In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from core.models import codi
from core.models.common.get_model import get_model
import torch
from core.models.ema import LitEma
from core.models.common.get_optimizer import get_optimizer
from musiccaps import MusicCapsDataset

def load_yaml_config(filepath):
    """Read a YAML file and return its parsed contents.

    Args:
        filepath: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document in the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

class ConfigObject(object):
    """Expose the top-level entries of a dictionary as instance attributes.

    Only the first level is converted: values are stored as given, so a
    nested dictionary stays a dictionary.
    """

    def __init__(self, dictionary):
        """Copy every key/value pair of ``dictionary`` onto this instance."""
        for name, value in dictionary.items():
            setattr(self, name, value)


### Model definition =============================================

# AudioLDM: wrap the autoencoder section of the YAML for attribute access.
audioldm_cfg = load_yaml_config('configs/model/audioldm.yaml')
audioldm = ConfigObject(audioldm_cfg["audioldm_autoencoder"])

# Optimus
optimus_cfg = load_yaml_config('configs/model/optimus.yaml')

# Replace the dictionaries inside the optimus_vae config with ConfigObject instances.
optimus_vae_args = optimus_cfg['optimus_vae']['args']

# Encoder and decoder: wrap the section itself, then wrap its nested 'config' entry.
for vae_field, cfg_section in (
    ('encoder', 'optimus_bert_encoder'),
    ('decoder', 'optimus_gpt2_decoder'),
):
    optimus_vae_args[vae_field] = ConfigObject(optimus_cfg[cfg_section])
    optimus_vae_args[vae_field].args['config'] = ConfigObject(optimus_cfg[cfg_section]['args']['config'])

# Tokenizers: only the section itself is wrapped.
for vae_field, cfg_section in (
    ('tokenizer_encoder', 'optimus_bert_tokenizer'),
    ('tokenizer_decoder', 'optimus_gpt2_tokenizer'),
):
    optimus_vae_args[vae_field] = ConfigObject(optimus_cfg[cfg_section])

optimus_vae_args['args'] = ConfigObject(optimus_vae_args['args'])
optimus = ConfigObject(optimus_cfg["optimus_vae"])

# CLAP
clap_cfg = load_yaml_config('configs/model/clap.yaml')
clap = ConfigObject(clap_cfg["clap_audio"])

# CoDi: attach the image, text and audio UNet configs to the CoDi UNet arguments.
unet_cfg = load_yaml_config('configs/model/openai_unet.yaml')
unet_codi_args = unet_cfg["openai_unet_codi"]["args"]
for unet_field, cfg_section in (
    ("unet_image_cfg", "openai_unet_2d"),
    ("unet_text_cfg", "openai_unet_0dmd"),
    ("unet_audio_cfg", "openai_unet_2d_audio"),
):
    unet_codi_args[unet_field] = ConfigObject(unet_cfg[cfg_section])
unet = ConfigObject(unet_cfg["openai_unet_codi"])

# Create the CoDi model instance from the four wrapped configs.
model = codi.CoDi(
    audioldm_cfg=audioldm,
    optimus_cfg=optimus,
    clap_cfg=clap,
    unet_config=unet,
)


#######################
# Running in eps mode #
#######################

Keeping EMAs of 3368.


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [46]:
# Dataset
dataset = MusicCapsDataset(
    csv_file='/raid/m236866/md-mt/datasets/musiccaps/musiccaps-public.csv',
    audio_dir='/raid/m236866/md-mt/datasets/musiccaps/musiccaps_30',
)

# Mini-batches of 5 samples, drawn in shuffled order.
dataloader = DataLoader(dataset, shuffle=True, batch_size=5)

# EMA tracker built from the model.
ema = LitEma(model)

# Optimizer settings, wrapped so that `type` and `args` are read as attributes.
# NOTE(review): only weight decay is configured here and no learning rate is
# passed — confirm which value get_optimizer ends up using.
optimizer_config = ConfigObject({
    'type': 'adam',
    'args': {
        'weight_decay': 1e-4,  # Weight decay
    },
})
optimizer = get_optimizer()(model, optimizer_config)

In [47]:
# Fetch the first sample to check that indexing into the dataset works.
dataset[0]

TypeError: join() argument must be str, bytes, or os.PathLike object, not 'int64'

In [11]:
# Show the caption table held by the dataset.
# NOTE(review): the output lists an integer "Unnamed: 0" column ahead of "ytid";
# presumably that integer is what reaches the failing join() call — confirm in
# MusicCapsDataset.
dataset.captions_frame

Unnamed: 0,ytid,start_s,end_s,audioset_positive_labels,aspect_list,caption,author_id,is_balanced_subset,is_audioset_eval
0,-0Gj8-vB1q4,30,40,"/m/0140xf,/m/02cjck,/m/04rlf","['low quality', 'sustained strings melody', 's...",The low quality recording features a ballad so...,4,False,True
1,-0SdAVK79lg,30,40,"/m/0155w,/m/01lyv,/m/0342h,/m/042v_gx,/m/04rlf...","['guitar song', 'piano backing', 'simple percu...",This song features an electric guitar as the m...,0,False,False
2,-0vPFx-wRRI,30,40,"/m/025_jnm,/m/04rlf","['amateur recording', 'finger snipping', 'male...",a male voice is singing a melody with changing...,6,False,True
3,-0xzrMun0Rs,30,40,"/m/01g90h,/m/04rlf","['backing track', 'jazzy', 'digital drums', 'p...",This song contains digital drums playing a sim...,6,False,True
4,-1LrH01Ei1w,30,40,"/m/02p0sh1,/m/04rlf","['rubab instrument', 'repetitive melody on dif...",This song features a rubber instrument being p...,0,False,False
...,...,...,...,...,...,...,...,...,...
5516,zw5dkiklbhE,15,25,"/m/01sm1g,/m/0l14md","['amateur recording', 'percussion', 'wooden bo...",This audio contains someone playing a wooden b...,6,False,False
5517,zwfo7wnXdjs,30,40,"/m/02p0sh1,/m/04rlf,/m/06j64v","['instrumental music', 'arabic music', 'genera...",The song is an instrumental. The song is mediu...,1,True,True
5518,zx_vcwOsDO4,50,60,"/m/01glhc,/m/02sgy,/m/0342h,/m/03lty,/m/04rlf,...","['instrumental', 'no voice', 'electric guitar'...",The rock music is purely instrumental and feat...,2,True,True
5519,zyXa2tdBTGc,30,40,"/m/04rlf,/t/dd00034","['instrumental music', 'gospel music', 'strong...",The song is an instrumental. The song is slow ...,1,False,False


In [27]:
### Training loop =============================================
num_epochs = 2
for epoch in range(num_epochs):
    for batch_idx, (texts, audios) in enumerate(dataloader):
        # Feed the batch to the model, compute the loss, and let the optimizer
        # update the model weights.
        optimizer.zero_grad()
        # Call the module itself instead of .forward() so that any registered
        # hooks run as part of the call.
        loss = model(x=audios, c=texts)  # loss computation
        loss.backward()
        optimizer.step()

        # Update the EMA.
        # NOTE(review): presumably LitEma exposes update(parameters) — confirm
        # against core.models.ema.
        ema.update(model.parameters())

        # Report progress every 100 batches.
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')

TypeError: join() argument must be str, bytes, or os.PathLike object, not 'int64'