In [1]:
import torch.nn as nn
import torch.nn.functional as F
import timm
import torch.utils.data as torchdata
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from sklearn.model_selection import StratifiedKFold
import librosa
import random
import torch
import pandas as pd
import os
import numpy as np
from utils.metrics import calculate_metrics, metrics_to_string, calculate_competition_metrics
from torch.utils.data.sampler import WeightedRandomSampler
from timm.scheduler import CosineLRScheduler
from utils.init_utils import AverageMeter
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

## Init Utils

In [2]:
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def init_logger(log_file='train.log'):
    """Create (or reconfigure) the module logger writing to console and a file.

    The logger is looked up by ``__name__``, so calling this function again
    returns the same logger object. Handlers attached by a previous call are
    removed and closed first; otherwise re-running the notebook cell would
    stack handlers and emit every message several times.

    Args:
        log_file: Path of the log file to write to.

    Returns:
        The configured ``logging.Logger`` at INFO level with exactly two
        handlers (stream + file), both using a bare ``%(message)s`` format.
    """
    from logging import INFO, FileHandler, Formatter, StreamHandler, getLogger
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Drop handlers left over from an earlier call to avoid duplicated lines.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    NOTE: this definition shadows the ``AverageMeter`` imported from
    ``utils.init_utils`` in the imports cell.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

## CNN

In [3]:

class NormalizeMelSpec(nn.Module):
    """Standardise each item over dims 1 and 2, then min-max rescale it to [0, 1].

    The reductions over ``(1, 2)`` and the ``[mask, None, None]`` indexing mean
    the input is expected to be a 3-D tensor whose first dim indexes items.
    Items whose standardised range does not exceed ``eps`` (e.g. constant
    input) come out as all zeros.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X):
        # Zero-mean / unit-variance per item.
        item_mean = X.mean((1, 2), keepdim=True)
        item_std = X.std((1, 2), keepdim=True)
        Xstd = (X - item_mean) / (item_std + self.eps)

        # Per-item extrema of the standardised values.
        norm_min = Xstd.min(-1)[0].min(-1)[0]
        norm_max = Xstd.max(-1)[0].max(-1)[0]
        spread = norm_max - norm_min
        fix_ind = spread > self.eps * torch.ones_like(spread)

        V = torch.zeros_like(Xstd)
        if not fix_ind.sum():
            # Every item is (near-)constant: nothing to rescale.
            return V

        lower = norm_min[fix_ind, None, None]
        upper = norm_max[fix_ind, None, None]
        clipped = torch.max(torch.min(Xstd[fix_ind], upper), lower)
        V[fix_ind] = (clipped - lower) / (upper - lower)
        return V


class GeM(nn.Module):
    """Generalised-mean pooling over the last two dims with a learnable exponent.

    The output keeps the pooled dims as size 1 (``avg_pool2d`` with a kernel
    equal to the full spatial extent).
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent, initialised to ``p``.
        self.p = torch.nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        height, width = x.size(-2), x.size(-1)
        # Clamp before the power so the fractional root stays well defined.
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(powered, (height, width))
        return pooled.pow(1.0 / self.p)

