In [33]:
import wandb
import pandas as pd
from torch import nn
from dataset import ETRIDataset
from trainer import ModelTrainer
from torch.utils.data import DataLoader
from utils import audio_embedding, seed
from models import CASEmodel, RoCASEmodel
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from transformers import AdamW, Wav2Vec2Config, RobertaConfig, BertConfig, AutoTokenizer

# Experiment configuration used by the rest of the notebook.
# Each key appears exactly once: "max_len" used to be declared twice (same value),
# which makes a later edit to only one of the two entries silently ineffective.
config = {
    "lr": 1e-5,
    "lm_path": 'klue/bert-base',
    "am_path": 'kresnik/wav2vec2-large-xlsr-korean',
    "train_bsz": 64,
    "val_bsz": 64,
    "val_ratio": 0.1,
    "max_len": 128,
    "epochs": 20,
    "device": 'cuda:2',
    "num_labels": 7,  # equals the number of entries in "label_dict"
    "data_path": 'data/train.csv',
    # NOTE(review): 'disqust' is kept exactly as written; presumably it matches the
    # label strings stored in the data file -- verify before "correcting" the spelling.
    "label_dict": {'angry': 0, 'neutral': 1, 'sad': 2, 'happy': 3, 'disqust': 4, 'surprise': 5, 'fear': 6},
    "sav_dir": 'save',
    "base_score": 0.45,  # Save the model according to the base validation score.
    "embedding_path": 'data/emb_train.pt',  # If an embedding file named "data/emb_train.pt" does not exist, generate one
    "audio_emb_type": 'last_hidden_state',  # audio embedding type: 'last_hidden_state' or 'extract_features'
    "seed": 42,
}


# Seed through the project helper (utils.seed) with the configured value.
seed.seed_setting(config["seed"])

# Optional experiment tracking -- uncomment to start a W&B run that records `config`.
# wandb.init(
#     project='comp',
#     group='bert_cls',
#     name='case_audio_base',
#     config=config,
# )

# Tokenizer and text-model config come from the language-model checkpoint,
# the wav2vec2 config from the acoustic-model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(config["lm_path"])
bert_config = BertConfig.from_pretrained(config["lm_path"])
wav_config = Wav2Vec2Config.from_pretrained(config["am_path"])

