In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
import random
from copy import deepcopy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from ay2.torch.nn import LambdaFunctionModule
from einops import rearrange

In [3]:
from torchvision.transforms import v2

In [144]:
try:
    from .rawnet.rawnet2 import RawNet2
    from .resnet import ResNet
    from .resnet1d import ResNet1D
except ImportError:
    from rawnet.rawnet2 import RawNet2
    from resnet import ResNet
    from resnet1d import ResNet1D

In [145]:
class MultiViewModel(nn.Module):
    """Two-branch classifier that cross-fuses a 1D and a 2D backbone.

    A ``ResNet1D`` branch and a ``ResNet`` (2D) branch are run stage by stage.
    Before stages 2, 3 and 4 each branch receives a learnable per-channel
    blend of its own previous-stage feature and the other branch's feature,
    re-shaped to match (see ``fuse_2Dfeat_for_1D`` / ``fuse_1Dfeat_for_2D``).

    In this notebook's smoke run an input of shape ``(3, 1, 48000)`` yields
    1D stage outputs ``(3, 64, 12000) ... (3, 512, 1500)``, 2D stage outputs
    ``(3, 64, 65, 65) ... (3, 512, 9, 9)`` and ``(3, 512)`` latent features
    for both branches.
    """

    def __init__(self, verbose=0, cfg=None, args=None, **kwargs):
        """Build both backbones, the classifier heads and the fusion weights.

        Args:
            verbose: value written to the ``verbose`` attribute of both
                backbones (via ``set_verbose``).
            cfg: stored on ``self.cfg``; not otherwise read by this class.
            args: accepted but not used by this class.
            **kwargs: accepted but not used by this class.
        """
        super().__init__()

        self.cfg = cfg

        self.feature_model2D = ResNet()
        self.feature_model1D = ResNet1D(
            in_channels=1,
            base_filters=64,
            kernel_size=3,
            downsample_stride=2,
            groups=1,
            n_block=8,
            n_classes=0,
            downsample_gap=2,
            increasefilter_gap=2,
            # The `verbose` attribute is overwritten by set_verbose() at the
            # end of __init__; this value only matters if ResNet1D uses it
            # during construction (not visible from here).
            verbose=1,
        )
        final_dim = 512

        # NOTE: proj1D and avgpool are not referenced by any method of this
        # class. They are left in place because removing them would change
        # the module's registered submodules/parameters.
        self.proj1D = nn.Linear(1024, 512)

        self.dropout = nn.Dropout(0.1)
        # Head over the concatenated (1D + 2D) feature -> single logit.
        self.cls_final = nn.Sequential(
            nn.utils.weight_norm(nn.Linear(final_dim * 2, 1, bias=False)),
        )
        # One independent single-logit head per branch.
        self.cls1D, self.cls2D = [
            nn.Sequential(
                nn.utils.weight_norm(nn.Linear(final_dim, 1, bias=False)),
            )
            for _ in range(2)
        ]

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # One spectrogram per fusion stage with n_fft = 128, 64, 32, 16.
        # Presumably chosen so that the spectrogram of the stage-i 1D feature
        # has the same (freq, time) size as the stage-i 2D feature map
        # (65x65, 33x33, ... in the smoke run) -- TODO confirm for other
        # input lengths.
        self.spectrogram_transforms = nn.ModuleList(
            [
                torchaudio.transforms.Spectrogram(
                    n_fft=512 // (2 ** (i + 2)), hop_length=187
                )
                for i in range(4)
            ]
        )

        # Learnable per-channel mixing weights, initialised to 0.5 (equal
        # blend). alphas weight the 1D branch's own feature, betas the 2D
        # branch's own feature. forward() only uses indices 0..2; the
        # 512-channel entries are currently unused.
        self.alphas = nn.ParameterList(
            [nn.Parameter(torch.ones(1, _c, 1) * 0.5) for _c in [64, 128, 256, 512]]
        )
        self.betas = nn.ParameterList(
            [nn.Parameter(torch.ones(1, _c, 1, 1) * 0.5) for _c in [64, 128, 256, 512]]
        )
        self.set_verbose(verbose)

    def feature_norm(self, code):
        """Rescale ``code`` so that its L2 norm along dim 1 equals 10.

        NOTE: a slice whose norm is exactly zero is divided by zero and
        produces inf/nan values.
        """
        code_norm = code.norm(p=2, dim=1, keepdim=True) / 10.0
        code = torch.div(code, code_norm)
        return code

    def print_shape(self, *args):
        """Print the ``.shape`` of every positional argument (debug helper)."""
        for x in args:
            print(x.shape)

    def set_verbose(self, verbose):
        """Set the ``verbose`` attribute on both backbones."""
        self.feature_model1D.verbose = verbose
        self.feature_model2D.verbose = verbose

    def transform_audio_into_spectorgram(self, x, transform):
        """Apply ``transform`` to ``x``, take the log and standardise.

        The misspelled method name is kept because it is part of the class's
        public interface.

        Args:
            x: input passed straight to ``transform``.
            transform: callable whose output must be 4-D, since statistics
                are taken over dims 1, 2 and 3.

        Returns:
            ``log(transform(x) + 1e-7)`` with, per sample (dim 0), the mean
            subtracted and divided by the standard deviation (+ 1e-9).
        """
        x = transform(x)
        # Small offset keeps log() finite where the transform output is 0.
        x = torch.log(x + 1e-7)
        x = (x - torch.mean(x, dim=(1, 2, 3), keepdim=True)) / (
            torch.std(x, dim=(1, 2, 3), keepdim=True) + 1e-9
        )
        return x

    def fuse_2Dfeat_for_1D(self, feat1D, feat2D, idx):
        """Blend a 2D feature map into a 1D feature of the same channel count.

        ``feat2D`` (b, c, h, w) is flattened to (b, c, w*h) with ``w`` as the
        outer axis, resized with nearest-neighbour interpolation to the
        length of ``feat1D`` and mixed as
        ``alpha * feat1D + (1 - alpha) * feat2D`` with ``alpha = alphas[idx]``.

        Args:
            feat1D: tensor whose last dimension is the target length.
            feat2D: 4-D tensor (b, c, h, w).
            idx: index into ``self.alphas``.

        Returns:
            Tensor with the shape of ``feat1D`` (after broadcasting).
        """
        target_len = feat1D.shape[-1]

        # Same element order as einops "b c h w -> b c (w h)".
        flat2D = feat2D.transpose(-2, -1).flatten(start_dim=2)

        # Resize to the exact target length. The previous implementation used
        # the deprecated F.upsample_nearest with `scale_factor + 0.0001`,
        # whose output length floor(h*w * scale) becomes L + 1 once h*w
        # reaches about 10^4 and then breaks the addition below. Passing
        # `size` guarantees a matching length. Compared with the old call the
        # nearest-neighbour source index may differ by one at a few positions.
        flat2D = F.interpolate(flat2D, size=target_len, mode="nearest")

        alpha = self.alphas[idx]
        return alpha * feat1D + (1 - alpha) * flat2D

    def fuse_1Dfeat_for_2D(self, feat1D, feat2D, idx):
        """Blend a 1D feature into a 2D feature map.

        ``feat1D`` is converted with ``self.spectrogram_transforms[idx]``
        (log + per-sample standardisation, see
        ``transform_audio_into_spectorgram``) and mixed as
        ``beta * feat2D + (1 - beta) * spec`` with ``beta = betas[idx]``.
        The converted feature must be broadcast-compatible with ``feat2D``.
        """
        feat1D = self.transform_audio_into_spectorgram(
            feat1D, self.spectrogram_transforms[idx]
        )
        feat2D = self.betas[idx] * feat2D + (1 - self.betas[idx]) * feat1D
        return feat2D

    def forward(self, x, stage="test", batch=None):
        """Run both branches with cross-fusion before stages 2, 3 and 4.

        Args:
            x: input passed to ``compute_stage1`` of both backbones
                (``(3, 1, 48000)`` in the notebook's smoke run).
            stage: accepted but not used by this method.
            batch: accepted but not used by this method.

        Returns:
            dict with keys
            ``"feature"`` (1D and 2D latent features concatenated on the
            last dim), ``"logit1D"``, ``"logit2D"`` and ``"logit"`` (each
            head's output with the trailing size-1 dim squeezed).
        """
        res = {}

        feat1_1 = self.feature_model1D.compute_stage1(x)
        feat2_1 = self.feature_model2D.compute_stage1(x)

        # Both fusions at a given level read the *unfused* outputs of the
        # previous stage, so the order of the two calls does not matter.
        feat1_2 = self.feature_model1D.compute_stage2(
            self.fuse_2Dfeat_for_1D(feat1_1, feat2_1, 0)
        )
        feat2_2 = self.feature_model2D.compute_stage2(
            self.fuse_1Dfeat_for_2D(feat1_1, feat2_1, 0)
        )

        feat1_3 = self.feature_model1D.compute_stage3(
            self.fuse_2Dfeat_for_1D(feat1_2, feat2_2, 1)
        )
        feat2_3 = self.feature_model2D.compute_stage3(
            self.fuse_1Dfeat_for_2D(feat1_2, feat2_2, 1)
        )

        feat1_4 = self.feature_model1D.compute_stage4(
            self.fuse_2Dfeat_for_1D(feat1_3, feat2_3, 2)
        )
        feat2_4 = self.feature_model2D.compute_stage4(
            self.fuse_1Dfeat_for_2D(feat1_3, feat2_3, 2)
        )

        # Stage-4 outputs are not fused; each branch is pooled on its own.
        feat1 = self.feature_model1D.compute_latent_feature(feat1_4)
        feat2 = self.feature_model2D.compute_latent_feature(feat2_4)

        res["feature"] = torch.concat([feat1, feat2], dim=-1)
        res["logit1D"] = self.cls1D(self.dropout(feat1)).squeeze(-1)
        res["logit2D"] = self.cls2D(self.dropout(feat2)).squeeze(-1)
        res["logit"] = self.cls_final(self.dropout(res["feature"])).squeeze(-1)
        return res

In [146]:
# Smoke test: push one random batch through the model. With verbose=1 the
# backbones print the shape after every stage (see the cell output below).
BATCH_SIZE, NUM_CHANNELS, NUM_SAMPLES = 3, 1, 48000

model = MultiViewModel(verbose=1)
x = torch.randn(BATCH_SIZE, NUM_CHANNELS, NUM_SAMPLES)
_ = model(x)  # result dict is discarded; only the printed shapes matter here

ResNet1D input shape torch.Size([3, 1, 48000])
ResNet1D Stage 1: output shape torch.Size([3, 64, 12000])
ResNet Stage 1: output shape torch.Size([3, 64, 65, 65])
ResNet1D Stage 2: output shape torch.Size([3, 128, 6000])
ResNet Stage 2: output shape torch.Size([3, 128, 33, 33])
ResNet1D Stage 3: output shape torch.Size([3, 256, 3000])
ResNet Stage 3: output shape torch.Size([3, 256, 17, 17])
ResNet1D Stage 4: output shape torch.Size([3, 512, 1500])
ResNet Stage 4: output shape torch.Size([3, 512, 9, 9])
ResNet1D Latent Feature: output shape torch.Size([3, 512])
ResNet Latent Feature: output shape torch.Size([3, 512])
