In [1]:
import re
import os
import gc
import sys
import cv2
import math
import numpy as np
import pandas as pd
from glob import glob
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import librosa
import scipy as sci
import timm

import torch
from torch import nn
from torchvision.models import efficientnet

import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Config:
    """Central settings for the BirdCLEF-24 ONNX inference notebook.

    Every value is a class attribute, so it can be read either from the class
    or from the ``config`` instance created right below it.

    Filesystem locations default to the original local layout and can be
    overridden with ``BIRDCLEF_*`` environment variables.  The variables are
    read once, when this cell is executed.

    NOTE(review): ``BirdModel.configure_optimizers`` reads ``config.LR``,
    ``config.WEIGHT_DECAY`` and ``config.EPOCHS``; none of them is defined
    here, so they must be added before this notebook is used for training.
    """

    # == global config ==
    SEED = 42  # random seed
    DEVICE = 'cpu'  # device to be used
    MIXED_PRECISION = False  # whether to use mixed-16 precision
    # output folder (on Kaggle: '/kaggle/working/')
    OUTPUT_DIR = os.environ.get('BIRDCLEF_OUTPUT_DIR', '/home/yaz/birdclef24/out/')

    # == data config ==
    # root folder (on Kaggle: '/kaggle/input/birdclef-2024')
    DATA_ROOT = os.environ.get('BIRDCLEF_DATA_ROOT', '/home/yaz/birdclef24/data')
    # pre-computed spectrograms (on Kaggle: '/kaggle/input/birdclef24-spectrograms-via-cupy')
    PREPROCESSED_DATA_ROOT = os.environ.get(
        'BIRDCLEF_SPEC_ROOT', '/home/yaz/birdclef24/data/specs'
    )
    LOAD_DATA = False  # whether to load data from pre-processed dataset
    FS = 32000  # sample rate
    N_FFT = 1095  # n FFT of Spec.
    WIN_SIZE = 412  # WIN_SIZE of Spec.
    WIN_LAP = 100  # overlap of Spec.
    MIN_FREQ = 40  # min frequency
    MAX_FREQ = 15000  # max frequency

    # == model config ==
    # MODEL_TYPE = 'efficientnet_b0'  # model type

    # == dataset config ==
    BATCH_SIZE = 64  # batch size of each step
    N_WORKERS = 4  # number of workers

    # == inference config ==
    # checkpoint folder (on Kaggle: '/kaggle/input/effnetbirdclef24/pytorch/baseline_v1/1')
    CKPT_ROOT = os.environ.get('BIRDCLEF_CKPT_ROOT', '/home/yaz/birdclef24/out')

    # == other config ==
    VISUALIZE = False  # whether to visualize data and batch


config = Config()

In [3]:
# Seed the random number generators through Lightning with the configured seed.
# `seed_everything` is the last expression, so its return value is what the
# cell displays.
print('fix seed')
pl.seed_everything(config.SEED, workers=True)

Seed set to 42


fix seed


42

In [4]:
# labels: every entry of <DATA_ROOT>/train_audio is one class; ids follow the sorted order
label_list = sorted(os.listdir(os.path.join(config.DATA_ROOT, 'train_audio')))
label_id_list = [idx for idx, _name in enumerate(label_list)]
# forward (name -> id) and reverse (id -> name) lookups
label2id = {name: idx for idx, name in enumerate(label_list)}
id2label = {idx: name for idx, name in enumerate(label_list)}

In [5]:
# Take the device from the shared configuration instead of repeating the literal,
# so that changing Config.DEVICE is enough to move the notebook to another device.
device = torch.device(config.DEVICE)

