In [1]:
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
import os
import librosa
import torch.nn as nn
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
import torch.optim as optim
import pandas as pd
torch.hub.list('zhanghang1989/ResNeSt', force_reload=True)

df = pd.read_csv('./asset/birdclef-2021/train_metadata.csv')

folder_path = './asset/birdclef-2021/train_short_audio/'

all_list = list()

all_sec = list()
for pri,sec,file in zip(df['primary_label'],df['secondary_labels'],df['filename']):
    sec_2 = list(sec.replace("'",'').replace('[','').replace(']','').replace(' ','').split(','))  
    sec_2.append(pri)
    if sec_2[0]=='':
        sec_2=sec_2[1:]
    for bird in sec_2:
        all_sec.append(bird)
    filename = os.path.join(folder_path,pri,file)
    all_list.append({'path':filename,'bird':sec_2})
    
blist = list(set(all_sec))
classes_num= len(blist)

Downloading: "https://github.com/zhanghang1989/ResNeSt/archive/master.zip" to /root/.cache/torch/hub/master.zip


In [2]:

def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def init_gru(rnn):
    """Initialize a GRU layer. """
    
    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)
    
        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
        
    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
    
    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, 'weight_ih_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)

        _concat_init(
            getattr(rnn, 'weight_hh_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)
        
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x
    

class AttBlockV2(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        # x: (n_samples, n_in, n_time)
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)

    

class Tmodel(nn.Module):
    def __init__(self,train=True):
        super(Tmodel,self).__init__()
        
        SPEC_HEIGHT = 128
        SPEC_WIDTH = 256
        NUM_MELS = SPEC_HEIGHT
        HOP_LENGTH = int(32000 * 5 / (SPEC_WIDTH - 1)) # sample rate * duration / spec width - 1 == 627
        FMIN = 500
        FMAX = 12500
        classes_num = 398
        self.interpolate_ratio = 8
        
        self.spectrogram_extractor  = Spectrogram(
                    n_fft=2048,
                    hop_length=HOP_LENGTH,
                    freeze_parameters=True)
        
        self.logmel_extractor = LogmelFilterBank(sr=32000,
            n_mels=NUM_MELS, fmin=FMIN, fmax=FMAX, freeze_parameters=True)

        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)
        
        self.bn0 = nn.BatchNorm2d(128)

        # load pretrained models, using ResNeSt-50 as an example
        if train:
            base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        else:
            base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=False)
            
        layers = list(base_model.children())[:-2]
        self.encoder = nn.Sequential(*layers)
        
        self.gru = nn.GRU(input_size=2048, hidden_size=1024, num_layers=1, 
            bias=True, batch_first=True, bidirectional=True)

        self.att_block = AttBlockV2(2048, classes_num, activation='sigmoid')
        self.init_weights()
        
    def init_weights(self):
        init_bn(self.bn0)
        init_gru(self.gru)
        
    def forward(self,input,mixup_lambda=None):
        
        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)
        
        frames_num = x.shape[2]
        
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        
        if self.training:
            x = self.spec_augmenter(x)
        
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
        
        x = torch.tile(x,(1,3,1,1))
        x = self.encoder(x)

        x = torch.mean(x, dim=3)
        x = x.transpose(1, 2)   # (batch_size, time_steps, channels)
        (x, _) = self.gru(x)
        x = x.transpose(1, 2)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        """cla: (batch_size, classes_num, time_stpes)"""
        
        logit = torch.sum(norm_att * self.att_block.cla(x), dim=2)
        segmentwise_logit = self.att_block.cla(x).transpose(1, 2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        interpolate_ratio = frames_num // segmentwise_output.size(1)
        
        # Framewise output
        framewise_output = interpolate(segmentwise_output,
                                       interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        framewise_logit = interpolate(segmentwise_logit, interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)
        
        output_dict = {
            "framewise_output": framewise_output,
            "segmentwise_output": segmentwise_output,
            "logit": logit,
            "framewise_logit": framewise_logit,
            "clipwise_output": clipwise_output
        }

            
        return output_dict
    
def interpolate(x: torch.Tensor, ratio: int):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.
    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Pad framewise_output to the same length as input frames. The pad value
    is the same as the value of the last frame.
    Args:
      framewise_output: (batch_size, frames_num, classes_num)
      frames_num: int, number of frames to pad
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    output = F.interpolate(
        framewise_output.unsqueeze(1),
        size=(frames_num, framewise_output.size(2)),
        align_corners=True,
        mode="bilinear").squeeze(1)

    return output

In [3]:
#prepareing
import random
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import random
import soundfile as sf

df = pd.read_csv('./asset/birdclef-2021/train_metadata.csv')

folder_path = './asset/birdclef-2021/train_short_audio/'

all_list = list()

all_sec = list()
for pri,sec,file in zip(df['primary_label'],df['secondary_labels'],df['filename']):
    sec_2 = list(sec.replace("'",'').replace('[','').replace(']','').replace(' ','').split(','))  
    sec_2.append(pri)
    if sec_2[0]=='':
        sec_2=sec_2[1:]
    for bird in sec_2:
        all_sec.append(bird)
    filename = os.path.join(folder_path,pri,file)
    all_list.append({'path':filename,'bird':sec_2})

class testData(Dataset):
    def __init__(self,all_list):
        self.all_list = all_list
        
    def __len__(self):
        return len(self.all_list)

    def __getitem__(self,idx):
        return self.all_list[idx]
    
blist = list(set(all_sec))
classes_num= len(blist)

# collate_fn by class type, handling with parameters/frame
'''
class Collator(object):
    def __init__(self,blist,frame_sec=7,sr=32000,number_of_frame=20,classes_num=397):
        self.frame_sec = frame_sec
        self.sr = sr
        self.number_of_frame = number_of_frame
        self.classes_num = classes_num
        self.blist = blist
        
    def __call__(self,batch):
        birds = list()
        frames = list()
        duration = self.frame_sec * self.sr
        
        batch_ind = 0
        wav, _ = librosa.load(batch[batch_ind]['path'],sr=32000)
        
        wav_ind = 0
        while len(frames) < self.number_of_frame:
            if wav_ind+duration>len(wav) and batch_ind<len(batch)-1:
                batch_ind += 1
                wav, _ = librosa.load(batch[batch_ind]['path'],sr=32000)
                wav_ind = 0
            else:
                frame = wav[wav_ind:wav_ind+duration]
                frames.append(frame[np.newaxis])
                
                bird_arr = np.zeros((1,self.classes_num))
                for bird in batch[batch_ind]['bird']:
                    bird_arr[0][self.blist.index(bird)]=1
                    birds.append(bird_arr)
                wav_ind += duration
                
        print(frames)
        frames = np.concatenate(frames)
        birds = np.concatenate(birds)
        
        return frames,birds
'''
class BCEFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, targets):
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(preds, targets)
        probas = torch.sigmoid(preds)
        loss = targets * self.alpha * \
            (1. - probas)**self.gamma * bce_loss + \
            (1. - targets) * probas**self.gamma * bce_loss
        loss = loss.mean()
        return loss


class BCEFocal2WayLoss(nn.Module):
    def __init__(self, weights=[1, 1], class_weights=None):
        super().__init__()

        self.focal = BCEFocalLoss()

        self.weights = weights

    def forward(self, input, target):
        input_ = input["logit"]
        target = target.float()

        framewise_output = input["framewise_logit"]
        clipwise_output_with_max, _ = framewise_output.max(dim=1)

        loss = self.focal(input_, target)
        aux_loss = self.focal(clipwise_output_with_max, target)

        return self.weights[0] * loss + self.weights[1] * aux_loss
    
class BatchCollator(object):
    def __init__(self,blist,frame_sec=7,sr=32000,classes_num=398):
        self.frame_sec = frame_sec
        self.sr = sr
        self.classes_num = classes_num
        self.blist = blist
        self.duration = frame_sec * sr

    def __call__(self,batch):
        waves = list()
        birds = np.zeros((len(batch),self.classes_num))
        for i,meta in enumerate(batch):
            wav, _ = sf.read(meta['path'])
            #wav, _ = librosa.load(meta['path'],sr=self.sr)
            if len(wav) < self.duration:
                wav = np.concatenate([wav,np.zeros((self.duration-len(wav)))])[np.newaxis,:]
            else:
                ind = random.randint(0,len(wav)-self.duration)
                wav = wav[ind:ind+self.duration][np.newaxis,:]
            waves.append(wav)
            for bird in meta['bird']:
                birds[i][blist.index(bird)] = 1

        waves = np.concatenate(waves)

        return waves,birds

dataset = testData(all_list)
collator = BatchCollator(blist)
#dataloader = DataLoader(dataset,batch_size=1,shuffle=True,collate_fn=make_batch)
dataloader = DataLoader(dataset,batch_size=20,shuffle=True,collate_fn=collator)

In [4]:
'''
import torch
from IPython.display import display
from utils.logging import Averager
from torch.optim import lr_scheduler
from autoth.core import ScoreCalculatorExample, HyperParamsOptimizer

#learning_rate = 0.1 #for onecycle
learning_rate = 0.001  #for cosine

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epochs=15
model = Tmodel().to(torch.float32)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, 
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

model.to(device)
model.train()

avg = Averager()
#scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(dataloader), epochs=epochs,
#                                        pct_start=0.2)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
dh = display('',display_id=True)
batch_max = len(dataloader)
loss_func = BCEFocal2WayLoss()

############################
##### auto threshold ######
#############################
score_calculator = ScoreCalculatorExample(dataloader.batch_size,classes_num)
init_params = torch.Tensor([0.3]*classes_num).to(device)
hyper_params_opt = HyperParamsOptimizer(score_calculator, 
    learning_rate=1e-2, epochs=10, step=0.01)

for epoch in range(epochs):
    #train
    for batch, data in enumerate(dataloader):
        wav, bird = data
        
        wav = torch.from_numpy(wav).to(torch.float32)
        wav = wav.to(device)
        
        bird_smooth = np.where(bird==1,0.995,0.0025)
        bird_smooth = torch.from_numpy(bird_smooth).to(torch.float32).to(device)
        output_dict = model(wav)
        
        loss = loss_func(output_dict, bird_smooth)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        scheduler.step()
        avg.add(loss)
        if batch%20==0:
            opt_score, init_params = hyper_params_opt.do_optimize(init_params, output_dict['clipwise_output'],\
                                                                  torch.from_numpy(bird).to(device))
        dh.update('Epoch : {} {}/{} loss : {:4f} / lr : {:4f} / auto_f1_score : {:4f} / cur_f1 : {:4f}'.format(\
                                epoch+1,batch+1,batch_max,avg.val(),\
                                optimizer.param_groups[0]['lr'], opt_score,\
                                score_calculator(init_params,output_dict['clipwise_output'], torch.from_numpy(bird).to(device))))
        
        
        del wav, bird, loss, output_dict, data
        torch.save(model.state_dict(),os.path.join('./result/sed_auto_th.pth'))
    #eval
    '''

import torch
from IPython.display import display
from utils.logging import Averager
from torch.optim import lr_scheduler
from autoth.core import ScoreCalculatorExample, HyperParamsOptimizer
import pickle


#learning_rate = 0.1 #for onecycle
learning_rate = 0.001  #for cosine

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epochs=15
model = Tmodel().to(torch.float32)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, 
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)

