In [None]:
!pip install ../input/timm-pytorch-image-models/pytorch-image-models-master/

In [None]:
import os
import re
from typing import List, Dict, Optional

import cv2
import yaml
import timm
import librosa
import numpy as np
import pandas as pd
from pathlib import Path
import albumentations as album
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader

from utils import seed_everything
from easydict import EasyDict as edict

In [None]:
# Audio sample rate in Hz. NOTE(review): not referenced in the cells shown here;
# the dataset resamples with `conf.sampling_rate` (also 32_000) instead.
TARGET_SAMPLE_RATE = 32_000

# Label code -> class index; the index is the model output column for that label.
# 'nocall' (397) is the last entry. `get_predict_labels` appends an extra
# "nothing above threshold" column right after the model outputs, which maps to
# 'nocall' only if cfg.model.n_classes == 397 — TODO confirm against the config.
BIRD_CODE = {
    'acafly': 0, 'acowoo': 1, 'aldfly': 2, 'ameavo': 3, 'amecro': 4, 'amegfi': 5, 'amekes': 6, 'amepip': 7, 'amered': 8,
    'amerob': 9, 'amewig': 10, 'amtspa': 11, 'andsol1': 12, 'annhum': 13, 'astfly': 14, 'azaspi1': 15, 'babwar': 16,
    'baleag': 17, 'balori': 18, 'banana': 19, 'banswa': 20, 'banwre1': 21, 'barant1': 22, 'barswa': 23, 'batpig1': 24,
    'bawswa1': 25, 'bawwar': 26, 'baywre1': 27, 'bbwduc': 28, 'bcnher': 29, 'belkin1': 30, 'belvir': 31, 'bewwre': 32,
    'bkbmag1': 33, 'bkbplo': 34, 'bkbwar': 35, 'bkcchi': 36, 'bkhgro': 37, 'bkmtou1': 38, 'bknsti': 39, 'blbgra1': 40,
    'blbthr1': 41, 'blcjay1': 42, 'blctan1': 43, 'blhpar1': 44, 'blkpho': 45, 'blsspa1': 46, 'blugrb1': 47, 'blujay': 48,
    'bncfly': 49, 'bnhcow': 50, 'bobfly1': 51, 'bongul': 52, 'botgra': 53, 'brbmot1': 54, 'brbsol1': 55, 'brcvir1': 56, 
    'brebla': 57, 'brncre': 58, 'brnjay': 59, 'brnthr': 60, 'brratt1': 61, 'brwhaw': 62, 'brwpar1': 63, 'btbwar': 64, 
    'btnwar': 65, 'btywar': 66, 'bucmot2': 67, 'buggna': 68, 'bugtan': 69, 'buhvir': 70, 'bulori': 71, 'burwar1': 72, 
    'bushti': 73, 'butsal1': 74, 'buwtea': 75, 'cacgoo1': 76, 'cacwre': 77, 'calqua': 78, 'caltow': 79, 'cangoo': 80, 
    'canwar': 81, 'carchi': 82, 'carwre': 83, 'casfin': 84, 'caskin': 85, 'caster1': 86, 'casvir': 87, 'categr': 88, 
    'ccbfin': 89, 'cedwax': 90, 'chbant1': 91, 'chbchi': 92, 'chbwre1': 93, 'chcant2': 94, 'chispa': 95, 'chswar': 96, 
    'cinfly2': 97, 'clanut': 98, 'clcrob': 99, 'cliswa': 100, 'cobtan1': 101, 'cocwoo1': 102, 'cogdov': 103, 'colcha1': 104, 
    'coltro1': 105, 'comgol': 106, 'comgra': 107, 'comloo': 108, 'commer': 109, 'compau': 110, 'compot1': 111, 'comrav': 112, 
    'comyel': 113, 'coohaw': 114, 'cotfly1': 115, 'cowscj1': 116, 'cregua1': 117, 'creoro1': 118, 'crfpar': 119, 'cubthr': 120, 
    'daejun': 121, 'dowwoo': 122, 'ducfly': 123, 'dusfly': 124, 'easblu': 125, 'easkin': 126, 'easmea': 127, 'easpho': 128, 
    'eastow': 129, 'eawpew': 130, 'eletro': 131, 'eucdov': 132, 'eursta': 133, 'fepowl': 134, 'fiespa': 135, 'flrtan1': 136, 
    'foxspa': 137, 'gadwal': 138, 'gamqua': 139, 'gartro1': 140, 'gbbgul': 141, 'gbwwre1': 142, 'gcrwar': 143, 'gilwoo': 144, 
    'gnttow': 145, 'gnwtea': 146, 'gocfly1': 147, 'gockin': 148, 'gocspa': 149, 'goftyr1': 150, 'gohque1': 151, 'goowoo1': 152, 
    'grasal1': 153, 'grbani': 154, 'grbher3': 155, 'grcfly': 156, 'greegr': 157, 'grekis': 158, 'grepew': 159, 'grethr1': 160, 
    'gretin1': 161, 'greyel': 162, 'grhcha1': 163, 'grhowl': 164, 'grnher': 165, 'grnjay': 166, 'grtgra': 167, 'grycat': 168, 
    'gryhaw2': 169, 'gwfgoo': 170, 'haiwoo': 171, 'heptan': 172, 'hergul': 173, 'herthr': 174, 'herwar': 175, 'higmot1': 176, 
    'hofwoo1': 177, 'houfin': 178, 'houspa': 179, 'houwre': 180, 'hutvir': 181, 'incdov': 182, 'indbun': 183, 'kebtou1': 184, 
    'killde': 185, 'labwoo': 186, 'larspa': 187, 'laufal1': 188, 'laugul': 189, 'lazbun': 190, 'leafly': 191, 'leasan': 192, 
    'lesgol': 193, 'lesgre1': 194, 'lesvio1': 195, 'linspa': 196, 'linwoo1': 197, 'littin1': 198, 'lobdow': 199, 'lobgna5': 200, 
    'logshr': 201, 'lotduc': 202, 'lotman1': 203, 'lucwar': 204, 'macwar': 205, 'magwar': 206, 'mallar3': 207, 'marwre': 208, 
    'mastro1': 209, 'meapar': 210, 'melbla1': 211, 'monoro1': 212, 'mouchi': 213, 'moudov': 214, 'mouela1': 215, 'mouqua': 216, 
    'mouwar': 217, 'mutswa': 218, 'naswar': 219, 'norcar': 220, 'norfli': 221, 'normoc': 222, 'norpar': 223, 'norsho': 224, 
    'norwat': 225, 'nrwswa': 226, 'nutwoo': 227, 'oaktit': 228, 'obnthr1': 229, 'ocbfly1': 230, 'oliwoo1': 231, 'olsfly': 232, 
    'orbeup1': 233, 'orbspa1': 234, 'orcpar': 235, 'orcwar': 236, 'orfpar': 237, 'osprey': 238, 'ovenbi1': 239, 'pabspi1': 240, 
    'paltan1': 241, 'palwar': 242, 'pasfly': 243, 'pavpig2': 244, 'phivir': 245, 'pibgre': 246, 'pilwoo': 247, 'pinsis': 248, 
    'pirfly1': 249, 'plawre1': 250, 'plaxen1': 251, 'plsvir': 252, 'plupig2': 253, 'prowar': 254, 'purfin': 255, 'purgal2': 256, 
    'putfru1': 257, 'pygnut': 258, 'rawwre1': 259, 'rcatan1': 260, 'rebnut': 261, 'rebsap': 262, 'rebwoo': 263, 'redcro': 264, 
    'reevir1': 265, 'rehbar1': 266, 'relpar': 267, 'reshaw': 268, 'rethaw': 269, 'rewbla': 270, 'ribgul': 271, 'rinkin1': 272, 
    'roahaw': 273, 'robgro': 274, 'rocpig': 275, 'rotbec': 276, 'royter1': 277, 'rthhum': 278, 'rtlhum': 279, 'ruboro1': 280, 
    'rubpep1': 281, 'rubrob': 282, 'rubwre1': 283, 'ruckin': 284, 'rucspa1': 285, 'rucwar': 286, 'rucwar1': 287, 'rudpig': 288, 
    'rudtur': 289, 'rufhum': 290, 'rugdov': 291, 'rumfly1': 292, 'runwre1': 293, 'rutjac1': 294, 'saffin': 295, 'sancra': 296, 
    'sander': 297, 'savspa': 298, 'saypho': 299, 'scamac1': 300, 'scatan': 301, 'scbwre1': 302, 'scptyr1': 303, 'scrtan1': 304, 
    'semplo': 305, 'shicow': 306, 'sibtan2': 307, 'sinwre1': 308, 'sltred': 309, 'smbani': 310, 'snogoo': 311, 'sobtyr1': 312, 
    'socfly1': 313, 'solsan': 314, 'sonspa': 315, 'soulap1': 316, 'sposan': 317, 'spotow': 318, 'spvear1': 319, 'squcuc1': 320, 
    'stbori': 321, 'stejay': 322, 'sthant1': 323, 'sthwoo1': 324, 'strcuc1': 325, 'strfly1': 326, 'strsal1': 327, 'stvhum2': 328, 
    'subfly': 329, 'sumtan': 330, 'swaspa': 331, 'swathr': 332, 'tenwar': 333, 'thbeup1': 334, 'thbkin': 335, 'thswar1': 336, 
    'towsol': 337, 'treswa': 338, 'trogna1': 339, 'trokin': 340, 'tromoc': 341, 'tropar': 342, 'tropew1': 343, 'tuftit': 344, 
    'tunswa': 345, 'veery': 346, 'verdin': 347, 'vigswa': 348, 'warvir': 349, 'wbwwre1': 350, 'webwoo1': 351, 'wegspa1': 352, 
    'wesant1': 353, 'wesblu': 354, 'weskin': 355, 'wesmea': 356, 'westan': 357, 'wewpew': 358, 'whbman1': 359, 'whbnut': 360, 
    'whcpar': 361, 'whcsee1': 362, 'whcspa': 363, 'whevir': 364, 'whfpar1': 365, 'whimbr': 366, 'whiwre1': 367, 'whtdov': 368, 
    'whtspa': 369, 'whwbec1': 370, 'whwdov': 371, 'wilfly': 372, 'willet1': 373, 'wilsni1': 374, 'wiltur': 375, 'wlswar': 376, 
    'wooduc': 377, 'woothr': 378, 'wrenti': 379, 'y00475': 380, 'yebcha': 381, 'yebela1': 382, 'yebfly': 383, 'yebori1': 384, 
    'yebsap': 385, 'yebsee1': 386, 'yefgra1': 387, 'yegvir': 388, 'yehbla': 389, 'yehcar1': 390, 'yelgro': 391, 'yelwar': 392, 
    'yeofly1': 393, 'yerwar': 394, 'yeteup1': 395, 'yetvir': 396, 'nocall': 397
}

