In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load
import cv2
import audioread
import logging
import os
import random
import time
import warnings

import librosa
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from fastprogress import progress_bar
from sklearn.metrics import f1_score
from torchvision import models

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
def set_seed(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN
    into deterministic mode.

    Args:
        seed: Value used for all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different convolution
    # algorithms from run to run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False

def get_logger(out_file=None):
    """Configure and return the root logger at INFO level.

    Any handlers already attached to the root logger are removed and closed
    first, so calling this function again (e.g. re-running the notebook cell)
    neither duplicates log lines nor leaks open log-file handles.

    Args:
        out_file: Optional path of a log file. When given, records are written
            to that file in addition to the stream handler.

    Returns:
        The configured root ``logging.Logger``.
    """
    logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    # Detach *and close* previous handlers; merely clearing the list would
    # leave earlier FileHandler streams open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.setLevel(logging.INFO)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

    if out_file is not None:
        fh = logging.FileHandler(out_file)
        fh.setFormatter(formatter)
        fh.setLevel(logging.INFO)
        logger.addHandler(fh)
    logger.info("logger set up")
    return logger

@contextmanager
def timer(name: str, logger: Optional[logging.Logger] = None):
    """Context manager reporting the wall-clock duration of its body.

    Emits ``[name] start`` on entry and ``[name] done in X.XX s`` on normal
    exit. Messages go to ``logger.info`` when a logger is supplied and to
    ``print`` otherwise. If the body raises, the exception propagates and the
    closing message is not emitted.

    Args:
        name: Label placed inside the square brackets of both messages.
        logger: Optional logger that receives the messages.
    """
    emit = print if logger is None else logger.info

    started_at = time.time()
    emit(f"[{name}] start")
    yield
    emit(f"[{name}] done in {time.time() - started_at:.2f} s")

In [3]:
# Root logger: records go to the notebook output and to the file 'main.log'.
logger = get_logger('main.log')
# Seed all RNGs before any other work so repeated runs are comparable.
set_seed(1213)

2024-01-11 14:01:52,328 - INFO - logger set up


In [4]:
# Sample rate (Hz) passed as `sr` to librosa.load when prediction() reads a clip.
TARGET_SR = 32000

In [5]:
# Test metadata: one row per requested prediction
# (columns shown below: site, row_id, seconds, audio_id).
test = pd.read_csv('/kaggle/input/birdcall-check/test.csv')
# Directory holding the recordings; prediction() loads "<audio_id>.mp3" from it.
test_audio = "/kaggle/input/birdcall-check/test_audio"
test.head(5)

Unnamed: 0,site,row_id,seconds,audio_id
0,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_5,5.0,41e6fe6504a34bf6846938ba78d13df1
1,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,10.0,41e6fe6504a34bf6846938ba78d13df1
2,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_15,15.0,41e6fe6504a34bf6846938ba78d13df1
3,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_20,20.0,41e6fe6504a34bf6846938ba78d13df1
4,site_1,site_1_41e6fe6504a34bf6846938ba78d13df1_25,25.0,41e6fe6504a34bf6846938ba78d13df1


In [6]:
class ResNet(nn.Module):
    """torchvision ResNet backbone with a three-layer MLP classification head.

    The backbone's last two children (its own pooling and ``fc`` layers) are
    replaced by a global max-pool followed by
    ``Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear``.
    Sub-module names and ordering (``encoder``, ``classifier``) are part of the
    checkpoint format and must not change.

    Args:
        base_model_name: Name of a constructor in ``torchvision.models`` whose
            model exposes an ``fc`` layer (e.g. ``"resnet50"``).
        pretrained: Forwarded to the torchvision constructor.
        num_classes: Size of the output layer.
    """

    def __init__(self, base_model_name: str, pretrained=False, num_classes=264):
        super().__init__()
        # getattr is the idiomatic way to look up a constructor by name.
        base_model = getattr(models, base_model_name)(pretrained=pretrained)
        in_features = base_model.fc.in_features

        # Drop the backbone's final pooling + fc and pool to 1x1 instead.
        layers = list(base_model.children())[:-2]
        layers.append(nn.AdaptiveMaxPool2d(1))
        self.encoder = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(in_features, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(p=0.2),
            nn.Linear(1024, num_classes))

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image batch; its first dimension is the batch dimension.

        Returns:
            Dict with ``"logits"`` (raw classifier output),
            ``"multiclass_proba"`` (softmax over dim 1) and
            ``"multilabel_proba"`` (element-wise sigmoid).
        """
        # Encoder output is pooled to 1x1, so flattening yields (batch, features).
        x = torch.flatten(self.encoder(x), start_dim=1)
        x = self.classifier(x)
        return {
            "logits": x,
            "multiclass_proba": F.softmax(x, dim=1),
            # F.sigmoid is deprecated in favour of torch.sigmoid.
            "multilabel_proba": torch.sigmoid(x),
        }

In [7]:
# Keyword arguments expanded into ResNet(**config) by get_model().
model_config = dict(
    base_model_name="resnet50",
    pretrained=False,
    num_classes=264,
)

# Extra keyword arguments forwarded to librosa.feature.melspectrogram.
melspectrogram_parameters = dict(
    n_mels=128,
    fmin=20,
    fmax=16000,
)

# Checkpoint read by get_model(); its "model_state_dict" entry is loaded.
weights_path = "/kaggle/input/birdcall-resnet50-init-weights/best.pth"

In [8]:
from sklearn.preprocessing import LabelEncoder
# Training metadata; only its "ebird_code" column is used in this cell.
df = pd.read_csv("/kaggle/input/birdsong-recognition/train.csv")

# Build the species-code <-> integer-label lookup tables.
# NOTE(review): the integer assigned to each species must match the class order
# the checkpoint was trained with — presumably the same LabelEncoder procedure; confirm.
unique_bird_names = df["ebird_code"].unique()
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(unique_bird_names)
# ebird_code -> integer label
BIRD_CODE = dict(zip(unique_bird_names,encoded_labels))
# integer label -> ebird_code; used by prediction_for_clip to name predicted classes
INV_BIRD_CODE = {v: k for k, v in BIRD_CODE.items()}

In [9]:
def mono_to_color(X: np.ndarray,
                  mean=None,
                  std=None,
                  norm_max=None,
                  norm_min=None,
                  eps=1e-6):
    """Convert a single-channel array into a 3-channel uint8 image.

    The input is replicated along a new last axis, standardised, clamped to
    ``[norm_min, norm_max]`` and rescaled to the 0..255 range.

    Args:
        X: Input array (e.g. a dB-scaled mel spectrogram).
        mean: Value subtracted before scaling; defaults to the mean of the data.
        std: Divisor used for standardisation; defaults to the standard
            deviation of the centred data.
        norm_max: Upper clamp applied after standardisation; defaults to the
            maximum of the standardised data.
        norm_min: Lower clamp applied after standardisation; defaults to the
            minimum of the standardised data.
        eps: Small constant guarding the division by ``std`` and used as the
            threshold below which the data is treated as constant.

    Returns:
        ``uint8`` array of shape ``X.shape + (3,)``. All zeros when the
        standardised data spans no more than ``eps``.
    """
    X = np.stack([X, X, X], axis=-1)

    # Compare against None explicitly: an explicit 0 / 0.0 is a legitimate
    # value and must not be replaced by the data-derived default.
    if mean is None:
        mean = X.mean()
    X = X - mean
    if std is None:
        std = X.std()
    Xstd = X / (std + eps)
    _min, _max = Xstd.min(), Xstd.max()
    if norm_max is None:
        norm_max = _max
    if norm_min is None:
        norm_min = _min
    if (_max - _min) > eps:
        # Normalize to [0, 255]
        V = Xstd
        V[V < norm_min] = norm_min
        V[V > norm_max] = norm_max
        V = 255 * (V - norm_min) / (norm_max - norm_min)
        V = V.astype(np.uint8)
    else:
        # Constant input: nothing to scale, return an all-zero image.
        V = np.zeros_like(Xstd, dtype=np.uint8)
    return V


class TestDataset(data.Dataset):
    """Dataset turning one audio clip into mel-spectrogram images.

    Each row of ``df`` describes one prediction target for the same ``clip``:

    * rows whose ``site`` is ``"site_3"`` yield a stack of images, one per
      complete ``period``-second window of the whole clip (a trailing partial
      window is dropped; a clip shorter than one window yields an empty array);
    * all other rows yield a single image for the ``period`` seconds that end
      at the row's ``seconds`` value.

    Args:
        df: Metadata with ``site``, ``row_id`` and ``seconds`` columns; rows are
            addressed by position.
        clip: 1-D waveform sampled at ``sample_rate``.
        img_size: Height of the produced images; width is scaled to keep the
            spectrogram's aspect ratio.
        melspectrogram_parameters: Extra keyword arguments forwarded to
            ``librosa.feature.melspectrogram`` (``None`` means no extras).
        sample_rate: Sampling rate of ``clip`` in Hz.
        period: Window length in seconds.
    """

    def __init__(self, df: pd.DataFrame, clip: np.ndarray,
                 img_size=224, melspectrogram_parameters=None,
                 sample_rate: int = 32000, period: int = 5):
        self.df = df
        self.clip = clip
        self.img_size = img_size
        # Copy so that neither a shared default nor the caller's dict is aliased.
        self.melspectrogram_parameters = dict(melspectrogram_parameters or {})
        self.sample_rate = sample_rate
        self.period = period

    def __len__(self):
        return len(self.df)

    def _to_image(self, y: np.ndarray) -> np.ndarray:
        """Convert a waveform segment into a channel-first float32 image in [0, 1]."""
        melspec = librosa.feature.melspectrogram(y=y,
                                                 sr=self.sample_rate,
                                                 **self.melspectrogram_parameters)
        melspec = librosa.power_to_db(melspec).astype(np.float32)
        image = mono_to_color(melspec)
        height, width, _ = image.shape
        # cv2.resize takes (width, height); scale so the height equals img_size.
        image = cv2.resize(image, (int(width * self.img_size / height), self.img_size))
        image = np.moveaxis(image, 2, 0)
        return (image / 255.0).astype(np.float32)

    def __getitem__(self, idx: int):
        """Return ``(image_or_image_stack, row_id, site)`` for row ``idx``."""
        # Positional lookup: works regardless of the frame's index labels.
        sample = self.df.iloc[idx]
        site = sample.site
        row_id = sample.row_id
        window = self.sample_rate * self.period

        if site == "site_3":
            y = self.clip.astype(np.float32)
            # Only complete windows are converted.
            images = [
                self._to_image(y[start:start + window])
                for start in range(0, len(y) - window + 1, window)
            ]
            return np.asarray(images), row_id, site

        end_seconds = int(sample.seconds)
        # Clamp at 0 so a window end earlier than `period` cannot produce a
        # negative start index (which would wrap around to the clip's tail).
        start_seconds = max(end_seconds - self.period, 0)
        start_index = self.sample_rate * start_seconds
        end_index = self.sample_rate * end_seconds

        y = self.clip[start_index:end_index].astype(np.float32)
        return self._to_image(y), row_id, site

In [10]:
def get_model(config: dict, weights_path: str):
    """Build a ResNet, load its weights and prepare it for inference.

    Args:
        config: Keyword arguments for the ``ResNet`` constructor.
        weights_path: Path of a checkpoint whose ``"model_state_dict"`` entry
            holds the weights.

    Returns:
        The model in eval mode, on CUDA when available and on CPU otherwise.
    """
    net = ResNet(**config)

    # Deserialise onto the CPU first; the model is moved to its final device below.
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state["model_state_dict"])

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Module.to() and Module.eval() both return the module itself.
    return net.to(target_device).eval()

In [11]:
def prediction_for_clip(test_df: pd.DataFrame,
                        clip: np.ndarray,
                        model: ResNet,
                        mel_params: dict,
                        threshold=0.5):
    """Predict the species present for every row that refers to one clip.

    Rows are served by ``TestDataset``. For ``"site_3"`` rows the dataset
    returns one image per window of the whole clip; those are scored in
    mini-batches and every class detected in any window is reported. All other
    rows carry a single image.

    Args:
        test_df: Rows of the test metadata belonging to ``clip``.
        clip: Waveform of the recording.
        model: Network whose output dict contains ``"multilabel_proba"``.
        mel_params: Keyword arguments for the mel-spectrogram computation.
        threshold: Minimum sigmoid probability for a class to be reported.

    Returns:
        Dict mapping each ``row_id`` to a space-separated string of species
        codes in ascending class-index order, or ``"nocall"`` when no class
        reaches the threshold.
    """
    dataset = TestDataset(df=test_df,
                          clip=clip,
                          img_size=224,
                          melspectrogram_parameters=mel_params)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    # Use the device the model actually lives on instead of guessing it.
    device = next(model.parameters()).device
    # Upper bound on windows scored at once, to keep memory use bounded.
    segment_batch_size = 16

    model.eval()
    prediction_dict = {}
    for image, row_id, site in progress_bar(loader):
        site = site[0]
        row_id = row_id[0]
        # Branch on the same condition TestDataset uses to decide what it returns.
        if site == "site_3":
            # Drop the loader's batch dimension: one entry per window remains
            # (or an empty 1-D tensor when the clip had no complete window).
            segments = image.squeeze(0)
            detected = set()
            for start in range(0, segments.size(0), segment_batch_size):
                batch = segments[start:start + segment_batch_size].to(device)
                with torch.no_grad():
                    proba = model(batch)["multilabel_proba"].cpu().numpy()
                # A class counts as present if any window in the batch fires.
                hits = (proba >= threshold).any(axis=0)
                detected.update(np.flatnonzero(hits).tolist())
            # Sort so the output does not depend on set iteration order.
            labels = sorted(detected)
        else:
            with torch.no_grad():
                proba = model(image.to(device))["multilabel_proba"]
            proba = proba.cpu().numpy().reshape(-1)
            labels = np.flatnonzero(proba >= threshold).tolist()

        if len(labels) == 0:
            prediction_dict[row_id] = "nocall"
        else:
            prediction_dict[row_id] = " ".join(INV_BIRD_CODE[label] for label in labels)
    return prediction_dict

In [12]:
import os


def prediction(test_df: pd.DataFrame,
               test_audio: Path,
               model_config: dict,
               mel_params: dict,
               weights_path: str,
               threshold=0.5):
    """Run inference for every recording referenced in ``test_df``.

    Each distinct ``audio_id`` is loaded once from ``<test_audio>/<audio_id>.mp3``
    at ``TARGET_SR`` and all rows that refer to it are predicted together.
    Loading and prediction times are reported through the module-level
    ``logger``.

    Args:
        test_df: Test metadata with at least an ``audio_id`` column plus the
            columns required by ``prediction_for_clip``.
        test_audio: Directory containing the recordings (``str`` or ``Path``).
        model_config: Keyword arguments for the model constructor.
        mel_params: Keyword arguments for the mel-spectrogram computation.
        weights_path: Path to the model checkpoint.
        threshold: Minimum probability for a class to be reported.

    Returns:
        DataFrame with columns ``row_id`` and ``birds``; empty (with those
        columns) when ``test_df`` references no recordings.
    """
    model = get_model(model_config, weights_path)

    prediction_dfs = []
    # Suppress warnings only while decoding/predicting instead of installing a
    # permanent process-wide "ignore" filter.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for audio_id in test_df.audio_id.unique():
            with timer(f"Loading {audio_id}", logger):
                # os.path.join accepts both str and Path for the directory.
                audio_path = os.path.join(test_audio, f"{audio_id}.mp3")
                clip, _ = librosa.load(audio_path,
                                       sr=TARGET_SR,
                                       mono=True,
                                       res_type="kaiser_fast")
                # Boolean mask instead of a string-built query expression.
                test_df_for_audio_id = test_df[
                    test_df.audio_id == audio_id].reset_index(drop=True)
            with timer(f"Prediction on {audio_id}", logger):
                prediction_dict = prediction_for_clip(test_df_for_audio_id,
                                                      clip=clip,
                                                      model=model,
                                                      mel_params=mel_params,
                                                      threshold=threshold)
            prediction_dfs.append(pd.DataFrame({
                "row_id": list(prediction_dict.keys()),
                "birds": list(prediction_dict.values())
            }))

    if not prediction_dfs:
        # pd.concat raises on an empty list; return an empty, well-formed frame.
        return pd.DataFrame(columns=["row_id", "birds"])
    return pd.concat(prediction_dfs, axis=0, sort=False).reset_index(drop=True)

In [13]:
pip install resampy

Collecting resampy
  Downloading resampy-0.4.2-py3-none-any.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: resampy
Successfully installed resampy-0.4.2
Note: you may need to restart the kernel to use updated packages.


In [14]:
# Run inference for every recording listed in `test`; a class is reported when
# its sigmoid probability is >= 0.8.
submission = prediction(test_df=test,
                        test_audio=test_audio,
                        model_config=model_config,
                        mel_params=melspectrogram_parameters,
                        weights_path=weights_path,
                        threshold=0.8)
# Write the per-row predictions to disk.
# NOTE(review): the later cells that modify `submission` run after this write, so
# their changes are not in submission.csv as far as this notebook shows — confirm intent.
submission.to_csv("submission.csv", index=False)

2024-01-11 14:02:06,328 - INFO - [Loading 41e6fe6504a34bf6846938ba78d13df1] start
2024-01-11 14:02:16,357 - INFO - [Loading 41e6fe6504a34bf6846938ba78d13df1] done in 10.03 s
2024-01-11 14:02:16,359 - INFO - [Prediction on 41e6fe6504a34bf6846938ba78d13df1] start


2024-01-11 14:02:18,886 - INFO - [Prediction on 41e6fe6504a34bf6846938ba78d13df1] done in 2.53 s
2024-01-11 14:02:18,889 - INFO - [Loading cce64fffafed40f2b2f3d3413ec1c4c2] start
2024-01-11 14:02:19,198 - INFO - [Loading cce64fffafed40f2b2f3d3413ec1c4c2] done in 0.31 s
2024-01-11 14:02:19,199 - INFO - [Prediction on cce64fffafed40f2b2f3d3413ec1c4c2] start


2024-01-11 14:02:21,315 - INFO - [Prediction on cce64fffafed40f2b2f3d3413ec1c4c2] done in 2.12 s
2024-01-11 14:02:21,318 - INFO - [Loading 99af324c881246949408c0b1ae54271f] start
2024-01-11 14:02:21,657 - INFO - [Loading 99af324c881246949408c0b1ae54271f] done in 0.34 s
2024-01-11 14:02:21,658 - INFO - [Prediction on 99af324c881246949408c0b1ae54271f] start


2024-01-11 14:02:23,445 - INFO - [Prediction on 99af324c881246949408c0b1ae54271f] done in 1.79 s
2024-01-11 14:02:23,448 - INFO - [Loading 6ab74e177aa149468a39ca10beed6222] start
2024-01-11 14:02:23,732 - INFO - [Loading 6ab74e177aa149468a39ca10beed6222] done in 0.28 s
2024-01-11 14:02:23,734 - INFO - [Prediction on 6ab74e177aa149468a39ca10beed6222] start


2024-01-11 14:02:25,220 - INFO - [Prediction on 6ab74e177aa149468a39ca10beed6222] done in 1.49 s
2024-01-11 14:02:25,221 - INFO - [Loading b2fd3f01e9284293a1e33f9c811a2ed6] start
2024-01-11 14:02:25,530 - INFO - [Loading b2fd3f01e9284293a1e33f9c811a2ed6] done in 0.31 s
2024-01-11 14:02:25,531 - INFO - [Prediction on b2fd3f01e9284293a1e33f9c811a2ed6] start


2024-01-11 14:02:27,291 - INFO - [Prediction on b2fd3f01e9284293a1e33f9c811a2ed6] done in 1.76 s
2024-01-11 14:02:27,293 - INFO - [Loading de62b37ebba749d2abf29d4a493ea5d4] start
2024-01-11 14:02:27,370 - INFO - [Loading de62b37ebba749d2abf29d4a493ea5d4] done in 0.08 s
2024-01-11 14:02:27,371 - INFO - [Prediction on de62b37ebba749d2abf29d4a493ea5d4] start


2024-01-11 14:02:27,681 - INFO - [Prediction on de62b37ebba749d2abf29d4a493ea5d4] done in 0.31 s
2024-01-11 14:02:27,682 - INFO - [Loading 8680a8dd845d40f296246dbed0d37394] start
2024-01-11 14:02:28,061 - INFO - [Loading 8680a8dd845d40f296246dbed0d37394] done in 0.38 s
2024-01-11 14:02:28,062 - INFO - [Prediction on 8680a8dd845d40f296246dbed0d37394] start


2024-01-11 14:02:30,549 - INFO - [Prediction on 8680a8dd845d40f296246dbed0d37394] done in 2.49 s
2024-01-11 14:02:30,552 - INFO - [Loading 940d546e5eb745c9a74bce3f35efa1f9] start
2024-01-11 14:02:31,173 - INFO - [Loading 940d546e5eb745c9a74bce3f35efa1f9] done in 0.62 s
2024-01-11 14:02:31,175 - INFO - [Prediction on 940d546e5eb745c9a74bce3f35efa1f9] start


2024-01-11 14:02:35,177 - INFO - [Prediction on 940d546e5eb745c9a74bce3f35efa1f9] done in 4.00 s
2024-01-11 14:02:35,178 - INFO - [Loading 07ab324c602e4afab65ddbcc746c31b5] start
2024-01-11 14:02:35,407 - INFO - [Loading 07ab324c602e4afab65ddbcc746c31b5] done in 0.23 s
2024-01-11 14:02:35,408 - INFO - [Prediction on 07ab324c602e4afab65ddbcc746c31b5] start


2024-01-11 14:02:36,786 - INFO - [Prediction on 07ab324c602e4afab65ddbcc746c31b5] done in 1.38 s
2024-01-11 14:02:36,787 - INFO - [Loading 899616723a32409c996f6f3441646c2a] start
2024-01-11 14:02:37,215 - INFO - [Loading 899616723a32409c996f6f3441646c2a] done in 0.43 s
2024-01-11 14:02:37,216 - INFO - [Prediction on 899616723a32409c996f6f3441646c2a] start


2024-01-11 14:02:39,735 - INFO - [Prediction on 899616723a32409c996f6f3441646c2a] done in 2.52 s
2024-01-11 14:02:39,737 - INFO - [Loading 9cc5d9646f344f1bbb52640a988fe902] start
2024-01-11 14:02:41,948 - INFO - [Loading 9cc5d9646f344f1bbb52640a988fe902] done in 2.21 s
2024-01-11 14:02:41,949 - INFO - [Prediction on 9cc5d9646f344f1bbb52640a988fe902] start


2024-01-11 14:02:54,360 - INFO - [Prediction on 9cc5d9646f344f1bbb52640a988fe902] done in 12.41 s
2024-01-11 14:02:54,361 - INFO - [Loading a56e20a518684688a9952add8a9d5213] start
2024-01-11 14:02:54,636 - INFO - [Loading a56e20a518684688a9952add8a9d5213] done in 0.27 s
2024-01-11 14:02:54,637 - INFO - [Prediction on a56e20a518684688a9952add8a9d5213] start


2024-01-11 14:02:56,005 - INFO - [Prediction on a56e20a518684688a9952add8a9d5213] done in 1.37 s
2024-01-11 14:02:56,007 - INFO - [Loading 96779836288745728306903d54e264dd] start
2024-01-11 14:02:56,181 - INFO - [Loading 96779836288745728306903d54e264dd] done in 0.17 s
2024-01-11 14:02:56,182 - INFO - [Prediction on 96779836288745728306903d54e264dd] start


2024-01-11 14:02:56,787 - INFO - [Prediction on 96779836288745728306903d54e264dd] done in 0.60 s
2024-01-11 14:02:56,788 - INFO - [Loading f77783ba4c6641bc918b034a18c23e53] start
2024-01-11 14:02:56,889 - INFO - [Loading f77783ba4c6641bc918b034a18c23e53] done in 0.10 s
2024-01-11 14:02:56,890 - INFO - [Prediction on f77783ba4c6641bc918b034a18c23e53] start


2024-01-11 14:02:57,173 - INFO - [Prediction on f77783ba4c6641bc918b034a18c23e53] done in 0.28 s
2024-01-11 14:02:57,175 - INFO - [Loading 856b194b097441958697c2bcd1f63982] start
2024-01-11 14:02:57,424 - INFO - [Loading 856b194b097441958697c2bcd1f63982] done in 0.25 s
2024-01-11 14:02:57,425 - INFO - [Prediction on 856b194b097441958697c2bcd1f63982] start


2024-01-11 14:02:58,517 - INFO - [Prediction on 856b194b097441958697c2bcd1f63982] done in 1.09 s


In [15]:
# Show the prediction table returned by prediction() (columns row_id, birds).
submission

Unnamed: 0,row_id,birds
0,site_1_41e6fe6504a34bf6846938ba78d13df1_5,aldfly
1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,aldfly
2,site_1_41e6fe6504a34bf6846938ba78d13df1_15,aldfly
3,site_1_41e6fe6504a34bf6846938ba78d13df1_20,nocall
4,site_1_41e6fe6504a34bf6846938ba78d13df1_25,aldfly
...,...,...
71,site_3_9cc5d9646f344f1bbb52640a988fe902,aldfly
72,site_3_a56e20a518684688a9952add8a9d5213,aldfly
73,site_3_96779836288745728306903d54e264dd,aldfly
74,site_3_f77783ba4c6641bc918b034a18c23e53,aldfly


In [16]:
# Derive a per-recording key by keeping the first 39 characters of row_id.
# For the rows displayed in this notebook that is "site_N_" (7 characters) plus the
# 32-character audio id, i.e. the trailing "_<seconds>" suffix is cut off.
# NOTE(review): the fixed width assumes single-digit site numbers and 32-character
# audio ids — confirm for the full test set.
submission['short_row_id'] = submission['row_id'].str[:39]
submission

Unnamed: 0,row_id,birds,short_row_id
0,site_1_41e6fe6504a34bf6846938ba78d13df1_5,aldfly,site_1_41e6fe6504a34bf6846938ba78d13df1
1,site_1_41e6fe6504a34bf6846938ba78d13df1_10,aldfly,site_1_41e6fe6504a34bf6846938ba78d13df1
2,site_1_41e6fe6504a34bf6846938ba78d13df1_15,aldfly,site_1_41e6fe6504a34bf6846938ba78d13df1
3,site_1_41e6fe6504a34bf6846938ba78d13df1_20,nocall,site_1_41e6fe6504a34bf6846938ba78d13df1
4,site_1_41e6fe6504a34bf6846938ba78d13df1_25,aldfly,site_1_41e6fe6504a34bf6846938ba78d13df1
...,...,...,...
71,site_3_9cc5d9646f344f1bbb52640a988fe902,aldfly,site_3_9cc5d9646f344f1bbb52640a988fe902
72,site_3_a56e20a518684688a9952add8a9d5213,aldfly,site_3_a56e20a518684688a9952add8a9d5213
73,site_3_96779836288745728306903d54e264dd,aldfly,site_3_96779836288745728306903d54e264dd
74,site_3_f77783ba4c6641bc918b034a18c23e53,aldfly,site_3_f77783ba4c6641bc918b034a18c23e53


In [17]:
# Collapse the table to one row per recording: keep only the first row for each
# short_row_id (so the retained `birds` value is that of the first row only),
# then use the shortened key as row_id and remove the helper column.
# NOTE(review): these are in-place mutations; re-running this cell fails because
# 'short_row_id' has already been dropped.
submission.drop_duplicates(subset="short_row_id", 
                     keep="first", inplace=True) 
submission['row_id'] = submission['short_row_id']
submission.drop('short_row_id',axis=1,inplace=True)
  

In [18]:
# Show the de-duplicated table (one row per recording; original index labels are kept).
submission

Unnamed: 0,row_id,birds
0,site_1_41e6fe6504a34bf6846938ba78d13df1,aldfly
5,site_1_cce64fffafed40f2b2f3d3413ec1c4c2,aldfly
12,site_1_99af324c881246949408c0b1ae54271f,aldfly
19,site_1_6ab74e177aa149468a39ca10beed6222,aldfly
25,site_1_b2fd3f01e9284293a1e33f9c811a2ed6,aldfly
32,site_2_de62b37ebba749d2abf29d4a493ea5d4,aldfly
33,site_2_8680a8dd845d40f296246dbed0d37394,aldfly
42,site_2_940d546e5eb745c9a74bce3f35efa1f9,nocall
56,site_2_07ab324c602e4afab65ddbcc746c31b5,aldfly
61,site_2_899616723a32409c996f6f3441646c2a,nocall
