# Import packages

Import all required packages.

In [1]:
import os
import gc
import sys
import cv2
import math
import numpy as np
import pandas as pd
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, StratifiedKFold
import librosa
from scipy import signal as sci_signal
import json

import torch
from torch import nn
from torchvision.models import efficientnet

#import tensorflow as tf

import albumentations as albu

import pytorch_lightning as pl
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

# import score function of BirdCLEF
#sys.path.append('/kaggle/input/birdclef-roc-auc')
#sys.path.append('/kaggle/usr/lib/kaggle_metric_utilities')
#from metric import score

In [2]:
# Import for visualization
import matplotlib as mpl
cmap = mpl.cm.get_cmap('coolwarm')
import matplotlib.pyplot as plt
import librosa.display as lid
import IPython.display as ipd
#import cv2

  cmap = mpl.cm.get_cmap('coolwarm')


# Configuration

Hyper-paramters

In [3]:
class config:
    
    # == global config ==
    SEED = 28082015  # random seed
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # device to be used
    MIXED_PRECISION = False  # whether to use mixed-16 precision
    OUTPUT_DIR = './output/'  # output folder
    
    # == data config ==
    DATA_ROOT = 'E:/PycharmProjects/birdclef24/data'  # root folder
    PREPROCESSED_DATA_ROOT = '/kaggle/input/birdclef24-spectrograms-via-cupy'
    LOAD_DATA = True  # whether to load data from pre-processed dataset

    
    # == model config ==
    MODEL_TYPE = 'efficientnet_b0'  # model type
    
    # == dataset config ==
    BATCH_SIZE = 256  # batch size of each step
    N_WORKERS = 6  # number of workers
    
    
    # == training config ==
    FOLDS = 7  # n fold
    EPOCHS = 200  # max epochs
    LR = 7e-4  # learning rate
    WEIGHT_DECAY = 9e-6  # weight decay of optimizer
    
    # == other config ==
    VISUALIZE = True  # whether to visualize data and batch
    
    
print('fix seed')
pl.seed_everything(config.SEED, workers=True)

CFG = config

fix seed


Seed set to 28082015


In [4]:
class ECA(nn.Module):
    def __init__(self, kernel_size=5):
        super().__init__()
        self.kernel_size = kernel_size
        self.supports_masking = True
        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=1, padding="same", bias=False)
    def forward(self, inputs):
        b, c, s = inputs.shape
        
        x = torch.mean(inputs, axis = -1)
        x = x.view(b, 1, c)
        x = self.conv(x)
        x = x.squeeze(1)
        x = nn.Sigmoid()(x)
        x = x[:,:,None]
        return inputs * x


class CausalDWConv1D(nn.Module):
    def __init__(self, 
        kernel_size=17,
        dilation_rate=1,
        use_bias=False,
        in_channels = 64,
        out_channels = 32,       
        depthwise_initializer='glorot_uniform',
        **kwargs):
        super().__init__()
        #self.causal_pad = tf.keras.layers.ZeroPadding1D((dilation_rate*(kernel_size-1),0),name=name + '_pad')
        self.dw_conv = nn.Conv1d(
            in_channels, 
            out_channels, 
            kernel_size, 
            stride=1, 
            padding='same', 
            dilation=dilation_rate, 
            groups=out_channels,
            bias=False, 
            padding_mode='zeros')

        
    def forward(self, inputs):
        x = self.dw_conv(inputs)
        return x


