In [11]:
import pandas as pd
import lightning as L

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import WandbLogger

from torchvision.transforms import v2
import albumentations as albu

from config import Config
from dataset import BirdCLEFDataset
from model import BirdCLEFModel

### Make a config

In [12]:
# Project-level settings object; the cells below read `config.metadata`,
# `config.batch_size` and `config.epochs` from it and also pass it to the
# dataset and model constructors.
config = Config()

### Load data

In [13]:
# Read the metadata CSV and keep only the two columns used in this notebook.
# Plain column selection replaces the former `pd.concat(..., names=[...])`:
# `names` labels hierarchical index levels (and is ignored without `keys`),
# it does not name columns.
metadata_raw = pd.read_csv(config.metadata)
labels_and_files = metadata_raw[['primary_label', 'filename']]

# One indicator column per distinct `primary_label`, appended to the right
# of the two base columns.
one_hot_labels = pd.get_dummies(labels_and_files['primary_label'])
data = pd.concat([labels_and_files, one_hot_labels], axis=1)

# 70/30 train/validation split with a fixed seed so the partition is
# reproducible across runs.
train_data, valid_data = train_test_split(data, train_size=0.7, shuffle=True, random_state=42)
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)

# Sanity check: the split must partition the table without losing rows.
assert len(train_data) + len(valid_data) == len(data)


In [14]:
# Augmentation pipeline: with probability 0.4, apply XYMasking with 1-3 masks
# along each axis (x masks up to 10 long, y masks up to 20 long).
augs = albu.Compose(
    [
        albu.XYMasking(
            num_masks_x=(1, 3),
            num_masks_y=(1, 3),
            mask_x_length=(1, 10),
            mask_y_length=(1, 20),
            p=0.4,
        ),
    ]
)

In [15]:
# Random masking is a training-time augmentation only. Evaluation data gets an
# identity pipeline, so validation/test metrics are deterministic and measured
# on unmodified inputs. It is still an `albu.Compose`, so the dataset can call
# it exactly like `augs`.
valid_augs = albu.Compose([albu.NoOp()])

train_dataset = BirdCLEFDataset(train_data, config, augs)
# NOTE: `test_dataset` is built from the validation split (`valid_data`).
test_dataset = BirdCLEFDataset(valid_data, config, valid_augs)

# Loader settings shared by both splits; only shuffling differs.
loader_kwargs = dict(batch_size=config.batch_size, num_workers=3, pin_memory=True)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

### Model initialization

In [16]:
# Build the model from the shared config. This same `model` object is later
# passed to `trainer.fit` / `trainer.test` and re-used for the ONNX export.
model = BirdCLEFModel(config)

### Training

In [None]:
# The Trainer comes from the `lightning` package (`import lightning as L`), so
# the logger is taken from the same package. `pytorch_lightning` is a separate
# distribution and mixing the two is not supported by Lightning.
from lightning.pytorch.loggers import WandbLogger

# The first positional argument of WandbLogger is the run `name` (not the
# W&B project); it is spelled out here to make that explicit.
wandb_logger = WandbLogger(name='asphodel_birdclef2024')
trainer = L.Trainer(max_epochs=config.epochs, logger=wandb_logger)

# `test_dataloader` wraps the validation split and serves as the validation
# loader during fitting.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

### Test

In [None]:
# Run the test loop. `test_dataloader` is built from `valid_data`, so this
# evaluates on the same split that was used for validation during training.
trainer.test(model, dataloaders=test_dataloader)

### Convert to ONNX

In [None]:
# NOTE(review): neither alias is used in this cell; kept in case later cells
# rely on them. Consider moving them to the import cell at the top.
import torch.onnx as onnx
import pytorch_lightning as pl

# Checkpoint to export. Leave empty to fall back to the best checkpoint
# recorded by the `trainer` from the training cell.
path_to_checkpoint = ''
if not path_to_checkpoint:
    checkpoint_callback = getattr(trainer, 'checkpoint_callback', None)
    path_to_checkpoint = getattr(checkpoint_callback, 'best_model_path', '')
if not path_to_checkpoint:
    raise ValueError(
        "No checkpoint to export: set `path_to_checkpoint` or train with "
        "checkpointing enabled."
    )

# `torch.load` unpickles the file, so only load checkpoints you trust.
# `map_location='cpu'` lets a GPU-trained checkpoint load on a CPU-only machine.
# NOTE(review): recent torch versions default to `weights_only=True`, which may
# reject checkpoints containing custom objects; confirm against your version.
checkpoint = torch.load(path_to_checkpoint, map_location='cpu')['state_dict']
model.load_state_dict(checkpoint)
model.eval()  # inference behavior for dropout / batch-norm during tracing

# Trace on whatever device the model currently lives on, so the dummy input
# and the weights never end up on different devices.
export_device = next(model.parameters()).device
# Fixed example input of shape (48, 3, 256, 256).
# NOTE(review): presumably batch=48, 3 channels, 256x256 images; confirm
# against the inference pipeline. The exported graph is traced with this shape.
dummy_input = torch.randn(48, 3, 256, 256, device=export_device)
onnx_file_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_file_path)