In [6]:
# Load the per-recording training metadata shipped with the dataset and preview
# the first rows (label, rating, filename, ... — see the table below).
metadata_df = pd.read_csv(f'{config.DATA_ROOT}/train_metadata.csv')
metadata_df.head()

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename
0,asbfly,[],['call'],39.2297,118.1987,Muscicapa dauurica,Asian Brown Flycatcher,Matt Slaymaker,Creative Commons Attribution-NonCommercial-Sha...,5.0,https://www.xeno-canto.org/134896,asbfly/XC134896.ogg
1,asbfly,[],['song'],51.403,104.6401,Muscicapa dauurica,Asian Brown Flycatcher,Magnus Hellström,Creative Commons Attribution-NonCommercial-Sha...,2.5,https://www.xeno-canto.org/164848,asbfly/XC164848.ogg
2,asbfly,[],['song'],36.3319,127.3555,Muscicapa dauurica,Asian Brown Flycatcher,Stuart Fisher,Creative Commons Attribution-NonCommercial-Sha...,2.5,https://www.xeno-canto.org/175797,asbfly/XC175797.ogg
3,asbfly,[],['call'],21.1697,70.6005,Muscicapa dauurica,Asian Brown Flycatcher,vir joshi,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/207738,asbfly/XC207738.ogg
4,asbfly,[],['call'],15.5442,73.7733,Muscicapa dauurica,Asian Brown Flycatcher,Albert Lastukhin & Sergei Karpeev,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/209218,asbfly/XC209218.ogg


In [7]:
def _to_samplename(filename):
    """Join the part before the first '/' with the last path component cut at its first '.'."""
    pieces = filename.split('/')
    return pieces[0] + '-' + pieces[-1].split('.')[0]


# Keep only the columns used downstream and derive, in this order:
#   target     - integer class id of the primary label
#   filepath   - location of the audio file under DATA_ROOT
#   samplename - '<label>-<file stem>' identifier
train_df = metadata_df[['primary_label', 'rating', 'filename']].assign(
    target=lambda frame: frame.primary_label.map(label2id),
    filepath=lambda frame: config.DATA_ROOT + '/train_audio/' + frame.filename,
    samplename=lambda frame: frame.filename.map(_to_samplename),
)

print(f'{len(train_df)} samples')

train_df.head()

24459 samples


Unnamed: 0,primary_label,rating,filename,target,filepath,samplename
0,asbfly,5.0,asbfly/XC134896.ogg,0,/home/yaz/birdclef24/data/train_audio/asbfly/X...,asbfly-XC134896
1,asbfly,2.5,asbfly/XC164848.ogg,0,/home/yaz/birdclef24/data/train_audio/asbfly/X...,asbfly-XC164848
2,asbfly,2.5,asbfly/XC175797.ogg,0,/home/yaz/birdclef24/data/train_audio/asbfly/X...,asbfly-XC175797
3,asbfly,4.0,asbfly/XC207738.ogg,0,/home/yaz/birdclef24/data/train_audio/asbfly/X...,asbfly-XC207738
4,asbfly,4.0,asbfly/XC209218.ogg,0,/home/yaz/birdclef24/data/train_audio/asbfly/X...,asbfly-XC209218


In [8]:
import numpy as np
import scipy as sci