class Conv1DBlock(nn.Module):
    def __init__(self, 
                 kernel_size=17,
                 channels = 32,
                 expand_channels = 64,
                 drop_rate=0.0,
                ):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = CausalDWConv1D(
                        kernel_size=kernel_size,
                        dilation_rate=1,
                        use_bias=False,
                        in_channels = expand_channels,
                        out_channels = expand_channels
                    )
        self.dnn_expand = nn.Linear(in_features = channels, 
                                    out_features = expand_channels
                                     )
        self.dnn_project = nn.Linear(in_features = expand_channels, 
                             out_features = channels
                                    )
        self.bn = nn.BatchNorm1d(num_features = expand_channels, eps=0.95)
        self.eca = ECA()
        self.dropout = nn.Dropout(drop_rate)
        self.act = nn.SiLU()

    def forward(self, inputs):
        skip = inputs

        x = inputs.permute([0,2,1])
        x = self.dnn_expand(x)
        
        x = x.permute([0,2,1])
        x = self.act(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.eca(x)
        
        x = x.permute([0,2,1])
        x = self.dnn_project(x)
        x = x.permute([0,2,1])

        return x + skip


class Conv1DModel(nn.Module):
    def __init__(self, 
                 kernel_size=17,
                 channels = 32,
                 expand_channels = 64,
                 drop_rate=0.0,
                 num_blocks_in_stage = 3,
                 input_len = 32_000*5,
                 n_classes = 182
                ):
        super().__init__()
        self.stem_conv = nn.Linear(in_features = 1, 
                                    out_features = channels
                                     )
        self.stem_bn = nn.BatchNorm1d(num_features = channels, eps=0.95)

        self.ConvStage_1 = nn.ModuleList([
            Conv1DBlock(kernel_size=kernel_size, channels = channels,expand_channels = expand_channels, drop_rate=drop_rate)
                                         for _ in range(num_blocks_in_stage)])
        self.PoolStage_1 = nn.AvgPool1d(kernel_size=(4))
        
        self.ConvStage_2 = nn.ModuleList([
            Conv1DBlock(kernel_size=kernel_size, channels = channels,expand_channels = expand_channels, drop_rate=drop_rate)
                                          for _ in range(num_blocks_in_stage)])
        self.PoolStage_2 = nn.AvgPool1d(kernel_size=(4))

        
        self.ConvStage_3 = nn.ModuleList([
            Conv1DBlock(kernel_size=kernel_size, channels = channels,expand_channels = expand_channels, drop_rate=drop_rate)
                                          for _ in range(num_blocks_in_stage)])
        self.PoolStage_3 = nn.AvgPool1d(kernel_size=(4))

        self.pre_out = nn.Linear(in_features = channels, out_features = n_classes*2)
        self.dropout = nn.Dropout(drop_rate)
        self.out_act = nn.SiLU()
        self.out = nn.Linear(in_features = n_classes*2, out_features = n_classes)
        self.sigmoid = nn.Sigmoid()

        
    def forward(self, inputs):
        
        b, s = inputs.shape
        x = inputs.view(b, s, 1)
        x = self.stem_conv(x)
        x = x.permute([0,2,1])
        x = self.stem_bn(x)

        for block in self.ConvStage_1:
            x = block(x)
        x = self.PoolStage_1(x)

        for block in self.ConvStage_2:
            x = block(x)
        x = self.PoolStage_2(x)

        for block in self.ConvStage_3:
            x = block(x)
        x = self.PoolStage_3(x)

        x = x.mean(axis=2)

        x = self.pre_out(x)
        x = self.dropout(x)
        x = self.out_act(x)
        
        logits = self.out(x)
        probs = self.sigmoid(logits)

        return {
                "clipwise_logits_long": logits,
                "clipwise_pred_long": probs,
            }


        

In [5]:

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
        

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head, dropout):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = nn.MultiheadAttention(n_embd, n_head)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, q = None):
        if q is not None:
            X = (q, x, x)
        else:
            X = (x, x, x)
        y = self.sa(*X)
        y = y[0]
        
        x = self.ln1(x + y)
        y = self.ffwd(x)
        x = self.ln2(x + y)
        return x


