In [1]:
!pip install "../input/timmpytorch/timm-0.2.1-py3-none-any.whl"

Processing /kaggle/input/timmpytorch/timm-0.2.1-py3-none-any.whl
Installing collected packages: timm
Successfully installed timm-0.2.1


In [2]:
import os, random, time, warnings, logging
import librosa

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from contextlib import contextmanager
from IPython.display import Audio
from pathlib import Path
from typing import Optional, List
from pathlib import Path
from fastprogress import progress_bar


import timm
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns, tf_efficientnet_b0_ns
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [4]:
ROOT = Path.cwd().parent
INPUT_ROOT = ROOT / "input"
RAW_DATA = INPUT_ROOT / "birdsong-recognition"
TRAIN_AUDIO_DIR = RAW_DATA / "train_audio"
TRAIN_RESAMPLED_AUDIO_DIRS = [
  INPUT_ROOT / "birdsong-resampled-train-audio-{:0>2}".format(i)  for i in range(5)
]
TEST_AUDIO_DIR = RAW_DATA / "test_audio"

In [5]:
if not TEST_AUDIO_DIR.exists():
    TEST_AUDIO_DIR = INPUT_ROOT / "birdcall-check" / "test_audio"
    test = pd.read_csv(INPUT_ROOT / "birdcall-check" / "test.csv")
else:
    test = pd.read_csv(RAW_DATA / "test.csv")

In [6]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = True  # type: ignore
    
    
def get_logger(out_file=None):
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    logger.handlers = []
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger
    
    
@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    t0 = time.time()
    msg = f"[{name}] start"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)
    yield

    msg = f"[{name}] done in {time.time() - t0:.2f} s"
    if logger is None:
        print(msg)
    else:
        logger.info(msg)
    
    
set_seed(1213)

In [7]:
class DFTBase(nn.Module):
    def __init__(self):
        """Base class for DFT and IDFT matrix"""
        super(DFTBase, self).__init__()

    def dft_matrix(self, n):
        (x, y) = np.meshgrid(np.arange(n), np.arange(n))
        omega = np.exp(-2 * np.pi * 1j / n)
        W = np.power(omega, x * y)
        return W

    def idft_matrix(self, n):
        (x, y) = np.meshgrid(np.arange(n), np.arange(n))
        omega = np.exp(2 * np.pi * 1j / n)
        W = np.power(omega, x * y)
        return W
    
    
class STFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window='hann', center=True, pad_mode='reflect', freeze_parameters=True):
        """Implementation of STFT with Conv1d. The function has the same output 
        of librosa.core.stft
        """
        super(STFT, self).__init__()

        assert pad_mode in ['constant', 'reflect']

        self.n_fft = n_fft
        self.center = center
        self.pad_mode = pad_mode

        # By default, use the entire frame
        if win_length is None:
            win_length = n_fft

        # Set the default hop, if it's not already specified
        if hop_length is None:
            hop_length = int(win_length // 4)

        fft_window = librosa.filters.get_window(window, win_length, fftbins=True)

        # Pad the window out to n_fft size
        fft_window = librosa.util.pad_center(fft_window, n_fft)

        # DFT & IDFT matrix
        self.W = self.dft_matrix(n_fft)

        out_channels = n_fft // 2 + 1

        self.conv_real = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """input: (batch_size, data_length)
        Returns:
          real: (batch_size, n_fft // 2 + 1, time_steps)
          imag: (batch_size, n_fft // 2 + 1, time_steps)
        """

        x = input[:, None, :]   # (batch_size, channels_num, data_length)

        if self.center:
            x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)

        real = self.conv_real(x)
        imag = self.conv_imag(x)
        # (batch_size, n_fft // 2 + 1, time_steps)

        real = real[:, None, :, :].transpose(2, 3)
        imag = imag[:, None, :, :].transpose(2, 3)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        return real, imag
    
    
