In [1]:
# Reload imported project modules before each cell executes, so edits to the
# local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Import

In [2]:
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from ay2.torch.nn import LambdaFunctionModule

In [2]:
# Prefer package-relative imports; when this code is run outside its package
# (e.g. directly as a notebook) the relative form raises ImportError, so fall
# back to importing the sibling modules by their plain names.
try:
    from .conv_attention import MLP, Attention
    from .utils import AdaptiveConv1d, DepthwiseSeparableConv1d, Multi_Head_Attention
except ImportError:
    from conv_attention import MLP, Attention
    from utils import AdaptiveConv1d, DepthwiseSeparableConv1d, Multi_Head_Attention

ImportError: attempted relative import with no known parent package


## Multi-Scale Fusion Module

In [10]:
class MultiScaleFusion(nn.Module):
    """Multi-scale fusion block for 1D feature sequences.

    The batch-normalised input is projected at three temporal resolutions.
    Every resolution is pooled to one vector per frame; the per-frame vectors
    of all resolutions are mixed by a small conv stack and an attention layer,
    and the resulting frame-level map multiplicatively gates each resolution.
    The gated branches are upsampled back to full length and summed into a
    residual update of the normalised input.

    Args:
        n_dim: Channel count of the input; preserved by the block.
        n_head: Number of heads handed to ``Multi_Head_Attention``.
        scales: Exactly three positive integer down-sampling factors. Each
            must divide ``samples_per_frame``. A factor of 1 keeps the full
            resolution (no pooling / upsampling for that branch).
        samples_per_frame: Number of time steps that make up one frame.

    Raises:
        ValueError: If ``scales`` does not contain exactly three entries, or
            an entry is smaller than 1 or does not divide
            ``samples_per_frame``.

    Shape:
        Input ``(B, n_dim, T)`` where ``T`` is expected to be a multiple of
        ``samples_per_frame`` (otherwise the gating multiply fails on a
        length mismatch). Output has the same shape as the input.
    """

    def __init__(self, n_dim, n_head=1, scales=[1, 5, 10], samples_per_frame=400):
        super().__init__()

        # Work on a private copy: the default list object is shared by all calls.
        scales = list(scales)
        if len(scales) != 3:
            # conv_fusion and the (v, k, q) unpacking in forward need three branches.
            raise ValueError(
                "scales must contain exactly 3 entries, got %d" % len(scales)
            )
        for s in scales:
            # Every branch is later re-expanded by samples_per_frame // s, so each
            # factor (not only the largest one) has to divide the frame length.
            if s < 1 or samples_per_frame % s != 0:
                raise ValueError(
                    "every scale must be >= 1 and divide samples_per_frame=%d, "
                    "got scales=%r" % (samples_per_frame, scales)
                )

        self.n_dim = n_dim
        self.samples_per_frame = samples_per_frame
        self.scales = tuple(scales)
        self.norm = nn.BatchNorm1d(n_dim)

        # Down-sampling branches. AvgPool1d(kernel=3s, stride=s, padding=s) maps a
        # length-T input to floor(T / s) steps (for T >= s); the conv keeps length.
        self.down_samples = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AvgPool1d(s * 3, stride=s, padding=s)
                    if s > 1
                    else nn.Identity(),
                    nn.Conv1d(n_dim, n_dim, 3, stride=1, padding=1),
                )
                for s in scales
            ]
        )

        # Mirror of the branches above: bring every scale back to full length.
        self.up_samples = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Upsample(scale_factor=s) if s > 1 else nn.Identity(),
                    nn.Conv1d(n_dim, n_dim, 3, stride=1, padding=1),
                )
                for s in scales
            ]
        )

        # Mixes the concatenated per-frame descriptors of the three scales and
        # emits 3 * n_dim channels that forward splits into v, k and q.
        self.conv_fusion = nn.Sequential(
            nn.Conv1d(n_dim * 3, n_dim, 3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.3),
            nn.Conv1d(n_dim, n_dim * 3, 3, stride=1, padding=1),
        )
        self.mha = Multi_Head_Attention(
            max_k=80, embed_dim=n_dim, num_heads=n_head, dropout=0.1
        )

        # Expands the frame-level attention map (one step per frame) to the
        # length of each down-sampled branch: n_frames * samples_per_frame / s.
        self.attn_upsamples = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Upsample(scale_factor=samples_per_frame // s),
                    nn.Conv1d(n_dim, n_dim, 3, stride=1, padding=1),
                )
                for s in scales
            ]
        )

        # Per-channel gain of the fused residual.
        self.register_parameter("alpha", nn.Parameter(torch.ones(1, n_dim, 1)))
        # Not read by forward; the registration is kept so the module's
        # parameter set (and therefore its state_dict keys) stays unchanged.
        self.register_parameter("beta", nn.Parameter(torch.ones(1, n_dim * 3, 1)))

    def forward(self, x):
        """Apply the fusion block.

        Args:
            x: Tensor of shape ``(B, n_dim, T)``.

        Returns:
            Tensor of shape ``(B, n_dim, T)`` equal to
            ``norm(x) + alpha * fused_multi_scale_features``.
        """
        x = self.norm(x)
        n_frames = x.shape[-1] // self.samples_per_frame

        ms_feat = []  # one (B, n_dim, T // s) tensor per scale
        frame_feat = []  # one (B, n_dim, n_frames) descriptor per scale
        for down in self.down_samples:
            y = down(x)
            ms_feat.append(y)
            # Frame descriptor = average-pooled + max-pooled features per frame.
            frame_feat.append(
                F.adaptive_avg_pool1d(y, n_frames) + F.adaptive_max_pool1d(y, n_frames)
            )

        fused = torch.cat(frame_feat, dim=1)  # (B, 3*n_dim, n_frames)
        fused = self.conv_fusion(fused)
        # Three (B, n_frames, n_dim) tensors, taken in channel order as v, k, q.
        v, k, q = [part.transpose(1, 2) for part in torch.split(fused, self.n_dim, dim=1)]

        # Multi_Head_Attention is project-local; its output is assumed to be
        # (B, n_frames, n_dim) -- TODO confirm against utils.Multi_Head_Attention.
        attn = self.mha(q, k, v).transpose(1, 2)  # (B, n_dim, n_frames)

        # Gate every scale with the attention map, then restore full length.
        rec_feat = [
            up(feat * attn_up(attn))
            for feat, attn_up, up in zip(ms_feat, self.attn_upsamples, self.up_samples)
        ]

        rec_feat = rec_feat[0] + rec_feat[1] + rec_feat[2]
        # NOTE(review): the residual is added to the *normalised* input, not to
        # the raw input -- behaviour kept as in the original implementation.
        return x + self.alpha * rec_feat

