In [1]:
import os, random
import cv2
import math
import librosa
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

from collections import Counter

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
from torchvision.models import efficientnet
from torchvision.transforms import transforms
# from efficientnet_pytorch import EfficientNet

import timm

import scikitplot as skplt
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, KFold
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import LabelEncoder

from glob import glob
from IPython.display import display, Audio

import cupy as cp
from cupyx.scipy import signal as cupy_signal
import yaml

from metric import score

import wandb

import plotly.graph_objects as go
import plotly.express as px

  from .autonotebook import tqdm as notebook_tqdm
  cupy._util.experimental('cupyx.jit.rawkernel')


In [2]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mlhk[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [4]:
default_config = {
    "VERSION": "v0.3",
    "DATA_PATH": "inputs",
    "LOAD_SPEC_DATA": True,
    "SEED": 24,
    "SAMPLE_RATE": 32000,
    "N_FFT": 1095,
    "WIN_SIZE": 412,
    "WIN_LAP": 100,
    "MIN_FREQ": 40,
    "MAX_FREQ": 15000,
    "EPOCHS": 10,
    "BACHSIZE": 16
}

try:
    with open('config.yaml', 'r') as f:
        default_config = yaml.load(f, Loader=yaml.SafeLoader)
except:
    pass

default_config

{'VERSION': 'v1.3',
 'DESCRIPTION': 'Get all 5s data',
 'DATA_PATH': 'inputs',
 'LOAD_SPEC_DATA': True,
 'SEED': 24,
 'SAMPLE_RATE': 32000,
 'N_FFT': 1095,
 'WIN_SIZE': 412,
 'WIN_LAP': 100,
 'MIN_FREQ': 40,
 'MAX_FREQ': 15000,
 'EPOCHS': 10,
 'FOLD': 5,
 'BACTHSIZE': 16,
 'LABEL_SMOOTHING': 0.0}

In [5]:
# Reproducibility
torch.manual_seed(default_config["SEED"])
random.seed(default_config["SEED"])
np.random.seed(default_config["SEED"])

In [6]:
def oog2spec_via_cupy(audio_data):
    
    audio_data = cp.array(audio_data)
    
    # handles NaNs
    mean_signal = cp.nanmean(audio_data)
    audio_data = cp.nan_to_num(audio_data, nan=mean_signal) if cp.isnan(audio_data).mean() < 1 else cp.zeros_like(audio_data)
    
    # to spec.
    frequencies, times, spec_data = cupy_signal.spectrogram(
        audio_data, 
        fs=default_config["SAMPLE_RATE"], 
        nfft=default_config["N_FFT"], 
        nperseg=default_config["WIN_SIZE"], 
        noverlap=default_config["WIN_LAP"], 
        window='hann'
    )
    
    # Filter frequency range
    valid_freq = (frequencies >= default_config["MIN_FREQ"]) & (frequencies <= default_config["MAX_FREQ"])
    spec_data = spec_data[valid_freq, :]
    
    # Log
    spec_data = cp.log10(spec_data + 1e-20)
    
    # min/max normalize
    spec_data = spec_data - spec_data.min()
    spec_data = spec_data / spec_data.max()
    
    return spec_data.get()

In [None]:
meta_data = pd.read_csv(f"{default_config['DATA_PATH']}/train_metadata.csv")
meta_data

In [8]:
f = open("data/duplicates.txt", 'r').readlines()
dup = [row.split(',')[0] for row in f]
print(f"Get {len(dup)} duplicates file")
for d in dup:
    meta_data = meta_data[meta_data.filename!=d]
meta_data.reset_index(drop=True)
meta_data.drop_duplicates(subset=['filename'], inplace=True)
meta_data.reset_index(drop=True, inplace=True)

Get 150 duplicates file


In [None]:
meta_data = meta_data[["primary_label", "filename"]]

label_list = sorted(meta_data["primary_label"].unique())
label_id_list = list(range(len(label_list)))
label2id = dict(zip(label_list, label_id_list))
id2label = dict(zip(label_id_list, label_list))

class_weights = meta_data.groupby(["primary_label"]).size().to_dict()
class_weights = dict((label2id[label], (num/len(meta_data))**(-0.5)) for (label, num) in class_weights.items())
class_weights


In [None]:
list_audio = glob("inputs/train_audio/*/*.ogg")
audio_count = dict.fromkeys(label_list,0)
for audio in list_audio:
    audio_count[audio.split("\\")[-2]] +=1

audio_count = dict(sorted(audio_count.items(), key=lambda item: item[1]))
audio_count

In [11]:
if default_config["LOAD_SPEC_DATA"]:
        all_bird_data = np.load(default_config["LOAD_SPEC_DATA"], allow_pickle=True).item()
else:
        all_bird_data = dict()

        for i, row_metadata in tqdm(meta_data.iterrows()):

                # load ogg
                audio_data, _ = librosa.load(f"{default_config['DATA_PATH']}/train_audio/{row_metadata.filename}", sr=default_config["SAMPLE_RATE"])

                # crop
                n_copy = math.ceil(5 * default_config["SAMPLE_RATE"] / len(audio_data))
                if n_copy > 1: audio_data = np.concatenate([audio_data]*n_copy)

                for start_idx in range(0,len(audio_data)-(len(audio_data)%(5*default_config["SAMPLE_RATE"])),5*default_config["SAMPLE_RATE"]):
                        start_idx = int(start_idx)
                        end_idx = int(start_idx + 5.0 * default_config["SAMPLE_RATE"])
                        input_audio = audio_data[start_idx:end_idx]

                        # ogg to spec.
                        input_spec = oog2spec_via_cupy(input_audio)

                        input_spec = cv2.resize(input_spec, (256, 256), interpolation=cv2.INTER_AREA)

                        all_bird_data[f"{row_metadata.filename}_{start_idx}"] = input_spec.astype(np.float32)
        # save to file
        np.save(os.path.join("data", f'spec_5sec_256_256_drop_dup.npy'), all_bird_data)

In [None]:
meta_data1 = pd.DataFrame(all_bird_data.items(), columns=["filename","audioarray"])
del all_bird_data
def drop_startidx(x):
    return x.split("_")[0]

def get_primary_label(x):
    return x.split("/")[0]

meta_data1["filename"] = meta_data1["filename"].apply(drop_startidx)
meta_data1['primary_label'] = meta_data1["filename"].apply(get_primary_label)
meta_data = meta_data1
meta_data

In [13]:
class BirdCLEF_Dataset(torch.utils.data.Dataset):
    def __init__(self, df_data, transforms):
        super(BirdCLEF_Dataset).__init__()
        self.df_data = df_data
        self.transform = transforms

    def __len__(self):
        return len(self.df_data)
    
    def __getitem__(self, index):
        X = np.array([self.df_data.iloc[index].audioarray])
        y = label2id[self.df_data.iloc[index].filename.split("/")[0]]

        if (y in minority_class):
            x = self.transform(x)

        return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long)

