# Loading the Trained Model

In [None]:
# pytorch
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
import torchaudio 
import timm

# inference
import os
from pathlib import Path
from joblib import Parallel, delayed
from tqdm import tqdm
import glob
import pandas as pd

In [None]:
# Hyper-parameters. The inference code in this notebook reads num_classes,
# sample_rate, hop_length, max_time, n_mels and n_fft; the epoch/fold/batch
# entries are not read by any code visible here (presumably training-only).
CONFIG = dict(
    epochs=10,
    num_fold=5,
    num_classes=262,          # size of the classifier's output layer
    train_batch_size=16,
    valid_batch_size=16,
    sample_rate=32_000,       # passed as MelSpectrogram's sample_rate
    hop_length=512,
    max_time=5,               # presumably seconds per window — TODO confirm
    n_mels=128,
    n_fft=1024,
)

class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the two trailing dimensions.

    Computes ``avg(clamp(x, min=eps) ** p) ** (1 / p)`` with a single
    average-pooling window that spans the last two dimensions, so those
    dimensions are reduced to size 1. ``p`` is stored as a learnable
    one-element ``nn.Parameter`` initialised to the constructor argument.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        """Apply GeM pooling to ``x`` with exponent ``p`` and lower clamp ``eps``."""
        # One pooling window covering the whole trailing 2-D extent.
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return F.avg_pool2d(powered, window).pow(1.0 / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{type(self).__name__}(p={p_value:.4f}, eps={self.eps})"

class BirdCLEFModel(nn.Module):
    """EfficientNet-B4 backbone -> GeM pooling -> linear embedding -> linear classifier.

    Args:
        model_name: accepted for backward compatibility but NOT used; the
            backbone is always created as ``'tf_efficientnet_b4'``.
        embedding_size: width of the linear layer between the pooled
            backbone features and the classifier.
        pretrained: accepted for backward compatibility but NOT used; backbone
            weights are read from ``checkpoint_path`` instead (no download).
        CONFIG: dict that must provide ``"num_classes"`` (size of the output layer).
        checkpoint_path: local timm checkpoint file for the backbone weights.

    Raises:
        ValueError: if ``CONFIG`` is not supplied.
    """

    def __init__(self,
                 model_name="tf_efficientnet_b4_ns",
                 embedding_size=768,
                 pretrained=True,
                 CONFIG=None,
                 checkpoint_path='/kaggle/input/tf-efficientnet/pytorch/tf-efficientnet-b4/1/tf_efficientnet_b4_aa-818f208c.pth'):
        super().__init__()
        # CONFIG["num_classes"] sizes the output layer; fail with a clear
        # message instead of "'NoneType' object is not subscriptable".
        if CONFIG is None:
            raise ValueError("BirdCLEFModel requires CONFIG with a 'num_classes' entry")
        self.config = CONFIG
        # Local loading of backbone weights (offline environment).
        self.model = timm.create_model('tf_efficientnet_b4', checkpoint_path=checkpoint_path)
        in_features = self.model.classifier.in_features
        # Replace timm's head and global pooling with identities so the backbone
        # output is (presumably) the un-pooled 4-D feature map; GeM pools it below.
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.embedding = nn.Linear(in_features, embedding_size)
        self.fc = nn.Linear(embedding_size, CONFIG["num_classes"])

    def forward(self, images):
        """Return raw (un-normalised) class scores, one row per input image."""
        features = self.model(images)
        pooled = self.pooling(features).flatten(1)
        return self.fc(self.embedding(pooled))

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BirdCLEFModel(CONFIG=CONFIG).to(device)
# Load the trained checkpoint directly onto the device the model lives on.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(
    "/kaggle/input/birdclef-model-pytorch/EFN-validsplit-10-32-32-32000-512-5-128-1024-0-7.pt",
    map_location=device,
)
model.load_state_dict(state_dict)
# Inference mode: disables dropout and makes BatchNorm use its stored running
# statistics. Without this, predictions depend on the batch and are non-deterministic.
model.eval()

# Run inference

In [None]:
# Candidate class columns: every primary label in the training metadata, alphabetically.
df = pd.read_csv('/kaggle/input/birdclef-2023/train_metadata.csv')
competition_classes = sorted(df.primary_label.unique())
# glob returns files in filesystem order; sort so files (and submission rows)
# are processed in a reproducible order.
filepaths = sorted(glob.glob('/kaggle/input/birdclef-2023/test_soundscapes/*.ogg'))
# These two labels do appear in the metadata, but they are not among the model's
# CONFIG["num_classes"] outputs (presumably dropped when training — confirm
# against the training code), so they are removed from the prediction columns.
competition_classes.remove("afpkin1")
competition_classes.remove("golher1")
# Fail fast, with a clear message, if the label list and the model head disagree
# in size; otherwise the mismatch only surfaces when the DataFrame is built.
if len(competition_classes) != CONFIG["num_classes"]:
    raise ValueError(
        f"expected {CONFIG['num_classes']} classes to match the model head, "
        f"got {len(competition_classes)}"
    )