def ogg2spec_via_scipy(audio_data):
    """Turn an audio signal into a log-scaled, min/max-normalised spectrogram.

    Steps: replace NaNs, compute a Hann-window spectrogram with the settings in
    ``config`` (FS, N_FFT, WIN_SIZE, WIN_LAP), keep the frequency rows inside
    [MIN_FREQ, MAX_FREQ], take log10 and rescale to [0, 1].

    Parameters
    ----------
    audio_data : np.ndarray
        Audio samples.  Assumed to be a 1-D signal sampled at ``config.FS`` —
        TODO confirm against callers.

    Returns
    -------
    np.ndarray
        2-D array (frequency rows x time frames) with values in [0, 1].  When
        the spectrogram is constant (e.g. silent or all-NaN input) the result
        is all zeros rather than NaN.
    """
    # Import the submodule explicitly: `import scipy as sci` alone does not
    # guarantee that `sci.signal` is loaded on every SciPy version.
    from scipy import signal

    # handles NaNs: fully-NaN input becomes silence, otherwise NaNs are
    # replaced by the mean of the valid samples (nanmean is only evaluated
    # when at least one valid sample exists, avoiding the all-NaN warning)
    nan_mask = np.isnan(audio_data)
    if nan_mask.all():
        audio_data = np.zeros_like(audio_data)
    else:
        audio_data = np.nan_to_num(audio_data, nan=np.nanmean(audio_data))

    # to spec.
    frequencies, times, spec_data = signal.spectrogram(
        audio_data,
        fs=config.FS,
        nfft=config.N_FFT,
        nperseg=config.WIN_SIZE,
        noverlap=config.WIN_LAP,
        window='hann'
    )

    # Filter frequency range
    valid_freq = (frequencies >= config.MIN_FREQ) & (frequencies <= config.MAX_FREQ)
    spec_data = spec_data[valid_freq, :]

    # Log (the small offset keeps log10 finite for zero-power bins)
    spec_data = np.log10(spec_data + 1e-20)

    # min/max normalize; skip the division for a constant spectrogram, where
    # the peak is 0 and dividing would turn every value into NaN
    spec_data = spec_data - spec_data.min()
    peak = spec_data.max()
    if peak > 0:
        spec_data = spec_data / peak

    return spec_data