saved_model = './result/sed_v3.pth'
use_saved= True
if os.path.exists(saved_model) and use_saved:
    model.load_state_dict(torch.load(saved_model,map_location=device))
model.to(device)    
model.train()

avg = Averager()
f1_avg= Averager()
#scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(dataloader), epochs=epochs,
#                                        pct_start=0.2)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
dh = display('',display_id=True)
batch_max = len(dataloader)
loss_func = BCEFocal2WayLoss()

############################
##### auto threshold ######
#############################
score_calculator = ScoreCalculatorExample(dataloader.batch_size,classes_num)
init_params = torch.Tensor([0.3]*classes_num).to(device)
hyper_params_opt = HyperParamsOptimizer(score_calculator, 
    learning_rate=1e-2, epochs=10, step=0.01)

with open('./result/sed_v3.pkl','rb') as f:
    init_params = pickle.load(f)
init_params = init_params.to(torch.float).to(device)

opt_score = 0
for epoch in range(epochs):
    #train
    for batch, data in enumerate(dataloader):
        wav, bird = data
        
        wav = torch.from_numpy(wav).to(torch.float32)
        wav = wav.to(device)
        
        bird_smooth = np.where(bird==1,0.995,0.0025)
        bird_smooth = torch.from_numpy(bird_smooth).to(torch.float32).to(device)
        output_dict = model(wav)
        
        loss = loss_func(output_dict, bird_smooth)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        scheduler.step()
        avg.add(loss)
        with torch.no_grad():
            if batch%20==0:
                bef_score = opt_score
                opt_score, init_params = hyper_params_opt.do_optimize(init_params, output_dict['clipwise_output'],\
                                                                      torch.from_numpy(bird).to(device))
                if bef_score < opt_score:
                    with open('./result/sed_v3.pkl','wb') as f:
                        pickle.dump(init_params,f)
                
            f1 = score_calculator(init_params,output_dict['clipwise_output'], torch.from_numpy(bird).to(device))
            f1_avg.add(f1)
            dh.update('Epoch : {} {}/{} loss : {:4f} / lr : {:4f} / auto_f1_score : {:4f} / cur_f1 : {:4f}'.format(\
                                    epoch+1,batch+1,batch_max,avg.val(),\
                                    optimizer.param_groups[0]['lr'], opt_score,\
                                    f1_avg.val()))
            
        del wav, bird, loss, output_dict, data
        torch.save(model.state_dict(),os.path.join('./result/sed_v3.pth'))
    #eval

