In [1]:
import os
import wandb
import argparse
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    AdamW, 
    get_linear_schedule_with_warmup,
    Wav2Vec2Config, 
    RobertaConfig, 
    BertConfig,
    AutoTokenizer
)

from sklearn.model_selection import train_test_split

from dataset import ETRIDataset
from trainer import ModelTrainer
from models import (
    CompressedCASEModel,
    CASEmodel, 
    CompressedCSEModel, 
    ConcatModel, 
    MultiModalMixer,
    TextOnlyModel,
    SpeechOnlyModel,
)
from utils import audio_embedding, seed, loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the experiment configuration as an argparse namespace. The notebook
# parses an empty argument list, so every option takes its default unless it is
# overridden on `args` afterwards.
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    `type=bool` must not be used with argparse: bool("False") is True because
    any non-empty string is truthy. This converter understands the usual
    spellings and rejects anything else instead of silently returning True.

    Args:
        value: the raw token (or an existing bool).

    Returns:
        The parsed boolean. The empty string maps to False, matching bool("").

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# -- Choose Pretrained Model
parser.add_argument("--lm_path", type=str, default="klue/bert-base", help="You can choose models among (klue-bert series and klue-roberta series) (default: klue/bert-base")
parser.add_argument("--am_path", type=str, default="kresnik/wav2vec2-large-xlsr-korean")

# -- Training Argument
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--train_bsz", type=int, default=64)
parser.add_argument("--valid_bsz", type=int, default=64)
parser.add_argument("--val_ratio", type=float, default=0.2)
parser.add_argument("--context_max_len", type=int, default=128)
parser.add_argument("--audio_max_len", type=int, default=512)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--scheduler", type=str, default=None)
parser.add_argument("--num_labels", type=int, default=7)
parser.add_argument("--audio_emb_type", type=str, default="last_hidden_state", help="Can chosse audio embedding type between 'last_hidden_state' and 'extract_features' (default: last_hidden_state)")
parser.add_argument("--model", type=str, default="CASE")
# Boolean flag: parsed with _str2bool so that "--contrastive False" is False.
parser.add_argument("--contrastive", type=_str2bool, default=False)
parser.add_argument("--loss", type=str, default="crossentropy")
parser.add_argument("--gamma", type=float, default=1.0, help="focalloss's gamma argument")
parser.add_argument("--size", type=str, default="base", help="model size parameter. Choose between 'base' and 'small'")

## -- directory
parser.add_argument("--data_path", type=str, default="data/train.csv")
parser.add_argument("--save_path", type=str, default="save")
###### TODO: add a description of emb_train
parser.add_argument("--embedding_path", type=str, default="data/emb_train.pt")

# -- utils
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--seed", type=int, default=0)

# -- wandb
parser.add_argument("--wandb_project", type=str, default="comp")
parser.add_argument("--wandb_entity", type=str, default=None)
parser.add_argument("--wandb_group", type=str, default=None)
parser.add_argument("--wandb_name", type=str, default="case_audio_base")

# -- train mode(PET or ORG)
# Boolean flag: parsed with _str2bool for the same reason as --contrastive.
parser.add_argument("--pet", type=_str2bool, default=False)

# An explicit empty list keeps argparse from reading the Jupyter kernel's own
# sys.argv.
args = parser.parse_args([])

# Configs of the pretrained acoustic and language models.
# NOTE(review): BertConfig is used whatever --lm_path is; confirm this is
# intended when a klue-roberta checkpoint is selected.
wav_config = Wav2Vec2Config.from_pretrained(args.am_path)
bert_config = BertConfig.from_pretrained(args.lm_path)

# Derive the hidden size from the requested model size: "small" pins it to 256,
# anything else follows the language model's hidden size.
if args.size == "small":
    args.audio_max_len = 512
    args.hidden_size = 256
else:
    args.hidden_size = bert_config.hidden_size

In [3]:
# Build the MultiModalMixer from the shared args and the two pretrained configs.
mmm = MultiModalMixer(args, wav_config, bert_config)
# Call the model's own freeze() so the trainable-parameter counts printed later
# reflect it; presumably this disables gradients on part of the model —
# TODO confirm against models.py.
mmm.freeze()

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Base 모델 vs MMM

In [4]:
# Instantiate the two comparison models with the current `args`. On a fresh
# top-to-bottom run this is the default "base" size, where
# args.hidden_size == bert_config.hidden_size.
# No freeze() is called on these here, unlike `mmm`.
case = CASEmodel(args, wav_config, bert_config)
cse = CompressedCSEModel(args, wav_config, bert_config)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.decoder.we

In [5]:
# Report, for each model, the total number of elements in parameters that have
# requires_grad set (i.e. the ones an optimizer would update).
for title, net in (
    ("MMM Paramter 수:", mmm),
    ("C-CSE Paramter 수:", cse),
    ("CASE Paramter 수:", case),
):
    n_trainable = sum(w.numel() for w in net.parameters() if w.requires_grad)
    print(title, n_trainable)

MMM Paramter 수: 1547527
C-CSE Paramter 수: 112658311
CASE Paramter 수: 113183239


### Small 모델 vs MMM

In [6]:
# Switch the shared config to the "small" variant for the cells that follow.
args.size = "small"

# Re-derive the size-dependent settings: "small" pins the audio length to 512
# and the hidden size to 256; any other size follows the language model.
if args.size == "small":
    args.audio_max_len = 512
args.hidden_size = 256 if args.size == "small" else bert_config.hidden_size

In [7]:
# Rebuild both comparison models now that args.size is "small" and
# args.hidden_size is 256. This rebinds `case` and `cse`, replacing the
# base-size instances created earlier; `mmm` is not rebuilt.
case = CASEmodel(args, wav_config, bert_config)
cse = CompressedCSEModel(args, wav_config, bert_config)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.decoder.we

In [8]:
# Report, for each model, the total number of elements in parameters that have
# requires_grad set. `cse` and `case` are the small-size instances at this
# point; `mmm` is the same object counted earlier.
for title, net in (
    ("MMM Paramter 수:", mmm),
    ("C-CSE Paramter 수:", cse),
    ("CASE Paramter 수:", case),
):
    n_trainable = sum(w.numel() for w in net.parameters() if w.requires_grad)
    print(title, n_trainable)

MMM Paramter 수: 1547527
C-CSE Paramter 수: 593031
CASE Paramter 수: 1646343
