In [1]:
!nvidia-smi

Fri Mar 24 16:42:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:01:00.0 Off |                  Off |
| 30%   41C    P2    81W / 300W |  11880MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:21:00.0 Off |                  Off |
| 30%   35C    P8    27W / 300W |      8MiB / 49140MiB |      0%      Default |
|       

In [2]:
import argparse
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    Wav2Vec2Config, 
    BertConfig,
    AutoTokenizer
)

from torcheval.metrics.functional import multiclass_f1_score, multiclass_accuracy
from tqdm import tqdm

from dataset import ETRIDataset
from models import CASEmodel, RoCASEmodel, CompressedCCEModel, ConcatModel, MultiModalMixer
from utils import audio_embedding, seed

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
parser = argparse.ArgumentParser()

# -- Choose Pretrained Model
parser.add_argument("--lm_path", type=str, default="klue/bert-base", help="You can choose models among (klue-bert series and klue-roberta series) (default: klue/bert-base")
parser.add_argument("--am_path", type=str, default="kresnik/wav2vec2-large-xlsr-korean")

# -- Training Argument
parser.add_argument("--test_bsz", type=int, default=16)
parser.add_argument("--context_max_len", type=int, default=128)
parser.add_argument("--audio_max_len", type=int, default=1024)
parser.add_argument("--num_labels", type=int, default=7)
parser.add_argument("--audio_emb_type", type=str, default="last_hidden_state", help="Can chosse audio embedding type between 'last_hidden_state' and 'extract_features' (default: last_hidden_state)")
parser.add_argument("--model", type=str, default="CASE")

## -- directory
parser.add_argument("--data_path", type=str, default="data/test.csv")
parser.add_argument("--model_path", type=str, default="save/epoch:1_CASEmodel.pt")
###### emb_train에 대한 설명 부과하기
parser.add_argument("--embedding_path", type=str, default="data/emb_test.pt")

# -- utils
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--seed", type=int, default=0)

args = parser.parse_args([])

In [4]:
seed.seed_setting(args.seed)

wav_config = Wav2Vec2Config.from_pretrained(args.am_path)
bert_config = BertConfig.from_pretrained(args.lm_path)
tokenizer = AutoTokenizer.from_pretrained(args.lm_path)

def text_audio_collator(batch):
    
    return {'audio_emb' : pad_sequence([item['audio_emb'] for item in batch], batch_first=True),
            'label' : torch.stack([item['label'] for item in batch]).squeeze(),
            'input_ids' :  torch.stack([item['input_ids'] for item in batch]).squeeze(),
            'attention_mask' :  torch.stack([item['attention_mask'] for item in batch]).squeeze(),
            'token_type_ids' :  torch.stack([item['token_type_ids'] for item in batch]).squeeze()}

# args.data_path
# args.embedding_path
test_data = pd.read_csv(args.data_path)
test_data.reset_index(inplace=True)

audio_emb = audio_embedding.save_and_load(args.am_path, test_data['audio'].tolist(), args.device, args.embedding_path)

label_dict = {'angry':0, 'neutral':1, 'sad':2, 'happy':3, 'disqust':4, 'surprise':5, 'fear':6}
test_dataset = ETRIDataset(
    audio_embedding = audio_emb, 
    dataset=test_data, 
    label_dict = label_dict,
    tokenizer = tokenizer,
    audio_emb_type = args.audio_emb_type,
    max_len = args.context_max_len, 
    )

# Create a DataLoader that batches audio sequences and pads them to a fixed length
test_dataloader = DataLoader(
    test_dataset, 
    batch_size=args.test_bsz,
    shuffle=False, 
    collate_fn=text_audio_collator, 
    num_workers=args.num_workers,
    )

In [5]:
check_dict = {'Model': [],
              'Test_ACC': [],
              'Test_M_F1': [],
              'Test_W_F1': []}

In [None]:
case_path = [os.path.join("save", path) for path in os.listdir('save') if "CASE" in path]

for model_path in tqdm(case_path):
    model = CASEmodel(args.lm_path, wav_config, bert_config, args.num_labels)
    
    model.load_state_dict(torch.load(model_path))
    model.to(args.device)
    model.eval()

    test_output = []
    test_label = []
    with torch.no_grad():

        pbar = tqdm(test_dataloader)
        for _, batch in enumerate(pbar):
            label = batch['label'].to(args.device)
            audio_tensor = batch['audio_emb'].to(args.device)

            input_ids = batch["input_ids"].to(args.device)
            attention_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)

            output = model(
                input_ids, 
                attention_mask,
                token_type_ids,
                audio_tensor 
                )['class_logit']
            
            test_output.append(output.detach().cpu())
            test_label.append(label.detach().cpu())

        logits = torch.cat(test_output)
        labels = torch.cat(test_label)

        test_m_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="macro").detach().cpu().item()
        test_w_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="weighted").detach().cpu().item()
        test_acc = multiclass_accuracy(logits, labels, 
                                            num_classes=args.num_labels).detach().cpu().item()

        check_dict['Model'].append(model_path.split("/")[1])
        check_dict['Test_ACC'].append(test_acc) 
        check_dict['Test_M_F1'].append(test_m_f1) 
        check_dict['Test_W_F1'].append(test_acc)    
        
        model.to("cpu")
        del model
        torch.cuda.empty_cache()


In [None]:
cce_path = [os.path.join("save", path) for path in os.listdir('save') if "CCE" in path]

