In [None]:
# Reload edited project modules (e.g. the local `modules` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

import importlib
from pathlib import Path
import numpy as np
import glob
import timm
import pandas as pd
import torchaudio as ta

from modules.preprocess import preprocess,prepare_cfg
from modules.dataset import get_train_dataloader
from modules.model import load_model
import modules.inception_next_nano


In [None]:
# Work from the repository root. If no `notebooks/` directory is visible from the
# current working directory, presumably the kernel was started inside it — step up one level.
cur_dir = Path().resolve()
at_repo_root = (cur_dir / "notebooks").exists()
if not at_repo_root:
    os.chdir(os.path.abspath("../"))
print(f"{Path().resolve()}")

# Config

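Select the training configuration by setting `model_name` (the name of a module under `configs/`) and `stage` in the cell below. The available options are: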

#### 2021-2nd CNN Model (seresnext26ts)
```python
model_name = "cnn_v1"
stage = "train_bce"
```

#### 2021-2nd CNN Model (rexnet_150)
```python
model_name = "cnn_v3_rexnet"
stage = "train_bce"
```

#### Simple CNN Model (inception_next_nano)
```python
model_name = "simple_cnn_v1"
stage = "train_bce"
```

In [None]:
# Choose one of the configurations listed above.
model_name = "cnn_v1"
stage = "train_bce"

# Load `basic_cfg` from configs/<model_name>.py and pass it through prepare_cfg for this stage.
config_module = importlib.import_module(f'configs.{model_name}')
cfg = prepare_cfg(config_module.basic_cfg, stage)

# Seed Lightning-managed RNGs (dataloader workers included) with this stage's seed.
pl.seed_everything(cfg.seed[stage], workers=True)


In [None]:
df_train, df_valid, df_label_train, df_label_valid, transforms = preprocess(cfg, stage)

# Tag both base splits with the same data-version marker.
df_train, df_valid = (frame.assign(version="2023") for frame in (df_train, df_valid))

len(df_train), len(df_valid)

In [None]:
# Optionally add the hand-cleaned 2024 recordings to the training split:
# oversample them, append matching one-hot targets, reshuffle, and recompute
# per-row class weights.
# NOTE: not idempotent — this rebinds `df_train` / `df_label_train` from their
# own previous values, so re-run the preprocess cell before re-running this one.
if cfg.use_2024_additional_cleaned:
    # Append the hand-labeled data.
    df_additional = pd.read_pickle(cfg.train_2024_additional_cleaned)
    # NOTE(review): these rows get the same "2023" version tag as the base split —
    # confirm that is what the dataset code expects for this data.
    df_additional["version"] = "2023"
    df_additional["presence_type"] = "foreground"

    # Oversample: repeat the whole cleaned set `num_cleaned_repeat` times.
    df_additional = pd.concat([df_additional] * cfg.num_cleaned_repeat).reset_index(drop=True)

    # One-hot targets over `cfg.bird_cols`. Labels not in that list become NaN
    # categories and therefore all-zero target rows.
    add_primary_label = pd.Categorical(df_additional["primary_label"], categories=cfg.bird_cols)
    add_primary_label = pd.get_dummies(add_primary_label, dtype=np.float64)
    # Compare as plain lists: elementwise `Index == Index` raises ValueError on a
    # length mismatch instead of failing the assertion with a clear message.
    assert list(add_primary_label.columns) == list(df_label_train.columns), (
        "one-hot columns of the additional data do not match df_label_train"
    )

    df_train = pd.concat([df_train, df_additional]).reset_index(drop=True)
    df_label_train = pd.concat([df_label_train, add_primary_label]).reset_index(drop=True)
    # Features and targets are paired by row position from here on.
    assert len(df_train) == len(df_label_train), "df_train / df_label_train row counts diverged"

    # Shuffle both frames with one shared permutation. After reset_index the
    # index labels equal row positions, so they can be used with .iloc.
    perm_idx = df_train.index.to_series().sample(frac=1, random_state=0)
    df_train = df_train.iloc[perm_idx.to_numpy()].reset_index(drop=True)
    df_label_train = df_label_train.iloc[perm_idx.to_numpy()].reset_index(drop=True)

    # Per-class weight = (class frequency) ** cfg.class_exponent_weight, rescaled
    # so the per-class weights average to 1; each row gets its primary_label's weight.
    all_primary_labels = df_train["primary_label"]
    sample_weights = all_primary_labels.value_counts(normalize=True) ** cfg.class_exponent_weight
    sample_weights = sample_weights / sample_weights.mean()
    df_train["weight"] = sample_weights.loc[df_train["primary_label"]].to_numpy()

len(df_train)

In [None]:
# No pseudo-labels are used for this run.
pseudo = None
dl_train, dl_val, ds_train, ds_val = get_train_dataloader(
    df_train, df_valid,
    df_label_train, df_label_valid,
    cfg, pseudo, transforms,
)


## Trainer

In [None]:
# Log to Weights & Biases under a run name built from the chosen config.
logger = WandbLogger(project='BirdClef-2023', name=f'{model_name}_{stage}')

# No metric is monitored and save_top_k=0, so only the `last` checkpoint
# (save_last=True, weights only) is written to this stage's output directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=cfg.output_path[stage],
    monitor=None,
    save_top_k=0,
    save_last=True,
    save_weights_only=True,
    every_n_epochs=1,
    verbose=True,
)

In [None]:
# Build the model for this stage and a single-device Lightning trainer.
model = load_model(cfg, stage)
callbacks_to_use = [checkpoint_callback]
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    precision=cfg.PRECISION,
    max_epochs=cfg.epochs[stage],
    val_check_interval=1.0,  # validate once per training epoch
    deterministic=None,      # left at Lightning's default
    logger=logger,
    callbacks=callbacks_to_use,
)

In [None]:
# Train on dl_train, validating on dl_val as configured on the trainer.
trainer.fit(
    model,
    train_dataloaders=dl_train,
    val_dataloaders=dl_val,
)