In [9]:
class EffNet(nn.Module):
    """timm backbone configured for single-channel input, followed by a sigmoid.

    Args:
        model_name: architecture name handed to ``timm.create_model``.
        num_classes: size of the classification head, forwarded to timm.
    """

    def __init__(self, model_name="efficientnet_b3", num_classes=None) -> None:
        super().__init__()

        backbone_kwargs = dict(pretrained=True, in_chans=1, num_classes=num_classes)
        self.model = timm.create_model(model_name, **backbone_kwargs)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Insert a channel axis, run the backbone and apply the sigmoid."""
        # [B, W, H] -> [B, 1, W, H]
        return self.sig(self.model(x.unsqueeze(1)))

# %% [markdown]
# ## Dataset


In [10]:
class BirdDataset(torch.utils.data.Dataset):
    """Map-style dataset over a ``{key: spectrogram}`` mapping.

    Items are addressed by integer position in the mapping's key order and are
    returned as ``float32`` tensors.  An optional ``augmentation`` callable is
    invoked as ``augmentation(image=spec)`` and must return a mapping with an
    ``'image'`` entry.
    """

    def __init__(
        self,
        bird_data,
        augmentation=None,
    ):
        super().__init__()
        self.bird_data = bird_data
        self.augmentation = augmentation
        # snapshot of the key order so that integer indices map to fixed keys
        self.keys_list = [key for key in bird_data.keys()]

    def __len__(self):
        return len(self.bird_data)

    def __getitem__(self, index):
        sample = self.bird_data[self.keys_list[index]]

        if self.augmentation is None:
            return torch.tensor(sample, dtype=torch.float32)

        augmented = self.augmentation(image=sample)
        return torch.tensor(augmented['image'], dtype=torch.float32)

In [11]:
class BirdModel(pl.LightningModule):
    """Lightning wrapper around ``EffNet`` used for training and checkpoint loading.

    The attribute names (``backbone``, ``loss_fn``) determine the keys of the
    checkpoint ``state_dict`` and must therefore stay as they are.

    NOTE(review):
    * ``EffNet.forward`` ends with a sigmoid while ``nn.CrossEntropyLoss``
      expects unnormalised logits according to the PyTorch documentation —
      confirm that this pairing is intended.  It is left untouched here so the
      module stays compatible with the already trained checkpoint.
    * ``configure_optimizers`` reads ``config.LR``, ``config.WEIGHT_DECAY`` and
      ``config.EPOCHS``, which are not defined in the ``Config`` class of this
      notebook.
    * ``on_validation_epoch_end`` calls ``score``, which this notebook imports
      only in a later cell (``from metric import score``).
    * ``train_dataloader`` / ``validation_dataloader`` return attributes that
      are never assigned in this class; the hook Lightning documents for
      validation data is ``val_dataloader`` — verify before training.
    """

    def __init__(self):
        super().__init__()

        # == backbone ==
        self.backbone = EffNet(num_classes=len(label_list))

        # == loss function ==
        self.loss_fn = nn.CrossEntropyLoss()

        # == record ==
        # per-batch validation outputs, consumed in on_validation_epoch_end
        self.validation_step_outputs = []

    def forward(self, images):
        """Return the backbone output for a batch of spectrograms."""
        return self.backbone(images)

    def configure_optimizers(self):
        """Build Adam plus a per-epoch cosine-annealing-with-restarts schedule."""

        # == define optimizer ==
        model_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )

        # == define learning rate scheduler ==
        # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            model_optimizer,
            T_0=config.EPOCHS,
            T_mult=1,
            eta_min=1e-6,
            last_epoch=-1
        )

        return {
            'optimizer': model_optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'interval': 'epoch',
                'monitor': 'val_loss',
                'frequency': 1
            }
        }

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one ``(image, target)`` batch."""

        # == obtain input and target ==
        image, target = batch
        image = image.to(self.device)
        target = target.to(self.device)

        # == pred ==
        y_pred = self(image)

        # == compute loss ==
        train_loss = self.loss_fn(y_pred, target)

        # == record ==
        # third positional argument of `log` is `prog_bar`
        self.log('train_loss', train_loss, True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        """Store predictions and targets of one batch for the epoch-end evaluation."""

        # == obtain input and target ==
        image, target = batch
        image = image.to(self.device)
        target = target.to(self.device)

        # == pred ==
        with torch.no_grad():
            y_pred = self(image)

        # the "logits" entry holds the model output as returned by forward()
        self.validation_step_outputs.append({"logits": y_pred, "targets": target})

    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader

    def on_validation_epoch_end(self):
        """Aggregate the stored batches, then log the validation loss and score.

        Returns a dict with both values, as the original implementation did.
        """

        # = merge batch data =
        outputs = self.validation_step_outputs

        output_val = torch.cat([x['logits'] for x in outputs], dim=0).cpu().detach()
        target_val = torch.cat([x['targets'] for x in outputs], dim=0).cpu().detach()

        # = compute validation loss =
        val_loss = self.loss_fn(output_val, target_val)

        # target to one-hot
        target_val = torch.nn.functional.one_hot(target_val, len(label_list))

        # = val with ROC AUC =
        gt_df = pd.DataFrame(target_val.numpy().astype(np.float32), columns=label_list)
        pred_df = pd.DataFrame(output_val.numpy().astype(np.float32), columns=label_list)

        gt_df['id'] = [f'id_{i}' for i in range(len(gt_df))]
        pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]

        val_score = score(gt_df, pred_df, row_id_column_name='id')

        # Log both metrics.  Previously val_loss was computed but never logged,
        # so nothing monitoring 'val_loss' could ever see it.
        self.log("val_loss", val_loss, True)
        self.log("val_score", val_score, True)

        # clear validation outputs
        self.validation_step_outputs = list()

        return {'val_loss': val_loss, 'val_score': val_score}

In [12]:
# Lightning checkpoint files to convert; the cells below export each one to ONNX
# and average the predictions over all of them.
ckpt_list = [f'{config.CKPT_ROOT}/fold_0-v2.ckpt']


In [13]:
# Names given to the single input / output of the exported ONNX graph.
input_names, output_names = ['x'], ['output']
# Example batch used for the export: BATCH_SIZE spectrograms of 256 x 256.
input_tensor = torch.randn((config.BATCH_SIZE, 256, 256))  # input shape