# Inverse mapping: class index -> label code, used to turn predicted column
# indices back into label strings.
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

# Kaggle input locations: the competition data and the trained experiment
# directory (holds `weight_best.pt` and `config.yml`).
DATA_DIR = Path('/kaggle/input/birdclef-2021')
LOG_DIR = Path('/kaggle/input/exp-001-20210407091109-0682/')

# Submission mode is detected by the presence of at least one .ogg file in
# test_soundscapes; otherwise the notebook runs against the train soundscapes.
TEST = next((DATA_DIR / "test_soundscapes/").glob("*.ogg"), None) is not None
AUDIO_DIR = DATA_DIR / ("test_soundscapes" if TEST else "train_soundscapes")

if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

## Dataset

In [None]:
class conf:
    """Audio and mel-spectrogram settings used by the inference dataset.

    Derived values (`samples`, `fmax`) are computed from the base settings in
    the class body, so only the base values need editing.
    """

    # --- raw audio -------------------------------------------------------
    sampling_rate = 32_000               # resampling rate passed to librosa.load (Hz)
    duration = 5                         # window length in seconds
    samples = sampling_rate * duration   # window length in samples

    # --- mel spectrogram -------------------------------------------------
    n_fft = 2048
    hop_length = 512
    n_mels = 128
    power = 2.0
    fmin = 20
    fmax = sampling_rate // 2            # Nyquist frequency