In [14]:
class BirdCLEF_Model_EfficientnetB0(nn.Module):
    def __init__(self, num_class):
        super(BirdCLEF_Model_EfficientnetB0, self).__init__()
        self.backbone = timm.create_model('tf_efficientnet_b0.in1k', pretrained=True, in_chans=1,  num_classes=num_class)
    
    def forward(self, x):
        x = self.backbone(x)
        return x

In [15]:
ssk = StratifiedKFold(n_splits=default_config["FOLD"], shuffle=True, random_state=default_config["SEED"])
meta_data['fold'] = 0
for fold, (train_idx, val_idx) in enumerate(ssk.split(np.zeros(len(meta_data)), meta_data["primary_label"].to_numpy())):
    meta_data.loc[val_idx, 'fold'] = fold

In [16]:
if not os.path.exists(f"model/{default_config['VERSION']}"):
    os.makedirs(f"model/{default_config['VERSION']}")

for f in range(default_config["FOLD"]):
    run = wandb.init(project="BirdCLEF2024_LeviKaay", name=f"BaseModel_EfficientB0_Fold{f}_{default_config['VERSION']}", entity="Kaay", config=default_config)
    
    # main loop of f-fold
    print('================================================================')
    print(f"==== Running training for fold {f} ====")
    
    train_df = meta_data[meta_data['fold'] != f].copy()
    valid_df = meta_data[meta_data['fold'] == f].copy()
    print(f'Train Samples: {len(train_df)}')
    print(f'Valid Samples: {len(valid_df)}')

    train_dataset = BirdCLEF_Dataset(train_df)
    valid_dataset = BirdCLEF_Dataset(valid_df)

    #-------------Over sampling---------------
    sample_weights = [0] * len(train_dataset)
    for idx, (data, label) in enumerate(train_dataset):
        sample_weights[idx] = class_weights[label.item()]

    sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_dataset), replacement=True)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=default_config["BACTHSIZE"], sampler=sampler)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=default_config["BACTHSIZE"], shuffle=True)

    model = BirdCLEF_Model_EfficientnetB0(num_class=len(label_list)).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=default_config["LABEL_SMOOTHING"])
    # criterion = nn.BCEWithLogitsLoss(reduction='sum')
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.00017, steps_per_epoch=len(train_dataloader), epochs=default_config["EPOCHS"], anneal_strategy='cos')

    # Training loop
    for epoch in range(default_config["EPOCHS"]):
        for idx,batch in enumerate(train_dataloader):
            inputs, targets = batch

            inputs = inputs.to(device)
            targets = targets.to(device)
            # targets = torch.nn.functional.one_hot(targets, len(label_list)).to(device).float()
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            
            # Compute loss
            loss = criterion(outputs, targets)
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            scheduler.step()
            print(f'Step {idx}/{len(train_dataloader)}, Loss: {loss.item():.4f}\r', end='', flush=True)
            wandb.log({"Learning Rate": scheduler.get_last_lr()[0]})
        model.eval()
        valid_step = []
        with torch.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                inputs, targets = batch
                inputs = inputs.to(device)

                outputs = model(inputs)
                valid_step.append({"logits": outputs, "targets": targets})
            
            output_val = torch.cat([x['logits'] for x in valid_step], dim=0).cpu().detach()
            target_val = torch.cat([x['targets'] for x in valid_step], dim=0).cpu().detach()

            val_loss = criterion(output_val, target_val)
            target_val = torch.nn.functional.one_hot(target_val, len(label_list))
            
            gt_df = pd.DataFrame(target_val.numpy().astype(np.float32), columns=label_list)
            pred_df = pd.DataFrame(output_val.numpy().astype(np.float32), columns=label_list)
            
            gt_df['id'] = [f'id_{i}' for i in range(len(gt_df))]
            pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]

            val_roc_auc = score(gt_df, pred_df, row_id_column_name='id')
        
        print(f"Epoch {epoch+1}/{default_config['EPOCHS']}, train_loss: {loss.item():.4f}, valid_loss: {val_loss:.4f} valid_roc_auc: {val_roc_auc:.4f}, lr: {scheduler.get_last_lr()}")
        wandb.log({"Training Loss": loss.item(),"Valid Loss": val_loss ,"Valid ROC_AUC": val_roc_auc})
    run.finish()
    torch.save(model.state_dict(), f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold{f}.pt")


[34m[1mwandb[0m: Currently logged in as: [33mlhk[0m ([33mKaay[0m). Use [1m`wandb login --relogin`[0m to force relogin


==== Running training for fold 0 ====
Train Samples: 154280
Valid Samples: 38570
Epoch 1/10, train_loss: 2.9878, valid_loss: 1.7302 valid_roc_auc: 0.9491, lr: [4.760255820614569e-05]
Epoch 2/10, train_loss: 0.9382, valid_loss: 1.1740 valid_roc_auc: 0.9783, lr: [0.00012920511625188894]
Epoch 3/10, train_loss: 0.4279, valid_loss: 0.8755 valid_roc_auc: 0.9840, lr: [0.00016999999990794095]
Epoch 4/10, train_loss: 1.2097, valid_loss: 0.7634 valid_roc_auc: 0.9850, lr: [0.00016158067091301808]
Epoch 5/10, train_loss: 0.0619, valid_loss: 0.6554 valid_roc_auc: 0.9870, lr: [0.00013799366818464807]
Epoch 6/10, train_loss: 0.0083, valid_loss: 0.6015 valid_roc_auc: 0.9856, lr: [0.00010391068688878939]
Epoch 7/10, train_loss: 0.0237, valid_loss: 0.5539 valid_roc_auc: 0.9843, lr: [6.608227947142118e-05]
Epoch 8/10, train_loss: 0.0001, valid_loss: 0.4985 valid_roc_auc: 0.9858, lr: [3.2000825956524385e-05]
Epoch 9/10, train_loss: 0.0001, valid_loss: 0.4747 valid_roc_auc: 0.9872, lr: [8.416576194321667e

0,1
Learning Rate,▁▁▂▃▃▄▅▆▇▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
Training Loss,█▃▂▄▁▁▁▁▁▁
Valid Loss,█▅▃▃▂▂▁▁▁▁
Valid ROC_AUC,▁▆▇███▇███

0,1
Learning Rate,0.0
Training Loss,8e-05
Valid Loss,0.47654
Valid ROC_AUC,0.98758


==== Running training for fold 1 ====
Train Samples: 154280
Valid Samples: 38570
Epoch 1/10, train_loss: 1.9119, valid_loss: 1.7355 valid_roc_auc: 0.9456, lr: [4.760255820614569e-05]
Epoch 2/10, train_loss: 0.6732, valid_loss: 1.1210 valid_roc_auc: 0.9784, lr: [0.00012920511625188894]
Epoch 3/10, train_loss: 0.4473, valid_loss: 0.8861 valid_roc_auc: 0.9837, lr: [0.00016999999990794095]
Epoch 4/10, train_loss: 0.2871, valid_loss: 0.7215 valid_roc_auc: 0.9837, lr: [0.00016158067091301808]
Epoch 5/10, train_loss: 0.5804, valid_loss: 0.6443 valid_roc_auc: 0.9849, lr: [0.00013799366818464807]
Epoch 6/10, train_loss: 0.0588, valid_loss: 0.5663 valid_roc_auc: 0.9863, lr: [0.00010391068688878939]
Epoch 7/10, train_loss: 0.0504, valid_loss: 0.5121 valid_roc_auc: 0.9857, lr: [6.608227947142118e-05]
Epoch 8/10, train_loss: 0.0001, valid_loss: 0.4943 valid_roc_auc: 0.9870, lr: [3.2000825956524385e-05]
Epoch 9/10, train_loss: 0.0000, valid_loss: 0.4766 valid_roc_auc: 0.9865, lr: [8.416576194321667e

0,1
Learning Rate,▁▁▂▃▃▄▅▆▇▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
Training Loss,█▃▃▂▃▁▁▁▁▁
Valid Loss,█▅▃▂▂▂▁▁▁▁
Valid ROC_AUC,▁▇▇▇██████

0,1
Learning Rate,0.0
Training Loss,6e-05
Valid Loss,0.4691
Valid ROC_AUC,0.98704


==== Running training for fold 2 ====
Train Samples: 154280
Valid Samples: 38570
Epoch 1/10, train_loss: 2.6655, valid_loss: 1.7204 valid_roc_auc: 0.9521, lr: [4.760255820614569e-05]
Epoch 2/10, train_loss: 0.9365, valid_loss: 1.1518 valid_roc_auc: 0.9776, lr: [0.00012920511625188894]
Epoch 3/10, train_loss: 0.0813, valid_loss: 0.9180 valid_roc_auc: 0.9840, lr: [0.00016999999990794095]
Epoch 4/10, train_loss: 0.0658, valid_loss: 0.7641 valid_roc_auc: 0.9861, lr: [0.00016158067091301808]
Epoch 5/10, train_loss: 0.2974, valid_loss: 0.6867 valid_roc_auc: 0.9838, lr: [0.00013799366818464807]
Epoch 6/10, train_loss: 0.8223, valid_loss: 0.5736 valid_roc_auc: 0.9869, lr: [0.00010391068688878939]
Epoch 7/10, train_loss: 0.0003, valid_loss: 0.5205 valid_roc_auc: 0.9843, lr: [6.608227947142118e-05]
Epoch 8/10, train_loss: 0.0009, valid_loss: 0.5113 valid_roc_auc: 0.9856, lr: [3.2000825956524385e-05]
Epoch 9/10, train_loss: 0.0015, valid_loss: 0.4776 valid_roc_auc: 0.9865, lr: [8.416576194321667e

0,1
Learning Rate,▁▁▂▃▃▄▅▆▇▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
Training Loss,█▃▁▁▂▃▁▁▁▁
Valid Loss,█▅▃▃▂▂▁▁▁▁
Valid ROC_AUC,▁▆▇█▇█▇███

0,1
Learning Rate,0.0
Training Loss,0.0004
Valid Loss,0.47761
Valid ROC_AUC,0.9871


==== Running training for fold 3 ====
Train Samples: 154280
Valid Samples: 38570
Epoch 1/10, train_loss: 2.0116, valid_loss: 1.7458 valid_roc_auc: 0.9482, lr: [4.760255820614569e-05]
Epoch 2/10, train_loss: 1.5779, valid_loss: 1.0668 valid_roc_auc: 0.9784, lr: [0.00012920511625188894]
Epoch 3/10, train_loss: 0.2863, valid_loss: 0.8690 valid_roc_auc: 0.9847, lr: [0.00016999999990794095]
Epoch 4/10, train_loss: 0.3880, valid_loss: 0.7325 valid_roc_auc: 0.9849, lr: [0.00016158067091301808]
Epoch 5/10, train_loss: 0.4466, valid_loss: 0.6634 valid_roc_auc: 0.9847, lr: [0.00013799366818464807]
Epoch 6/10, train_loss: 0.2023, valid_loss: 0.5805 valid_roc_auc: 0.9856, lr: [0.00010391068688878939]
Epoch 7/10, train_loss: 0.0000, valid_loss: 0.5623 valid_roc_auc: 0.9840, lr: [6.608227947142118e-05]
Epoch 8/10, train_loss: 0.0003, valid_loss: 0.5075 valid_roc_auc: 0.9841, lr: [3.2000825956524385e-05]
Epoch 9/10, train_loss: 0.0020, valid_loss: 0.4969 valid_roc_auc: 0.9852, lr: [8.416576194321667e

0,1
Learning Rate,▁▁▂▃▃▄▅▆▇▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
Training Loss,█▆▂▂▃▂▁▁▁▁
Valid Loss,█▄▃▂▂▂▁▁▁▁
Valid ROC_AUC,▁▇████████

0,1
Learning Rate,0.0
Training Loss,0.0
Valid Loss,0.48682
Valid ROC_AUC,0.98539


==== Running training for fold 4 ====
Train Samples: 154280
Valid Samples: 38570
Epoch 1/10, train_loss: 2.5626, valid_loss: 1.7607 valid_roc_auc: 0.9482, lr: [4.760255820614569e-05]
Epoch 2/10, train_loss: 0.7122, valid_loss: 1.1084 valid_roc_auc: 0.9764, lr: [0.00012920511625188894]
Epoch 3/10, train_loss: 0.4616, valid_loss: 0.9821 valid_roc_auc: 0.9821, lr: [0.00016999999990794095]
Epoch 4/10, train_loss: 0.5413, valid_loss: 0.7357 valid_roc_auc: 0.9852, lr: [0.00016158067091301808]
Epoch 5/10, train_loss: 0.2633, valid_loss: 0.6528 valid_roc_auc: 0.9834, lr: [0.00013799366818464807]
Epoch 6/10, train_loss: 0.5144, valid_loss: 0.6136 valid_roc_auc: 0.9852, lr: [0.00010391068688878939]
Epoch 7/10, train_loss: 0.2688, valid_loss: 0.5738 valid_roc_auc: 0.9843, lr: [6.608227947142118e-05]
Epoch 8/10, train_loss: 0.0001, valid_loss: 0.5006 valid_roc_auc: 0.9844, lr: [3.2000825956524385e-05]
Epoch 9/10, train_loss: 0.0000, valid_loss: 0.4841 valid_roc_auc: 0.9860, lr: [8.416576194321667e

0,1
Learning Rate,▁▁▂▃▃▄▅▆▇▇███████▇▇▇▇▆▆▆▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
Training Loss,█▃▂▂▂▂▂▁▁▁
Valid Loss,█▄▄▂▂▂▁▁▁▁
Valid ROC_AUC,▁▆▇███████

0,1
Learning Rate,0.0
Training Loss,7e-05
Valid Loss,0.48576
Valid ROC_AUC,0.98576


In [17]:
input_tensor = torch.randn(default_config["BACTHSIZE"], 1, 256, 256)

### Export model to ONNX

In [18]:
  # input shape
input_names = ['x']
output_names = ['output']

for fold in range(default_config["FOLD"]):
    bird_model = BirdCLEF_Model_EfficientnetB0(num_class=len(label_list))
    weights = torch.load(f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold{fold}.pt", map_location=torch.device('cpu'))
    bird_model.load_state_dict(weights)
    bird_model.eval()

    torch.onnx.export(bird_model, input_tensor, f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold{fold}.onnx", input_names=input_names, output_names=output_names)

### Export model to OpenVino

In [19]:
from openvino.runtime import Core
import openvino as ov

In [20]:
for fold in range(default_config["FOLD"]):
    bird_model = BirdCLEF_Model_EfficientnetB0(num_class=len(label_list))
    weights = torch.load(f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold{fold}.pt", map_location=torch.device('cpu'))
    bird_model.load_state_dict(weights)
    bird_model.eval()

    ov_model = ov.convert_model(bird_model, example_input=input_tensor)
    ov.save_model(ov_model, f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold{fold}.xml")

In [21]:
ie = Core()
classification_model_xml = f"model/v0.4/BaseModel_EfficientB0_Fold0.xml"
model = ie.read_model(model=classification_model_xml)

### PLot the ROC curve for each class

In [22]:
valid_df = meta_data[meta_data['fold'] == 0].copy()
valid_dataset = BirdCLEF_Dataset(valid_df)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=default_config["BACTHSIZE"], shuffle=True)
model = BirdCLEF_Model_EfficientnetB0(num_class=len(label_list))
weights = torch.load(f"model/{default_config['VERSION']}/BaseModel_EfficientB0_Fold0.pt")
model.load_state_dict(weights)

model.eval()
valid_step = []
with torch.no_grad():
    for idx, batch in enumerate(valid_dataloader):
        inputs, targets = batch
        inputs = inputs.to('cpu')

        outputs = model(inputs)
        valid_step.append({"logits": outputs, "targets": targets})
    
    output_val = nn.Softmax(dim=1)(torch.cat([x['logits'] for x in valid_step], dim=0)).cpu().detach()
    target_val = torch.cat([x['targets'] for x in valid_step], dim=0).cpu().detach()

    target_val = torch.nn.functional.one_hot(target_val, len(label_list))
    
    gt_df = pd.DataFrame(target_val.numpy().astype(np.float32), columns=label_list)
    pred_df = pd.DataFrame(output_val.numpy().astype(np.float32), columns=label_list)
    
    gt_df['id'] = [f'id_{i}' for i in range(len(gt_df))]
    pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]

    val_roc_auc = score(gt_df, pred_df, row_id_column_name='id')

In [23]:
amount = dict()
for i,amm in enumerate(valid_df["primary_label"].value_counts().to_numpy()):
    amount[i]=amm

In [24]:
# One hot encode the labels in order to plot them
y_onehot = pd.DataFrame(target_val.numpy())
y_scores = output_val

# Create an empty figure, and iteratively add new lines
# every time we compute a new class
fig = go.Figure()
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1
)

for i in range(y_scores.shape[1]):
    y_true = y_onehot.iloc[:, i]
    y_score = y_scores[:, i]

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_score = roc_auc_score(y_true, y_score)

    name = f"{id2label[y_onehot.columns[i]]} (Amount={amount[y_onehot.columns[i]]}) (AUC={auc_score:.2f})"
    fig.add_trace(go.Scatter(x=fpr, y=tpr, name=name, mode='lines'))

fig.update_layout(
    title="ROC curve on Validation set",
    template='plotly_dark',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    yaxis=dict(scaleanchor="x", scaleratio=1),
    xaxis=dict(constrain='domain')
)
fig.show()

In [25]:
fig.write_html(f"model/{default_config['VERSION']}/ROC_Curve.html")