In [1]:
!pip uninstall torch --y
!pip uninstall torchaudio --y

Found existing installation: torch 1.5.1
Uninstalling torch-1.5.1:
  Successfully uninstalled torch-1.5.1
Found existing installation: torchaudio 0.5.0a0+738ccba
Uninstalling torchaudio-0.5.0a0+738ccba:
  Successfully uninstalled torchaudio-0.5.0a0+738ccba


In [2]:
# Newest torch required for tracing. Torchaudio for data loading
# Installs torch 1.6.0 / torchaudio 0.6.0 from wheels in attached input datasets, plus torchlibrosa
# from its source directory (its pip output is discarded). pip warns that kornia and allennlp pin
# torch<1.6; neither package is imported by the cells shown here.
# NOTE(review): torchlibrosa's Spectrogram/LogmelFilterBank/SpecAugmentation are imported below
# but not used by the visible cells — confirm this dependency is still needed.
!pip install ../input/pytorch-160-with-torchvision-070/torch-1.6.0cu101-cp37-cp37m-linux_x86_64.whl
!pip install ../input/torchaudio/torchaudio-0.6.0-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/bird-panns/torchlibrosa-master/torchlibrosa-master/ > /dev/null

import os
import gc
import time
import math
import shutil
import random
import warnings
import typing as tp
from pathlib import Path
from contextlib import contextmanager

import yaml
from joblib import delayed, Parallel

import cv2
import librosa
import audioread
import soundfile as sf

import numpy as np
import pandas as pd

from fastprogress import progress_bar
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold

import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair
import torch.utils.data as data
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
#from efficientnet_pytorch import EfficientNet


# Let DataFrame reprs show up to 500 rows / columns before pandas truncates them.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)

