In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from .utils import AdaptiveConv1d, DepthwiseSeparableConv1d, Multi_Head_Attention

In [3]:
from utils import AdaptiveConv1d, DepthwiseSeparableConv1d, Multi_Head_Attention

In [4]:
class MultiScaleFusion(nn.Module):
    def __init__(self, n_dim, kernel_size, samples_per_frame=400):
        super().__init__()

        self.samples_per_frame = samples_per_frame

        self.norm = nn.BatchNorm1d(n_dim)
        
        # strides = [1, kernel_size, kernel_size * 2]
        kernel_sizes = [5, 5, 25]
        strides = [1, 5, 25]
        assert samples_per_frame % strides[-1] == 0, samples_per_frame
        self.adap_conv_blocks = nn.ModuleList(
            [
                AdaptiveConv1d(
                    n_dim=n_dim,
                    kernel_size=kernel_sizes[i],
                    stride=strides[i],
                    reduction=strides[i],
                    groups=n_dim,
                    conv_transpose="upsample",
                )
                for i in range(3)
            ]
        )
        self.conv_blocks = nn.ModuleList(
            [
                DepthwiseSeparableConv1d(
                    n_dim, n_dim, kernel_size=3, stride=1, padding="same"
                )
                for i in range(6)
            ]
        )

        self.mha = Multi_Head_Attention(max_k=80, embed_dim=n_dim, num_heads=1)
        self.attn_upsamples = nn.ModuleList(
            [
                nn.Upsample(scale_factor=samples_per_frame // strides[i])
                for i in range(3)
            ]
        )

        self.register_parameter('alpha', nn.Parameter(torch.ones(1, n_dim, 1)) )

    def forward(self, x):
        short_cut = x
        x = self.norm(x)
        n_frames = x.shape[-1] // self.samples_per_frame
        avg_pool = partial(F.adaptive_avg_pool1d, output_size=n_frames)
        max_pool = partial(F.adaptive_max_pool1d, output_size=n_frames)

        frame_feat = []
        ms_feat = []
        for i in range(3):
            y = self.adap_conv_blocks[i](x)
            # y = self.conv_blocks[i](y)
            # print("scale", i, y.shape)
            ms_feat.append(y)
            attn = avg_pool(y) + max_pool(y)
            frame_feat.append(attn.transpose(1, 2))  # (B, n_frames, n_dim)

        v, k, q = frame_feat
        attn = self.mha(q, k, v)
        attn = attn.transpose(1, 2)  # (B, n_dim, n_frames)
        # print("attn shape: ", attn.shape)

        rec_feat = []
        for i in range(3):
            _attn = self.attn_upsamples[i](attn)
            y = ms_feat[i] * _attn
            y = self.adap_conv_blocks[i].reverse(y)
            # y = self.conv_blocks[i + 3](y)
            rec_feat.append(y)

        rec_feat = rec_feat[0] + rec_feat[1] + rec_feat[2]
        x = x + self.alpha * rec_feat
        return x

In [5]:
def build_stage(
    n_dim_in, n_dim_out, n_blocks, kernel_size, samples_per_frame, downsample_factor=1
):
    # print(n_dim_in, n_dim_out)
    conv1 = nn.Conv1d(n_dim_in, n_dim_out, 3, stride=1, padding=1)
    conv_blocks = [
        MultiScaleFusion(
            n_dim=n_dim_out,
            kernel_size=kernel_size,
            samples_per_frame=samples_per_frame,
        )
        for i in range(n_blocks)
    ]
    module = nn.Sequential(conv1, *conv_blocks)
    if downsample_factor > 1:
        module.add_module(
            "down-sample", nn.Conv1d(n_dim_out, n_dim_out, 5, stride=2, padding=2)
        )
    return module

In [6]:
modle = build_stage(
    n_dim_in=32,
    n_dim_out=128,
    n_blocks=3,
    kernel_size=25,
    samples_per_frame=400,
    downsample_factor=2,
)
with torch.autograd.profiler.profile(enabled=True) as prof:
    x = torch.randn(16, 32, 24000)
    _ = modle(x).shape
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

INFO:2023-07-20 09:37:30 2875196:2875196 init.cpp:149] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti
STAGE:2023-07-20 09:37:30 2875196:2875196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                            Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
        aten::mkldnn_convolution        77.04%        2.267s        77.07%        2.268s     113.407ms            20  
        aten::upsample_nearest1d         6.23%     183.205ms         6.24%     183.570ms      10.198ms            18  
                       aten::add         4.35%     128.147ms         4.35%     128.147ms       6.102ms            21  
         aten::native_batch_norm         3.38%      99.615ms         3.39%      99.684ms      33.228ms             3  
                       aten::mul         3.23%      95.018ms         3.23%      95.086ms       5.283ms            18  
                   aten::normal_         2.38%  