In [11]:
# Smoke test: batch of 2, 32 channels, 4000 time steps -- i.e. 10 frames with
# the default samples_per_frame=400.
model = MultiScaleFusion(n_dim=32)
x = torch.randn(2, 32, 4000)
model(x)

tensor([[[ 0.7555, -0.6272,  0.9750,  ..., -0.2995,  0.8353,  0.4291],
         [-0.1354, -0.7936,  0.2231,  ..., -0.7562, -0.5525,  1.4898],
         [ 0.1626,  0.1518,  1.1208,  ...,  0.2509, -0.5045,  0.3458],
         ...,
         [-2.5315, -0.2392, -0.4742,  ...,  1.4658, -0.5670,  0.5073],
         [-1.9215,  0.7726,  0.0569,  ...,  0.5763,  0.0100,  2.2050],
         [-0.5215,  0.5268, -2.8960,  ...,  0.3329,  2.2136, -1.2991]],

        [[ 1.0391, -1.4285,  0.1913,  ..., -0.1417, -0.7822,  1.2220],
         [-1.4263, -0.1288,  0.5291,  ..., -0.2794, -2.0096, -2.5121],
         [-1.0814, -1.7781, -2.0418,  ...,  0.3363, -1.7723, -0.2150],
         ...,
         [ 0.5747, -0.1943,  0.6696,  ..., -1.1426,  0.3285, -0.1630],
         [ 0.7465, -0.6417,  1.3040,  ...,  0.2831,  1.9883,  1.6042],
         [-0.8657, -0.9114,  1.8017,  ..., -0.7329, -0.5312, -1.6505]]],
       grad_fn=<AddBackward0>)

## 2D Multi-Scale Fusion Module