class Spectrogram(nn.Module):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window='hann', center=True, pad_mode='reflect', power=2.0, 
        freeze_parameters=True):
        """Calculate spectrogram using pytorch. The STFT is implemented with 
        Conv1d. The function has the same output of librosa.core.stft
        """
        super(Spectrogram, self).__init__()

        self.power = power

        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, 
            win_length=win_length, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=True)

    def forward(self, input):
        """input: (batch_size, 1, time_steps, n_fft // 2 + 1)
        Returns:
          spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        (real, imag) = self.stft.forward(input)
        # (batch_size, n_fft // 2 + 1, time_steps)

        spectrogram = real ** 2 + imag ** 2

        if self.power == 2.0:
            pass
        else:
            spectrogram = spectrogram ** (power / 2.0)

        return spectrogram

    
class LogmelFilterBank(nn.Module):
    def __init__(self, sr=32000, n_fft=2048, n_mels=64, fmin=50, fmax=14000, is_log=True, 
        ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
        """Calculate logmel spectrogram using pytorch. The mel filter bank is 
        the pytorch implementation of as librosa.filters.mel 
        """
        super(LogmelFilterBank, self).__init__()

        self.is_log = is_log
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

        self.melW = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
            fmin=fmin, fmax=fmax).T
        # (n_fft // 2 + 1, mel_bins)

        self.melW = nn.Parameter(torch.Tensor(self.melW))

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """input: (batch_size, channels, time_steps)
        
        Output: (batch_size, time_steps, mel_bins)
        """

        # Mel spectrogram
        mel_spectrogram = torch.matmul(input, self.melW)

        # Logmel spectrogram
        if self.is_log:
            output = self.power_to_db(mel_spectrogram)
        else:
            output = mel_spectrogram

        return output


    def power_to_db(self, input):
        """Power to db, this function is the pytorch implementation of 
        librosa.core.power_to_lb
        """
        ref_value = self.ref
        log_spec = 10.0 * torch.log10(torch.clamp(input, min=self.amin, max=np.inf))
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, ref_value))

        if self.top_db is not None:
            if self.top_db < 0:
                raise ParameterError('top_db must be non-negative')
            log_spec = torch.clamp(log_spec, min=log_spec.max().item() - self.top_db, max=np.inf)

        return log_spec
    
class DropStripes(nn.Module):
    def __init__(self, dim, drop_width, stripes_num):
        """Drop stripes. 
        Args:
          dim: int, dimension along which to drop
          drop_width: int, maximum width of stripes to drop
          stripes_num: int, how many stripes to drop
        """
        super(DropStripes, self).__init__()

        assert dim in [2, 3]    # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def forward(self, input):
        """input: (batch_size, channels, time_steps, freq_bins)"""

        assert input.ndimension() == 4

        if self.training is False:
            return input

        else:
            batch_size = input.shape[0]
            total_width = input.shape[self.dim]

            for n in range(batch_size):
                self.transform_slice(input[n], total_width)

            return input


    def transform_slice(self, e, total_width):
        """e: (channels, time_steps, freq_bins)"""

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]

            if self.dim == 2:
                e[:, bgn : bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn : bgn + distance] = 0


class SpecAugmentation(nn.Module):
    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, 
        freq_stripes_num):
        """Spec augmetation. 
        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. 
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method 
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
        Args:
          time_drop_width: int
          time_stripes_num: int
          freq_drop_width: int
          freq_stripes_num: int
        """

        super(SpecAugmentation, self).__init__()

        self.time_dropper = DropStripes(dim=2, drop_width=time_drop_width, 
            stripes_num=time_stripes_num)

        self.freq_dropper = DropStripes(dim=3, drop_width=freq_drop_width, 
            stripes_num=freq_stripes_num)

    def forward(self, input):
        x = self.time_dropper(input)
        x = self.freq_dropper(x)
        return x

In [8]:
BIRD_CODE = {
    'aldfly': 0, 'ameavo': 1, 'amebit': 2, 'amecro': 3, 'amegfi': 4,
    'amekes': 5, 'amepip': 6, 'amered': 7, 'amerob': 8, 'amewig': 9,
    'amewoo': 10, 'amtspa': 11, 'annhum': 12, 'astfly': 13, 'baisan': 14,
    'baleag': 15, 'balori': 16, 'banswa': 17, 'barswa': 18, 'bawwar': 19,
    'belkin1': 20, 'belspa2': 21, 'bewwre': 22, 'bkbcuc': 23, 'bkbmag1': 24,
    'bkbwar': 25, 'bkcchi': 26, 'bkchum': 27, 'bkhgro': 28, 'bkpwar': 29,
    'bktspa': 30, 'blkpho': 31, 'blugrb1': 32, 'blujay': 33, 'bnhcow': 34,
    'boboli': 35, 'bongul': 36, 'brdowl': 37, 'brebla': 38, 'brespa': 39,
    'brncre': 40, 'brnthr': 41, 'brthum': 42, 'brwhaw': 43, 'btbwar': 44,
    'btnwar': 45, 'btywar': 46, 'buffle': 47, 'buggna': 48, 'buhvir': 49,
    'bulori': 50, 'bushti': 51, 'buwtea': 52, 'buwwar': 53, 'cacwre': 54,
    'calgul': 55, 'calqua': 56, 'camwar': 57, 'cangoo': 58, 'canwar': 59,
    'canwre': 60, 'carwre': 61, 'casfin': 62, 'caster1': 63, 'casvir': 64,
    'cedwax': 65, 'chispa': 66, 'chiswi': 67, 'chswar': 68, 'chukar': 69,
    'clanut': 70, 'cliswa': 71, 'comgol': 72, 'comgra': 73, 'comloo': 74,
    'commer': 75, 'comnig': 76, 'comrav': 77, 'comred': 78, 'comter': 79,
    'comyel': 80, 'coohaw': 81, 'coshum': 82, 'cowscj1': 83, 'daejun': 84,
    'doccor': 85, 'dowwoo': 86, 'dusfly': 87, 'eargre': 88, 'easblu': 89,
    'easkin': 90, 'easmea': 91, 'easpho': 92, 'eastow': 93, 'eawpew': 94,
    'eucdov': 95, 'eursta': 96, 'evegro': 97, 'fiespa': 98, 'fiscro': 99,
    'foxspa': 100, 'gadwal': 101, 'gcrfin': 102, 'gnttow': 103, 'gnwtea': 104,
    'gockin': 105, 'gocspa': 106, 'goleag': 107, 'grbher3': 108, 'grcfly': 109,
    'greegr': 110, 'greroa': 111, 'greyel': 112, 'grhowl': 113, 'grnher': 114,
    'grtgra': 115, 'grycat': 116, 'gryfly': 117, 'haiwoo': 118, 'hamfly': 119,
    'hergul': 120, 'herthr': 121, 'hoomer': 122, 'hoowar': 123, 'horgre': 124,
    'horlar': 125, 'houfin': 126, 'houspa': 127, 'houwre': 128, 'indbun': 129,
    'juntit1': 130, 'killde': 131, 'labwoo': 132, 'larspa': 133, 'lazbun': 134,
    'leabit': 135, 'leafly': 136, 'leasan': 137, 'lecthr': 138, 'lesgol': 139,
    'lesnig': 140, 'lesyel': 141, 'lewwoo': 142, 'linspa': 143, 'lobcur': 144,
    'lobdow': 145, 'logshr': 146, 'lotduc': 147, 'louwat': 148, 'macwar': 149,
    'magwar': 150, 'mallar3': 151, 'marwre': 152, 'merlin': 153, 'moublu': 154,
    'mouchi': 155, 'moudov': 156, 'norcar': 157, 'norfli': 158, 'norhar2': 159,
    'normoc': 160, 'norpar': 161, 'norpin': 162, 'norsho': 163, 'norwat': 164,
    'nrwswa': 165, 'nutwoo': 166, 'olsfly': 167, 'orcwar': 168, 'osprey': 169,
    'ovenbi1': 170, 'palwar': 171, 'pasfly': 172, 'pecsan': 173, 'perfal': 174,
    'phaino': 175, 'pibgre': 176, 'pilwoo': 177, 'pingro': 178, 'pinjay': 179,
    'pinsis': 180, 'pinwar': 181, 'plsvir': 182, 'prawar': 183, 'purfin': 184,
    'pygnut': 185, 'rebmer': 186, 'rebnut': 187, 'rebsap': 188, 'rebwoo': 189,
    'redcro': 190, 'redhea': 191, 'reevir1': 192, 'renpha': 193, 'reshaw': 194,
    'rethaw': 195, 'rewbla': 196, 'ribgul': 197, 'rinduc': 198, 'robgro': 199,
    'rocpig': 200, 'rocwre': 201, 'rthhum': 202, 'ruckin': 203, 'rudduc': 204,
    'rufgro': 205, 'rufhum': 206, 'rusbla': 207, 'sagspa1': 208, 'sagthr': 209,
    'savspa': 210, 'saypho': 211, 'scatan': 212, 'scoori': 213, 'semplo': 214,
    'semsan': 215, 'sheowl': 216, 'shshaw': 217, 'snobun': 218, 'snogoo': 219,
    'solsan': 220, 'sonspa': 221, 'sora': 222, 'sposan': 223, 'spotow': 224,
    'stejay': 225, 'swahaw': 226, 'swaspa': 227, 'swathr': 228, 'treswa': 229,
    'truswa': 230, 'tuftit': 231, 'tunswa': 232, 'veery': 233, 'vesspa': 234,
    'vigswa': 235, 'warvir': 236, 'wesblu': 237, 'wesgre': 238, 'weskin': 239,
    'wesmea': 240, 'wessan': 241, 'westan': 242, 'wewpew': 243, 'whbnut': 244,
    'whcspa': 245, 'whfibi': 246, 'whtspa': 247, 'whtswi': 248, 'wilfly': 249,
    'wilsni1': 250, 'wiltur': 251, 'winwre3': 252, 'wlswar': 253, 'wooduc': 254,
    'wooscj2': 255, 'woothr': 256, 'y00475': 257, 'yebfly': 258, 'yebsap': 259,
    'yehbla': 260, 'yelwar': 261, 'yerwar': 262, 'yetvir': 263
}

INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [9]:
def init_layer(layer):
    """Initialize a Linear or Convolutional layer. """
    nn.init.xavier_uniform_(layer.weight)
 
    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)
            
    
def init_bn(bn):
    """Initialize a Batchnorm layer. """
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.conv2 = nn.Conv2d(in_channels=out_channels, 
                              out_channels=out_channels,
                              kernel_size=(3, 3), stride=(1, 1),
                              padding=(1, 1), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x


class ConvBlock5x5(nn.Module):
    def __init__(self, in_channels, out_channels):
        
        super(ConvBlock5x5, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=in_channels, 
                              out_channels=out_channels,
                              kernel_size=(5, 5), stride=(1, 1),
                              padding=(2, 2), bias=False)
                              
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()
        
    def init_weight(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

        
    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        
        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg+max':
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2
        else:
            raise Exception('Incorrect argument!')
        
        return x


class AttBlock(nn.Module):
    def __init__(self, n_in, n_out, activation='linear', temperature=1.):
        super(AttBlock, self).__init__()
        
        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
        self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
        
        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()
        
    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)
         
    def forward(self, x):
        # x: (n_samples, n_in, n_time)
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)


class Cnn14_16k(nn.Module):
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num):
        
        super(Cnn14_16k, self).__init__() 

        assert sample_rate == 16000
        assert window_size == 512
        assert hop_size == 160
        assert mel_bins == 64
        assert fmin == 50
        assert fmax == 8000

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, pad_mode=pad_mode, 
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 
            freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)

        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset1 = nn.Linear(2048, classes_num, bias=True)
        
        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset1)
 
    def forward(self, input, mixup_lambda=None):
        """
        Input: (batch_size, data_length)"""
        
        #x = input

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        
        #if self.training:
        #    x = self.spec_augmenter(x)

        # Mixup on spectrogram
        #if self.training and mixup_lambda is not None
        #    x = do_mixup(x, mixup_lambda)
        
        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)
        
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        #embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset1(x))
        
        #output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding}

        return clipwise_output

In [10]:
encoder_params = {
    "resnest50d" : {
        "features" : 2048,
        "init_op"  : partial(timm.models.resnest50d, pretrained=False, in_chans=1)
    },
    "densenet201" : {
        "features": 1920,
        "init_op": partial(timm.models.densenet201, pretrained=False)
    },
    "dpn92" : {
        "features": 2688,
        "init_op": partial(timm.models.dpn92, pretrained=False)
    },
    "dpn131": {
        "features": 2688,
        "init_op": partial(timm.models.dpn131, pretrained=False)
    },
    "tf_efficientnet_b0_ns": {
        "features": 1280,
        "init_op": partial(tf_efficientnet_b0_ns, pretrained=False, drop_path_rate=0.2)
    },
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=False, drop_path_rate=0.2, in_chans=1)
    },
    "tf_efficientnet_b2_ns": {
        "features": 1408,
        "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2, in_chans=1)
    },
    "tf_efficientnet_b4_ns": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=False, drop_path_rate=0.2, in_chans=1)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=False, drop_path_rate=0.2, in_chans=1)
    }
}


class BirdClassifier(nn.Module):
    def __init__(self, encoder, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num):
        super().__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None
        
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, pad_mode=pad_mode, 
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 
            freeze_parameters=True)

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)

        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(0.2)
        self.fc = Linear(encoder_params[encoder]['features'], classes_num)

    def forward(self, input, spec_aug=False, mixup_lambda=None):

        x = self.spectrogram_extractor(input)   # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)    # (batch_size, 1, time_steps, mel_bins)

        #x = input
        
        if spec_aug:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return torch.sigmoid(x)

In [11]:
%%time

cnn14_config = {
    "sample_rate": 16000,
    "window_size": 512,
    "hop_size": 160,
    "mel_bins": 64,
    "fmin": 50,
    "fmax": 8000,
    "classes_num": 264
}

cnn14_stage3 = Cnn14_16k(**cnn14_config)
cnn14_stage3.load_state_dict(torch.load("../input/stage3-final-models/cnn14-fold-0.bin", map_location=device))
cnn14_stage3 = cnn14_stage3.to(device)
cnn14_stage3.eval()

print("done")

done
CPU times: user 2.85 s, sys: 1.71 s, total: 4.55 s
Wall time: 10.7 s


In [12]:
%%time

resnest50d_config = {
    "encoder" : "resnest50d",
    "sample_rate": 16000,
    "window_size": 512,
    "hop_size": 160,
    "mel_bins": 64,
    "fmin": 50,
    "fmax": 8000,
    "classes_num": 264
}

resnest50d_stage3 = BirdClassifier(**resnest50d_config)
resnest50d_stage3.load_state_dict(torch.load("../input/stage3-final-models/resnest50d-fold-1.bin", map_location=device))
resnest50d_stage3 = resnest50d_stage3.to(device)
resnest50d_stage3.eval()

print("done")

done
CPU times: user 897 ms, sys: 125 ms, total: 1.02 s
Wall time: 1.82 s


In [13]:
%%time

eff3_config = {
    "encoder" : "tf_efficientnet_b3_ns",
    "sample_rate": 16000,
    "window_size": 512,
    "hop_size": 160,
    "mel_bins": 64,
    "fmin": 50,
    "fmax": 8000,
    "classes_num": 264
}


eff3_stage3 = BirdClassifier(**eff3_config)
eff3_stage3.load_state_dict(torch.load("../input/stage3-final-models/eff3-fold-1.bin", map_location=device))
eff3_stage3 =  eff3_stage3.to(device)
eff3_stage3.eval()

eff3_stage3_f1 = BirdClassifier(**eff3_config)
eff3_stage3_f1.load_state_dict(torch.load("../input/stage3-final-models/eff3-fold-3.bin", map_location=device))
eff3_stage3_f1 =  eff3_stage3_f1.to(device)
eff3_stage3_f1.eval()

print("done")

done
CPU times: user 1.37 s, sys: 59.9 ms, total: 1.43 s
Wall time: 2.21 s


In [14]:
%%time

eff4_config = {
    "encoder" : "tf_efficientnet_b4_ns",
    "sample_rate": 16000,
    "window_size": 512,
    "hop_size": 160,
    "mel_bins": 64,
    "fmin": 50,
    "fmax": 8000,
    "classes_num": 264
}

eff4_stage3 = BirdClassifier(**eff4_config)
eff4_stage3.load_state_dict(torch.load("../input/stage3-final-models/eff4-fold-1.bin", map_location=device))
eff4_stage3 =  eff4_stage3.to(device)
eff4_stage3.eval()


eff4_stage3_f1 = BirdClassifier(**eff4_config)
eff4_stage3_f1.load_state_dict(torch.load("../input/stage3-final-models/eff4-fold-4.bin", map_location=device))
eff4_stage3_f1 =  eff4_stage3_f1.to(device)
eff4_stage3_f1.eval()

print("done")

done
CPU times: user 1.62 s, sys: 77.7 ms, total: 1.7 s
Wall time: 2.81 s


In [15]:
%%time

eff5_config = {
    "encoder" : "tf_efficientnet_b5_ns",
    "sample_rate": 16000,
    "window_size": 512,
    "hop_size": 160,
    "mel_bins": 120,
    "fmin": 50,
    "fmax": 8000,
    "classes_num": 264
}

eff5_stage3 = BirdClassifier(**eff5_config)
eff5_stage3.load_state_dict(torch.load("../input/stage3-final-models/eff5-fold-0.bin", map_location=device))
eff5_stage3 =  eff5_stage3.to(device)
eff5_stage3.eval()

print("done")

done
CPU times: user 1.1 s, sys: 82.6 ms, total: 1.18 s
Wall time: 2.04 s


In [16]:
SR = 16000
PERIOD = SR * 5
batch_size = 16

In [17]:
def prediction_for_clip(test_df: pd.DataFrame,
                        clip: np.ndarray,
                        threshold=0.5):
    
    clip = clip.astype(np.float32)
    
    y_torch = torch.from_numpy(clip)
    y_batch = list(y_torch.split(PERIOD))
    
    if y_batch[-1].shape[0] < PERIOD:
        last = torch.zeros(PERIOD)
        last[:y_batch[-1].shape[0]] = y_batch[-1]
        y_batch[-1] = last
    y_batch = torch.stack(y_batch)
    
    whole_size = y_batch.size(0)
    
    if whole_size % batch_size == 0:
        n_iter = whole_size // batch_size
    else:
        n_iter = whole_size // batch_size + 1
        
    site = test_df["site"].values[0]
    audio_id = test_df["audio_id"].values[0]
    
    if site in {"site_1", "site_2"}:
        
        #nocall_outs = []
        call_outs   = []
        
        for batch_i in range(n_iter):
            
            batch = y_batch[batch_i * batch_size:(batch_i + 1) * batch_size]
            
            with torch.no_grad():
                
                y = batch.to(device)
                
                cnn14_s3 = cnn14_stage3(y, None).data.cpu().numpy()

                resnet_s3 = resnest50d_stage3(y, None).data.cpu().numpy()

                eff3_s3 = eff3_stage3(y, None).data.cpu().numpy()
                eff3_s3_f1 = eff3_stage3_f1(y, None).data.cpu().numpy()

                eff4_s3 = eff4_stage3(y, None).data.cpu().numpy()
                eff4_s3_f1 = eff4_stage3_f1(y, None).data.cpu().numpy()

                eff5_s3 = eff5_stage3(y, None).data.cpu().numpy()
                
            
                call_out = np.mean([cnn14_s3, resnet_s3, eff3_s3, eff3_s3_f1, eff4_s3, eff4_s3_f1, eff5_s3], axis=0).tolist()
                
                
                call_outs.extend(call_out)
                
        
        ebirds = []
        
        call_outs = np.array(call_outs)
        
        events = call_outs >= threshold
        
        for i in range(len(events)):
            event = events[i, :]
            labels = np.argwhere(event).reshape(-1).tolist()
            
            if len(labels) == 0:
                label = "nocall"
            else:
                label = set()
                for l in labels:
                    label.add(INV_BIRD_CODE[l])
                label = " ".join(list(label))
            
            ebirds.append(label)
        
        
        row_ids = [f'{site}_{audio_id}_{i*5}' for i in range(1, len(call_outs)+1)]
        
        pred_df = pd.DataFrame({
            "row_id" : row_ids,
            "birds" : ebirds,
            #"nocall" : nocall_outs
        })
    
    else:
        
        call_outs   = []
        
        for batch_i in range(n_iter):
            
            batch = y_batch[batch_i * batch_size:(batch_i + 1) * batch_size]
            
            with torch.no_grad():
                
                y = batch.to(device)
                
                
                cnn14_s3 = cnn14_stage3(y, None).data.cpu().numpy()

                resnet_s3 = resnest50d_stage3(y, None).data.cpu().numpy()

                eff3_s3 = eff3_stage3(y, None).data.cpu().numpy()
                eff3_s3_f1 = eff3_stage3_f1(y, None).data.cpu().numpy()

                eff4_s3 = eff4_stage3(y, None).data.cpu().numpy()
                eff4_s3_f1 = eff4_stage3_f1(y, None).data.cpu().numpy()

                eff5_s3 = eff5_stage3(y, None).data.cpu().numpy()
                
            
                call_out = np.mean([cnn14_s3, resnet_s3, eff3_s3, eff3_s3_f1, eff4_s3, eff4_s3_f1, eff5_s3], axis=0).tolist()
                
                
                call_outs.extend(call_out)
        
        ebirds = []
        
        call_outs = np.array(call_outs)
        
        events = call_outs >= threshold
        
        label_set = set()
        
        for i in range(len(events)):
            event = events[i, :]
            labels = np.argwhere(event).reshape(-1).tolist()
            
            if len(labels) != 0:
                for l in labels:
                    label_set.add(INV_BIRD_CODE[l])
        
        label_list = list(label_set)
        
        if len(label_list) != 0:
            ebirds = " ".join(label_list)
        else:
            ebirds = "nocall"
            
        row_ids = test_df["row_id"].values[0]
        
        pred_df = pd.DataFrame({
            "row_id" : [row_ids],
            "birds" : [ebirds],
            #"nocall" : [0]
        })
        
    return pred_df

In [18]:
def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               threshold=0.5):
    
    unique_audio_id = test_df.audio_id.unique()
    
    warnings.filterwarnings("ignore")
    prediction_dfs = []
    
    for audio_id in progress_bar(unique_audio_id):
       
        clip, _ = librosa.load(test_audio / (audio_id + ".mp3"),
                               sr=SR,
                               mono=True,
                               res_type="kaiser_fast")
        
        
        #clip, _ = librosa.load("../input/birdsong-recognition/example_test_audio/BLKFR-10-CPL_20190611_093000.pt540.mp3", sr=SR, mono=True, res_type="kaiser_fast")
        
        test_df_for_audio_id = test_df.query(f"audio_id == '{audio_id}'").reset_index(drop=True)
        
        
        prediction_df = prediction_for_clip(test_df_for_audio_id,clip=clip, threshold=threshold)

        prediction_dfs.append(prediction_df)
        
        #break
     
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
    
    return prediction_df

In [19]:
prediction_df = prediction(
    test_df=test,
    test_audio=TEST_AUDIO_DIR,
    threshold=0.3
)

In [20]:
prediction_df.birds.value_counts()

aldfly                                9
nocall                                6
chswar aldfly                         3
amerob aldfly btnwar                  3
gockin aldfly                         2
                                     ..
canwar aldfly                         1
grcfly aldfly olsfly rethaw           1
hamfly cedwax btnwar aldfly bkbwar    1
olsfly                                1
aldfly olsfly                         1
Name: birds, Length: 62, dtype: int64

In [21]:
prediction_df

Unnamed: 0,row_id,birds
0,site_1_41e6fe6504a34bf6846938ba78d13df1_5,chswar aldfly
1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,aldfly wlswar camwar amered
2,site_1_41e6fe6504a34bf6846938ba78d13df1_15,moudov aldfly amered
3,site_1_41e6fe6504a34bf6846938ba78d13df1_20,camwar amered wlswar chswar bkbwar
4,site_1_41e6fe6504a34bf6846938ba78d13df1_25,balori aldfly
...,...,...
81,site_3_9cc5d9646f344f1bbb52640a988fe902,amecro comyel pinwar solsan comloo sposan bkts...
82,site_3_a56e20a518684688a9952add8a9d5213,canwar yerwar pasfly btnwar norwat amered brwh...
83,site_3_96779836288745728306903d54e264dd,hamfly cedwax btnwar aldfly bkbwar
84,site_3_f77783ba4c6641bc918b034a18c23e53,hoowar amegfi yebfly pinsis aldfly


In [22]:
all_row_id = test[["row_id"]]
submission = all_row_id.merge(prediction_df, on="row_id", how="left")

In [23]:
submission = submission.fillna("nocall")
submission.to_csv("submission.csv", index=False)

In [24]:
submission.birds.value_counts()

aldfly                                                                                                       7
chswar aldfly                                                                                                3
gockin aldfly                                                                                                2
btnwar amerob bkbwar                                                                                         2
reevir1 aldfly                                                                                               2
nocall                                                                                                       2
amerob btnwar                                                                                                2
amerob aldfly btnwar                                                                                         2
robgro amerob aldfly                                                                                         2
c