Processing /kaggle/input/pytorch-160-with-torchvision-070/torch-1.6.0cu101-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch
[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver to test your packages with the new resolver before it becomes the default.

kornia 0.3.2 requires torch<1.6.0,>=1.5.0, but you'll have torch 1.6.0+cu101 which is incompatible.
allennlp 1.0.0 requires torch<1.6.0,>=1.5.0, but you'll have torch 1.6.0+cu101 which is incompatible.[0m
Successfully installed torch-1.6.0+cu101
Processing /kaggle/input/torchaudio/torchaudio-0.6.0-cp37-cp37m-manylinux1_x86_64.whl
Installing collected packages: torchaudio
Successfully installed torchaudio-0.6.0


In [3]:
# Inference configuration. Inline comments keep the alternative values / checkpoints the author
# tried (the numbers in parentheses are the author's notes — meaning not recorded here).
args = dict(
    n_mels=224,  # 224
    hop_length=384,
    std_sr=32000,
    threshold=.55,  # 0.6
    model='../input/birdweights/c2_9_18.pt',  # c2_9_18.pt (569), c2_0_41.pt (548), c3_4_17 (554)
)

# TorchScript checkpoint: deserialise on CPU, then move to the GPU for inference.
model = torch.jit.load(args['model'], map_location='cpu').cuda()
model.eval();

root = os.path.join('../input', "birdsong-recognition")
test_dir = os.path.join(root, 'test_audio')
print(test_dir)

# The competition's test_audio folder decides between the real run and a debug run that uses
# the birdcall-check dataset plus a previously generated submission as placeholder.
is_submission_run = os.path.isdir(test_dir)
if is_submission_run:
    print('Running actual submission!')
    test_df = pd.read_csv(os.path.join(root, 'test.csv'))
    sub = pd.read_csv(os.path.join(root, 'sample_submission.csv'))
else:
    print('Running debug submission')
    debug_root = os.path.join('../input', 'birdcall-check')
    test_dir = os.path.join(debug_root, 'test_audio')
    test_df = pd.read_csv(os.path.join(debug_root, 'test.csv'))
    sub = pd.read_csv('../input/inference-pytorch-birdcall-resnet-baseline/submission.csv')
    print('Naturally no test dir found')

# Placeholder so submission.csv exists even if inference fails; the prediction cells further
# down overwrite it.
sub.to_csv("submission.csv", index=False)

../input/birdsong-recognition/test_audio
Running debug submission
Naturally no test dir found


In [4]:
# Bird codes in model-output order: position i in this list is class index i of the model.
# The order must never change, otherwise predictions map to the wrong species.
_BIRD_CODE_ORDER = [
    'aldfly', 'ameavo', 'amebit', 'amecro', 'amegfi',
    'amekes', 'amepip', 'amered', 'amerob', 'amewig',
    'amewoo', 'amtspa', 'annhum', 'astfly', 'baisan',
    'baleag', 'balori', 'banswa', 'barswa', 'bawwar',
    'belkin1', 'belspa2', 'bewwre', 'bkbcuc', 'bkbmag1',
    'bkbwar', 'bkcchi', 'bkchum', 'bkhgro', 'bkpwar',
    'bktspa', 'blkpho', 'blugrb1', 'blujay', 'bnhcow',
    'boboli', 'bongul', 'brdowl', 'brebla', 'brespa',
    'brncre', 'brnthr', 'brthum', 'brwhaw', 'btbwar',
    'btnwar', 'btywar', 'buffle', 'buggna', 'buhvir',
    'bulori', 'bushti', 'buwtea', 'buwwar', 'cacwre',
    'calgul', 'calqua', 'camwar', 'cangoo', 'canwar',
    'canwre', 'carwre', 'casfin', 'caster1', 'casvir',
    'cedwax', 'chispa', 'chiswi', 'chswar', 'chukar',
    'clanut', 'cliswa', 'comgol', 'comgra', 'comloo',
    'commer', 'comnig', 'comrav', 'comred', 'comter',
    'comyel', 'coohaw', 'coshum', 'cowscj1', 'daejun',
    'doccor', 'dowwoo', 'dusfly', 'eargre', 'easblu',
    'easkin', 'easmea', 'easpho', 'eastow', 'eawpew',
    'eucdov', 'eursta', 'evegro', 'fiespa', 'fiscro',
    'foxspa', 'gadwal', 'gcrfin', 'gnttow', 'gnwtea',
    'gockin', 'gocspa', 'goleag', 'grbher3', 'grcfly',
    'greegr', 'greroa', 'greyel', 'grhowl', 'grnher',
    'grtgra', 'grycat', 'gryfly', 'haiwoo', 'hamfly',
    'hergul', 'herthr', 'hoomer', 'hoowar', 'horgre',
    'horlar', 'houfin', 'houspa', 'houwre', 'indbun',
    'juntit1', 'killde', 'labwoo', 'larspa', 'lazbun',
    'leabit', 'leafly', 'leasan', 'lecthr', 'lesgol',
    'lesnig', 'lesyel', 'lewwoo', 'linspa', 'lobcur',
    'lobdow', 'logshr', 'lotduc', 'louwat', 'macwar',
    'magwar', 'mallar3', 'marwre', 'merlin', 'moublu',
    'mouchi', 'moudov', 'norcar', 'norfli', 'norhar2',
    'normoc', 'norpar', 'norpin', 'norsho', 'norwat',
    'nrwswa', 'nutwoo', 'olsfly', 'orcwar', 'osprey',
    'ovenbi1', 'palwar', 'pasfly', 'pecsan', 'perfal',
    'phaino', 'pibgre', 'pilwoo', 'pingro', 'pinjay',
    'pinsis', 'pinwar', 'plsvir', 'prawar', 'purfin',
    'pygnut', 'rebmer', 'rebnut', 'rebsap', 'rebwoo',
    'redcro', 'redhea', 'reevir1', 'renpha', 'reshaw',
    'rethaw', 'rewbla', 'ribgul', 'rinduc', 'robgro',
    'rocpig', 'rocwre', 'rthhum', 'ruckin', 'rudduc',
    'rufgro', 'rufhum', 'rusbla', 'sagspa1', 'sagthr',
    'savspa', 'saypho', 'scatan', 'scoori', 'semplo',
    'semsan', 'sheowl', 'shshaw', 'snobun', 'snogoo',
    'solsan', 'sonspa', 'sora', 'sposan', 'spotow',
    'stejay', 'swahaw', 'swaspa', 'swathr', 'treswa',
    'truswa', 'tuftit', 'tunswa', 'veery', 'vesspa',
    'vigswa', 'warvir', 'wesblu', 'wesgre', 'weskin',
    'wesmea', 'wessan', 'westan', 'wewpew', 'whbnut',
    'whcspa', 'whfibi', 'whtspa', 'whtswi', 'wilfly',
    'wilsni1', 'wiltur', 'winwre3', 'wlswar', 'wooduc',
    'wooscj2', 'woothr', 'y00475', 'yebfly', 'yebsap',
    'yehbla', 'yelwar', 'yerwar', 'yetvir',
]

# code -> class index, and class index -> code (both in class-index order).
BIRD_CODE = {code: index for index, code in enumerate(_BIRD_CODE_ORDER)}

INV_BIRD_CODE = dict(enumerate(_BIRD_CODE_ORDER))

In [5]:
# Explicit import: `from torchaudio.transforms import *` also rebinds `Spectrogram`, silently
# shadowing the torchlibrosa class imported in the setup cell.
from torchaudio.transforms import MelSpectrogram

fmin, fmax = 25, 11000  # mel filterbank frequency range (Hz)


def make_mel_transform(n_fft):
    """Build a GPU MelSpectrogram using the notebook's sample rate, hop and mel settings.

    Only the FFT size varies between the transforms, so each one trades time resolution
    for frequency resolution differently.
    """
    return MelSpectrogram(
        sample_rate=args['std_sr'],
        n_fft=n_fft,
        hop_length=args['hop_length'],
        f_min=fmin,
        f_max=fmax,
        n_mels=args['n_mels'],
    ).cuda()


# One transform per FFT size; the inference loop stacks their outputs as 3 input channels.
f1, f2, f3 = (make_mel_transform(n_fft) for n_fft in (1024, 2048, 4096))

In [6]:
class TestDataset(data.Dataset):
    """Serve normalised spectrogram windows for the test rows of ONE recording.

    Every row of ``df`` must refer to the recording whose spectrogram is passed in.
    ``site_1``/``site_2`` rows get the single window that ends at ``row.seconds``;
    ``site_3`` rows get the whole recording tiled into consecutive windows (stacked).

    Args:
        df: rows with ``site``, ``row_id`` and (for site_1/site_2) ``seconds`` columns.
            Rows are fetched with ``df.loc[idx]``, so the index must be 0..len(df)-1.
        spectrogram: 3-D tensor whose last axis is sliced as time frames; windows are
            normalised per leading-axis (channel) slice.
        sr: sample rate used to convert seconds to frames; defaults to ``args['std_sr']``.
        hop_length: STFT hop in samples; defaults to ``args['hop_length']``.
        clip_seconds: window length in seconds.

    Raises:
        ValueError: if one window would be shorter than a single frame.
    """

    def __init__(self, df, spectrogram, sr=None, hop_length=None, clip_seconds=5):
        self.df = df
        self.spectrogram = spectrogram
        self.sr = args['std_sr'] if sr is None else sr
        self.hop_length = args['hop_length'] if hop_length is None else hop_length
        self.clip_seconds = clip_seconds
        # Same integer arithmetic as the frame conversion: sr * seconds // hop.
        self.clip_frames = self._frame(clip_seconds)
        if self.clip_frames <= 0:
            # A zero-length window would make the site_3 tiling loop never advance.
            raise ValueError(
                f"clip of {clip_seconds}s is shorter than one hop "
                f"({self.hop_length} samples at {self.sr} Hz)"
            )

    def __len__(self):
        return len(self.df)

    def normalize(self, spect):
        """Standardise each leading-axis slice to mean 0 and std 0.5.

        The 1e-8 added to the std avoids division by zero on constant input.
        """
        spect = spect.contiguous()
        flat = spect.view(spect.size(0), -1)
        miu = flat.mean(-1).unsqueeze(-1).unsqueeze(-1)
        std = flat.std(-1).unsqueeze(-1).unsqueeze(-1) + 1e-8
        return (spect - miu) / std / 2

    def _frame(self, seconds):
        """Convert a time in seconds to a spectrogram frame index."""
        return int(self.sr * seconds // self.hop_length)

    def _tile_windows(self, y):
        """Cut ``y`` into consecutive clip-sized windows, normalised and stacked.

        The final window is right-aligned to the end of ``y`` so it is full length
        whenever ``y`` is at least one clip long.
        """
        clip = self.clip_frames
        n_frames = y.shape[-1]
        windows = []
        start = 0
        while start < n_frames:
            end = start + clip
            if end >= n_frames:
                windows.append(self.normalize(y[:, :, -clip:]))
                break
            windows.append(self.normalize(y[:, :, start:end]))
            start = end
        return torch.stack(windows)

    def _window_ending_at(self, y, end_seconds):
        """Slice the clip that ends at ``end_seconds``, kept inside the spectrogram."""
        start_index = max(self._frame(int(end_seconds - self.clip_seconds)), 0)
        end_index = self._frame(end_seconds)
        n_frames = y.shape[-1]
        if end_index > n_frames:
            # The row points past the end of the audio: shift the window back so it
            # keeps its length instead of returning a truncated (or empty) slice.
            start_index = max(n_frames - (end_index - start_index), 0)
            end_index = n_frames
        return y[:, :, start_index:end_index]

    def __getitem__(self, idx):
        sample = self.df.loc[idx, :]
        site = sample.site
        row_id = sample.row_id

        if site == "site_3":
            return self._tile_windows(self.spectrogram), row_id, site

        y = self._window_ending_at(self.spectrogram, int(sample.seconds))
        return self.normalize(y), row_id, site
    
# Default threshold was originally 0.5; the inference loop passes args['threshold'] explicitly.
def prediction_for_clip(dataset, threshold=0.54, net=None, label_map=None):
    """Predict bird labels for every row served by ``dataset``.

    Each dataset item is ``(image, row_id, site)``. For ``site_1``/``site_2`` the image is a
    single window scored in one forward pass. For any other site (``site_3`` in practice)
    the item is a stack of windows, scored in chunks of 16, and the labels of all windows
    are unioned.

    Args:
        dataset: map-style dataset yielding ``(tensor, row_id, site)``.
        threshold: sigmoid probability at or above which a class is reported.
        net: model to run; defaults to the notebook-level ``model``.
        label_map: class index -> bird code; defaults to ``INV_BIRD_CODE``.

    Returns:
        dict mapping ``row_id`` to space-separated bird codes (ascending class index),
        or ``"nocall"`` when no class reaches ``threshold``.
    """
    net = model if net is None else net
    label_map = INV_BIRD_CODE if label_map is None else label_map
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chunk_size = 16  # windows per forward pass, to avoid one huge batch for long recordings

    net.eval()
    prediction_dict = {}
    with torch.no_grad():
        for image, row_id, site in loader:
            # batch_size=1: unwrap the collated string lists.
            site = site[0]
            row_id = row_id[0]
            image = image.to(device).float()

            if site in {"site_1", "site_2"}:
                proba = net(image).reshape(-1).sigmoid().cpu().numpy()
                labels = set(np.flatnonzero(proba >= threshold).tolist())
            else:
                # Drop the DataLoader's batch dim: (1, n_windows, ...) -> (n_windows, ...).
                windows = image.squeeze(0)
                labels = set()
                for start in range(0, windows.size(0), chunk_size):
                    chunk = windows[start:start + chunk_size]
                    if chunk.ndim == 3:
                        chunk = chunk.unsqueeze(0)
                    proba = net(chunk).reshape(chunk.size(0), -1).sigmoid().cpu().numpy()
                    # Column indices of every (window, class) cell above threshold.
                    labels.update(np.nonzero(proba >= threshold)[1].tolist())

            if labels:
                prediction_dict[row_id] = " ".join(label_map[x] for x in sorted(labels))
            else:
                prediction_dict[row_id] = "nocall"
    return prediction_dict


# For each test recording: decode once, build a 3-channel log-mel spectrogram, then predict
# every test row that refers to that recording.
unique_audio_id = test_df.audio_id.unique()
prediction_dfs = []
with warnings.catch_warnings():
    # Silence audio-decoding warnings for this loop only, not for the rest of the session.
    warnings.simplefilter("ignore")
    for audio_id in progress_bar(unique_audio_id):
        signal, sample_rate = torchaudio.load(os.path.join(test_dir, audio_id + '.mp3'))

        # Down-mix to mono, then resample to the rate the mel transforms were built for.
        signal = signal.mean(0).cuda()
        signal = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=args['std_sr'])(signal)

        # Three FFT resolutions stacked as channels. The half->float round trip deliberately
        # drops precision (presumably to match how training features were stored — TODO confirm).
        mels = torch.stack([f1(signal), f2(signal), f3(signal)], dim=0)
        spectrogram = torch.log(mels + 1e-8).half().float().cpu()

        # Boolean mask instead of a string-built query(): no quoting pitfalls in audio ids.
        test_df_for_audio_id = test_df[test_df.audio_id == audio_id].reset_index(drop=True)
        dataset = TestDataset(test_df_for_audio_id, spectrogram)
        prediction_dict = prediction_for_clip(dataset=dataset, threshold=args['threshold'])
        prediction_dfs.append(
            pd.DataFrame(list(prediction_dict.items()), columns=["row_id", "birds"])
        )

# pd.concat([]) raises, so keep an empty, correctly-shaped frame when there is nothing to predict.
if prediction_dfs:
    prediction_df = pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)
