In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import statistics
from copy import deepcopy
from functools import partial

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from ay2.torch.deepfake_detection import DeepfakeAudioClassification
from ay2.torch.losses import (
    BinaryTokenContrastLoss,
    Focal_loss,
    LabelSmoothingBCE,
    MultiClass_ContrastLoss,
)
from ay2.torch.optim import Adam_GC
from ay2.torch.optim.selective_weight_decay import (
    Optimizers_with_selective_weight_decay,
    Optimizers_with_selective_weight_decay_for_modulelist,
)
from ay2.torchaudio.transforms import AddGaussianSNR
from ay2.torchaudio.transforms.self_operation import (
    AudioToTensor,
    CentralAudioClip,
    RandomAudioClip,
    RandomPitchShift,
    RandomSpeed,
)
from tqdm.auto import tqdm

In [None]:
from ay2.tools import (
    find_unsame_name_for_file,
    freeze_modules,
    rich_bar,
    unfreeze_modules,
)

In [None]:
from ay2.torchaudio.transforms import SpecAugmentBatchTransform
from ay2.torchaudio.transforms.self_operation import RandomSpeed

# Module-level RandomSpeed transform configured with min_speed=0.5,
# max_speed=2.0 and p=0.5. Nothing in the code visible here references it.
# NOTE(review): `p` is presumably the probability of applying the transform --
# confirm against the ay2 implementation.
random_speed = RandomSpeed(min_speed=0.5, max_speed=2.0, p=0.5)

In [None]:
# Import MultiViewModel relative to the enclosing package when possible. If the
# relative import raises ImportError (as it does when this file is executed
# outside a package, e.g. as a script or notebook), fall back to a plain
# top-level import of the same module.
try:
    from .multiView_model import MultiViewModel
except ImportError:
    from multiView_model import MultiViewModel

In [None]:
class MultiViewModel_lit(DeepfakeAudioClassification):
    """Training wrapper pairing ``MultiViewModel`` with its losses and optimizer.

    The class plugs a ``MultiViewModel`` into the ``DeepfakeAudioClassification``
    base class by providing:

    * the loss functions (``configure_loss_fn`` / ``calcuate_loss``),
    * the optimizer (``configure_optimizers``),
    * the forward/prediction step shared by all stages (``_shared_pred``).

    NOTE(review): the base class is defined outside this file; it is presumably
    a ``pytorch_lightning.LightningModule`` that calls these hooks -- confirm.
    """

    def __init__(self, cfg=None, args=None, **kwargs):
        """Build the model, the augmentation transform and the loss functions.

        Args:
            cfg: Configuration object. ``cfg.aug_policy`` is read
                unconditionally, so despite the ``None`` default a real config
                must be supplied (passing ``None`` raises ``AttributeError``).
            args: Not read by this constructor.
            **kwargs: Not read by this constructor.

        ``save_hyperparameters()`` is called without arguments at the end, which
        lets Lightning record the constructor arguments on the module.
        """
        super().__init__()
        self.model = MultiViewModel()
        self.cfg = cfg

        # Stored on the module but not used by any method visible in this class.
        self.audio_transform = SpecAugmentBatchTransform.from_policy(cfg.aug_policy)

        self.configure_loss_fn()
        self.save_hyperparameters()

    def configure_loss_fn(
        self,
    ):
        """Create the loss modules consumed by ``calcuate_loss``."""
        self.bce_loss = LabelSmoothingBCE(label_smoothing=0.1)
        self.contrast_loss2 = BinaryTokenContrastLoss(alpha=0.1)

    def calcuate_loss(self, batch_res, batch, stage="train"):
        """Compute the individual loss terms and their weighted sum.

        The (misspelled) method name is kept as-is because it is part of the
        class interface and may be referenced from outside this file.

        Args:
            batch_res: Model output mapping; the keys ``"logit"``,
                ``"logit1D"``, ``"logit2D"`` and ``"feature"`` are read.
            batch: Input batch mapping; only ``"label"`` is read.
            stage: Accepted for interface compatibility; not used here.

        Returns:
            dict with the entries ``"cls_loss1D"``, ``"cls_loss2D"``,
            ``"cls_loss"``, ``"contrast_loss"`` and the combined ``"loss"``.
        """
        label = batch["label"]
        losses = {}

        # The same label-smoothed BCE is applied to each of the three logit
        # outputs returned by the model.
        losses["cls_loss1D"] = self.bce_loss(batch_res["logit1D"], label)
        losses["cls_loss2D"] = self.bce_loss(batch_res["logit2D"], label)
        losses["cls_loss"] = self.bce_loss(batch_res["logit"], label)
        losses["contrast_loss"] = self.contrast_loss2(batch_res["feature"], label)

        # Total objective: the "logit" loss at full weight, the contrastive
        # term at 0.5, and the two auxiliary ("1D"/"2D") losses at 0.1.
        losses["loss"] = (
            losses["cls_loss"]
            + 0.5 * losses["contrast_loss"]
            + 0.1 * (losses["cls_loss1D"] + losses["cls_loss2D"])
        )

        return losses

    def configure_optimizers(self):
        """Return a one-element list holding the optimizer for ``self.model``.

        The optimizer is built by
        ``Optimizers_with_selective_weight_decay_for_modulelist`` with
        ``optimizer="Adam"``, ``lr=0.0001`` and ``weight_decay=0.01``.
        NOTE(review): which parameters are excluded from weight decay is decided
        inside that helper, which is not visible here.
        """
        optimizer = Optimizers_with_selective_weight_decay_for_modulelist(
            [self.model],
            optimizer="Adam",
            lr=0.0001,
            weight_decay=0.01,
        )
        return [optimizer]

    def _shared_pred(self, batch, batch_idx, stage="train"):
        """Common prediction step for train/val/test.

        According to the original author's note, data augmentation is done
        inside ``self.model.feature_extractor`` (not visible in this file).

        Args:
            batch: Input batch mapping; ``"audio"`` is passed to the model.
            batch_idx: Accepted for interface compatibility; not used here.
            stage: Forwarded to the model. The full ``batch`` is handed to the
                model only when ``stage == "train"``; otherwise ``None`` is
                passed in its place.

        Returns:
            The mapping returned by the model, with an added ``"pred"`` entry
            holding the integer class prediction derived from ``"logit"``.
        """
        audio = batch["audio"]
        batch_res = self.model(
            audio, stage=stage, batch=batch if stage == "train" else None
        )
        # sigmoid(.) + 0.5 followed by integer truncation rounds the probability
        # to 0 or 1, i.e. it thresholds the sigmoid output at 0.5.
        batch_res["pred"] = (torch.sigmoid(batch_res["logit"]) + 0.5).int()

        return batch_res