In [None]:
import logging
import pandas as pd
from tqdm import tqdm 
import torch
import torch.nn as  nn 
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from datasets import load_from_disk
from omegaconf import OmegaConf
import hydra
from hydra import compose, initialize

import erc


def get_label(batch: dict, task: erc.constants.Task = None):
    """Build the label dictionary the model consumes from one collated batch.

    Args:
        batch: Collated batch mapping. Must contain ``"emotion"``,
            ``"valence"`` and ``"arousal"``; ``"vote_emotion"`` is optional.
        task: Currently unused; accepted only for call-site compatibility.

    Returns:
        A dict with, in this order:
            ``"emotion"``      -- ``batch["emotion"]`` unchanged.
            ``"regress"``      -- valence and arousal stacked along dim 1
                                  (for 1-D inputs of length B: shape (B, 2)).
            ``"vote_emotion"`` -- ``batch["vote_emotion"]``, or ``None`` when
                                  the batch has no such key.
    """
    # TODO: Add Multilabel Fetch
    emotion = batch["emotion"]
    regress = torch.stack((batch["valence"], batch["arousal"]), dim=1)
    vote_emotion = batch.get("vote_emotion")
    return dict(emotion=emotion, regress=regress, vote_emotion=vote_emotion)


##################
# valid-dataloader
##################
# Samples per validation batch; the inference cell also uses this value.
BATCH_SIZE = 8
# Load the pre-processed validation split with Hugging Face `datasets`
# (machine-specific absolute path). By its name this looks like the
# KEMDy19+KEMDy20 fold-3 validation split — TODO confirm.
valid_dataset = load_from_disk("/home/hoesungryu/etri-erc/kemdy19-kemdy20_valid3_multilabelTrue_rdeuceFalse")
# NOTE: `valid_dataloadaer` (sic) is referenced by name in the inference cell,
# so rename both places together. DataLoader's default shuffle=False keeps
# the dataset order.
valid_dataloadaer = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

################
# model load
################
# Compose the `mlp_mixer_roberta` config from ./config/model without going
# through a `@hydra.main` entry point.
with initialize(version_base=None, config_path="./config/model"):
    cfg = compose(config_name="mlp_mixer_roberta")
# Override the text backbone, and point `_target_` at the inference variant
# of the model class so that `hydra.utils.instantiate` below builds it.
cfg.config['txt'] = "klue/roberta-large"
cfg['_target_'] = "erc.model.inference_mlp_mixer.MLP_Mixer_Roberta"


# Fold-3 checkpoint (by file name; appears to match the fold-3 split above — TODO confirm).
CKPT = '/home/hoesungryu/etri-erc/weights_AI_HUB/robertaL_fold=3_epoch=29.ckpt'
# map_location="cpu" loads every tensor onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
ckpt = torch.load(CKPT,map_location="cpu")
# Keep only the weights; `pop` also removes them from `ckpt`.
model_ckpt = ckpt.pop("state_dict")

# Construct the model, passing the weights through the `checkpoint` kwarg
# (presumably loaded inside the constructor — confirm in erc.model), then
# switch it to eval mode for inference.
model = hydra.utils.instantiate(cfg, checkpoint=model_ckpt).eval()

In [None]:
# Run the validation set through the model and collect the pooled audio/text
# embeddings together with the per-sample `vote_emotion` labels.
#
# Module-level results (used by later cells):
#   total_wav / total_txt -- pooled embeddings from every batch, concatenated
#                            along dim 0 in dataloader order.
#   total_vote_emotion    -- `vote_emotion` labels concatenated in the same
#                            order, kept in the labels' own dtype (the old
#                            float zero sentinel could promote them to float).
pbar = tqdm(
    valid_dataloadaer,
    # len(DataLoader) includes the final partial batch;
    # int(len(dataset) / BATCH_SIZE) dropped it from the progress total.
    total=len(valid_dataloadaer),
)

wav_chunks = []
txt_chunks = []
vote_chunks = []

# Inference only. Without no_grad, every batch's autograd graph stays alive
# through the accumulated outputs, so memory grows with the dataset.
with torch.no_grad():
    for batch in pbar:
        labels = get_label(batch)

        wav_pooled, txt_pooled = model(
            wav=batch["wav"],
            wav_mask=batch["wav_mask"],
            txt=batch["txt"],
            txt_mask=batch["txt_mask"],
            labels=labels,
        )
        wav_chunks.append(wav_pooled)
        txt_chunks.append(txt_pooled)

        # get_label returns None when the batch has no "vote_emotion";
        # fail with a clear message instead of a TypeError inside torch.cat.
        if labels["vote_emotion"] is None:
            raise KeyError("batch has no 'vote_emotion'; cannot collect vote labels")
        vote_chunks.append(labels["vote_emotion"])

# Concatenate once at the end: repeated torch.concat inside the loop copies
# everything gathered so far on every batch (quadratic). For an empty loader,
# fall back to empty tensors shaped as the previous version left them.
total_wav = torch.cat(wav_chunks, dim=0) if wav_chunks else torch.zeros([0, 256])
total_txt = torch.cat(txt_chunks, dim=0) if txt_chunks else torch.zeros([0, 256])
total_vote_emotion = torch.cat(vote_chunks, dim=0) if vote_chunks else torch.zeros([0])