else:
    prediction_df = pd.DataFrame(columns=["row_id", "birds"])
prediction_df['birds'].value_counts()

aldfly           51
nocall           19
comyel            2
amerob            1
aldfly hamfly     1
btnwar            1
aldfly comyel     1
Name: birds, dtype: int64

In [7]:
# Persist the raw model predictions (the post-processing cell rewrites this same file).
submission_path = 'submission.csv'
prediction_df.to_csv(submission_path, index=False)
prediction_df

Unnamed: 0,row_id,birds
0,site_1_41e6fe6504a34bf6846938ba78d13df1_5,aldfly
1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,aldfly
2,site_1_41e6fe6504a34bf6846938ba78d13df1_15,aldfly
3,site_1_41e6fe6504a34bf6846938ba78d13df1_20,nocall
4,site_1_41e6fe6504a34bf6846938ba78d13df1_25,aldfly
5,site_1_cce64fffafed40f2b2f3d3413ec1c4c2_5,aldfly
6,site_1_cce64fffafed40f2b2f3d3413ec1c4c2_10,aldfly
7,site_1_cce64fffafed40f2b2f3d3413ec1c4c2_15,aldfly
8,site_1_cce64fffafed40f2b2f3d3413ec1c4c2_20,nocall
9,site_1_cce64fffafed40f2b2f3d3413ec1c4c2_25,aldfly


# Post-processing: force every site_3 row to 'nocall', then rewrite the submission.
# NOTE(review): this discards the whole-recording (site_3) predictions computed above —
# presumably a deliberate trade-off; confirm it is still intended.
is_site_3 = prediction_df['row_id'].str.contains('site_3', regex=False, na=False)
# .loc writes through the DataFrame itself; assigning into `.birds.values[i]` relied on
# `.values` returning a writable view, which pandas Copy-on-Write does not guarantee.
prediction_df.loc[is_site_3, 'birds'] = 'nocall'
prediction_df.to_csv('submission.csv', index=False)
prediction_df