In [None]:
# Automatically reload edited Python modules (including the project's `configs`
# and `modules.*`) before each cell executes, so code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
import torch
import os
import gc
import json
import importlib
from pathlib import Path
import numpy as np
import glob
import configs
from ast import literal_eval
import pandas as pd

from modules.preprocess import preprocess,prepare_cfg
from modules.dataset import get_train_dataloader
from modules.model import load_model
import modules.inception_next_nano

In [None]:
# move to repo root
cur_dir = Path().resolve()

if not (cur_dir / "notebooks").exists():
    os.chdir(os.path.abspath("../"))
print(f"{Path().resolve()}")

# Config

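Set the model configuration name (`model_name`, a module in the `configs` package) and the training `stage` in the code cell below. Available options: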

#### 2021-2nd CNN Model (seresnext26ts)
We did not pre-train the seresnext26ts model on the BC2021–2023 data, so there is no pretraining configuration for it here.

#### 2021-2nd CNN Model (rexnet_150)
```python
model_name = "cnn_v3_rexnet"
stage = "pretrain_bce"
```

#### Simple CNN Model (inception_next_nano)
```python
model_name = "simple_cnn_v1"
stage = "pretrain_bce"
```

In [None]:
model_name = "simple_cnn_v1"   # module name inside the `configs` package
stage = "pretrain_bce"         # key used to pick per-stage settings (cfg.seed[stage], cfg.epochs[stage], ...)

# Load `basic_cfg` from configs.<model_name> and specialise it for `stage`.
config_module = importlib.import_module(f"configs.{model_name}")
cfg = prepare_cfg(config_module.basic_cfg, stage)

# Pretraining metadata; the path is relative to the repository root.
cfg.train_data = "./notebooks/train_metadata_rich_pretrain_merge.pkl"

In [None]:
# Seed all RNGs (dataloader workers included) with the stage-specific seed
# before any data preparation, so randomness downstream is reproducible.
pl.seed_everything(cfg.seed[stage], workers=True)

# Train/valid frames, their label frames, and the transforms (per the names).
df_train, df_valid, df_label_train, df_label_valid, transforms = preprocess(cfg, stage)

# Tag both splits with the same "version" value.
for split_frame in (df_train, df_valid):
    split_frame["version"] = "2023"

df_train.shape, df_valid.shape

In [None]:
# `pseudo` is passed as None — presumably meaning no pseudo-label data for this
# run; confirm against modules.dataset.get_train_dataloader.
pseudo = None

# Arguments are passed positionally, in this order.
loader_inputs = (
    df_train,
    df_valid,
    df_label_train,
    df_label_valid,
    cfg,
    pseudo,
    transforms,
)
dl_train, dl_val, ds_train, ds_val = get_train_dataloader(*loader_inputs)

In [None]:
# One W&B run per config/stage combination.
logger = WandbLogger(project='BirdClef-2023', name=f'{model_name}_{stage}')

# No metric is monitored: save_top_k=0 keeps no "best-k" checkpoints, while
# save_last=True maintains a `last` checkpoint (weights only) in
# cfg.output_path[stage]. `mode` only matters once a metric is monitored.
# To keep the best model instead, set monitor to a logged metric
# (e.g. 'val_loss', if the model logs it) and use a positive save_top_k.
checkpoint_callback = ModelCheckpoint(
    dirpath=cfg.output_path[stage],
    monitor=None,
    mode='min',
    save_top_k=0,
    save_last=True,
    save_weights_only=True,
    every_n_epochs=1,
    verbose=True,
)

In [None]:
# Build the model for this config/stage and a single-device Lightning trainer.
model = load_model(cfg, stage)
callbacks_to_use = [checkpoint_callback]

trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    precision=cfg.PRECISION,
    max_epochs=cfg.epochs[stage],
    val_check_interval=1.0,   # validate once per training epoch
    deterministic=None,       # Lightning default: deterministic algorithms not forced
    callbacks=callbacks_to_use,
    logger=logger,
)

In [None]:
# Quick look at the size of the validation labels before training.
df_label_valid.shape

In [None]:
# Train, then close the W&B run even if training fails. Otherwise the run stays
# active in the kernel, and a WandbLogger created on a later re-run (e.g. with a
# different model_name/stage) attaches to that same run instead of a new one.
try:
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
finally:
    logger.experiment.finish()