STAGE:2023-07-20 09:37:34 2875196:2875196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-20 09:37:34 2875196:2875196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [7]:
class AudioModel(nn.Module):
    def __init__(
        self, dims=[32, 64, 128, 256], n_blocks=[2, 2, 6, 2], samples_per_frame=400
    ):
        super().__init__()

        self.samples_per_frame = samples_per_frame
        self.conv_head = nn.Sequential(
            nn.Conv1d(1, dims[0], 2, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(dims[0], dims[0], 3, stride=1, padding=1)
        )

        self.stages = nn.ModuleList(
            [
                build_stage(
                    n_dim_in=dims[max(i - 1, 0)],
                    n_dim_out=dims[i],
                    n_blocks=n_blocks[i],
                    kernel_size=25,
                    samples_per_frame=400 // (2 * (2**i)),
                    downsample_factor=1 if i == 0 else 2,
                )
                for i in range(4)
            ]
        )

        self.cls_head = nn.Linear(dims[-1], 1, bias=False)

    def forward(self, x):
        audio_length = x.shape[-1]
        audio_frames = audio_length // self.samples_per_frame

        x = self.conv_head(x)
        # print(x.shape)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            # print("Output of the %d-th stage"%(i+1), x.shape)

        x = F.adaptive_avg_pool1d(x, 1)
        # x = F.adaptive_avg_pool1d(x, audio_frames)
        x = self.cls_head(x.transpose(1, 2))
        x = torch.mean(x, dim=1)
        return x

In [15]:
model = AudioModel()
x = torch.randn(32, 1, 48000)
model(x)
with torch.autograd.profiler.profile(enabled=True) as prof:
    x = torch.randn(16, 1, 48000)
    _ = model(x).shape
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

torch.Size([32, 32, 24000])
torch.Size([32, 1])


STAGE:2023-07-19 23:17:46 2780097:2780097 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


torch.Size([16, 32, 24000])
torch.Size([16, 1])


STAGE:2023-07-19 23:17:51 2780097:2780097 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-19 23:17:51 2780097:2780097 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
      aten::mkldnn_convolution        78.91%        2.940s        78.98%        2.943s      36.783ms            80  
      aten::upsample_nearest1d         6.59%     245.545ms         6.63%     246.891ms       3.429ms            72  
                     aten::add         3.79%     141.279ms         3.79%     141.279ms       2.355ms            60  
     aten::adaptive_max_pool2d         3.34%     124.550ms         3.34%     124.550ms       3.460ms            36  
                     aten::mul         2.23%      83.057ms         2.24%      83.401ms       1.390ms            60  
    aten::_adaptive_avg_pool2d         2.02%      75.330ms      

In [17]:
model.to("cuda:1")

import torch
from torch.autograd import Variable

x = torch.randn(16, 1, 48000)
y = Variable(x, requires_grad=True).to("cuda:1")

In [26]:
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    z = model(y)
    print(y.shape)
    z = torch.sum(z)
    z.backward()
# NOTE: some columns were removed for brevityM
print(prof.key_averages().table(sort_by="self_cuda_time_total"))

STAGE:2023-07-19 23:23:53 2780097:2780097 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


torch.Size([16, 32, 24000])
torch.Size([16, 1])
torch.Size([16, 1, 48000])


STAGE:2023-07-19 23:23:54 2780097:2780097 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-07-19 23:23:54 2780097:2780097 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             aten::convolution_backward         2.13%       8.113ms         2.92%      11.117ms     138.963us     237.628ms        61.27%     245.728ms       3.072ms            80  
                                            aten::copy_        14.06%      53.490ms        14.06%      53.490ms     316.509us      47.666ms        12.29%      47.666ms     282.047us           169  
         