class CNN(nn.Module):
    """Waveform classifier: log-mel front end -> timm backbone -> GeM -> BN -> linear.

    The model consumes and returns the batch dictionary: it reads
    ``x["wave"]`` and stores the class logits under ``x["logit"]``.

    Args:
        config: Mapping providing ``mel_spec_params`` (``sample_rate``,
            ``n_mels``, ``f_min``, ``f_max``, ``n_fft``, ``hop_length`` and,
            optionally, ``normalized`` / ``top_db``), ``backbone``,
            ``pretrained``, ``in_chans`` and ``target_columns``.
    """

    def __init__(self, config):
        super().__init__()

        mel_spec_params = config["mel_spec_params"]
        self.logmelspec_extractor = nn.Sequential(
            MelSpectrogram(
                sample_rate=mel_spec_params["sample_rate"],
                n_mels=mel_spec_params["n_mels"],
                f_min=mel_spec_params["f_min"],
                f_max=mel_spec_params["f_max"],
                n_fft=mel_spec_params["n_fft"],
                hop_length=mel_spec_params["hop_length"],
                # Honour the config; fall back to the previously hard-coded value.
                normalized=mel_spec_params.get("normalized", True),
            ),
            AmplitudeToDB(top_db=float(mel_spec_params.get("top_db", 80.0))),
            NormalizeMelSpec(),
        )

        # Two backbone stages are pooled and concatenated.
        out_indices = (3, 4)
        self.backbone = timm.create_model(
            config["backbone"],
            features_only=True,
            pretrained=config["pretrained"],
            in_chans=config["in_chans"],
            num_classes=0,
            out_indices=out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList([GeM() for _ in out_indices])
        # Plain Python int: layer constructors should not receive NumPy scalars.
        self.mid_features = int(np.sum(feature_dims))
        self.neck = torch.nn.BatchNorm1d(self.mid_features)
        self.head = nn.Linear(self.mid_features, len(config["target_columns"]))

    def backbone_pass(self, x):
        """Compute logits for ``x["wave"]`` and store them in ``x["logit"]``.

        Args:
            x: Batch dictionary with a ``"wave"`` tensor
                (presumably ``(batch, samples)`` - TODO confirm against the loader).

        Returns:
            The same dictionary, mutated in place, with ``"logit"`` added.
        """
        # Add the channel dim the backbone expects.
        spec = self.logmelspec_extractor(x["wave"]).unsqueeze(1)
        feature_maps = self.backbone(spec)
        # GeM keeps singleton spatial dims; flatten to (batch, channels) because
        # BatchNorm1d only accepts 2-D or 3-D input.
        pooled = [
            global_pool(fmap).flatten(1)
            for fmap, global_pool in zip(feature_maps, self.global_pools)
        ]
        h = torch.cat(pooled, dim=1)
        x["logit"] = self.head(self.neck(h))
        return x

    def forward(self, x):
        """Run the full pipeline; ``model(batch)`` returns the batch dict with ``"logit"``."""
        # Previously this called the raw backbone with the dict, which cannot work
        # and never produced the "logit" entry the training loop reads.
        return self.backbone_pass(x)
            

## Dataset

In [4]:
class BirdDataset(torchdata.Dataset):
    """Dataset yielding fixed-length waveform crops and multi-hot species targets.

    Args:
        df: Metadata frame with ``path``, ``rating``, ``primary_label`` and
            ``secondary_labels`` columns. ``secondary_labels`` may hold either
            the raw string form (e.g. ``"['a', 'b']"``) or already-parsed lists.
            The frame is copied, so the caller's object is never modified.
        config: Mapping providing ``bird2id``, ``period``, ``secondary_coef``
            and ``smooth_label``.
        num_classes: Length of the target vector.
        add_secondary_labels: When true, secondary labels known to ``bird2id``
            are written into the target with value ``secondary_coef``.
    """

    def __init__(self, df, config, num_classes, add_secondary_labels = False):
        # Copy so that parsing below does not mutate the caller's frame.
        self.df = df.copy()
        self.bird2id = config['bird2id']
        self.period = config['period']
        self.secondary_coef = config['secondary_coef']
        self.df["secondary_labels"] = self.df["secondary_labels"].map(
            self._parse_secondary_labels
        )

        self.smooth_label = config['smooth_label']
        self.num_classes = num_classes
        self.add_secondary_labels = add_secondary_labels

    @staticmethod
    def _parse_secondary_labels(raw):
        """Turn one ``secondary_labels`` cell into a list of label strings.

        Already-parsed sequences pass through (as a new list), which makes
        constructing a dataset from a previously processed frame safe.
        Missing values become an empty list.
        """
        if isinstance(raw, (list, tuple, np.ndarray)):
            return [str(label) for label in raw]
        if raw is None or (isinstance(raw, float) and np.isnan(raw)):
            return []
        cleaned = str(raw)
        for char in "[],'":
            cleaned = cleaned.replace(char, "")
        # An empty list string yields [""]; empty tokens are skipped by prepare_target.
        return cleaned.split(" ")

    def __len__(self):
        return len(self.df)

    def prepare_target(self, idx):
        """Build the float32 target vector for row ``idx``.

        A ``'nocall'`` primary label produces an all-zero vector. Otherwise the
        primary class is set to 1.0 and, if enabled, each known secondary
        class is set to ``secondary_coef``.
        """
        target = np.zeros(self.num_classes, dtype=np.float32)
        primary = self.df["primary_label"].iloc[idx]
        if primary != 'nocall':
            target[self.bird2id[primary]] = 1.0
            if self.add_secondary_labels:
                for s in self.df["secondary_labels"].iloc[idx]:
                    if s != "" and s in self.bird2id:
                        target[self.bird2id[s]] = self.secondary_coef
        return torch.from_numpy(target).float()

    def load_wave_and_crop(self, filename, period, start=None):
        """Load an audio file at 32 kHz and cut a ``period``-second segment.

        The clip is tiled (at least three copies, and at least three segment
        lengths long) so a crop can always be taken, wrapping around the end.

        Args:
            filename: Path of the audio file.
            period: Segment length in seconds.
            start: Sample offset of the crop, or ``None`` for a random offset.
                An explicit offset is shifted back by ``(period - 5) / 2``
                seconds and wrapped modulo the clip length.

        Returns:
            ``(waveform_orig, waveform_seg, sample_rate, start)``.

        Raises:
            ValueError: If the file decodes to zero samples.
        """
        # mono=True guarantees a 1-D array so len() counts samples; with
        # mono=False a multi-channel file made len() return the channel count.
        waveform_orig, sample_rate = librosa.load(filename, sr=32000, mono=True)

        wave_len = len(waveform_orig)
        if wave_len == 0:
            # The tiling below would otherwise never reach the target length.
            raise ValueError(f"Audio file contains no samples: {filename}")

        effective_length = int(sample_rate * period)
        # Same result as "three copies, then append until >= 3 segments long".
        n_copies = max(3, -(-(effective_length * 3) // wave_len))
        waveform = np.tile(waveform_orig, n_copies)

        if start is not None:
            start = start - (period - 5) / 2 * sample_rate
            # Wrap into [0, wave_len): negative offsets as before, and offsets
            # past the end of the clip no longer run off the tiled buffer.
            start = int(np.floor(start)) % wave_len
        else:
            if wave_len < effective_length:
                start = np.random.randint(effective_length - wave_len)
            elif wave_len > effective_length:
                start = np.random.randint(wave_len - effective_length)
            else:
                start = 0

        waveform_seg = waveform[start: start + effective_length]

        return waveform_orig, waveform_seg, sample_rate, start

    def __getitem__(self, idx):
        """Return the batch dictionary for row ``idx``.

        Keys: ``wave`` (float tensor crop), ``rating``, ``primary_targets``
        (target thresholded at 0.5) and ``smooth_targets`` (label-smoothed target).
        """
        path = self.df["path"].iloc[idx]

        # Always crops from the beginning of the clip; the length comes from the
        # configured period instead of a hard-coded 5 seconds.
        _, waveform_seg, _, _ = self.load_wave_and_crop(path, period=self.period, start=0)
        waveform_seg = torch.from_numpy(np.nan_to_num(waveform_seg)).float()
        rating = self.df["rating"].iloc[idx]
        target = self.prepare_target(idx)

        batch_dict = {
            "wave": waveform_seg,
            "rating": rating,
            "primary_targets": (target > 0.5).float(),
            "smooth_targets": target * (1-self.smooth_label) + self.smooth_label / target.size(-1),
        }

        return batch_dict


## Config

In [5]:
# Single source of truth for the experiment; every later cell reads from it.
config = {
    # Log-mel spectrogram front-end settings.
    "mel_spec_params": {
        "sample_rate": 32000,
        "n_mels": 128,
        "f_min": 20,
        "f_max": 16000,
        "n_fft": 2048,
        "hop_length": 512,
        "normalized": True,
        "top_db": 80,
    },

    # Reproducibility and target construction.
    "seed": 42,
    "secondary_coef": 1.0,
    "smooth_label": 0.05,
    "period": 5,

    # Model.
    "backbone": "eca_nfnet_l0",
    "pretrained": True,
    "fold": 3,
    "in_chans": 1,

    # Output locations.
    "output_folder": "outputs",
    "exp_name": "EXP3",

    # Runtime; "apex" is passed to the train/validation loops as their AMP switch.
    "device": get_device(),
    "apex": True,
    "max_grad_norm": 10,

    # Training length.
    "early_stopping": 10,
    "epochs": 120,

    # DataLoader keyword arguments.
    "train_loader_config": {
        "batch_size": 32,
        "num_workers": 8,
        "pin_memory": True,
        "drop_last": True,
    },
    "val_loader_config": {
        "batch_size": 64,
        "num_workers": 8,
        "pin_memory": True,
        "drop_last": False,
    },

    # Optimiser / scheduler.
    "lr_max": 2.5e-4,
    "lr_min": 1e-7,
    "weight_decay": 1e-6,

    # Class names, ten per row so that row k starts at class id 10 * k.
    "target_columns": [
        'asbfly', 'ashdro1', 'ashpri1', 'ashwoo2', 'asikoe2', 'asiope1', 'aspfly1', 'aspswi1', 'barfly1', 'barswa',
        'bcnher', 'bkcbul1', 'bkrfla1', 'bkskit1', 'bkwsti', 'bladro1', 'blaeag1', 'blakit1', 'blhori1', 'blnmon1',
        'blrwar1', 'bncwoo3', 'brakit1', 'brasta1', 'brcful1', 'brfowl1', 'brnhao1', 'brnshr', 'brodro1', 'brwjac1',
        'brwowl1', 'btbeat1', 'bwfshr1', 'categr', 'chbeat1', 'cohcuc1', 'comfla1', 'comgre', 'comior1', 'comkin1',
        'commoo3', 'commyn', 'compea', 'comros', 'comsan', 'comtai1', 'copbar1', 'crbsun2', 'cregos1', 'crfbar1',
        'crseag1', 'dafbab1', 'darter2', 'eaywag1', 'emedov2', 'eucdov', 'eurbla2', 'eurcoo', 'forwag1', 'gargan',
        'gloibi', 'goflea1', 'graher1', 'grbeat1', 'grecou1', 'greegr', 'grefla1', 'grehor1', 'grejun2', 'grenig1',
        'grewar3', 'grnsan', 'grnwar1', 'grtdro1', 'gryfra', 'grynig2', 'grywag', 'gybpri1', 'gyhcaf1', 'heswoo1',
        'hoopoe', 'houcro1', 'houspa', 'inbrob1', 'indpit1', 'indrob1', 'indrol2', 'indtit1', 'ingori1', 'inpher1',
        'insbab1', 'insowl1', 'integr', 'isbduc1', 'jerbus2', 'junbab2', 'junmyn1', 'junowl1', 'kenplo1', 'kerlau2',
        'labcro1', 'laudov1', 'lblwar1', 'lesyel1', 'lewduc1', 'lirplo', 'litegr', 'litgre1', 'litspi1', 'litswi1',
        'lobsun2', 'maghor2', 'malpar1', 'maltro1', 'malwoo1', 'marsan', 'mawthr1', 'moipig1', 'nilfly2', 'niwpig1',
        'nutman', 'orihob2', 'oripip1', 'pabflo1', 'paisto1', 'piebus1', 'piekin1', 'placuc3', 'plaflo1', 'plapri1',
        'plhpar1', 'pomgrp2', 'purher1', 'pursun3', 'pursun4', 'purswa3', 'putbab1', 'redspu1', 'rerswa1', 'revbul',
        'rewbul', 'rewlap1', 'rocpig', 'rorpar', 'rossta2', 'rufbab3', 'ruftre2', 'rufwoo2', 'rutfly6', 'sbeowl1',
        'scamin3', 'shikra1', 'smamin1', 'sohmyn1', 'spepic1', 'spodov', 'spoowl1', 'sqtbul1', 'stbkin1', 'sttwoo1',
        'thbwar1', 'tibfly3', 'tilwar1', 'vefnut1', 'vehpar1', 'wbbfly1', 'wemhar1', 'whbbul2', 'whbsho3', 'whbtre1',
        'whbwag1', 'whbwat1', 'whbwoo2', 'whcbar1', 'whiter2', 'whrmun', 'whtkin2', 'woosan', 'wynlau1', 'yebbab1',
        'yebbul3', 'zitcis1',
    ],
}

# Class name -> class id, i.e. the position of the name in "target_columns".
config["bird2id"] = {
    bird_name: class_id for class_id, bird_name in enumerate(config["target_columns"])
}

## Utils

In [6]:
def setup_output_dir(config):
    """Create the output and experiment folders and record the latter in ``config``.

    Args:
        config: Mapping with ``output_folder`` and ``exp_name``; it gains an
            ``exp_folder`` entry (``<output_folder>/<exp_name>``).

    Returns:
        The same ``config`` object.
    """
    output_root = config["output_folder"]
    os.makedirs(output_root, exist_ok=True)

    config["exp_folder"] = os.path.join(output_root, config["exp_name"])
    os.makedirs(config["exp_folder"], exist_ok=True)
    return config

def normalize_rating(df):
    """Scale ``df["rating"]`` by its maximum and clip the result to [0.1, 1.0]."""
    ratings = df["rating"]
    scaled = ratings / ratings.max()
    return np.clip(scaled, 0.1, 1.0)


def do_kfold(df, KFOLD=5, seed=None):
    """Assign a stratified validation fold id to every row of ``df``.

    Adds (or overwrites) a ``fold`` column in place; rows are stratified on
    ``primary_label``.

    Args:
        df: Frame with a ``primary_label`` column.
        KFOLD: Number of folds.
        seed: Random state for the shuffled split. Defaults to the notebook's
            ``config["seed"]`` when omitted.

    Returns:
        The same frame, with ``fold`` holding values in ``0 .. KFOLD - 1``.
    """
    if seed is None:
        seed = config["seed"]
    skf = StratifiedKFold(n_splits=KFOLD, random_state=seed, shuffle=True)
    df['fold'] = -1
    fold_col = df.columns.get_loc('fold')
    for fold, (_, val_idx) in enumerate(skf.split(X=df, y=df["primary_label"].values)):
        # split() yields positions, so write positionally; label-based .loc is
        # only correct when the index happens to be a default RangeIndex.
        df.iloc[val_idx, fold_col] = fold
    return df


def read_dataframe(path='../data/birdclef-2024/train_metadata.csv'):
    """Load the training metadata CSV.

    Args:
        path: Location of the CSV file; defaults to the competition's
            ``train_metadata.csv`` relative to the notebook.

    Returns:
        The file contents as a ``pandas.DataFrame``.
    """
    return pd.read_csv(path)


## Dataframe Prep

In [7]:
# Create <output_folder>/<exp_name> and store its path in config["exp_folder"].
config = setup_output_dir(config)
df = read_dataframe()
# Seed all RNGs before any stochastic step below.
set_seed(config["seed"])
# Replace raw ratings by values scaled by the maximum and clipped to [0.1, 1.0].
df["rating"] = normalize_rating(df)
# Adds a "fold" column with a stratified 5-fold assignment.
df = do_kfold(df, KFOLD=5)

## Select Fold and Split Data

In [8]:
# Id of the held-out validation fold for this run.
fold = config["fold"]
# Log to the notebook output and to <exp_folder>/<fold>.log.
logger = init_logger(log_file=os.path.join(config["exp_folder"], f"{fold}.log"))


logger.info("=" * 90)
logger.info(f"Fold {fold} Training")
logger.info("=" * 90)


# Rows of the selected fold form the validation set; all other rows are used for training.
trn_df = df[df['fold'] != fold].reset_index(drop=True)
val_df = df[df['fold'] == fold].reset_index(drop=True)
print(trn_df.shape)
# Record the split sizes and per-class counts in the log.
logger.info(trn_df.shape)
logger.info(trn_df['primary_label'].value_counts())
logger.info(val_df.shape)
logger.info(val_df['primary_label'].value_counts())



Fold 3 Training
(19567, 13)
primary_label
zitcis1    400
lirplo     400
litgre1    400
comgre     400
comkin1    400
          ... 
paisto1      5
niwpig1      4
asiope1      4
integr       4
blaeag1      4
Name: count, Length: 182, dtype: int64
(4892, 13)
primary_label
zitcis1    100
lirplo     100
comgre     100
comkin1    100
commoo3    100
          ... 
integr       1
asiope1      1
wynlau1      1
nilfly2      1
niwpig1      1
Name: count, Length: 182, dtype: int64


(19567, 13)


## Prepare Data Loaders

In [9]:
# Class-balanced sampling: each training row is weighted by the inverse frequency
# of its primary label, and one epoch draws as many samples as there are rows.
labels = trn_df["primary_label"].values
# Count every class in a single pass instead of rescanning `labels` once per class.
un_labels, label_counts = np.unique(labels, return_counts=True)
weight = {t: 1.0 / int(c) for t, c in zip(un_labels, label_counts)}
samples_weight = np.array([weight[t] for t in labels])
sampler = WeightedRandomSampler(torch.from_numpy(samples_weight).type('torch.DoubleTensor'),
                                len(samples_weight))

In [10]:
# Training set: secondary labels are added to the targets; row order comes from
# the weighted sampler, so shuffle stays False.
trn_dataset = BirdDataset(df=trn_df.reset_index(drop=True), config=config, num_classes=len(config["target_columns"]), add_secondary_labels=True)
train_loader = torch.utils.data.DataLoader(trn_dataset, shuffle=False, sampler=sampler,
                                                   **config["train_loader_config"])

# Validation set: primary labels only, iterated in frame order.
v_ds = BirdDataset(df=val_df.reset_index(drop=True), config=config, num_classes=len(config["target_columns"]), add_secondary_labels=False)
val_loader = torch.utils.data.DataLoader(v_ds, shuffle=False, **config["val_loader_config"])

## Loss Function

In [11]:
import torch
import torchvision


class FocalLoss(torch.nn.Module):
    """Sigmoid focal loss computed from a batch dictionary.

    Args:
        alpha: Weighting factor forwarded to ``sigmoid_focal_loss``.
        gamma: Focusing exponent forwarded to ``sigmoid_focal_loss``.
        reduction: Reduction mode forwarded to ``sigmoid_focal_loss``.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2, reduction: str = "mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, x):
        """Return the loss between ``x["logit"]`` and ``x["primary_targets"]``."""
        focal_loss_fn = torchvision.ops.focal_loss.sigmoid_focal_loss
        return focal_loss_fn(
            inputs=x["logit"],
            targets=x["primary_targets"],
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )

## Model

In [12]:
# Build the network on the configured device.
model = CNN(config).to(config["device"])
# Loss reads "logit" and "primary_targets" from the batch dictionary.
criterion = FocalLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr_max"], betas=(0.9, 0.999), eps=1e-08, weight_decay=config["weight_decay"], amsgrad=False, )
# Cosine schedule in epoch units: 10-epoch cycles with a 1-epoch warmup, decaying to lr_min.
scheduler = CosineLRScheduler(optimizer, t_initial=10, warmup_t=1, cycle_limit=40, cycle_decay=1.0, lr_min=config["lr_min"], t_in_epochs=True, )

feature dims: [1536, 2304]


## Train / Val Functions

In [13]:
def batch_to_device(batch, device):
    """Return a new dict holding every value of ``batch`` moved to ``device``.

    Every value must provide a ``.to(device)`` method.
    """
    moved = {}
    for key in batch:
        moved[key] = batch[key].to(device)
    return moved

In [14]:
def train_one_epoch(data_loader, model, criterion, optimizer, scheduler, epoch, device, apex,
                             max_grad_norm, target_columns):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        data_loader: Iterable of batch dictionaries.
        model: Callable taking the batch dict and returning a dict with ``"logit"``.
        criterion: Callable taking the model output dict and returning a scalar loss.
        optimizer: Optimizer stepped once per batch.
        scheduler: Scheduler stepped once per batch with a fractional epoch.
        epoch: Epoch offset added to the within-epoch fraction passed to the scheduler.
        device: Device the batch tensors are moved to.
        apex: Enables autocast and gradient scaling when true.
        max_grad_norm: Maximum gradient norm used for clipping.
        target_columns: Forwarded to ``calculate_competition_metrics``.

    Returns:
        ``(scores, average_loss)`` where ``scores`` is whatever
        ``calculate_competition_metrics`` returns for the epoch's predictions.
    """

    model.train()
    losses = AverageMeter()
    optimizer.zero_grad(set_to_none=True)
    # NOTE(review): a fresh GradScaler is created on every call, so its loss
    # scale does not carry over between epochs.
    scaler = GradScaler(enabled=apex)
    iters = len(data_loader)
    gt = []
    preds = []
    with tqdm(enumerate(data_loader), total=len(data_loader)) as t:
        for i, (batch) in t:
            batch = batch_to_device(batch, device)

            with autocast(enabled=apex):
                outputs = model(batch)
                loss = criterion(outputs)

            # Sample-weighted running loss.
            losses.update(loss.item(), batch["wave"].size(0))
            scaler.scale(loss).backward()
            # Unscale first so clipping operates on the true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            # Per-batch schedule update using the fractional position in the epoch.
            scheduler.step(epoch + i / iters)
            t.set_postfix(
                loss=losses.avg,
                grad=grad_norm.item(),
                lr=optimizer.param_groups[0]["lr"]
            )
            # Keep targets and sigmoid probabilities on the CPU for the epoch metrics.
            gt.append(batch["primary_targets"].cpu().detach().numpy())
            preds.append(outputs["logit"].sigmoid().cpu().detach().numpy())

    gt = np.concatenate(gt)
    preds = np.concatenate(preds)
    scores = calculate_competition_metrics(gt, preds, target_columns)
    return scores, losses.avg

In [15]:
def validate_one_epoch(data_loader, model, criterion, device, apex, target_columns):
    """Evaluate ``model`` on ``data_loader`` without computing gradients.

    Args:
        data_loader: Iterable of batch dictionaries.
        model: Callable taking the batch dict and returning a dict with ``"logit"``.
        criterion: Callable taking the model output dict and returning a scalar loss.
        device: Device the batch tensors are moved to.
        apex: Enables autocast when true.
        target_columns: Forwarded to ``calculate_competition_metrics``.

    Returns:
        ``(scores, average_loss)`` where ``scores`` is whatever
        ``calculate_competition_metrics`` returns for the collected predictions.
    """
    model.eval()
    losses = AverageMeter()
    gt = []
    preds = []

    with tqdm(enumerate(data_loader), total=len(data_loader)) as t:
        for i, batch in t:
            batch = batch_to_device(batch, device)
            with autocast(enabled=apex):
                with torch.no_grad():
                    outputs = model(batch)
                    loss = criterion(outputs)

            # Sample-weighted running loss.
            losses.update(loss.item(), batch["wave"].size(0))
            t.set_postfix(loss=losses.avg)

            gt.append(batch["primary_targets"].cpu().detach().numpy())
            # Read logits from the model output (as the training loop does) rather
            # than relying on the model having mutated the input batch in place.
            preds.append(outputs["logit"].sigmoid().cpu().detach().numpy())

    gt = np.concatenate(gt)
    preds = np.concatenate(preds)
    scores = calculate_competition_metrics(gt, preds, target_columns)
    return scores, losses.avg


## Train Fold

In [16]:
def main():
    """Train the configured fold with early stopping on validation ROC.

    Uses the notebook-level ``config``, ``model``, ``criterion``, ``optimizer``,
    ``scheduler``, ``train_loader``, ``val_loader``, ``logger`` and ``fold``.

    Checkpoints written to ``config["exp_folder"]``:
        * ``<fold>.bin`` whenever the validation ROC improves;
        * ``final_<fold>.bin`` when training ends, either through early
          stopping or after the last configured epoch.
    """
    patience = config["early_stopping"]
    best_score = 0.0
    n_patience = 0
    last_epoch = 0

    def _save_checkpoint(epoch, filename):
        # Bundle weights, optimizer state and the best score so far into one file.
        # The "best_loss" key name is kept for compatibility with existing checkpoints.
        state = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "best_loss": best_score,
            "optimizer": optimizer.state_dict(),
        }
        torch.save(state, os.path.join(config["exp_folder"], filename))

    for epoch in range(1, config["epochs"] + 1):
        last_epoch = epoch

        # The scheduler works on a 0-based, fractional epoch. Passing a constant 0
        # kept the learning rate inside the first epoch of the schedule forever.
        train_scores, train_losses_avg = train_one_epoch(
            data_loader=train_loader, model=model,
            criterion=criterion, optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch - 1, device=config["device"],
            apex=config["apex"],
            max_grad_norm=config["max_grad_norm"],
            target_columns=config["target_columns"])

        train_scores_str = metrics_to_string(train_scores, "Train")
        logger.info(f"Epoch {epoch} - Train loss: {train_losses_avg:.4f}, {train_scores_str}")

        val_scores, val_losses_avg = validate_one_epoch(
            data_loader=val_loader, model=model, criterion=criterion,
            device=config["device"], apex=config["apex"],
            target_columns=config["target_columns"])

        val_scores_str = metrics_to_string(val_scores, "Valid")
        logger.info(f"Epoch {epoch} - Valid loss: {val_losses_avg:.4f}, {val_scores_str}")

        val_score = val_scores["ROC"]
        is_better = val_score > best_score
        best_score = max(val_score, best_score)

        if is_better:
            logger.info(
                f"Epoch {epoch} - Save Best Score: {best_score:.4f} Model\n")
            _save_checkpoint(epoch, f"{fold}.bin")
            n_patience = 0
        else:
            n_patience += 1
            # Early stopping tracks the ROC score, not the loss.
            logger.info(
                f"Valid score didn't improve last {n_patience} epochs.\n")

        if n_patience >= patience:
            logger.info(
                "Early stop, Training End.\n")
            _save_checkpoint(epoch, f"final_{fold}.bin")
            break
    else:
        # Ran through every epoch without early stopping: still keep the final weights.
        if last_epoch > 0:
            _save_checkpoint(last_epoch, f"final_{fold}.bin")

In [None]:
# Start training when this cell/script is executed directly (not when imported).
if __name__ == "__main__":
    main()