In [None]:
# make predictions
# Samples per prediction window: sample_rate * max_time (32_000 * 5 with CONFIG above).
MIN_WINDOW = CONFIG["sample_rate"] * CONFIG["max_time"]
from torchaudio import transforms as audtr
import torch.nn.functional as F


def _window_to_image(crop, mel_transform):
    """Turn one mono waveform window into a max-abs-normalised 3-channel mel image."""
    mel = mel_transform(crop)
    # Repeat the mel spectrogram three times to form a 3-channel image.
    image = torch.stack([mel, mel, mel])
    peak = image.abs().max()
    # A silent window yields an all-zero spectrogram; dividing by 0 would make
    # the whole image (and therefore the prediction) NaN.
    if peak > 0:
        image = image / peak
    return image


def process(filepath, n_windows=120):
    """Predict class probabilities for consecutive windows of one soundscape.

    Args:
        filepath: path to an audio file readable by ``torchaudio.load``.
        n_windows: number of consecutive ``MIN_WINDOW``-sample windows to score
            (120 by default). Recordings shorter than that are zero-padded;
            audio beyond the last window is ignored.

    Returns:
        list of dicts ``{"row_id": f"{stem}_{end}", "predictions": tensor}`` where
        ``end = (i + 1) * CONFIG["max_time"]`` for window ``i`` and ``predictions``
        is a (1, num_classes) CPU tensor of softmax probabilities.
    """
    name = Path(filepath).stem
    waveform, sr = torchaudio.load(filepath)
    audio = waveform[0]  # first channel only
    # The spectrogram settings assume CONFIG["sample_rate"]; resample otherwise.
    if sr != CONFIG["sample_rate"]:
        audio = torchaudio.functional.resample(audio, sr, CONFIG["sample_rate"])
    # Zero-pad short recordings so every window has exactly MIN_WINDOW samples
    # (an empty crop would make the spectrogram transform fail).
    missing = n_windows * MIN_WINDOW - audio.shape[0]
    if missing > 0:
        audio = F.pad(audio, (0, missing))

    # Built once per file rather than once per window.
    mel_transform = audtr.MelSpectrogram(sample_rate=CONFIG["sample_rate"],
                                         n_mels=CONFIG["n_mels"],
                                         n_fft=CONFIG["n_fft"],
                                         hop_length=CONFIG["hop_length"])
    # Inputs must be on the same device as the model's parameters.
    model_device = next(model.parameters()).device

    all_predictions = []
    for i in range(n_windows):
        crop = audio[i * MIN_WINDOW:(i + 1) * MIN_WINDOW]
        image = _window_to_image(crop, mel_transform)
        with torch.no_grad():
            logits = model(image[None].to(model_device))
            pred = F.softmax(logits, dim=1).cpu()
        t = (i + 1) * CONFIG["max_time"]
        all_predictions.append({"row_id": f'{name}_{t}', "predictions": pred})
    return all_predictions


all_predictions = Parallel(n_jobs=os.cpu_count())(
    delayed(process)(filepath)
    for filepath in tqdm(filepaths, 'Processing files')
)
all_predictions = [p2 for p in all_predictions for p2 in p]  # flatten

In [None]:
# convert predictions into a dataframe
# One row per entry of all_predictions: its row_id plus one probability column per class.
if all_predictions:
    # Each entry holds a (1, num_classes) tensor; [0] drops the batch dimension.
    # .cpu() is a no-op for CPU tensors and required before .numpy() on GPU ones.
    probabilities = torch.stack([p['predictions'][0] for p in all_predictions]).cpu().numpy()
else:
    # torch.stack raises on an empty list (e.g. no test soundscapes were found);
    # build an empty frame that still carries every class column.
    probabilities = torch.empty((0, len(competition_classes))).numpy()
df = pd.concat([
    pd.DataFrame({'row_id': [p['row_id'] for p in all_predictions]}),
    pd.DataFrame(probabilities, columns=competition_classes)
], axis=1)
df

In [None]:
# add removed birds
# The model produces no scores for these two labels; give them a constant 0 so
# the frame has a column for them.
for missing_label in ("afpkin1", "golher1"):
    df[missing_label] = 0
# Column order: row_id first, then every class column in alphabetical order.
class_columns = sorted(column for column in df.columns if column != 'row_id')
df = df[['row_id', *class_columns]]

In [None]:
# what were the top birds predicted?
# For each window take the highest-scoring class, then count how often each class
# wins. Only row_id is dropped, so every class column takes part in idxmax.
df.drop(columns='row_id').idxmax(axis=1).value_counts()

In [None]:
# Summary statistics of the predicted probability for the "thrnig1" class across all rows.
df.loc[:, "thrnig1"].describe()

In [None]:
# Summary statistics of the submission frame (pandas `describe` defaults).
summary = df.describe()
summary

In [None]:
# Write the submission (without the DataFrame index) to the current working directory.
df.to_csv(path_or_buf='submission.csv', index=False)