In [None]:
import logging
import pandas as pd
from tqdm import tqdm 
import torch
import torch.nn as  nn 
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from datasets import load_from_disk
from omegaconf import OmegaConf
import hydra
from hydra import compose, initialize
import os
import random
from pathlib import Path
import logging
import time

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import erc
import os
import pickle

# device = torch.device('cuda:1')

def get_label(batch: dict, task: "erc.constants.Task" = None, device="cuda:1"):
    """Collect the label tensors of a collated batch into one dict.

    Args:
        batch: collated batch; must contain ``"emotion"``, ``"valence"`` and
            ``"arousal"``; ``"vote_emotion"`` is optional.
        task: currently unused (kept for interface compatibility).
        device: device every tensor label is moved to. Defaults to
            ``"cuda:1"``, the device this notebook runs on.

    Returns:
        dict with
        - ``"emotion"``: ``batch["emotion"]`` on ``device``;
        - ``"regress"``: valence and arousal stacked along dim 1, on ``device``;
        - ``"vote_emotion"``: ``batch["vote_emotion"]`` (moved to ``device`` if
          it is a tensor) or ``None`` when absent.
    """
    vote_emotion = batch.get("vote_emotion", None)
    if isinstance(vote_emotion, torch.Tensor):
        vote_emotion = vote_emotion.to(device)
    labels = {
        "emotion": batch["emotion"].to(device),
        # Move the regression targets too, so all labels live on the same device.
        "regress": torch.stack([batch["valence"], batch["arousal"]], dim=1).to(device),
        "vote_emotion": vote_emotion,
    }
    # TODO: Add Multilabel Fetch
    return labels



In [None]:
def _seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True)
    # If the above line is uncommented, we get the following RuntimeError:
    #  max_pool3d_with_indices_backward_cuda does not have a deterministic implementation
    torch.backends.cudnn.benchmark = False
_seed_everything(42)

In [None]:
# ---------------------------------------------------------------- data ----
BATCH_SIZE = 6
# Both splits are read in stored order (no shuffling) so batch indices are
# stable between runs.
_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False)

valid_dataset = load_from_disk(
    "/home/hoesungryu/etri-erc/kemdy19-kemdy20_valid4_multilabelFalse_rdeuceTrue"
)
valid_dataloadaer = DataLoader(valid_dataset, **_loader_kwargs)

train_dataset = load_from_disk(
    "/home/hoesungryu/etri-erc/kemdy19-kemdy20_train4_multilabelFalse_rdeuceTrue"
)
train_dataloadaer = DataLoader(train_dataset, **_loader_kwargs)

# -------------------------------------------------------- model config ----
with initialize(version_base=None, config_path="./config/model"):
    cfg = compose(config_name="mlp_mixer_roberta")
# Use the large Korean RoBERTa as text encoder.
cfg.config['txt'] = "klue/roberta-large"
# `_target_` only matters if `cfg` is handed to hydra's instantiate; the
# visible cells build the model class directly instead.
cfg['_target_'] = "erc.model.mlp_mixer.MLP_Mixer_Roberta"

# ---------------------------------------------------------- checkpoint ----
CKPT = '/home/hoesungryu/etri-erc/outputs/2023-04-08/13-23-20/61299-valid_acc0.949.ckpt'
SAVE_PATH = "./RobertaL_valid_results"
os.makedirs(SAVE_PATH, exist_ok=True)

ckpt = torch.load(CKPT, map_location=torch.device('cuda:1'))
# Split the model weights off the checkpoint; the remaining entries stay in `ckpt`.
model_ckpt = ckpt.pop("state_dict")


In [None]:
from functools import partial

import torch
from transformers import Wav2Vec2ForSequenceClassification, BertForSequenceClassification, RobertaForSequenceClassification
from torch import nn
from einops.layers.torch import Rearrange, Reduce
from peft import get_peft_model, LoraConfig, TaskType

from erc.constants import Task
import erc


# Module-level logger obtained from the project's `erc.utils.get_logger` helper;
# `_sort_outputs` reports malformed batch data through it.
logger = erc.utils.get_logger(__name__)