In [14]:
# Export the backbone of every checkpoint to "<checkpoint stem>.onnx" in the
# current working directory and remember the produced file names.
onnx_ckpt_list = list()
for ckpt_path in ckpt_list:
    # Strip only the final extension: splitting on the first '.' would map
    # e.g. 'fold_0.1.ckpt' and 'fold_0.2.ckpt' to the same output file.
    ckpt_name = os.path.splitext(os.path.basename(ckpt_path))[0]
    onnx_path = f"{ckpt_name}.onnx"

    # == init model ==
    bird_model = BirdModel()

    # == load ckpt ==
    weights = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
    bird_model.load_state_dict(weights)
    bird_model.eval()

    # == convert to onnx ==
    # Only the backbone (EffNet, sigmoid included) is exported, traced with the
    # example `input_tensor`.
    torch.onnx.export(
        bird_model.backbone,
        input_tensor,
        onnx_path,
        verbose=False,
        input_names=input_names,
        output_names=output_names,
    )

    onnx_ckpt_list.append(onnx_path)

In [15]:
# Maps '<soundscape id>_<end second>' -> 256 x 256 spectrogram of one chunk.
all_bird_data = dict()

# https://www.kaggle.com/code/markwijkhuizen/birdclef-2024-efficientvit-inference
# Use the test soundscapes when there are any; otherwise fall back to the first
# 10 unlabeled soundscapes (sorted) so the notebook can still be exercised.
ogg_file_paths = glob(f'{config.DATA_ROOT}/test_soundscapes/*.ogg')
if len(ogg_file_paths) == 0:
    ogg_file_paths = sorted(glob(f'{config.DATA_ROOT}/unlabeled_soundscapes/*.ogg'))[:10]

for file_path in tqdm(ogg_file_paths, total=len(ogg_file_paths)):
    row_id = re.search(r'/([^/]+)\.ogg$', file_path).group(1)  # filename
    audio_data, _ = librosa.load(file_path, sr=config.FS)

    # to spec.
    spec = ogg2spec_via_scipy(audio_data)

    # pad the time axis up to the next multiple of 512 frames; a length that is
    # already a multiple needs no padding (the former `512 - n % 512` appended
    # a whole empty chunk in that case)
    pad = (-spec.shape[1]) % 512
    if pad > 0:
        spec = np.pad(spec, ((0,0), (0,pad)))

    # reshape: (512, T) -> (512, n_chunks, 512) -> (512, 512, n_chunks), then
    # shrink every chunk to 256 x 256
    spec = spec.reshape(512,-1,512).transpose([0, 2, 1])
    spec = cv2.resize(spec, (256, 256), interpolation=cv2.INTER_AREA)

    # 48 chunks per file, keyed by 5, 10, ..., 240.
    # NOTE(review): this indexes chunks 0..47 unconditionally, i.e. it assumes
    # every file yields at least 48 chunks — confirm for the real test data.
    for j in range(48):
        all_bird_data[f'{row_id}_{(j+1)*5}'] = spec[:, :, j]

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:13<00:00,  1.40s/it]


In [33]:
def predict(data_loader, onnx_model):
    """Run an ONNX Runtime session over every batch and stack the results.

    The session is always fed exactly ``config.BATCH_SIZE`` samples: a shorter
    batch is topped up with all-zero 256 x 256 samples, and the rows produced
    for that padding are dropped again before the outputs are collected.  No
    activation is applied here.

    Args:
        data_loader: iterable yielding tensors of shape (batch, 256, 256).
        onnx_model: object exposing ``run(output_names, feeds)`` (an
            ``onnxruntime.InferenceSession`` in this notebook).

    Returns:
        np.ndarray with the per-sample outputs of all batches, concatenated
        along axis 0.
    """
    full_batch = config.BATCH_SIZE
    collected = []
    for x in tqdm(data_loader):
        with torch.no_grad():
            # rows that belong to real samples (never more than one full batch)
            n_keep = min(x.shape[0], full_batch)

            # == make sure the batch_size equal to setting
            missing = full_batch - x.shape[0]
            if missing > 0:
                x = torch.cat([x, torch.zeros((missing, 256, 256))], dim=0)

            raw = onnx_model.run(output_names, {input_names[0]: x.numpy()})[0]
            collected.append(raw[:n_keep, ...])

    return np.concatenate(collected, axis=0)