In [33]:
class MultiScaleFusion2D(nn.Module):
    """Multi-scale fusion block for 2D feature maps.

    The batch-normalised input is projected at three spatial resolutions
    (factors 1, 2 and 3). Each resolution goes through the same shared
    ``Attention`` + ``MLP`` stack, is upsampled back, zero-padded on the
    right/bottom to the input size, and the three maps are combined as a
    learned per-channel weighted mean that is added to the normalised input.

    Args:
        n_dim: Channel count of the input; preserved by the block.
        n_head: Accepted but not used by this class.
        scales: Accepted but not used; the factors are fixed to ``[1, 2, 3]``.
        samples_per_frame: Accepted but not used by this class.

    Shape:
        Input ``(B, n_dim, H, W)``; output has the same shape.
    """

    def __init__(self, n_dim, n_head=1, scales=[1, 5, 10], samples_per_frame=400):
        super().__init__()

        self.n_dim = n_dim
        self.norm = nn.BatchNorm2d(n_dim)

        # The spatial factors are fixed here, overriding the `scales` argument.
        scales = [1, 2, 3]

        def conv3x3():
            # Length-preserving, bias-free 3x3 projection used by every branch.
            return nn.Conv2d(n_dim, n_dim, 3, stride=1, padding=1, groups=1, bias=False)

        # Down-sampling branches: the first one keeps the full resolution.
        down_branches = []
        for idx, factor in enumerate(scales):
            if idx == 0:
                reducer = nn.Identity()
            else:
                reducer = nn.AvgPool2d(factor * 3, stride=factor, padding=factor)
            down_branches.append(nn.Sequential(reducer, conv3x3()))
        self.down_samples = nn.ModuleList(down_branches)

        # One attention + MLP stack, shared by all three resolutions.
        self.conv_attention = nn.Sequential(
            Attention(dim=n_dim), MLP(dim=n_dim, mlp_ratio=2.0)
        )

        # Up-sampling branches mirroring the down-sampling ones.
        up_branches = []
        for idx, factor in enumerate(scales):
            if idx == 0:
                expander = nn.Identity()
            else:
                expander = nn.Upsample(
                    scale_factor=factor, mode="bilinear", align_corners=True
                )
            up_branches.append(nn.Sequential(expander, conv3x3()))
        self.up_samples = nn.ModuleList(up_branches)

        # Per-channel gains: alpha1..alpha3 weight the three resolutions,
        # alpha scales the combined residual.
        for gain_name in ("alpha1", "alpha2", "alpha3", "alpha"):
            self.register_parameter(gain_name, nn.Parameter(torch.ones(1, n_dim, 1, 1)))

    def forward(self, x):
        """Apply the fusion block.

        Args:
            x: Tensor of shape ``(B, n_dim, H, W)``.

        Returns:
            Tensor of the same shape:
            ``norm(x) + alpha * (alpha1*r0 + alpha2*r1 + alpha3*r2) / 3``
            where ``r_i`` is the restored feature map of resolution ``i``.
        """
        _, _, height, width = x.shape
        x = self.norm(x)

        # Project each resolution and run it through the shared attention stack.
        # Attention / MLP are project-local; they are assumed to keep the
        # (B, C, h, w) layout -- TODO confirm against conv_attention.
        attended = []
        for down in self.down_samples:
            attended.append(self.conv_attention(down(x)))

        # Restore each resolution to the input size; integer down/up-sampling can
        # leave the map a few pixels short, which is filled with zeros on the
        # right and bottom edges.
        restored = []
        for up, feat in zip(self.up_samples, attended):
            y = up(feat)
            missing_w = width - y.shape[-1]
            missing_h = height - y.shape[-2]
            restored.append(F.pad(y, (0, missing_w, 0, missing_h)))

        r0, r1, r2 = restored
        combined = (self.alpha1 * r0 + self.alpha2 * r1 + self.alpha3 * r2) / 3
        return x + self.alpha * combined

In [None]:
# Smoke test on a non-square map. 224 is not a multiple of 3, so the coarsest
# branch comes back 2 rows short and exercises the zero-padding in forward.
# The block preserves the input shape: (2, 64, 224, 252).
module = MultiScaleFusion2D(n_dim=64)
x = torch.randn(2, 64, 224, 252)
module(x).shape

In [51]:
# spectrogram = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=187)
# x = torch.randn(2, 1, 48000)
# spectrogram(x).shape

torch.Size([2, 1, 257, 257])