In [1]:
import os
import wandb
import argparse
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import AdamW, Wav2Vec2Config, RobertaConfig, BertConfig, AutoTokenizer

from sklearn.model_selection import train_test_split

from dataset import ETRIDataset
from trainer import ModelTrainer
from models import CASEmodel, RoCASEmodel
from utils import audio_embedding, seed

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import random
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
from typing import Callable, Tuple, Dict
from torch.utils.data import Dataset
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

class ETRIDataset(Dataset):
    """
    This is a class that returns audio embeddings and text tokenization results.
    """
    def __init__(self, 
                 audio_embedding, # audio embedding load from '~.pt' file
                 dataset:pd.DataFrame, 
                 label_dict:Dict, # label to idx dictionary
                 tokenizer:Callable,
                 audio_emb_type:str='last_hidden_state', # audio embedding type: 'last_hidden_state' or 'extract_features'
                 max_len:int=128):
        super().__init__()
        
        self.audio_emb = audio_embedding
        self.dataset = dataset
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.emb_type = audio_emb_type
        
    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset['text'].iloc[idx]
        
        label = torch.tensor(self.label_dict[self.dataset['labels'].iloc[idx]])
        
        emb_key = str(self.dataset.index[idx])
        wav_emb = self.audio_emb[emb_key][self.emb_type]

        encoded_dict = self.tokenizer(text, 
                                      return_tensors='pt',
                                      add_special_tokens=True,
                                      max_length=self.max_len,
                                      padding='max_length',
                                      truncation=True,
                                      return_attention_mask=True,
                                      return_token_type_ids=True
                                      )
        
        encoded_dict['audio_emb'] = wav_emb
        encoded_dict['label'] = label
        
        return encoded_dict

In [9]:
parser = argparse.ArgumentParser()

# -- Choose Pretrained Model
parser.add_argument("--lm_path", type=str, default="klue/bert-base", help="You can choose models among (klue-bert series and klue-roberta series) (default: klue/bert-base")
parser.add_argument("--am_path", type=str, default="kresnik/wav2vec2-large-xlsr-korean")

# -- Training Argument
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--train_bsz", type=int, default=64)
parser.add_argument("--valid_bsz", type=int, default=256)
parser.add_argument("--val_ratio", type=float, default=0.2)
parser.add_argument("--context_max_len", type=int, default=128)
parser.add_argument("--audio_max_len", type=int, default=1024)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_labels", type=int, default=7)
parser.add_argument("--audio_emb_type", type=str, default="last_hidden_state", help="Can chosse audio embedding type between 'last_hidden_state' and 'extract_features' (default: last_hidden_state)")
parser.add_argument("--model", type=str, default="CASE")
## -- directory
parser.add_argument("--data_path", type=str, default="data/train.csv")
parser.add_argument("--save_path", type=str, default="save")
###### emb_train에 대한 설명 부과하기
parser.add_argument("--embedding_path", type=str, default="data/emb_train.pt")

# -- utils
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--seed", type=int, default=0)

# -- wandb
parser.add_argument("--wandb_project", type=str, default="comp")
parser.add_argument("--wandb_entity", type=str, default=None)
parser.add_argument("--wandb_group", type=str, default=None)
parser.add_argument("--wandb_name", type=str, default="case_audio_base")

args = parser.parse_args([])

In [10]:
wav_config = Wav2Vec2Config.from_pretrained(args.am_path)
bert_config = BertConfig.from_pretrained(args.lm_path)
tokenizer = AutoTokenizer.from_pretrained(args.lm_path)

In [11]:
def text_audio_collator(batch):
    
    return {'audio_emb' : pad_sequence([item['audio_emb'] for item in batch], batch_first=True),
            'label' : torch.stack([item['label'] for item in batch]).squeeze(),
            'input_ids' :  torch.stack([item['input_ids'] for item in batch]).squeeze(),
            'attention_mask' :  torch.stack([item['attention_mask'] for item in batch]).squeeze(),
            'token_type_ids' :  torch.stack([item['token_type_ids'] for item in batch]).squeeze()}

dataset = pd.read_csv(args.data_path)
dataset.reset_index(inplace=True)

In [12]:
train_df, val_df = train_test_split(dataset, test_size = args.val_ratio, random_state=args.seed)

In [7]:
audio_emb = audio_embedding.save_and_load(args.am_path, dataset['audio'].tolist(), args.device, args.embedding_path)

Some weights of the model checkpoint at kresnik/wav2vec2-large-xlsr-korean were not used when initializing Wav2Vec2Model: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 10262/10262 [04:29<00:00, 38.13it/s]


In [13]:
label_dict = {'angry':0, 'neutral':1, 'sad':2, 'happy':3, 'disqust':4, 'surprise':5, 'fear':6}
train_dataset = ETRIDataset(
    audio_embedding = audio_emb, 
    dataset=train_df, 
    label_dict = label_dict,
    tokenizer = tokenizer,
    audio_emb_type = args.audio_emb_type,
    max_len = args.context_max_len, 
    )

In [14]:
train_dataloader = DataLoader(
        train_dataset, 
        batch_size=args.train_bsz,
        shuffle=True, 
        collate_fn=text_audio_collator, 
        num_workers=args.num_workers,
        )

In [15]:
from tqdm import tqdm
for i in tqdm(range(len(train_dataset))):
    train_dataset[i]

100%|██████████| 8209/8209 [00:01<00:00, 7167.01it/s]


In [14]:
samp = next(iter(train_dataset))

In [16]:
samp.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'audio_emb', 'label'])

In [13]:
sample = next(iter(train_dataloader))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/jjonhwa/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/jjonhwa/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/jjonhwa/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_761833/2472082722.py", line 41, in __getitem__
    wav_emb = self.audio_emb[emb_key][self.emb_type]
KeyError: '1392'