for model_path in tqdm(cce_path):
    model = CompressedCCEModel(args, wav_config, bert_config)
    
    model.load_state_dict(torch.load(model_path))
    model.to(args.device)
    model.eval()

    test_output = []
    test_label = []
    with torch.no_grad():

        pbar = tqdm(test_dataloader)
        for _, batch in enumerate(pbar):
            label = batch['label'].to(args.device)
            audio_tensor = batch['audio_emb'].to(args.device)

            input_ids = batch["input_ids"].to(args.device)
            attention_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)

            output = model(
                input_ids, 
                attention_mask,
                token_type_ids,
                audio_tensor 
                )['class_logit']
            
            test_output.append(output.detach().cpu())
            test_label.append(label.detach().cpu())

        logits = torch.cat(test_output)
        labels = torch.cat(test_label)

        test_m_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="macro").detach().cpu().item()
        test_w_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="weighted").detach().cpu().item()
        test_acc = multiclass_accuracy(logits, labels, 
                                            num_classes=args.num_labels).detach().cpu().item()

        check_dict['Model'].append(model_path.split("/")[1])
        check_dict['Test_ACC'].append(test_acc) 
        check_dict['Test_M_F1'].append(test_m_f1) 
        check_dict['Test_W_F1'].append(test_acc)    
        
        model.to("cpu")
        del model
        torch.cuda.empty_cache()


In [31]:
# test_check = pd.DataFrame(check_dict)
# test_check.to_csv("data/score.csv", index=False)

In [7]:
check_dict = {'Model': [],
              'Test_ACC': [],
              'Test_M_F1': [],
              'Test_W_F1': []}

In [6]:
concat_path = [os.path.join("save", path) for path in os.listdir('save') if "Concat" in path]

for model_path in tqdm(concat_path):
    model = ConcatModel(args, wav_config, bert_config)
    
    model.load_state_dict(torch.load(model_path))
    model.to(args.device)
    model.eval()

    test_output = []
    test_label = []
    with torch.no_grad():

        pbar = tqdm(test_dataloader)
        for _, batch in enumerate(pbar):
            label = batch['label'].to(args.device)
            audio_tensor = batch['audio_emb'].to(args.device)

            input_ids = batch["input_ids"].to(args.device)
            attention_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)

            output = model(
                input_ids, 
                attention_mask,
                token_type_ids,
                audio_tensor 
                )['class_logit']
            
            test_output.append(output.detach().cpu())
            test_label.append(label.detach().cpu())

        logits = torch.cat(test_output)
        labels = torch.cat(test_label)

        test_m_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="macro").detach().cpu().item()
        test_w_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="weighted").detach().cpu().item()
        test_acc = multiclass_accuracy(logits, labels, 
                                            num_classes=args.num_labels).detach().cpu().item()

        check_dict['Model'].append(model_path.split("/")[1])
        check_dict['Test_ACC'].append(test_acc) 
        check_dict['Test_M_F1'].append(test_m_f1) 
        check_dict['Test_W_F1'].append(test_acc)    
        
        model.to("cpu")
        del model
        torch.cuda.empty_cache()


  0%|          | 0/10 [00:00<?, ?it/s]Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 161/161 [00:33<00:00,  4.81it/s]
 10%|█         | 1/10 [00:37<05:37, 37

In [7]:
mmm_path = [os.path.join("save", path) for path in os.listdir('save') if "MMM" in path]

for model_path in tqdm(mmm_path):
    model = MultiModalMixer(args, wav_config, bert_config)
    
    model.load_state_dict(torch.load(model_path))
    model.to(args.device)
    model.eval()

    test_output = []
    test_label = []
    with torch.no_grad():

        pbar = tqdm(test_dataloader)
        for _, batch in enumerate(pbar):
            label = batch['label'].to(args.device)
            audio_tensor = batch['audio_emb'].to(args.device)

            input_ids = batch["input_ids"].to(args.device)
            attention_mask = batch["attention_mask"].to(args.device)
            token_type_ids = batch["token_type_ids"].to(args.device)

            output = model(
                input_ids, 
                attention_mask,
                token_type_ids,
                audio_tensor 
                )['class_logit']
            
            test_output.append(output.detach().cpu())
            test_label.append(label.detach().cpu())

        logits = torch.cat(test_output)
        labels = torch.cat(test_label)

        test_m_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="macro").detach().cpu().item()
        test_w_f1 = multiclass_f1_score(logits, labels, 
                                            num_classes=args.num_labels, 
                                            average="weighted").detach().cpu().item()
        test_acc = multiclass_accuracy(logits, labels, 
                                            num_classes=args.num_labels).detach().cpu().item()

        check_dict['Model'].append(model_path.split("/")[1])
        check_dict['Test_ACC'].append(test_acc) 
        check_dict['Test_M_F1'].append(test_m_f1) 
        check_dict['Test_W_F1'].append(test_acc)    
        
        model.to("cpu")
        del model
        torch.cuda.empty_cache()

  0%|          | 0/3 [00:00<?, ?it/s]Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 161/161 [00:25<00:00,  6.21it/s]
 33%|███▎      | 1/3 [00:28<00:56, 28.0

In [8]:
test_check = pd.read_csv("data/score.csv")

more = pd.DataFrame(check_dict)

In [10]:
test_check = pd.concat([test_check, more], axis=0)

In [12]:
test_check[test_check['Test_M_F1'] > 0.24]

Unnamed: 0,Model,Test_ACC,Test_M_F1,Test_W_F1
2,epoch:8_CASEmodel_shceduler-linear_False.pt,0.880359,0.2519,0.880359
7,epoch:5_CASEmodel.pt,0.879969,0.256384,0.879969
8,epoch:5_CASEmodel_True.pt,0.879969,0.256384,0.879969
14,epoch:3_CCEmodel.pt,0.874903,0.246014,0.874903
18,epoch:3_CCEmodel_True.pt,0.874903,0.246014,0.874903


: 