Using cache found in /root/.cache/torch/hub/zhanghang1989_ResNeSt_master


'Epoch : 4 309/3144 loss : 0.005671 / lr : 0.000976 / auto_f1_score : 0.007692 / cur_f1 : 0.007085'

KeyboardInterrupt: 

In [None]:

def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def init_gru(rnn):
    """Initialize a GRU layer. """
    
    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)
    
        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
        
    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
    
    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, 'weight_ih_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)

        _concat_init(
            getattr(rnn, 'weight_hh_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)
        
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x
    

class AttBlockV2(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        # x: (n_samples, n_in, n_time)
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)

    

class Tmodel(nn.Module):
    def __init__(self,train=True):
        super(Tmodel,self).__init__()
        
        SPEC_HEIGHT = 128
        SPEC_WIDTH = 256
        NUM_MELS = SPEC_HEIGHT
        HOP_LENGTH = int(32000 * 5 / (SPEC_WIDTH - 1)) # sample rate * duration / spec width - 1 == 627
        FMIN = 500
        FMAX = 12500
        classes_num = 398
        self.interpolate_ratio = 8
        
        self.spectrogram_extractor  = Spectrogram(
                    n_fft=2048,
                    hop_length=HOP_LENGTH,
                    freeze_parameters=True)
        
        self.logmel_extractor = LogmelFilterBank(sr=32000,
            n_mels=NUM_MELS, fmin=FMIN, fmax=FMAX, freeze_parameters=True)

        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)
        
        self.bn0 = nn.BatchNorm2d(128)

        # load pretrained models, using ResNeSt-50 as an example
        if train:
            base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
        else:
            base_model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=False)
            
        layers = list(base_model.children())[:-2]
        self.encoder = nn.Sequential(*layers)

        if hasattr(base_model, "fc"):
            in_features = base_model.fc.in_features
        else:
            in_features = base_model.classifier.in_features
        self.fc1 = nn.Linear(in_features, in_features, bias=True)
        
        self.att_block = AttBlockV2(in_features, classes_num, activation='sigmoid')
        self.init_weights()
        
    def init_weights(self):
        init_layer(self.fc1)
        init_bn(self.bn0)
        
    def forward(self,input,mixup_lambda=None):
        
        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)
        
        frames_num = x.shape[2]
        
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        
        if self.training:
            x = self.spec_augmenter(x)
        
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
        
        x = torch.tile(x,(1,3,1,1))
        x = x.transpose(2, 3)
        # (batch_size, channels, freq, frames)
        x = self.encoder(x)

        # (batch_size, channels, frames)
        x = torch.mean(x, dim=2)

        # channel smoothing
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)
        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        logit = torch.sum(norm_att * self.att_block.cla(x), dim=2)
        segmentwise_logit = self.att_block.cla(x).transpose(1, 2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        interpolate_ratio = frames_num // segmentwise_output.size(1)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        framewise_logit = interpolate(segmentwise_logit, interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            "framewise_output": framewise_output,
            "segmentwise_output": segmentwise_output,
            "logit": logit,
            "framewise_logit": framewise_logit,
            "clipwise_output": clipwise_output
        }

        return output_dict
    
def interpolate(x: torch.Tensor, ratio: int):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.
    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Pad framewise_output to the same length as input frames. The pad value
    is the same as the value of the last frame.
    Args:
      framewise_output: (batch_size, frames_num, classes_num)
      frames_num: int, number of frames to pad
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    output = F.interpolate(
        framewise_output.unsqueeze(1),
        size=(frames_num, framewise_output.size(2)),
        align_corners=True,
        mode="bilinear").squeeze(1)

    return output