In [6]:
class FeatureExctractor(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, kernel_size = 7, channels=16, expand_channels=32, drop_rate = 0.1, n_features=25):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        self.Scales = nn.ModuleList([nn.Conv1d(in_channels = 1,
                                                out_channels = channels,
                                                kernel_size = 1,
                                                stride=1, 
                                                padding='same') for _ in range(n_features)])
        
        self.ConvExt = nn.ModuleList([
            Conv1DBlock(kernel_size=kernel_size, channels = channels,expand_channels = expand_channels, drop_rate=drop_rate)
                                          for _ in range(n_features)])
        self.BNs = nn.ModuleList([nn.BatchNorm1d(60) for _ in range(n_features)])
    def forward(self, x):
        x = x.view(-1, 556)
        
        state_t = x[:, 0:60] - 273
        state_q0001 = x[:, 60:120] *1_000
        state_q0002 = x[:, 120:180] *1_000
        state_q0003 = x[:, 180:240] *1_000
        state_u = x[:, 240:300] / 100
        state_v = x[:, 300:360] / 100
    
        state_ps = x[:, 360:361]/ 100_000 - 1
        pbuf_SOLIN = x[:, 361:362] / 1000
        pbuf_LHFLX = x[:, 362:363] / 1000
        pbuf_SHFLX = x[:, 363:364] / 1000
        pbuf_TAUX = x[:, 364:365] / 1
        pbuf_TAUY = x[:, 365:366] / 1
        pbuf_COSZRS = x[:, 366:367] / 1
        cam_in_ALDIF = x[:, 367:368] / 1
        cam_in_ALDIR = x[:, 368:369] / 1
        cam_in_ASDIF = x[:, 369:370] / 1
        cam_in_ASDIR = x[:, 370:371] / 1
        cam_in_LWUP = x[:, 371:372] / 1000
        cam_in_ICEFRAC = x[:, 372:373] / 1
        cam_in_LANDFRAC = x[:, 373:374] /1
        cam_in_OCNFRAC = x[:, 374:375]  /1
        cam_in_SNOWHLAND = x[:, 375:376] / 1
    
        pbuf_ozone = x[:, 376:436] * 100_000
        pbuf_CH4 = x[:, 436:496] * 100_000
        pbuf_N2O = x[:, 496:556] * 100_000
            
        inputs = [
                state_t,
                state_q0001,
                state_q0002,
                state_q0003, 
                state_u,
                state_v,
    
                torch.repeat_interleave(state_ps, 60, dim=-1),
                torch.repeat_interleave(pbuf_SOLIN, 60, dim=-1),
                torch.repeat_interleave(pbuf_LHFLX, 60, dim=-1),
                torch.repeat_interleave(pbuf_SHFLX, 60, dim=-1),
                torch.repeat_interleave(pbuf_TAUX, 60, dim=-1),
               torch.repeat_interleave(pbuf_TAUY, 60, dim=-1),
                torch.repeat_interleave(pbuf_COSZRS, 60, dim=-1),
                torch.repeat_interleave(cam_in_ALDIF, 60, dim=-1),
                torch.repeat_interleave(cam_in_ALDIR, 60, dim=-1),
               torch.repeat_interleave(cam_in_ASDIF, 60, dim=-1),
                torch.repeat_interleave(cam_in_ASDIR, 60, dim=-1),
                torch.repeat_interleave(cam_in_LWUP, 60, dim=-1),
                torch.repeat_interleave(cam_in_ICEFRAC, 60, dim=-1),
                torch.repeat_interleave(cam_in_LANDFRAC, 60, dim=-1),
                torch.repeat_interleave(cam_in_OCNFRAC, 60, dim=-1),
                torch.repeat_interleave(cam_in_SNOWHLAND, 60, dim=-1),
    
                pbuf_ozone,
                pbuf_CH4,
                pbuf_N2O
        ]

        output = []
        for i, conv in enumerate(self.ConvExt):
            #t = self.BNs[i](inputs[i])
            t = inputs[i]
            t = t.view(-1, 1, 60)
            t = self.Scales[i](t)
            output.append(conv(t))

        return torch.cat(output, 1)#.permute([0,2,1])