def get_transforms(params: Optional[Dict]):
    """Build an albumentations pipeline from a config mapping.

    Args:
        params: mapping ``{key: spec}``. Each ``spec`` exposes ``spec.name``
            (a transform class name) and, optionally, ``spec.params`` (keyword
            arguments for it; missing or null means "no arguments"). Keys are
            not used; transforms are applied in mapping order. ``None`` disables
            augmentation.

    Returns:
        ``album.Compose`` over the instantiated transforms, or ``None`` when
        ``params`` is ``None``.
    """
    def get_object(transform):
        # Prefer an albumentations class with that name; otherwise resolve the
        # name in this module's namespace (e.g. a custom transform class).
        # SECURITY: eval() runs arbitrary code taken from the config file —
        # only load configs from trusted sources.
        if hasattr(album, transform.name):
            return getattr(album, transform.name)
        return eval(transform.name)

    if params is None:
        return None

    transforms = [
        # A spec without `params` (or with `params: null` in YAML) previously
        # crashed on `**None`; treat it as "use the transform's defaults".
        get_object(transform)(**(getattr(transform, "params", None) or {}))
        for transform in params.values()
    ]
    return album.Compose(transforms)


def mono_to_color(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
    """Convert a single-channel array into a 3-channel ``uint8`` image.

    The input is replicated along a new last axis and standardised as
    ``(X - mean) / (std + eps)``. It is then clipped to ``[norm_min, norm_max]``
    and linearly rescaled to ``[0, 255]``.

    Args:
        X: array to convert (in this notebook, a dB-scaled mel spectrogram).
        mean, std: standardisation statistics. They are computed from ``X`` only
            when ``None``, so explicit zeros are honoured.
        norm_max, norm_min: clipping bounds applied after standardisation.
            They default to the max / min of the standardised data.
        eps: added to ``std`` to avoid division by zero. It is also the minimum
            data range below which the output is all zeros.

    Returns:
        ``np.uint8`` array of shape ``X.shape + (3,)``.
    """
    X = np.stack([X, X, X], axis=-1)

    # `is None` rather than `or`: `mean=0.0` / `norm_min=0.0` are valid explicit
    # values and must not be replaced by data statistics.
    if mean is None:
        mean = X.mean()
    if std is None:
        std = X.std()
    Xstd = (X - mean) / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min

    if (_max - _min) > eps:
        # Clip to the requested bounds, then scale to [0, 255].
        V = np.clip(Xstd, norm_min, norm_max)
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Constant input: nothing to scale, return a black image.
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


class CustomTestDataset(Dataset):
    """Inference dataset yielding one spectrogram image per ``(file, window)`` row.

    Each row of ``df`` gives a soundscape ``file_name`` (resolved against
    ``AUDIO_DIR``) and ``seconds``, the END time of a ``conf.duration``-second
    window. ``__getitem__`` processes that window in order:

    - decodes it;
    - builds a dB-scaled mel spectrogram;
    - converts it to a 3-channel image with ``mono_to_color``;
    - optionally applies ``cfg.transforms``;
    - resizes it.

    The result is a channels-first ``float32`` array scaled to ``[0, 1]``.

    Args:
        df: frame with ``file_name`` and ``seconds`` columns.
        cfg: dataset config providing ``transforms`` (see ``get_transforms``)
            and ``img_size.height`` / ``img_size.width``.
    """

    def __init__(self, df: pd.DataFrame, cfg: Dict):
        super().__init__()
        self.cfg = cfg
        self.filenames = df["file_name"].values
        self.seconds = df["seconds"].values
        self.transforms = get_transforms(cfg.transforms)
        # One-entry cache of the most recently decoded file. Consecutive rows of
        # the same soundscape (e.g. an unshuffled loader) decode it only once.
        self.y = None
        self.prior_filename = None

    def __len__(self):
        return len(self.filenames)

    def _load_audio(self, filename):
        """Return the waveform of ``filename``, reusing the cached one when it matches."""
        if filename != self.prior_filename:
            self.y, _ = librosa.load(str(AUDIO_DIR / filename), sr=conf.sampling_rate)
            self.prior_filename = filename
        return self.y

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        seconds = self.seconds[idx]
        y = self._load_audio(filename)

        # `seconds` marks the END of the window: slice [seconds - duration, seconds).
        end_index = conf.sampling_rate * seconds
        start_index = end_index - conf.samples
        clip = y[start_index:end_index].astype(np.float32)
        if len(clip) < conf.samples:
            # Zero-pad a window that runs past the end of the recording so every
            # item covers the same duration and an empty slice never reaches librosa.
            clip = np.pad(clip, (0, conf.samples - len(clip)))

        # `y=` must be passed by keyword (keyword-only since librosa 0.10).
        # n_fft / hop_length / power equal librosa's defaults; passing them just
        # makes the values declared in `conf` actually take effect.
        melspec = librosa.feature.melspectrogram(
            y=clip,
            sr=conf.sampling_rate,
            n_fft=conf.n_fft,
            hop_length=conf.hop_length,
            n_mels=conf.n_mels,
            fmin=conf.fmin,
            fmax=conf.fmax,
            power=conf.power,
        )
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        image = mono_to_color(melspec)

        if self.transforms:
            image = self.transforms(image=image)["image"]

        # NOTE(review): cv2.resize expects dsize=(width, height); this passes
        # (height, width), which swaps the axes unless the image is square. Left
        # unchanged because training presumably used the same call — confirm
        # against the training pipeline before "fixing" it.
        image = cv2.resize(image, (self.cfg.img_size.height, self.cfg.img_size.width))
        image = image.transpose(2, 0, 1)  # HWC -> CHW
        return (image / 255.0).astype(np.float32)

## Pooling layer

In [None]:
def gem(x, p=3, eps=1e-6):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the final two
    axes. Each pooled axis is kept with size 1, e.g.
    ``(N, C, H, W) -> (N, C, 1, 1)``.

    ``p`` may be a number or a tensor (``GeM`` passes its learnable parameter).
    ``p=1`` reduces to plain average pooling.
    """
    powered = x.clamp(min=eps).pow(p)
    window = (x.size(-2), x.size(-1))
    pooled = F.avg_pool2d(powered, window)
    return pooled.pow(1.0 / p)


class GeM(nn.Module):
    """Generalized-mean pooling layer with a learnable exponent.

    Args:
        p: initial exponent. It is stored as a 1-element ``Parameter``, so it is
            trained and saved with the module's state dict.
        eps: lower clamp applied to the input before exponentiation.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{self.__class__.__name__}(p={p_value:.4f}, eps={self.eps})"
    
    
# Registry of pooling layers selectable by name (the `pooling_name` argument
# of CustomModel).
layer_encoder = dict(GeM=GeM)

## Model

In [None]:
class CustomModel(nn.Module):
    """timm backbone + named global pooling + dropout + linear classification head.

    Args:
        n_classes: number of output logits.
        model_name: ``timm`` model name, created with ``pretrained=False``
            (weights come from the checkpoint loaded by the caller).
        pooling_name: key into ``layer_encoder`` selecting the pooling layer.
        args_pooling: keyword arguments for the pooling layer. ``None`` (the
            default) uses the layer's own defaults; previously it crashed on
            ``**None``.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(
        self,
        n_classes: int,
        model_name: str = "resnet50",
        pooling_name: str = "GeM",
        args_pooling: Optional[Dict] = None,
        dropout: float = 0.5,
    ):
        super().__init__()

        backbone = timm.create_model(model_name, pretrained=False)
        children = list(backbone.children())
        # Assumes the last child is the Linear classifier (e.g. `fc` in timm
        # ResNets). The last two children (presumably global pooling +
        # classifier) are dropped and replaced by `self.pooling` / `self.fc`.
        final_in_features = children[-1].in_features
        # Attribute names and Sequential indices define the state_dict keys,
        # so they must stay unchanged for checkpoint loading.
        self.backbone = nn.Sequential(*children[:-2])

        self.pooling = layer_encoder[pooling_name](**(args_pooling or {}))

        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout)
        self.fc = nn.Linear(final_in_features, n_classes)

    def forward(self, x):
        """Return ``(batch, n_classes)`` logits for an image batch ``x``."""
        x = self.backbone(x)
        x = self.pooling(x)
        x = x.flatten(1)
        x = self.act(x)
        x = self.drop(x)
        return self.fc(x)

## Inference

In [None]:
def predict(df, cfg):
    """Run test-time-augmented inference over ``df``.

    The function:

    - builds ``CustomTestDataset`` / ``DataLoader`` from ``cfg.data.test``;
    - restores ``LOG_DIR / 'weight_best.pt'`` into a ``CustomModel``;
    - runs the loader ``cfg.data.test.tta.iter_num`` times;
    - averages the sigmoid outputs across passes.

    Args:
        df: frame accepted by ``CustomTestDataset``.
        cfg: experiment config (``data.test``, ``model``).

    Returns:
        float64 array of shape ``(len(df), cfg.model.n_classes)``, rows in the
        order of ``df`` (assuming the loader config does not set ``drop_last``).

    Raises:
        ValueError: if ``cfg.data.test.tta.iter_num`` is less than 1.
    """
    n_classes = cfg.model.n_classes
    # Use the *test* TTA count for both running and averaging. The averaging
    # step used to read `cfg.data.valid.tta.iter_num`, which mis-averages (or
    # raises IndexError) whenever the two differ.
    n_tta = cfg.data.test.tta.iter_num
    if n_tta < 1:
        raise ValueError(f"cfg.data.test.tta.iter_num must be >= 1, got {n_tta}")

    test_dataset = CustomTestDataset(df, cfg.data.test)
    # Predictions are matched to rows of `df` by position, so never shuffle.
    loader_kwargs = dict(cfg.data.test.loader, shuffle=False)
    test_loader = DataLoader(test_dataset, **loader_kwargs)

    model = CustomModel(
        model_name=cfg.model.backbone,
        n_classes=n_classes,
        **cfg.model.params
    ).to(DEVICE)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    state_dict = torch.load(str(LOG_DIR / 'weight_best.pt'), map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    tta_preds = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        for _ in range(n_tta):
            batch_preds = [
                model(images.to(DEVICE)).sigmoid().cpu().numpy()
                for images in test_loader
            ]
            if batch_preds:
                tta_preds.append(np.concatenate(batch_preds, axis=0))
            else:
                tta_preds.append(np.zeros((0, n_classes), dtype=np.float32))

    # (n_tta, N, n_classes) -> (N, n_classes), accumulated in float64.
    return np.stack(tta_preds, axis=0).mean(axis=0, dtype=np.float64)


def get_predict_labels(
    preds,
    threshold: Optional[float] = None,
    inv_code: Optional[Dict[int, str]] = None,
):
    """Turn per-class probabilities into space-separated label strings.

    A class is predicted when its probability is ``>= threshold``. An extra
    column is appended that is True for rows where no class passes. Its index
    is ``preds.shape[1]``, so ``inv_code`` must map that index to the no-call
    label. With ``INV_BIRD_CODE`` this presumes 397 model outputs (index 397 is
    'nocall') — TODO confirm against ``cfg.model.n_classes``.

    Args:
        preds: 2-D array ``(n_rows, n_classes)`` of probabilities.
        threshold: decision threshold; defaults to the global ``cfg.threshold``.
        inv_code: column index -> label mapping; defaults to ``INV_BIRD_CODE``.

    Returns:
        List with one string per row; labels are in ascending column order.
    """
    if threshold is None:
        threshold = cfg.threshold
    if inv_code is None:
        inv_code = INV_BIRD_CODE

    events = np.asarray(preds) >= threshold
    nocall_col = ~events.any(axis=1, keepdims=True)
    events = np.concatenate([events, nocall_col], axis=1)

    return [
        " ".join(inv_code[label] for label in np.flatnonzero(row).tolist())
        for row in events
    ]

## Main

In [None]:
if TEST:
    test_df = pd.read_csv(DATA_DIR / "test.csv")
    sub_df = pd.read_csv(DATA_DIR / "sample_submission.csv")
else:
    # Dry run outside the submission environment: predict on the labelled
    # train soundscapes instead.
    test_df = pd.read_csv(DATA_DIR / "train_soundscape_labels.csv")
    sub_df = pd.read_csv(DATA_DIR / "train_soundscape_labels.csv", usecols=["row_id"])

# Audio files are named "<audio_id>_<site>_<digits>.ogg"; map the
# "<audio_id>_<site>" prefix of each file to its full file name.
test_df["file_id"] = test_df["audio_id"].astype(str) + "_" + test_df["site"]

# Escaped dot + end anchor: only real ".ogg" files match.
fname_pattern = re.compile(r"^(.+)_\d+\.ogg$")
file_id2fname = {}
for fname in os.listdir(AUDIO_DIR):
    match = fname_pattern.search(fname)
    if match:
        file_id2fname[match.group(1)] = fname
test_df["file_name"] = test_df["file_id"].map(file_id2fname)

# Fail early with a clear message instead of crashing later inside the
# DataLoader on a NaN path.
unmatched = test_df.loc[test_df["file_name"].isna(), "file_id"].unique()
if len(unmatched):
    raise FileNotFoundError(
        f"no audio file in {AUDIO_DIR} for file_id(s): {list(unmatched[:5])}"
    )

with open(LOG_DIR / 'config.yml', 'r') as yf:
    cfg = edict(yaml.safe_load(yf))

# Inference-time override of the batch size stored in the training config.
cfg.data.test.loader.batch_size = 16

seed_everything(cfg.seed)

test_preds = predict(test_df, cfg)
test_preds_labels = get_predict_labels(test_preds)
# NOTE(review): assigned by position — assumes sub_df rows are in the same order
# as test_df rows (true for the dry run, where both come from the same CSV).
sub_df["birds"] = test_preds_labels

sub_df.to_csv('submission.csv', index=False)
sub_df.head()

In [None]:
# Distribution of predicted label strings — a quick sanity check on how many
# rows fall back to "nocall".
sub_df.birds.value_counts()