In [34]:
# Offline install of the onnxruntime wheel from the Kaggle input folder
# (`--no-index`: package indexes are not consulted). The wheel path only exists
# on Kaggle; the captured output below shows the command failing elsewhere.
!pip install /kaggle/input/onnxruntime/onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl --no-index --find-links /kaggle/input/onnxruntime

[0mLooking in links: /kaggle/input/onnxruntime
Processing /kaggle/input/onnxruntime/onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
[31mERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: '/kaggle/input/onnxruntime/onnxruntime-1.17.3-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl'
[0m[31m
[0m

In [35]:
import onnx
import onnxruntime as ort

In [36]:

# == create dataset & dataloader ==
# The inputs are identical for every checkpoint, so they are built once instead
# of once per loop iteration.  shuffle=False keeps the sample order aligned
# with the key order of `all_bird_data`.
test_dataset = BirdDataset(all_bird_data)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=config.N_WORKERS,
    shuffle=False,
    drop_last=False
)

predictions = []

for ckpt in onnx_ckpt_list:

    # == init ONNX model ==
    # InferenceSession accepts the model path directly; there is no need to
    # load the protobuf with `onnx` and serialise it back to bytes first.
    onnx_session = ort.InferenceSession(ckpt)

    predictions.append(predict(test_loader, onnx_session))
    gc.collect()

# ensemble: plain average over the per-checkpoint predictions
predictions = np.mean(predictions, axis=0)

100%|██████████| 8/8 [00:51<00:00,  6.43s/it]


In [37]:
# One row per spectrogram chunk: the ids come from the keys of `all_bird_data`,
# the score columns from `predictions`, one column per label.
sub_id = pd.DataFrame(dict(row_id=list(all_bird_data)))
sub_pred = pd.DataFrame(data=predictions, columns=label_list)

# ids first, then the per-label scores
sub = pd.concat([sub_id, sub_pred], axis=1)

sub.to_csv('submission.csv',index=False)
print(f'Submissionn shape: {sub.shape}')
sub.head(5)

Submissionn shape: (480, 183)


Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,1000170626_5,0.003723,0.004229,0.000502,0.000394,0.004116,0.000652,0.000469,0.000983,0.002857,...,0.001649,0.000674,0.001414,0.008595,0.008795,0.011197,0.000999,0.000743,0.000443,0.000741
1,1000170626_10,0.00285,0.002882,0.000791,0.000392,0.008801,0.000827,0.000542,0.001651,0.001828,...,0.004571,0.001344,0.001578,0.00332,0.015401,0.025757,0.001358,0.000568,0.000581,0.000644
2,1000170626_15,0.001534,0.000863,0.000717,0.000377,0.001915,0.000769,0.000427,0.00097,0.001388,...,0.001491,0.000644,0.002587,0.004041,0.001557,0.022457,0.001209,0.000743,0.000322,0.001332
3,1000170626_20,0.006436,0.003801,0.00032,0.000314,0.005617,0.000833,0.000591,0.00135,0.003048,...,0.008529,0.001082,0.000811,0.007079,0.059859,0.003592,0.001096,0.000792,0.000612,0.000438
4,1000170626_25,0.002176,0.001391,0.0011,0.000898,0.005209,0.001141,0.000803,0.00118,0.001547,...,0.002974,0.000957,0.004663,0.001848,0.003423,0.033054,0.00156,0.000833,0.000645,0.003131


## Test scoring

In [38]:
from metric import score

In [39]:
sub

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,1000170626_5,0.003723,0.004229,0.000502,0.000394,0.004116,0.000652,0.000469,0.000983,0.002857,...,0.001649,0.000674,0.001414,0.008595,0.008795,0.011197,0.000999,0.000743,0.000443,0.000741
1,1000170626_10,0.002850,0.002882,0.000791,0.000392,0.008801,0.000827,0.000542,0.001651,0.001828,...,0.004571,0.001344,0.001578,0.003320,0.015401,0.025757,0.001358,0.000568,0.000581,0.000644
2,1000170626_15,0.001534,0.000863,0.000717,0.000377,0.001915,0.000769,0.000427,0.000970,0.001388,...,0.001491,0.000644,0.002587,0.004041,0.001557,0.022457,0.001209,0.000743,0.000322,0.001332
3,1000170626_20,0.006436,0.003801,0.000320,0.000314,0.005617,0.000833,0.000591,0.001350,0.003048,...,0.008529,0.001082,0.000811,0.007079,0.059859,0.003592,0.001096,0.000792,0.000612,0.000438
4,1000170626_25,0.002176,0.001391,0.001100,0.000898,0.005209,0.001141,0.000803,0.001180,0.001547,...,0.002974,0.000957,0.004663,0.001848,0.003423,0.033054,0.001560,0.000833,0.000645,0.003131
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,1001358022_220,0.001727,0.001369,0.000358,0.000391,0.001919,0.000576,0.000424,0.000409,0.001516,...,0.000916,0.000785,0.000760,0.000938,0.001830,0.013524,0.000561,0.000278,0.000184,0.001260
476,1001358022_225,0.000508,0.001780,0.001442,0.000609,0.006306,0.000954,0.000733,0.001339,0.001542,...,0.003178,0.001743,0.003349,0.003541,0.001490,0.066607,0.001411,0.000759,0.000349,0.001924
477,1001358022_230,0.000564,0.001413,0.001334,0.000540,0.002917,0.000805,0.000638,0.000891,0.001284,...,0.001145,0.000890,0.002608,0.002513,0.001668,0.040288,0.000964,0.000494,0.000320,0.001725
478,1001358022_235,0.000769,0.001117,0.000992,0.000515,0.001577,0.000742,0.000474,0.000725,0.001662,...,0.001456,0.000800,0.001106,0.001448,0.002291,0.022501,0.000770,0.000355,0.000280,0.001204


In [40]:
len(list(all_bird_data.keys()))

480

In [41]:
predictions.shape

(480, 182)

In [54]:
# Stand-in "ground truth" for exercising the metric: same layout as `sub`, with
# every label column replaced by random 0/1 values.
gt = sub.copy()
n_rows, n_cols = gt.shape
gt.iloc[:, 1:] = np.random.randint(0, 2, size=(n_rows, n_cols - 1))

In [57]:
gt

Unnamed: 0,row_id,asbfly,ashdro1,ashpri1,ashwoo2,asikoe2,asiope1,aspfly1,aspswi1,barfly1,...,whbwoo2,whcbar1,whiter2,whrmun,whtkin2,woosan,wynlau1,yebbab1,yebbul3,zitcis1
0,1000170626_5,1.0,0.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,...,0.0,1.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0
1,1000170626_10,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,...,1.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0
2,1000170626_15,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0
3,1000170626_20,0.0,0.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0,...,1.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
4,1000170626_25,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
475,1001358022_220,1.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,1.0,...,1.0,0.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,0.0
476,1001358022_225,0.0,1.0,0.0,1.0,0.0,1.0,1.0,1.0,0.0,...,1.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0
477,1001358022_230,0.0,0.0,1.0,0.0,1.0,1.0,0.0,1.0,1.0,...,1.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0
478,1001358022_235,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,...,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0


In [58]:
# Checks the metric call only: `gt` holds random labels, so the value printed
# below (about 0.5) says nothing about the quality of the model.
score(gt, sub, row_id_column_name="row_id")

0.4994172031772896