In [7]:
class LEADHead(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()

        self.act = nn.SELU()
        self.conv_seq = nn.Conv1d(in_channels = n_embd, out_channels = 6,
                                                kernel_size = 1,
                                                stride=1, 
                                                padding='same')
        
        self.conv_flat = nn.Conv1d(in_channels = n_embd, out_channels = 8,
                                                kernel_size = 1,
                                                stride=1, 
                                                padding='same')
        
    def forward(self, x):

        x = x.permute([0,2,1])
        
        p_seq = self.conv_seq(x)
        p_seq = nn.Flatten()(p_seq)
    
        p_flat = self.conv_flat(x)
        p_flat = torch.mean(p_flat, axis = -1)
        
        return torch.cat([p_seq, p_flat], axis= -1)

In [8]:
LEADHead(32)(torch.ones([8,60,32]))

tensor([[-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713],
        [-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713],
        [-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713],
        ...,
        [-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713],
        [-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713],
        [-0.5447, -0.5447, -0.5447,  ..., -0.8126,  1.0072,  0.6713]],
       grad_fn=<CatBackward0>)

In [47]:
# batch_size = 16
# block_size = 256
# max_iters = 5000
# learning_rate = 3e-4
# eval_iters = 100
# n_embd = 384
# n_head = 8
# n_layer = 12
# dropout = 0.2

nn_config = dict(
    n_embd = 128,
    n_head = 4,
    fe_channels = 32, 
    encoder_layers = 4, 
    fe_drop_rate = 0.1,
    att_drop_rate = 0.2,
    n_features = 25
)

    
class LEADModelAtt(nn.Module):
    def __init__(self, n_embd = 64, n_head = 4, encoder_layers = 3, fe_channels=16, fe_drop_rate=0.1, att_drop_rate=0.2, n_features = 25):
        super().__init__()
        self.fe = FeatureExctractor(kernel_size = 7, channels=fe_channels, expand_channels=fe_channels*2, drop_rate = fe_drop_rate, n_features=n_features)
        
        self.bottleneck = nn.Sequential(
            nn.Conv1d(in_channels = n_features*fe_channels, out_channels = n_embd*4, kernel_size = 3, stride=1,  padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(n_embd*4),
            
            nn.Conv1d(in_channels = n_embd*4, out_channels = n_embd*2, kernel_size = 3, stride=1,  padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(n_embd*2),
            
            nn.Conv1d(in_channels = n_embd*2, out_channels = n_embd, kernel_size = 1, stride=1,  padding='same')
        )

        self.blocks = nn.Sequential(*[Block(n_embd = n_embd, n_head=n_head, dropout = att_drop_rate) for _ in range(encoder_layers)])
        self.head  = LEADHead(n_embd = n_embd)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, inputs, targets=None):
        #B, T = inputs.shape

        xf = self.fe(inputs)
        xf = self.bottleneck(xf)
        x = xf.permute([0,2,1])
        
        x = self.blocks(x)

        out = self.head(x)
        return out

In [48]:
LEADModelAtt(**nn_config)(torch.ones([8, 556])).shape

torch.Size([8, 368])

In [30]:
import torchvision

class FocalLossBCE(torch.nn.Module):
    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction) #, pos_weight=sample_weights_420)
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight

    def forward(self, logits, targets):
        focall_loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        bce_loss = self.bce(logits, targets)
        return self.bce_weight * bce_loss + self.focal_weight * focall_loss


criterion = FocalLossBCE(focal_weight=5, alpha = 0.3)

# DATASET

In [31]:
df = pd.read_parquet("train_data_sample.parquet").sample(100000).drop('sample_id', axis=1).reset_index(drop=True)

In [32]:
mean_y = df.iloc[:, 556:].mean().to_numpy()
std_y = df.iloc[:, 556:].std().to_numpy()
std_y = np.clip(std_y, 1e-10, 1e3)

In [33]:
class LEAD_Dataset(torch.utils.data.Dataset):
    def __init__(self, df, augmentation=False, mode='train'):
        if mode == 'train':
            self.df = df.reset_index(drop=True)
        elif mode == 'valid':
            self.df = df.reset_index(drop=True)
        else:
            self.df = df.reset_index(drop=True)
        self.mode = mode
        self.augmentation = augmentation
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        x = self.df.iloc[idx, :556].to_numpy()
        y = self.df.iloc[idx, 556:].to_numpy() 
        y = (y - mean_y) / std_y
        
        return torch.tensor(x), torch.tensor(y)

 

In [34]:
LEAD_Dataset(df).__getitem__(3)[1].view(1, -1)

tensor([[-8.4946e-01, -8.4117e-01, -1.0590e+00, -8.3551e-01, -8.2164e-01,
         -9.1835e-01, -8.5223e-01, -7.7115e-01, -7.8668e-01, -8.1774e-01,
         -8.6398e-01, -9.4047e-01, -1.0246e+00, -1.0578e+00, -1.0390e+00,
         -1.0456e+00, -8.9029e-01, -3.4263e-01, -3.8428e-01,  1.7187e-02,
         -2.7066e-01, -4.7442e-02, -5.5365e-01, -3.2171e-01, -4.0188e-01,
         -4.0085e-01, -3.8172e-01, -2.1434e-01, -1.1706e-01, -1.9239e-01,
         -4.9736e-01, -8.4497e-01, -8.7760e-01, -8.1536e-01, -4.2608e-01,
         -5.6318e-01, -6.5608e-01,  2.3213e-01, -1.6411e-01, -1.5355e-01,
          3.5687e-01,  6.4876e-02,  3.7405e-01,  1.8168e-01,  1.2330e-01,
         -3.7806e-01, -6.2132e-01, -7.4627e-01, -8.1018e-01, -5.5426e-01,
         -6.8920e-01, -7.0061e-01, -6.6780e-01, -6.5509e-01, -5.6214e-01,
         -5.0127e-01, -2.9542e-01, -7.3688e-02,  1.6327e-01,  4.6431e-01,
         -2.6445e-06,  2.8567e-07,  2.8671e-07,  2.1952e-07,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.

In [35]:
LEADModelAtt(**nn_config)(torch.cat([LEAD_Dataset(df).__getitem__(3)[0].view(1, -1), LEAD_Dataset(df).__getitem__(5)[0].view(1, -1)]))

tensor([[-0.2631,  0.2642,  0.2161,  0.1341,  0.1177,  0.5436,  0.8648,  0.7057,
          0.8269,  0.7984,  0.8157,  0.5113,  0.6099,  0.6781,  0.5989,  0.6146,
          0.6173,  0.5490,  0.3686,  0.4140,  0.5298,  0.4512,  0.5016,  0.3730,
          0.4644,  0.4565,  0.4043,  0.5224,  0.6393,  0.5659,  0.5884,  0.4525,
          0.2773,  0.1116, -0.0298, -0.3391, -0.4142, -0.5168, -0.5530, -0.5055,
         -0.4807, -0.5546, -0.5800, -0.5910, -0.4927, -0.5330, -0.6515, -0.5910,
         -0.5682, -0.6745, -0.7100, -0.4314, -0.6006, -0.6802, -0.5908, -0.5121,
         -0.5633, -0.5850, -0.5313,  0.3399,  0.0268,  0.0852,  0.0098, -0.0980,
         -0.3195, -0.7499, -0.5332, -0.2038,  0.0202,  0.1026,  0.1323,  0.0908,
          0.1467,  0.2413,  0.2793,  0.2152,  0.2440,  0.2394,  0.2800,  0.2054,
          0.2528,  0.2755,  0.3343,  0.2714,  0.2261,  0.1722,  0.1364,  0.1414,
          0.1799,  0.2093,  0.1009,  0.1203,  0.1349,  0.1740,  0.0771,  0.0988,
         -0.0686, -0.0794, -

In [36]:
a = LEAD_Dataset(df).__getitem__(63)[1]
b = LEAD_Dataset(df).__getitem__(9)[1]

In [37]:
a.shape

torch.Size([368])

In [38]:
nn.MSELoss()(torch.tensor(np.expand_dims(a, 0)), torch.tensor(np.expand_dims(b, 0)))

tensor(1.6672)

In [49]:

import torch
#from torcheval.metrics import R2Score 
from torchmetrics.regression import R2Score
metric = R2Score()



class LEADModel(pl.LightningModule):
    
    def __init__(self):
        super().__init__()
        
        # == backbone ==
        self.backbone = LEADModelAtt(**nn_config).to(config.DEVICE)
        

        self.loss_fn = nn.MSELoss()
        self.metric = R2Score()
        
        # == record ==
        self.validation_step_outputs = []
        
    def forward(self, images):
        return self.backbone(images)
    
    def configure_optimizers(self):
        
        # == define optimizer ==
        model_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )
        
        # == define learning rate scheduler ==
        lr_scheduler = CosineAnnealingWarmRestarts(
            model_optimizer,
            T_0=config.EPOCHS,
            T_mult=1,
            eta_min=1e-7,
            last_epoch=-1
        )
        
        return {
            'optimizer': model_optimizer,
            'lr_scheduler': {
                'scheduler': lr_scheduler,
                'interval': 'epoch',
                'monitor': 'val_loss',
                'frequency': 1
            }
        }
    
    def training_step(self, batch, batch_idx):
        
        # == obtain input and target ==
        image, target = batch
        image = image.to(self.device).float()
        target = target.to(self.device).float()
        
        # == pred ==
        y_pred = self(image)
        
        # == compute loss ==
        train_loss = self.loss_fn(y_pred, target)
        
        # == record ==
        self.log('train_loss', train_loss, True)
        
        return train_loss
    
    def validation_step(self, batch, batch_idx):
        
        # == obtain input and target ==
        image, target = batch
        image = image.to(self.device).float()
        target = target.to(self.device).float()
        
        # == pred ==
        with torch.no_grad():
            y_pred = self(image)
            
        self.validation_step_outputs.append({"logits": y_pred, "targets": target})
        
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader
    
    def on_epoch_start(self):
        print('\n')

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        state_dict = checkpoint["state_dict"]
        model_state_dict = self.state_dict()
        is_changed = False
        for k in state_dict:
            if k in model_state_dict:
                if state_dict[k].shape != model_state_dict[k].shape:
                    print(f"Skip loading parameter: {k}, "
                                f"required shape: {model_state_dict[k].shape}, "
                                f"loaded shape: {state_dict[k].shape}")
                    state_dict[k] = model_state_dict[k]
                    is_changed = True
            else:
                print(f"Dropping parameter {k}")
                is_changed = True

        if is_changed:
            checkpoint.pop("optimizer_states", None)
    
    def on_validation_epoch_end(self):
        
        # = merge batch data =
        outputs = self.validation_step_outputs
        
        #output_val = nn.Sigmoid()(torch.cat([x['logits'] for x in outputs], dim=0)).cpu().detach()
        #output_val = torch.cat([x['logits'] for x in outputs], dim=0).cpu().detach()
        #target_val = torch.cat([x['targets'] for x in outputs], dim=0).cpu().detach()
        output_val = torch.cat([x['logits'] for x in outputs], dim=0)#.cpu().detach()
        target_val = torch.cat([x['targets'] for x in outputs], dim=0)#.cpu().detach()
        
        
        # = compute validation loss =
        val_loss = self.loss_fn(output_val, target_val)
        # == record ==
        print(f"val_loss: {val_loss}")
        self.log('val_loss', val_loss, True)
        
        val_loss = val_loss.cpu().detach()

    
        #output_val = nn.Sigmoid()(output_val).cpu().detach()
        output_val = output_val.cpu().detach()
        target_val = target_val.cpu().detach()



        r2=0
        for i in range(368):
            r2_i = self.metric(output_val[:, i], target_val[:, i])
            if r2_i > 1e-6:
                r2 += r2_i
        val_score  = r2/ 368
            

        
        # self.metric.update(target_val, output_val)
        # val_score = self.metric.compute()
        
        # target to one-hot
        #target_val = torch.nn.functional.one_hot(target_val, len(label_list))
        
        # = val with ROC AUC =
        # gt_df = pd.DataFrame(target_val.numpy().astype(np.float32), columns=label_list)
        # pred_df = pd.DataFrame(output_val.numpy().astype(np.float32), columns=label_list)
        
        # gt_df['id'] = [f'id_{i}' for i in range(len(gt_df))]
        # pred_df['id'] = [f'id_{i}' for i in range(len(pred_df))]
        
        # val_score = score(gt_df.drop(cols_drop_on_val, axis=1), pred_df.drop(cols_drop_on_val, axis=1), row_id_column_name='id')
        
        print(f"val_R2: {val_score}")
        
        self.log("val_R2", val_score, True)
        
        # clear validation outputs
        self.validation_step_outputs = list()
        
        return {'val_loss': val_loss, 'val_R2': val_score}

In [50]:
USE_CHECKPOINT = False
#CHK_PATH = './pretrain_checkpoints/eca_nfnet_l0_fold_0_0.97126.ckpt'


def run_training(fold_id, total_df):
    print('================================================================')
    print(f"==== Running training for fold {fold_id} ====")
    
    # == create dataset and dataloader ==
    train_df = total_df[total_df['fold'] != fold_id].drop('fold', axis=1).copy()
    valid_df = total_df[total_df['fold'] == fold_id].drop('fold', axis=1).copy()
    
    print(f'Train Samples: {len(train_df)}')
    print(f'Valid Samples: {len(valid_df)}')
    
  
    train_ds = LEAD_Dataset(train_df)
    val_ds =  LEAD_Dataset(valid_df)
    #val_ds = WaveAllFileDataset(df=valid_df, name_col="filepath", **val_dataset_config)
    
    
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        #num_workers=config.N_WORKERS,
        pin_memory=True,
        #persistent_workers=True
    )
    
    val_dl = torch.utils.data.DataLoader(
        val_ds,
        batch_size=config.BATCH_SIZE * 2,
        shuffle=False,
        #num_workers=config.N_WORKERS,
        pin_memory=True,
        #persistent_workers=True
    )
    
    # == init model ==
    if USE_CHECKPOINT:
        model = LEADModel.load_from_checkpoint(CHK_PATH, strict=False)
    else:
        model = LEADModel()
    # == init callback ==
    checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                          dirpath=config.OUTPUT_DIR,
                                          save_top_k=1,
                                          save_last=True,
                                          save_weights_only=True,
                                          filename=f"fold_{fold_id}",
                                          mode='min')

    callbacks_to_use = [checkpoint_callback, TQDMProgressBar(refresh_rate=1)]

    print(f'trainer')
    # == init trainer ==
    trainer = pl.Trainer(
        max_epochs=config.EPOCHS,
        val_check_interval=1.,
        num_sanity_val_steps=0,
        callbacks=callbacks_to_use,
        enable_model_summary=False,
        accelerator="gpu" if torch.cuda.is_available() else 'auto',
        deterministic=True,
        precision='16-mixed' if config.MIXED_PRECISION else 32,
    )
    
    # == Training ==
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    
    # == Prediction ==
    best_model_path = checkpoint_callback.best_model_path
    weights = torch.load(best_model_path)['state_dict']
    model.load_state_dict(weights)
    
    
    return trainer

In [41]:
#train_df = train_df[train_df.target<30].reset_index(drop=True)

kf = KFold(n_splits=config.FOLDS, shuffle=True, random_state=config.SEED)
df['fold'] = 0
for fold, (train_idx, val_idx) in enumerate(kf.split(df)):
    df.loc[val_idx, 'fold'] = fold
    

In [42]:
#config.EPOCHS = 10
#config.LR = 1e-5

In [51]:


import logging

def disable_logging_during_tests():
    # Store the current log level to restore it later
    original_log_level = logging.getLogger().getEffectiveLevel()

    # Set the log level to a higher level, e.g., WARNING or CRITICAL
    logging.disable(logging.ERROR)

    # Run your tests here

    # Restore the original log level after the tests
    logging.disable(original_log_level)

# Call this function before running your tests
disable_logging_during_tests()



In [53]:
selected_folds = [0,4,5]
    
# training
torch.set_float32_matmul_precision('high')



for f in range(config.FOLDS):
    
    if f not in selected_folds:
        continue
    
    # get validation index
    #val_idx = list(train_df[train_df['fold'] == f].index)
    
    # main loop of f-fold
    trainer = run_training(f, df)
    

    
    # only training one fold
    #break


# for idx, val_score in enumerate(fold_val_score_list):
#     print(f'Fold {idx} Val Score: {val_score:.5f}')

# oof_gt_df = oof_df[['samplename'] + label_list].copy()
# oof_pred_df = oof_df[['samplename'] + pred_cols].copy()
# oof_pred_df.columns = ['samplename'] + label_list
# oof_score = score(oof_gt_df, oof_pred_df, 'samplename')
# print(f'OOF Score: {oof_score:.5f}')

#oof_df.to_csv(f"{config.OUTPUT_DIR}/oof_pred.csv", index=False)

==== Running training for fold 0 ====
Train Samples: 85714
Valid Samples: 14286
trainer


Training: |                                                                                                   …

Validation: |                                                                                                 …

val_loss: 0.6069918274879456
val_R2: 0.16118839383125305


Validation: |                                                                                                 …

val_loss: 0.5179113149642944
val_R2: 0.2442925125360489


Validation: |                                                                                                 …

val_loss: 0.48237138986587524
val_R2: 0.27924787998199463


Validation: |                                                                                                 …

val_loss: 0.4769443869590759
val_R2: 0.2849314510822296


Validation: |                                                                                                 …

val_loss: 0.4639175236225128
val_R2: 0.2979210317134857


Validation: |                                                                                                 …

val_loss: 0.4352833926677704
val_R2: 0.3267115354537964


Validation: |                                                                                                 …

val_loss: 0.4233962595462799
val_R2: 0.33874624967575073


Validation: |                                                                                                 …

val_loss: 0.417022168636322
val_R2: 0.3449324369430542


Validation: |                                                                                                 …

val_loss: 0.41143104434013367
val_R2: 0.35035938024520874


Validation: |                                                                                                 …

val_loss: 0.404816210269928
val_R2: 0.35782480239868164


Validation: |                                                                                                 …

val_loss: 0.3996056020259857
val_R2: 0.3627639412879944


Validation: |                                                                                                 …

val_loss: 0.4095602333545685
val_R2: 0.35248953104019165


Validation: |                                                                                                 …

val_loss: 0.3929939568042755
val_R2: 0.3701450526714325


Validation: |                                                                                                 …

val_loss: 0.38637152314186096
val_R2: 0.3770081400871277


Validation: |                                                                                                 …

val_loss: 0.38330891728401184
val_R2: 0.37897974252700806


Validation: |                                                                                                 …

val_loss: 0.3845108449459076
val_R2: 0.3788372576236725


Validation: |                                                                                                 …

val_loss: 0.3854481279850006
val_R2: 0.37703290581703186


Validation: |                                                                                                 …

val_loss: 0.3852533996105194
val_R2: 0.37699344754219055


Validation: |                                                                                                 …

val_loss: 0.3820417821407318
val_R2: 0.3807038962841034


Validation: |                                                                                                 …

val_loss: 0.3795163035392761
val_R2: 0.38351428508758545


Validation: |                                                                                                 …

val_loss: 0.3797265887260437
val_R2: 0.38284164667129517


Validation: |                                                                                                 …

val_loss: 0.3769436180591583
val_R2: 0.3851703703403473


Validation: |                                                                                                 …

val_loss: 0.37291577458381653
val_R2: 0.3908238410949707


Validation: |                                                                                                 …

val_loss: 0.37641677260398865
val_R2: 0.38570401072502136


Validation: |                                                                                                 …

val_loss: 0.37051206827163696
val_R2: 0.39211270213127136


Validation: |                                                                                                 …

val_loss: 0.3885743319988251
val_R2: 0.373775452375412


Validation: |                                                                                                 …

val_loss: 0.3689335882663727
val_R2: 0.39379170536994934


Validation: |                                                                                                 …

val_loss: 0.3688802421092987
val_R2: 0.3947659432888031


Validation: |                                                                                                 …

val_loss: 0.37876036763191223
val_R2: 0.38444286584854126


Validation: |                                                                                                 …

val_loss: 0.3694678843021393
val_R2: 0.3932725787162781


Validation: |                                                                                                 …

val_loss: 0.36338332295417786
val_R2: 0.39931273460388184


Validation: |                                                                                                 …

val_loss: 0.3777080178260803
val_R2: 0.3856325149536133
==== Running training for fold 4 ====
Train Samples: 85714
Valid Samples: 14286
trainer


Training: |                                                                                                   …

FileNotFoundError: [Errno 2] No such file or directory: ''

In [None]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(tokenized, split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [None]:
model = LEADModelAtt(**nn_config).to(config.DEVICE)

In [None]:
train_dl = torch.utils.data.DataLoader(
        LEAD_Dataset(df.drop('fold', axis=1)),
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        #persistent_workers=True
    )

In [None]:
max_iters = 2000
eval_iters = 1000
learning_rate = 3e-4
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for step in range(max_iters):
    #print(iter)
    # if iter % eval_iters == 0:
    #     losses = estimate_loss()
    #     print(f"step: {iter}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")

    # sample a batch of data
    #xb, yb = LEAD_Dataset(df.drop('fold', axis=1)).__getitem__(iter)

    xb, yb = next(iter(train_dl))

    # evaluate the loss
    logits = model.forward(xb.to('cuda').float(), yb.to('cuda').float())

    loss = nn.MSELoss()(logits, yb.to('cuda').float())
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    print(loss.item())





In [None]:
#config.EPOCHS = 25  # max epochs
#config.LR = 3e-4  # learning rate

In [None]:
%reload_ext tensorboard
%tensorboard --logdir ./lightning_logs/version_0/

