In [1]:
import pandas as pd
import lightning as L

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import WandbLogger
from lightning.pytorch.callbacks import RichProgressBar

from torchvision.transforms import v2
import albumentations as albu

import config
from dataset import BirdCLEFDataset
from model import BirdCLEFModel, ModelUtils

### Make a config

In [2]:
name_config = "resnet_config.toml"
current_config = config.load_config(name_config)

### Load data

In [3]:
data = pd.read_csv(current_config['meta_parameters']['metadata'])
data = pd.concat(
    [
        pd.Series(data['primary_label']),  
        pd.Series(data['filename'])
    ], 
    axis=1, names=['primary_label', 'filename']
)

data = pd.concat([data, pd.get_dummies(data['primary_label'])], axis=1)

train_data, valid_data = train_test_split(data, train_size=0.8, shuffle=True, random_state=42)
train_data = train_data.reset_index(drop=True)
valid_data = valid_data.reset_index(drop=True)


In [4]:
augs = albu.Compose([
    albu.HorizontalFlip(p=0.2), 
    albu.XYMasking(p=0.35, num_masks_x=(1, 3), num_masks_y=(1, 3), mask_x_length=(1, 10),mask_y_length=(1, 20))
])

In [5]:
train_dataset = BirdCLEFDataset(train_data, current_config, augs)
test_dataset = BirdCLEFDataset(valid_data, current_config, augs)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=current_config['model_parameters']['batch_size'], num_workers=3, pin_memory=True)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=current_config['model_parameters']['batch_size'], num_workers=3, pin_memory=True)

### Model initialization

In [6]:
model = BirdCLEFModel(current_config)

In [8]:
checkpoint = torch.load("resnetBEST.pt")
keys_to_remove = [key for key in checkpoint.keys() if (key.startswith('model.fc') or key.startswith('criterion.weight'))]
for key in keys_to_remove:
    del checkpoint[key]
model.load_state_dict(checkpoint, strict=False)

_IncompatibleKeys(missing_keys=['model.fc.weight', 'model.fc.bias', 'criterion.weight'], unexpected_keys=[])

### Training

In [9]:
wandb_logger = WandbLogger(current_config["meta_parameters"]["run_name"])
trainer = L.Trainer(max_epochs=current_config['model_parameters']['epochs'], logger=wandb_logger, callbacks=[RichProgressBar()])  

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(model, train_dataloader, test_dataloader)

`Trainer.fit` stopped: `max_epochs=40` reached.


### Test

In [11]:
trainer.test(model, test_dataloader)

[{'test_loss': 4.708274256178965, 'test_f1': 0.9073024988174438}]

In [12]:
torch.save(model.state_dict(), "resnet50best.pt")

### Convert to ONNX

In [35]:
checkpoint = torch.load('/workspace/birdclef/lightning_logs/h24iw852/checkpoints/epoch=29-step=63947.ckpt')['state_dict']

keys_to_remove = [key for key in checkpoint.keys() if (key.startswith('model.classifier') or key.startswith('model.criterion') or key.startswith('criterion.weight'))]
for key in keys_to_remove:
    del checkpoint[key]

model.load_state_dict(checkpoint, strict=False)

_IncompatibleKeys(missing_keys=['model.classifier.weight', 'model.classifier.bias', 'criterion.weight'], unexpected_keys=[])

In [12]:
import torch.onnx as onnx
import pytorch_lightning as pl

path_to_checkpoint = '/workspace/birdclef/resnetBEST.pt'
checkpoint = torch.load(path_to_checkpoint)
#checkpoint = checkpoint['state_dict']
model.load_state_dict(checkpoint)
model.eval().to('cpu')

dummy_input = torch.randn(48, 3, 128, 313)
onnx_file_path = "resnet.onnx"
torch.onnx.export(model, dummy_input, onnx_file_path)