def text_audio_collator(batch):
    """Collate a list of samples into one batch dict, padding the audio embeddings.

    The previous body assigned ``batch['audio_emb']`` on ``batch`` itself, which
    raises ``TypeError`` for the list of samples that a ``DataLoader`` passes to
    ``collate_fn``. This version builds and returns a new dict instead.

    Args:
        batch: Non-empty sequence of mapping samples. Every sample must hold a
            tensor under ``'audio_emb'``; these may differ in length along their
            first dimension and are zero-padded to the longest one.

    Returns:
        dict: ``'audio_emb'`` holds the padded tensor with the batch dimension
        first; every other key of the samples is collated with PyTorch's default
        collate function.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("text_audio_collator received an empty batch")

    # Imported locally so the function does not rely on an extra notebook-level
    # import; this module path is the long-standing location of default_collate.
    from torch.utils.data.dataloader import default_collate

    padded_audio = pad_sequence([item['audio_emb'] for item in batch], batch_first=True)

    # Everything except the variable-length audio embedding goes through the
    # default collation. NOTE(review): this assumes the remaining fields already
    # have matching shapes across samples (e.g. fixed-length token ids) -- confirm
    # against ETRIDataset.__getitem__.
    remaining = [
        {key: value for key, value in item.items() if key != 'audio_emb'}
        for item in batch
    ]
    collated = dict(default_collate(remaining)) if remaining[0] else {}
    collated['audio_emb'] = padded_audio
    return collated

# Load the annotation table. reset_index() stores the original row labels in a
# regular column, so they stay available after the frame is split.
dataset = pd.read_csv(config['data_path']).reset_index()

# Seeded hold-out split. The resulting frames keep the row labels of `dataset`,
# so their indexes are not contiguous 0..n-1 ranges.
train_df, val_df = train_test_split(
    dataset,
    test_size=config['val_ratio'],
    random_state=config['seed'],
)

# Audio embeddings are requested for every row of the full table (not only the
# training split) and stored at / read from config['embedding_path'].
audio_emb = audio_embedding.save_and_load(
    config['am_path'],
    dataset['audio'].to_list(),
    'cuda:3',  # cuda is required to run the audio embedding generation model.
    config['embedding_path'],
)

train_dataset = ETRIDataset(
    audio_embedding=audio_emb,
    dataset=train_df,
    label_dict=config['label_dict'],
    tokenizer=tokenizer,
    audio_emb_type=config['audio_emb_type'],
    # Read from the local `config` dict: the wandb.init(...) call in this notebook
    # is commented out, so `wandb.config` is only populated by leftover kernel
    # state and is not available after Restart & Run All.
    max_len=config['max_len'],
)

# Smoke test: pull every sample once so that item-construction errors surface here.
# NOTE(review): the recorded run of this cell ended with `KeyError: 0`. Because
# `train_df` keeps the pre-split row labels, this looks like a label-based lookup
# (e.g. `df[col][idx]`) inside ETRIDataset where a positional one (`.iloc[idx]`)
# is needed -- confirm in dataset.py.
for sample in train_dataset:
    pass

KeyError: 0

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


In [19]:
# Scratch inspection of the training split. A notebook cell echoes only its last
# expression, so the row label below is evaluated but never displayed.
train_df.index[213]
# Text of the row at position 230 (positional access via .iloc); note this is a
# different position from the 213 used on the line above.
train_df['text'].iloc[230]

'그래서 이제 그래가지고 그런 것들이 몇 개 있었는데 그중에 하나가 감자랑 뭐 케찹, 특히 케찹.'

In [28]:
# Rebuild one training sample by hand to check how its pieces are looked up.
idx = 0

# Positional scalar access into the split frame.
text = train_df['text'].iat[idx]
# Raw value of the 'labels' column; it is not mapped through a label dict here.
label = train_df['labels'].iat[idx]

# The embedding store is keyed by the row's index label rendered as a string.
emb_key = str(train_df.index[idx])
wav_emb = audio_emb[emb_key]['last_hidden_state']

In [32]:
# Display the embedding entry stored under the string key '5087'. The recorded
# output shows a dict with 'last_hidden_state' and 'extract_features' tensors.
audio_emb['5087']

{'last_hidden_state': tensor([[-0.2264,  0.4745,  0.3608,  ..., -0.6388,  0.1840, -1.1324],
         [-0.3262,  0.3578,  0.2932,  ..., -0.2424,  0.4000, -1.1593],
         [-0.2379, -0.0612,  0.1637,  ..., -0.6568,  1.0239, -1.2081],
         ...,
         [-0.2023,  0.5404,  1.0275,  ..., -1.2963,  0.4411, -1.1646],
         [-0.0138,  0.8220,  1.2323,  ..., -1.3627, -0.1051, -0.9415],
         [ 0.4475,  1.1410, -0.1176,  ..., -1.5894,  0.6625, -0.7491]]),
 'extract_features': tensor([[ 0.0915, -0.9580, -0.8608,  ...,  0.1364, -0.1297, -0.5361],
         [ 0.1730, -0.5584,  0.6368,  ..., -0.3559, -0.5224,  0.1257],
         [ 0.4155, -0.4394,  1.1773,  ...,  0.9019,  0.1899,  0.1428],
         ...,
         [ 0.3458, -0.8104, -1.1789,  ...,  0.1491, -0.1316, -0.9596],
         [ 0.3236, -0.9368, -1.0070,  ...,  0.2956,  0.0048, -0.6200],
         [ 1.0934, -0.9369, -1.3964,  ...,  0.5943,  0.2064, -1.2340]])}

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
