In [None]:
import importlib
from typing import List, Dict

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from tqdm.auto import tqdm
from torch import nn
from torchmetrics import Accuracy, AUROC, ConcordanceCorrCoef, F1Score
import wandb

import erc
import os
import random
import numpy as np 

# Module-level logger from the project's helper; used by ERCModule and setup_trainer below.
logger = erc.utils.get_logger(name=__name__)


class ERCModule(pl.LightningModule):
    """Lightning wrapper around an ERC backbone: optimizers, dataloaders, metrics and logging.

    Args:
        model: backbone called with ``wav``, ``wav_mask``, ``txt``, ``txt_mask``, ``labels``
            and ``gender`` keyword arguments. It exposes a ``TASK`` attribute (read by
            ``get_label``) and returns a dict of outputs (losses, ``cls_pred`` /
            ``reg_pred``, ``emotion`` / ``regress`` targets, ...).
        train_loader: loader returned by the ``train_dataloader`` hook.
        valid_loader: loader returned by the ``val_dataloader`` hook.
        optimizer: hydra config of the optimizer; must contain ``_target_``.
        scheduler: optional hydra config. Its ``scheduler`` entry is instantiated with the
            optimizer, and the whole config is returned as Lightning's ``lr_scheduler`` dict.
        load_from_checkpoint: accepted so a hydra ``module`` config can be passed as-is;
            it is not used inside this class.
        separate_lr: optional ``{submodule_name: lr}`` mapping. When given, ONLY the
            parameters of those submodules are optimized, each with its own learning rate.
    """

    def __init__(self,
                 model: nn.Module,
                 train_loader: torch.utils.data.DataLoader,
                 valid_loader: torch.utils.data.DataLoader,
                 optimizer: omegaconf.DictConfig,
                 scheduler: omegaconf.DictConfig = None,
                 load_from_checkpoint: str = None,
                 separate_lr: dict = None):
        super().__init__()
        self.model = model

        # Dataloaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Optimizations: built eagerly here and handed back by `configure_optimizers`.
        if separate_lr is not None:
            self.opt_config = self._configure_separate_lr_optimizer(optimizer=optimizer,
                                                                    scheduler=scheduler,
                                                                    separate_lr=separate_lr)
        else:
            self.opt_config = self._configure_optimizer(optimizer=optimizer, scheduler=scheduler)

        # Metrics Configuration: 7-way emotion classification + valence/arousal CCC.
        self.acc = Accuracy(task="multiclass", num_classes=7)
        self.auroc = AUROC(task="multiclass", num_classes=7)
        self.f1 = F1Score(task="multiclass", num_classes=7, average="macro")
        self.ccc_val = ConcordanceCorrCoef(num_outputs=1)
        self.ccc_aro = ConcordanceCorrCoef(num_outputs=1)

        # Class names for the confusion matrix: every key of emotion2idx except the last
        # (presumably a non-class entry — confirm against erc.constants).
        self.label_keys = list(erc.constants.emotion2idx.keys())[:-1]

    def train_dataloader(self):
        """Lightning hook: the training dataloader given at construction."""
        return self.train_loader

    def val_dataloader(self):
        """Lightning hook: the validation dataloader given at construction.

        Lightning only looks up ``val_dataloader``; ``valid_dataloader`` is never called
        by the trainer and is kept purely as an alias.
        """
        return self.valid_loader

    def valid_dataloader(self):
        """Backward-compatible alias of :meth:`val_dataloader`."""
        return self.valid_loader

    @staticmethod
    def _attach_scheduler(opt: torch.optim.Optimizer, scheduler: omegaconf.DictConfig):
        """Return ``opt`` alone, or Lightning's optimizer/scheduler dict when a scheduler is configured."""
        sch = hydra.utils.instantiate(scheduler, scheduler={"optimizer": opt}) \
            if scheduler is not None else None
        if sch is None:
            return opt
        return {"optimizer": opt, "lr_scheduler": dict(**sch)}

    def _configure_optimizer(self, optimizer: omegaconf.DictConfig, scheduler: omegaconf.DictConfig):
        """Instantiate the optimizer over all model parameters, plus an optional scheduler."""
        opt = hydra.utils.instantiate(optimizer, params=self.model.parameters())
        return self._attach_scheduler(opt, scheduler)

    def _configure_separate_lr_optimizer(self,
                                         optimizer: omegaconf.DictConfig,
                                         scheduler: omegaconf.DictConfig,
                                         separate_lr: dict):
        """Build an optimizer with one parameter group (own lr) per named submodule.

        Falls back to :meth:`_configure_optimizer` (all parameters, config lr) as soon as
        a name in ``separate_lr`` is not an attribute of ``self.model``.
        """
        param_groups = []
        for submodel_name, lr in separate_lr.items():
            submodel = getattr(self.model, submodel_name, None)
            if submodel is None:
                logger.warning("separate_lr was given but submodel was not found: %s", submodel_name)
                # Return immediately: continuing would build an optimizer over only the
                # groups collected so far (possibly none) and discard this fallback.
                return self._configure_optimizer(optimizer=optimizer, scheduler=scheduler)
            param_groups.append({"params": submodel.parameters(), "lr": lr})

        # Resolve the optimizer class from `_target_` by hand (presumably so the param-group
        # dicts reach the optimizer untouched instead of going through hydra's config handling).
        opt_kwargs = dict(optimizer)
        module_path, _, class_name = opt_kwargs.pop("_target_").rpartition(".")
        opt_cls = getattr(importlib.import_module(module_path), class_name)
        opt = opt_cls(params=param_groups, **opt_kwargs)
        return self._attach_scheduler(opt, scheduler)

    def configure_optimizers(self) -> torch.optim.Optimizer | dict:
        """Lightning hook: return the optimizer (config) prepared in ``__init__``."""
        return self.opt_config

    def get_label(self, batch: dict, task: erc.constants.Task = None):
        """Extract the targets for ``task`` (defaults to ``self.model.TASK``) from a batch.

        Raises:
            ValueError: if the task has no label extraction rule.
        """
        task = task or self.model.TASK
        if task == erc.constants.Task.CLS:
            # (batch_size,) | Long
            return batch["emotion"].long()
        if task == erc.constants.Task.REG:
            # (batch_size, 2) | Float
            return torch.stack([batch["valence"], batch["arousal"]], dim=1).float()
        if task == erc.constants.Task.ALL:
            return {
                "emotion": batch["emotion"],
                "regress": torch.stack([batch["valence"], batch["arousal"]], dim=1),
                "vote_emotion": batch.get("vote_emotion", None)
            }
        # TODO: Add Multilabel Fetch
        raise ValueError(f"Unsupported task for label extraction: {task}")

    def forward(self, batch):
        """Run the backbone on one batch; returns the backbone's output dict."""
        labels = self.get_label(batch)
        try:
            return self.model(wav=batch["wav"],
                              wav_mask=batch["wav_mask"],
                              txt=batch["txt"],
                              txt_mask=batch["txt_mask"],
                              labels=labels,
                              gender=batch.get("gender", None))
        except RuntimeError:
            # Kept for debugging CUDA device-side assert errors: log the labels, then
            # re-raise the original exception with its message and traceback intact.
            logger.warning("Label given %s", labels)
            raise

    def _sort_outputs(self, outputs: List[Dict]) -> dict:
        """Merge a list of per-step output dicts into one dict of tensors.

        Keys are taken from the first dict. 0-dim values are stacked, 1-/2-dim values are
        concatenated along dim 0, and higher-dim values are dropped. Non-tensor values
        are skipped with a warning.
        """
        result = dict()
        if not outputs:
            return result
        for key, sample in outputs[0].items():
            ndim = getattr(sample, "ndim", None)
            if ndim is None:
                logger.warning("Skipping non-tensor output %s of type %s", key, type(sample))
                continue
            values = [o[key] for o in outputs if key in o]
            if ndim == 0:
                # Scalar value result
                result[key] = torch.stack(values)
            elif ndim in (1, 2):
                # Batched
                result[key] = torch.concat(values)
        return result

    def remove_deuce(self, outputs: dict) -> dict:
        """Drop samples whose 2-D ``emotion`` target has a tied maximum ("deuce").

        For 2-D ``emotion``, keeps only rows whose maximum is unique, filters every
        non-scalar output with the same row mask, and converts ``emotion`` to class
        indices via ``argmax``. If every row is tied, only the 0-dim entries are
        returned. Non-2-D ``emotion`` passes through unchanged. ``outputs`` is not mutated.
        """
        emotion = outputs["emotion"]
        if emotion.ndim != 2:
            return outputs
        row_max = emotion.max(dim=1, keepdim=True).values
        unique_mask = (emotion == row_max).sum(dim=1) == 1  # (bsz, ), unique mask
        if not unique_mask.any():
            # Every sample in the batch was a deuce: keep scalar metrics only
            # (removing cls/reg predictions and logits).
            return {k: v for k, v in outputs.items() if v.ndim == 0}
        result = {k: (v[unique_mask] if v.ndim > 0 else v) for k, v in outputs.items()}
        result["emotion"] = result["emotion"].argmax(dim=1)
        return result

    def log_result(
        self,
        outputs: List[Dict] | dict,
        mode: erc.constants.RunMode | str = "train",
        unit: str = "epoch"
    ):
        """Log losses and metrics as ``f"{unit}/{mode}_<name>"`` and return the outputs used.

        Args:
            outputs: one step's output dict, or a list of them (merged by ``_sort_outputs``).
            mode: run mode used in the logged names (e.g. ``"train"``, ``"valid"``).
            unit: ``"step"`` or ``"epoch"``; at step level deuced samples are removed first.
        Returns:
            The merged and/or deuce-filtered output dict.
        """
        result: dict = self._sort_outputs(outputs=outputs) if isinstance(outputs, list) else outputs
        if unit == "step":
            # Only steps need filtering: the epoch hooks receive the step results,
            # which were already filtered here.
            result = self.remove_deuce(outputs=result)

        # Log Losses
        for loss_key in ("loss", "cls_loss", "reg_loss"):
            if loss_key in result:
                self.log(f"{unit}/{mode}_{loss_key}", torch.mean(result[loss_key]), prog_bar=True)

        # NOTE(review): the metric objects below are shared by train/valid and by step/epoch
        # logging, so their accumulated states may mix across them. Consider one metric
        # instance per mode/unit.

        # Log Classification Metrics: Accuracy, AUROC & F1
        if "cls_pred" in result and "emotion" in result:
            self.acc(preds=result["cls_pred"], target=result["emotion"])
            self.auroc(preds=result["cls_pred"], target=result["emotion"])
            self.f1(preds=result["cls_pred"], target=result["emotion"])
            self.log(f'{unit}/{mode}_acc', self.acc)
            self.log(f'{unit}/{mode}_auroc', self.auroc)
            self.log(f'{unit}/{mode}_f1', self.f1)

        # Log Regression Metrics: CCC (column 0 = valence, column 1 = arousal)
        if "reg_pred" in result and "regress" in result:
            self.ccc_val(result["reg_pred"][:, 0], result["regress"][:, 0])
            self.ccc_aro(result["reg_pred"][:, 1], result["regress"][:, 1])
            self.log(f"{unit}/{mode}_ccc(val)", self.ccc_val)
            self.log(f"{unit}/{mode}_ccc(aro)", self.ccc_aro)
        return result

    def log_confusion_matrix(self, result: dict):
        """Log a W&B confusion matrix of argmax(``cls_pred``) vs. ``emotion``.

        Does nothing when either key is missing. Assumes ``self.logger.experiment`` is a
        W&B run (its ``.log`` is called) — TODO confirm for other loggers.
        """
        if "cls_pred" not in result or "emotion" not in result:
            return
        preds = result["cls_pred"].cpu().detach().argmax(dim=1).numpy()
        labels = result["emotion"].cpu().numpy()
        cf = wandb.plot.confusion_matrix(y_true=labels,
                                         preds=preds,
                                         class_names=self.label_keys)
        self.logger.experiment.log({"confusion_matrix": cf})

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Forward one batch, log step metrics, and return the (filtered) outputs."""
        result = self.forward(batch)
        return self.log_result(outputs=result, mode="train", unit="step")

    def training_epoch_end(self, outputs: List[Dict]):
        """Log epoch-level training metrics over all step outputs."""
        self.log_result(outputs=outputs, mode="train", unit="epoch")

    def validation_step(self, batch, batch_idx):
        """Forward one validation batch, log step metrics, and return the (filtered) outputs."""
        result = self.forward(batch)
        return self.log_result(outputs=result, mode="valid", unit="step")

    def validation_epoch_end(self, outputs: List[Dict]):
        """Log epoch-level validation metrics and the confusion matrix."""
        result = self.log_result(outputs=outputs, mode="valid", unit="epoch")
        self.log_confusion_matrix(result)

def _seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True)
    # If the above line is uncommented, we get the following RuntimeError:
    #  max_pool3d_with_indices_backward_cuda does not have a deterministic implementation
    torch.backends.cudnn.benchmark = False


In [None]:
from datasets import load_from_disk
from torch.utils.data import DataLoader
import numpy as np 

# Reproducibility: seed every RNG before any data is loaded.
_seed_everything(42)

BATCH_SIZE = 6
# Root of the preprocessed `datasets` folders. Override it with the ERC_DATA_DIR
# environment variable instead of editing this hard-coded absolute path.
DATA_DIR = os.environ.get("ERC_DATA_DIR", "/home/hoesungryu/etri-erc")
valid_dataset = load_from_disk(os.path.join(DATA_DIR, "kemdy19-kemdy20_valid4_multilabelFalse_rdeuceTrue"))
train_dataset = load_from_disk(os.path.join(DATA_DIR, "kemdy19-kemdy20_train4_multilabelFalse_rdeuceTrue"))

# Both loaders keep a fixed order (shuffle=False). The misspelled names are kept
# because later cells refer to them.
valid_dataloadaer = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
train_dataloadaer = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from hydra import compose, initialize

# Compose the model config ./config/model/mlp_mixer_roberta.yaml.
with initialize(config_path="./config/model", version_base=None):
    cfg = compose(config_name="mlp_mixer_roberta")

# Manual overrides on top of the composed defaults:
cfg.config.txt = "klue/roberta-large"                      # text encoder checkpoint name
cfg["_target_"] = "erc.model.mlp_mixer.MLP_Mixer_Roberta"  # model class path


In [None]:

# Compose the top-level training config (provides `optim` and `scheduler` used below).
with initialize(config_path="./config", version_base=None):
    config_ = compose(config_name="train.yaml")



In [None]:
from erc.model.mlp_mixer import MLP_Mixer_Roberta

In [None]:
# Dotted import paths of the loss per head, passed to the model as `criterions`:
# "cls" for emotion classification, "reg" for valence/arousal regression.
criterions_dict = dict(
    cls='erc.optims.FocalLoss',
    reg='torch.nn.MSELoss',
)

In [None]:
from torch.optim import AdamW

# Build the MLP-Mixer/RoBERTa backbone and wrap it in the Lightning module;
# the optimizer and scheduler come from the composed training config.
backbone = MLP_Mixer_Roberta(cfg.config, criterions=criterions_dict)
model = ERCModule(
    model=backbone,
    train_loader=train_dataloadaer,
    valid_loader=valid_dataloadaer,
    optimizer=config_.optim,
    scheduler=config_.scheduler,
)


In [None]:
# Restore the Lightning module's weights from a finished training run.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
CKPT = '/home/hoesungryu/etri-erc/outputs/2023-04-08/13-23-20/61299-valid_acc0.949.ckpt'
device = torch.device('cuda:1')  # single definition, reused for map_location and .to()
ckpt = torch.load(CKPT, map_location=device)
model_ckpt = ckpt.pop("state_dict")
model.to(device).load_state_dict(model_ckpt)

In [None]:

def setup_trainer(config: omegaconf.DictConfig) -> tuple[pl.LightningModule, dict]:
    """Instantiate the Lightning module and its dataloaders from a hydra config.

    Steps:
        1. Seed RNGs from ``config.misc.seed``.
        2. Optionally read a checkpoint's ``state_dict`` (``config.module.load_from_checkpoint``).
        3. Build the backbone from ``config.model``.
        4. Build the dataloaders from ``config.dataset`` / ``config.dataloader``.
        5. Build the module from ``config.module``.

    Returns:
        A 2-tuple ``(module, dataloaders)``. ``dataloaders`` is indexed by mode
        (``"train"`` and ``"valid"`` are read here).
    """
    logger.info("Start Setting up")
    erc.utils._seed_everything(config.misc.seed)

    ckpt_path = config.module.load_from_checkpoint
    if ckpt_path:
        # map_location="cpu" avoids depending on the GPU index the checkpoint was saved
        # from (presumably the model copies the weights into its own parameters — confirm).
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model_ckpt = torch.load(ckpt_path, map_location="cpu").pop("state_dict")
    else:
        model_ckpt = None

    logger.info("Start intantiating Models & Optimizers")
    model = hydra.utils.instantiate(config.model, checkpoint=model_ckpt)

    logger.info("Start instantiating dataloaders")
    dataloaders = erc.datasets.get_dataloaders(ds_cfg=config.dataset,
                                               dl_cfg=config.dataloader,
                                               modes=config.misc.modes)

    logger.info("Start instantiating Pytorch-Lightning Trainer")
    # config.module still carries `load_from_checkpoint`; the target module
    # (presumably ERCModule) accepts that argument.
    module = hydra.utils.instantiate(config.module,
                                     model=model,
                                     optimizer=config.optim,
                                     scheduler=config.scheduler,
                                     train_loader=dataloaders["train"],
                                     valid_loader=dataloaders["valid"])

    return module, dataloaders



In [None]:
def get_label(batch: dict, task: "erc.constants.Task" = None, device: "torch.device | str" = "cuda:1"):
    """Build the multi-task label dict for one batch, with its tensors on ``device``.

    Mirrors the ``Task.ALL`` branch of ``ERCModule.get_label``. ``task`` is accepted for
    signature compatibility but ignored.

    Args:
        batch: batch with ``emotion``, ``valence`` and ``arousal`` tensors and an optional
            ``vote_emotion`` entry.
        task: unused.
        device: target device; defaults to ``"cuda:1"``, the device this notebook evaluates on.

    Returns:
        ``{"emotion", "regress", "vote_emotion"}``. ``regress`` stacks valence and arousal
        along dim 1 (shape ``(batch_size, 2)`` for 1-D inputs). ``vote_emotion`` is ``None``
        when absent and is moved to ``device`` only when it is a tensor.
    """
    vote_emotion = batch.get("vote_emotion", None)
    if isinstance(vote_emotion, torch.Tensor):
        vote_emotion = vote_emotion.to(device)
    # TODO: Add Multilabel Fetch
    return {
        "emotion": batch["emotion"].to(device),
        # Moved as well, so every label tensor lives on the same device as `emotion`.
        "regress": torch.stack([batch["valence"], batch["arousal"]], dim=1).to(device),
        "vote_emotion": vote_emotion,
    }


In [None]:
# Display the wrapped module; nn.Module's repr lists its submodule tree.
model

In [None]:
from torchmetrics import Accuracy, AUROC, ConcordanceCorrCoef, F1Score


# Evaluate the restored backbone on the validation split and report plain accuracy.
# A separate name keeps `model` bound to the ERCModule, so re-running this cell does
# not try to unwrap `.model` twice.
eval_model = model.model
eval_model.eval()  # inference behaviour for dropout / normalization layers

pbar = tqdm(enumerate(valid_dataloadaer), total=len(valid_dataloadaer))

total = 0
correct = 0
with torch.no_grad():  # evaluation only: no autograd graph needed
    for batch_idx, batch in pbar:
        labels = get_label(batch)
        # NOTE(review): ERCModule.forward also passes `gender=batch.get("gender", None)`;
        # it is omitted here — confirm the model does not need it at eval time.
        result = eval_model(wav=batch["wav"].to(device),
                            wav_mask=batch["wav_mask"].to(device),
                            txt=batch["txt"].to(device),
                            txt_mask=batch["txt_mask"].to(device),
                            labels=labels)
        total += result["emotion"].size(0)
        correct += (result["cls_pred"].argmax(dim=1) == result["emotion"]).sum().item()
        pbar.set_postfix(acc=f"{correct / total:.4f}")

accuracy = correct / total if total else float("nan")
accuracy

In [None]:

def train(config: omegaconf.DictConfig) -> None:
    """Build module and dataloaders from ``config``, wire logger and callbacks, and fit.

    The experiment logger is used with ``.watch`` and ``wandb.config``, so a W&B
    logger is expected in ``config.logger``.
    """
    module, dataloaders = setup_trainer(config)

    # Local name deliberately differs from the module-level `logger` to avoid shadowing it.
    run_logger = hydra.utils.instantiate(config.logger)
    run_logger.watch(module)
    # Hard-code config uploading (fully resolved, failing on missing values).
    resolved_config = omegaconf.OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
    wandb.config.update(resolved_config)

    callback_map: dict = hydra.utils.instantiate(config.callbacks)
    trainer: pl.Trainer = hydra.utils.instantiate(config.trainer,
                                                  logger=run_logger,
                                                  callbacks=[*callback_map.values()])
    trainer.fit(model=module,
                train_dataloaders=dataloaders["train"],
                val_dataloaders=dataloaders["valid"])