def pair(x):
    """Return ``x`` unchanged if it is a tuple, otherwise the 2-tuple ``(x, x)``.

    Lets callers pass either a single size or an explicit ``(h, w)`` tuple,
    e.g. ``MLPMixer``'s ``image_size``.
    """
    return x if isinstance(x, tuple) else (x, x)

class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: ``fn(LayerNorm(x)) + x``.

    Args:
        dim: size of the last axis normalised by ``LayerNorm``.
        fn: module applied to the normalised input; its output must be
            addable to ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # The attribute names `fn` and `norm` appear in state-dict keys, so
        # checkpoints only load if they stay unchanged.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        normalized = self.norm(x)
        return x + self.fn(normalized)

def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Two-layer MLP: ``dense -> GELU -> Dropout -> dense -> Dropout``.

    Args:
        dim: input and output width.
        expansion_factor: hidden width is ``int(dim * expansion_factor)``.
        dropout: dropout probability used after each projection's stage.
        dense: factory ``dense(in, out)`` for both projections (``nn.Linear``
            by default; a 1x1 ``Conv1d`` partial is also used in this file).

    The layer order fixes the ``nn.Sequential`` indices (0..4) that appear in
    state-dict keys, so it must not change.
    """
    hidden = int(dim * expansion_factor)
    layers = [
        dense(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(hidden, dim),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)

def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor=4, expansion_factor_token=0.5, dropout=0.):
    """Build an MLP-Mixer classifier as one flat ``nn.Sequential``.

    Input is ``(b, channels, h, w)``; it is cut into ``patch_size`` patches,
    each embedded to ``dim``, passed through ``depth`` token/channel mixing
    blocks, mean-pooled over patches and projected to ``num_classes``.

    Sequential indices (and therefore state-dict keys) are:
    0 patch rearrange, 1 patch embedding, 2..depth+1 mixer blocks,
    then LayerNorm, mean-reduce, classifier.
    """
    image_h, image_w = pair(image_size)
    assert (image_h % patch_size) == 0 and (image_w % patch_size) == 0, 'image must be divisible by patch size'
    num_patches = (image_h // patch_size) * (image_w // patch_size)

    # On a (b, patches, dim) tensor a 1x1 Conv1d treats the patch axis as its
    # channels (token mixing), while nn.Linear acts on the last, feature axis
    # (channel mixing).
    token_dense = partial(nn.Conv1d, kernel_size=1)
    channel_dense = nn.Linear

    def mixer_block():
        token_mixing = PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, token_dense))
        channel_mixing = PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, channel_dense))
        return nn.Sequential(token_mixing, channel_mixing)

    layers = [
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
        nn.Linear((patch_size ** 2) * channels, dim),
    ]
    layers.extend(mixer_block() for _ in range(depth))
    layers.extend([
        nn.LayerNorm(dim),
        Reduce('b n c -> b c', 'mean'),
        nn.Linear(dim, num_classes),
    ])
    return nn.Sequential(*layers)
class MLP_Mixer_Roberta(nn.Module):
    """wav2vec2 + RoBERTa emotion model whose fused features feed an MLP-Mixer.

    Inference variant: ``forward`` returns detached classification logits (plus
    the ground-truth emotion labels when ``labels`` is given); no loss is computed.
    """

    TASK = Task.ALL

    def __init__(
        self,
        config,
        cls_coef: float = 0.5,
        **config_kwargs
    ):
        """Build both encoders, their projections and the MLP-Mixer head.

        Args:
            config: mapping (e.g. the OmegaConf ``DictConfig`` from hydra's
                ``compose``) with keys ``"wav"`` (wav2vec2 checkpoint passed
                to ``from_pretrained``), ``"txt"`` (``"klue/roberta-base"`` or
                ``"klue/roberta-large"``; any other value raises ``KeyError``)
                and ``"mlp_mixer"`` (keyword arguments forwarded to ``MLPMixer``).
            cls_coef: stored as ``self.cls_coef``; values outside the open
                interval (0, 1) are replaced by 0.7 and ``self.reg_coef`` is
                ``1 - cls_coef``. Not used by this inference ``forward``.
            **config_kwargs: optional flags ``use_gender``, ``wav_gender``,
                ``txt_gender``, ``use_peakl`` (default ``False``); stored but
                not used by this ``forward``.
        """
        super().__init__()
        # NOTE: attribute names form the state-dict keys the checkpoint is
        # loaded against; do not rename them.
        self.wav_model = Wav2Vec2ForSequenceClassification.from_pretrained(config['wav'])
        self.txt_model = RobertaForSequenceClassification.from_pretrained(config['txt'])

        # Both modalities are projected to wav2vec2's classifier_proj_size, so
        # their outer product is a square (proj_size x proj_size) "image".
        proj_size = self.wav_model.config.classifier_proj_size
        self.mlp_mixer = MLPMixer(image_size=proj_size,
                                  **config['mlp_mixer'])
        self.wav_projector = nn.Linear(self.wav_model.config.hidden_size, proj_size)
        last_hdn_size = {
            "klue/roberta-base": 768, "klue/roberta-large": 1024
        }[config["txt"]]
        self.txt_projector = nn.Linear(last_hdn_size, proj_size)
        self.gender_embed = nn.Embedding(num_embeddings=2, embedding_dim=proj_size)
        self.use_gender = config_kwargs.get("use_gender", False)
        self.wav_gender = config_kwargs.get("wav_gender", False)
        self.txt_gender = config_kwargs.get("txt_gender", False)
        self.use_peakl = config_kwargs.get("use_peakl", False)
        # Keep only the bare encoders; the HF classification wrappers are dropped.
        self.wav_model = self.wav_model.wav2vec2
        self.txt_model = self.txt_model.roberta

        if not (0 < cls_coef < 1):
            cls_coef = 0.7
        self.cls_coef = cls_coef
        self.reg_coef = 1 - cls_coef

    def forward(
        self,
        wav: torch.Tensor,
        wav_mask: torch.Tensor,
        txt: torch.Tensor,
        txt_mask: torch.Tensor,
        labels: dict = None,
        **kwargs
    ) -> dict:
        """Encode audio and text, fuse them and classify.

        Args:
            wav: audio passed to wav2vec2 as ``input_values``.
            wav_mask: audio attention mask, or ``None`` to mean-pool all frames.
            txt: token ids passed to RoBERTa as ``input_ids``.
            txt_mask: text attention mask.
            labels: optional dict; only ``labels["emotion"]`` is read.
            **kwargs: ignored.

        Returns:
            dict with ``"emotion"`` (detached ``labels["emotion"]``, only when
            ``labels`` is given) and ``"cls_pred"`` (detached MLP-Mixer output
            without its last two columns).
        """
        # --- audio: encoder sequence -> projected frames -> pooled (B, proj_size)
        wav_outputs = self.wav_model(input_values=wav, attention_mask=wav_mask)
        hidden_states = self.wav_projector(wav_outputs[0])
        if wav_mask is None:
            pooled_wav_output = hidden_states.mean(dim=1)
        else:
            # Masked mean over the frames the encoder marks as valid.
            padding_mask = self.wav_model._get_feature_vector_attention_mask(hidden_states.shape[1], wav_mask)
            hidden_states[~padding_mask] = 0.0
            # clamp: a sample with no valid frame would otherwise give 0/0 = NaN.
            valid_frames = padding_mask.sum(dim=1).view(-1, 1).clamp(min=1)
            pooled_wav_output = hidden_states.sum(dim=1) / valid_frames

        # --- text: hidden state of the first token -> (B, proj_size)
        txt_last_hidden_state = self.txt_model(input_ids=txt, attention_mask=txt_mask)[0]  # (B, S, txt_hidden)
        txt_outputs = txt_last_hidden_state[:, 0, :]
        pooled_txt_output = self.txt_projector(txt_outputs)

        # Outer product of the pooled vectors: (B, 1, proj_size, proj_size).
        matmul_output = torch.bmm(pooled_wav_output.unsqueeze(2), pooled_txt_output.unsqueeze(1)).unsqueeze(1)
        logits = self.mlp_mixer(matmul_output)

        # The last two outputs are excluded from classification (presumably
        # the valence/arousal regression outputs -- TODO confirm).
        cls_logits = logits[:, :-2]

        outputs = {}
        if labels is not None:
            outputs["emotion"] = labels["emotion"].detach()
        outputs["cls_pred"] = cls_logits.detach()
        return outputs


In [None]:
# Build the network directly from the composed config's `config` section
# (pretrained encoders are fetched via `from_pretrained`).
model = MLP_Mixer_Roberta(config=cfg.config)

In [None]:
def parser(k):
    """Drop the first dot-separated component of a state-dict key.

    ``"model.wav_model.w"`` -> ``"wav_model.w"``; a key without a dot maps to ``""``.
    """
    _head, _sep, remainder = k.partition(".")
    return remainder
# Re-key the checkpoint weights without their leading prefix component so they
# can be loaded into the bare `MLP_Mixer_Roberta` instance.
new_ckpt = dict(
    (parser(param_name), weight) for param_name, weight in model_ckpt.items()
)

In [None]:
def remove_deuce(outputs: dict) -> dict:
    """Drop samples whose emotion row has a tied ("deuce") maximum.

    When ``outputs["emotion"]`` is 2-D ``(batch, num_class)``, a sample is kept
    only if its row maximum occurs exactly once. Every batched (``ndim > 0``)
    entry is filtered with that mask and ``"emotion"`` is converted to class
    indices via ``argmax(dim=1)``; scalar (0-d) entries are passed through.
    If every sample is tied, only the scalar entries are returned.

    A non-2-D ``"emotion"`` returns ``outputs`` itself, unchanged.
    ``outputs`` is never mutated; filtered results are returned in a new dict.
    """
    emotion = outputs["emotion"]
    if emotion.ndim != 2:
        return outputs

    row_max = emotion.max(dim=1, keepdim=True).values
    unique_mask = (emotion == row_max).sum(dim=1) == 1  # (bsz,)
    if not unique_mask.any():
        # Every sample was tied: keep only scalar metrics.
        return {k: v for k, v in outputs.items() if v.ndim == 0}

    # Copy first so the caller's dict is not overwritten.
    result = dict(outputs)
    result.update({k: v[unique_mask] for k, v in outputs.items() if v.ndim > 0})
    result["emotion"] = result["emotion"].argmax(dim=1)
    return result

In [None]:
def _sort_outputs(outputs):
        try:
            result = dict()
            keys: list = outputs[0].keys()
            for key in keys:
                data = outputs[0][key]
                if data.ndim == 0:
                    # Scalar value result
                    result[key] = torch.stack([o[key] for o in outputs if key in o])
                elif data.ndim in [1, 2]:
                    # Batched 
                    result[key] = torch.concat([o[key] for o in outputs if key in o])
        except AttributeError:
            logger.warn("Error provoking data %s", outputs)
            breakpoint()
        return result

In [None]:
from torchmetrics import Accuracy, AUROC, ConcordanceCorrCoef, F1Score


device = torch.device('cuda:1')
model.to(device).load_state_dict(new_ckpt)
# Inference only: disable dropout for BOTH splits before any batch is scored.
model.eval()

# Accumulates over the train AND the valid loader; the next cell reports the
# combined value.
acc = Accuracy(task="multiclass", num_classes=7).to(device)


def _accumulate_accuracy(dataloader, desc):
    """Run `model` over every batch of `dataloader` and feed results into `acc`.

    Uses the module-level `model`, `device`, `acc` and `get_label`. Gradients
    are disabled because nothing is trained here.
    """
    with torch.no_grad():
        # tqdm takes the length from the DataLoader, which counts a trailing
        # partial batch correctly.
        for batch in tqdm(dataloader, desc=desc):
            labels = get_label(batch)
            result = model(
                wav=batch["wav"].to(device),
                wav_mask=batch["wav_mask"].to(device),
                txt=batch["txt"].to(device),
                txt_mask=batch["txt_mask"].to(device),
                labels=labels,
            )
            # update() only accumulates; no per-batch value is needed.
            acc.update(preds=result["cls_pred"], target=result["emotion"])


_accumulate_accuracy(train_dataloadaer, "train")
_accumulate_accuracy(valid_dataloadaer, "valid")
    


In [None]:
# Final multiclass accuracy over every batch that was fed into `acc`.
acc_ = acc.compute()
print(f"Accuracy on all data: {acc_}")