In [1]:
import os
import cv2
import csv
import sys
import copy
from tqdm import tqdm
from typing import Union, List, Dict, Any, cast
import random
import librosa
import librosa.display
import numpy as np 
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import label_ranking_average_precision_score, accuracy_score
import torchvision

import audiomentations as audioaa

import matplotlib.pyplot as plt
import IPython.display as ipd 
import skimage.io
from skimage.transform import resize
import albumentations as albu
from albumentations import pytorch as AT
from PIL import Image
from functools import partial


import pretrainedmodels
#from resnest.torch import resnest50

sys.path.append('../')

import src.audio_augs as aa
from src.utils import patch_first_conv
from src.loss import lsep_loss_stable, lsep_loss
from src.batch_mixer import BatchMixer
from src.pann import *

import timm
from timm.models.resnet import resnet50d
from timm.models.efficientnet import tf_efficientnet_b0_ns, tf_efficientnet_lite4, mobilenetv2_140, tf_efficientnet_b1_ns
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation


import warnings
warnings.filterwarnings('ignore')

In [2]:
# --- Input locations (relative to the notebook directory) -------------------
train_folder_path = "../data/train/"
train_np_folder_path = "../data/train_np/"
test_folder_path = "../data/test/"
sample_submission = "../data/sample_submission.csv"
train_tp_path = "../data/train_tp.csv"
train_fp_path = "../data/train_fp.csv"

# --- Fold assignments --------------------------------------------------------
train_tp_folds = pd.read_csv("../data/train_tp_folds_v3.csv")
# Use the keyword form of `drop`: passing `axis` positionally (`drop(label, 1)`)
# is rejected by pandas >= 2.0. This also matches how `_df` is cleaned below.
train_fp_folds = pd.read_csv("train_fp_folds.csv").drop(columns="Unnamed: 0")

train_files = os.listdir(train_folder_path)
test_files = os.listdir(test_folder_path)

# --- Raw annotation tables ---------------------------------------------------
train_tp = pd.read_csv(train_tp_path)
train_fp = pd.read_csv(train_fp_path)

# Extra annotations; only consumed when `add_dop` is enabled further down.
_df = pd.read_csv("missing_3classes_extended.csv")
_df = _df.drop(columns="Unnamed: 0")

# --- Pseudo-label tables -----------------------------------------------------
pseudo_hard = pd.read_csv("../data/hard_pseudo_227.csv")
pseudo_tp = pd.read_csv("../data/tp_pseudo_268.csv")
pseudo_fp = pd.read_csv("../data/fp_pseudo_268.csv")

In [3]:
class Config:
    """Run-wide constants: reproducibility, fold selection, optimiser and audio settings."""

    # -- reproducibility / problem size --
    SEED = 25
    NUM_BIRDS = 24

    # -- data loading --
    BATCH_SIZE = 16
    NUM_WORKERS = 4

    # -- cross-validation: FOLD is held out for validation in the split cell --
    FOLD = 4
    TEST_FOLD = 5

    EPOCHS = 50

    # -- optimizer params --
    LR = 0.01
    LR_ADAM = 1e-3
    WEIGHT_DECAY = 0.0001
    MOMENTUM = 0.9
    T_MAX = 8

    # -- scheduler params --
    FACTOR = 0.8
    PATIENCE = 4

    # -- audio: sample rate in Hz, clip lengths expressed in samples --
    SR = 48000
    LENGTH_1 = SR * 10   # 10-second window
    LENGTH_2 = SR * 5    # 5-second window
    # TODO: move the augmentation settings into this config as well
    
# Backbone registry used by AudioSEDModel: `init_op` builds the encoder and
# `features` is the input width of the model's first linear layer.
# NOTE(review): the key says "efficientnet_b0" but the factory is
# `mobilenetv2_140`; the key is only a lookup name.
# Other widths tried previously: 1280, 2048.
encoder_params = dict(
    efficientnet_b0=dict(
        features=1792,
        init_op=partial(mobilenetv2_140, pretrained=True, drop_path_rate=0.2),
    ),
)
    
# Keyword arguments forwarded verbatim to AudioSEDModel(**model_param).
model_param = dict(
    encoder='efficientnet_b0',   # key into `encoder_params`
    sample_rate=48000,
    window_size=2048,            # STFT n_fft and window length
    hop_size=512,                # STFT hop length
    mel_bins=224,
    fmin=80,
    fmax=15000,
    classes_num=24,
)

In [4]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + current CUDA device) RNGs.

    Also exports PYTHONHASHSEED and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for apply_seed in (random.seed, np.random.seed,
                       torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed every RNG once, before any data sampling or model construction.
seed_everything(Config.SEED)


In [5]:
class AudioSEDModel(nn.Module):
    """Sound-event-detection network operating on raw waveforms.

    The forward pass turns the waveform into a log-mel spectrogram, runs it
    through an image backbone taken from ``encoder_params``, and feeds the
    pooled features to an attention head that yields clip-level and
    segment-level predictions.

    ``Mixup``, ``AttBlock_V2``, ``init_layer``, ``init_bn``, ``do_mixup``,
    ``interpolate`` and ``pad_framewise_output`` are not defined in this
    notebook; presumably they come from the ``src.pann`` star import.
    """

    def __init__(self, encoder, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num):
        """Build the feature extractors, backbone and attention head.

        Args:
            encoder: key into the module-level ``encoder_params`` registry.
            sample_rate: forwarded to the log-mel filter bank as ``sr``.
            window_size: used as both ``n_fft`` and ``win_length`` of the STFT.
            hop_size: STFT hop length.
            mel_bins: number of mel bands; also the BatchNorm channel count.
            fmin: lower frequency bound of the mel filter bank.
            fmax: upper frequency bound of the mel filter bank.
            classes_num: number of output classes of the attention head.
        """
        super().__init__()

        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None
        # Up-sampling factor applied to segment-wise outputs in forward().
        # NOTE(review): presumably matches the backbone's temporal stride -- confirm.
        self.interpolate_ratio = 30
        # Not referenced anywhere else in this class.
        self.mixup_coff = Mixup(1.)

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 
            win_length=window_size, window=window, center=center, pad_mode=pad_mode, 
            freeze_parameters=True)

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 
            n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 
            freeze_parameters=True)

        # Spec augmenter (constructed, but forward() currently never applies it)
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 
            freq_drop_width=8, freq_stripes_num=2)
        
        # Model Encoder
        self.encoder = encoder_params[encoder]["init_op"]()
        self.fc1 = nn.Linear(encoder_params[encoder]["features"], 1024, bias=True)
        self.att_block = AttBlock_V2(1024, classes_num, activation="sigmoid")
        self.bn0 = nn.BatchNorm2d(mel_bins)
        self.init_weight()
    
    def init_weight(self):
        """Initialise the layers owned directly by this module (fc1 and bn0)."""
        init_layer(self.fc1)
        init_bn(self.bn0)
    
    def forward(self, input, mixup_lambda=None):
        """Run the network on a batch of waveforms.

        Args:
            input: waveform batch, (batch_size, data_length).
            mixup_lambda: optional mixup coefficients; only used in training
                mode, where they are passed to ``do_mixup`` together with the
                normalised spectrogram.

        Returns:
            dict with keys ``framewise_output``, ``segmentwise_output``,
            ``logit``, ``framewise_logit`` and ``clipwise_output``.
        """

        x = self.spectrogram_extractor(input)
        # batch_size x 1 x time_steps x freq_bins
        x = self.logmel_extractor(x)
        # batch_size x 1 x time_steps x mel_bins

        # Remembered so the up-sampled outputs can be padded back to this length.
        frames_num = x.shape[2]

        # Swap the mel axis into the channel position so BatchNorm2d(mel_bins)
        # normalises per mel band, then swap back.
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        # SpecAugment is switched off on purpose by the constant `False`.
        if self.training and False:
            x = self.spec_augmenter(x)
        
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
        
        # Replicate the single channel three times for the image backbone.
        # Output shape (batch size, channels, time, frequency)
        x = x.expand(x.shape[0], 3, x.shape[2], x.shape[3])
        x = self.encoder.forward_features(x)
        # Average out the last axis of the backbone feature map.
        x = torch.mean(x, dim=3)

        # Smooth along the remaining sequence axis with max + average pooling.
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        # nn.Linear acts on the last axis, hence the transposes around fc1.
        x = x.transpose(1, 2)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        # Attention-weighted sum of the head's `cla` output gives clip-level logits.
        logit = torch.sum(norm_att * self.att_block.cla(x), dim=2)
        segmentwise_logit = self.att_block.cla(x).transpose(1, 2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)
        
        framewise_logit = interpolate(segmentwise_logit, self.interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            "framewise_output": framewise_output,
            "segmentwise_output": segmentwise_output,
            "logit": logit,
            "framewise_logit": framewise_logit,
            "clipwise_output": clipwise_output
        }


        return output_dict
    
def crop_or_pad(y, is_train=True):
    """Return `y` as float32 with exactly Config.LENGTH_2 samples.

    Shorter inputs are padded on both sides (random split between left and
    right) using numpy's "minimum" pad mode; longer inputs are randomly
    cropped. Inputs of the right length are only cast.

    Args:
        y: 1-D waveform array.
        is_train: accepted for API compatibility; not used by the body.

    Returns:
        np.ndarray of dtype float32.
    """
    target_len = Config.LENGTH_2
    n_samples = len(y)

    if n_samples < target_len:
        missing = target_len - n_samples
        left = np.random.randint(0, missing)
        y = np.pad(y, (left, missing - left), "minimum")
    elif n_samples > target_len:
        offset = np.random.randint(n_samples - target_len)
        y = y[offset:offset + target_len]

    return y.astype(np.float32, copy=False)

In [6]:
class RainforestDataset(Dataset):
    """Random fixed-length waveform crops around annotated events.

    Each row of `df` is one annotation. Rows are read positionally, so the
    frame must keep this column layout: recording id at position 0, `t_min`
    at 3, `t_max` at 5, and the target value (label strength) in the last
    column. The frame must also expose `recording_id`, `t_min`, `t_max` and
    `species_id` by name, because overlapping events are looked up with
    `DataFrame.query`.
    """

    def __init__(self, df, audio_transforms = None, image_transforms = None,):
        """
        Args:
            df: annotation table (see class docstring for the layout).
            audio_transforms: optional callable applied to the waveform crop.
            image_transforms: stored but not used by this class.
        """
        self.audio_transforms = audio_transforms
        self.img_transforms = image_transforms
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(waveform, labels)` for annotation `idx`.

        Returns:
            waveform: float32 tensor with Config.LENGTH_2 samples, scaled by 10.
            labels: float32 array of length Config.NUM_BIRDS; every species
                with an annotation overlapping the crop gets this row's
                target value.
        """
        sample = copy.deepcopy(self.df.iloc[idx, :].values)
        recording_id = sample[0]

        # Prefer the pre-decoded .npy; only a *missing* file triggers the
        # fallback to the .flac in the test folder. Any other failure
        # (corrupt file, permissions, ...) is a real error and must surface.
        try:
            wav = np.load(train_np_folder_path + recording_id + ".npy")
        except FileNotFoundError:
            wav, _ = librosa.load(test_folder_path + recording_id + ".flac", sr=None)

        clip_len = Config.LENGTH_2
        tmin = float(sample[3]) * Config.SR
        tmax = float(sample[5]) * Config.SR
        center = int(np.round((tmin + tmax) / 2))

        # Pick a start in [center - clip_len/2, center) so the event centre
        # always falls in the first half of the crop.
        earliest = max(int(center - clip_len / 2), 0)
        if center > earliest:
            beginning = np.random.randint(earliest, center)
        else:
            # randint requires low < high; this happens for events centred at t=0.
            beginning = earliest

        ending = beginning + clip_len
        if ending > len(wav):
            ending = len(wav)
            # Never let the start go negative: a negative index would slice
            # from the end of the recording instead.
            beginning = max(ending - clip_len, 0)

        wav_slice = wav[beginning:ending]
        if len(wav_slice) < clip_len:
            # Recording shorter than one clip: zero-pad so batches stay rectangular.
            wav_slice = np.pad(wav_slice, (0, clip_len - len(wav_slice)), "constant")

        # Collect every annotated event of this recording that overlaps the crop.
        beginning_time = beginning / Config.SR
        ending_time = ending / Config.SR
        query_string = f"recording_id == '{recording_id}' & "
        query_string += f"t_min < {ending_time} & t_max > {beginning_time}"
        all_tp_events = self.df.query(query_string)

        label_array = np.zeros(Config.NUM_BIRDS, dtype=np.float32)
        for species_id in all_tp_events["species_id"].unique():
            label_array[int(species_id)] = sample[-1]

        if self.audio_transforms:
            wav_slice = self.audio_transforms(wav_slice)

        # Fixed gain of 10; load_val_file applies the same factor at validation time.
        wav_slice = wav_slice.astype(np.float32) * 10.

        return torch.tensor(wav_slice), label_array

In [7]:
# Ground-truth flag on the fold tables: 1 = true positive, 0 = false positive.
train_tp_folds["true"] = 1
train_fp_folds["true"] = 0
#X_train = train_tp_folds[(train_tp_folds['fold'] != Config.FOLD) & (train_tp_folds['fold'] != Config.TEST_FOLD)].reset_index(drop=True)
use_pseudo = True


if use_pseudo:
    # Assign the fold BEFORE sampling. Previously the assignment came after
    # `.sample(...)`, so on a fresh kernel the sampled rows did not carry it
    # (while on a re-run of this cell they did).
    pseudo_fp["fold"] = 3
    pseudo_fp_samples = pseudo_fp.sample(n=2000, random_state=Config.SEED)

    pseudo = pd.concat([pseudo_tp, pseudo_fp_samples], ignore_index=True)
    X_train = pseudo[(pseudo['fold'] != Config.FOLD)].reset_index(drop=True)
    X_val = pseudo[pseudo['fold'] == Config.FOLD].reset_index(drop=True)

else:
    X_train = train_tp_folds[(train_tp_folds['fold'] != Config.FOLD)].reset_index(drop=True)
    X_val = train_tp_folds[train_tp_folds['fold'] == Config.FOLD].reset_index(drop=True)
#X_test = train_tp_folds[train_tp_folds['fold'] == Config.TEST_FOLD].reset_index(drop=True)
#X_train = X_train[~(X_train.species_id == 12)]
#X_val = X_val[~(X_val.species_id == 12)]

#X_train = train_tp_folds
add_dop = False

if add_dop:
    _df = _df[_df.recording_id.isin(X_train.recording_id)]
    X_train = X_train[_df.columns]
    X_train = pd.concat([X_train, _df])

# RainforestDataset reads rows by position (id=0, t_min=3, t_max=5, label=last),
# so this column order must not change. `.copy()` makes the new column land on
# an independent frame rather than on a slice of the previous one.
X_train = X_train[['recording_id', 'species_id', 'songtype_id', 't_min', 'f_min', 't_max', "f_max", "is_pseudo"]].copy()
# Float from the start: the column receives 0.85 below, and writing a float
# into an integer column is an incompatible-dtype assignment in recent pandas.
X_train["label"] = 1.0
# NOTE(review): this fills *every* missing value with 1, including
# `is_pseudo`, so rows with a missing flag end up treated as pseudo-labelled
# (label 0.85) -- confirm that is intended.
X_train = X_train.fillna(value=1)
X_train.loc[X_train.is_pseudo == 1, "label"] = 0.85



add_pseudo = False

if add_pseudo:
    # NOTE(review): `pseudo_clear` is not defined in the visible part of this
    # notebook; enabling this flag raises NameError unless it is loaded first.
    X_train = X_train[pseudo_clear.columns]
    X_train = pd.concat([X_train, pseudo_clear])

print('Training on ' + str(len(X_train)) + ' examples')
print('Validating on ' + str(len(X_val)) + ' examples')
#print('Testing on ' + str(len(X_test)) + ' examples')

Training on 12415 examples
Validating on 2766 examples


In [8]:
# Training-time waveform augmentation built from the project's `src.audio_augs`
# module: one of two noise types, then optional pitch shift, a time shift and
# optional sine-shaped volume modulation. This is the pipeline handed to the
# training dataset.
audio_transform_train = aa.Compose([
    aa.OneOf([
        aa.GaussianNoiseSNR(min_snr=5.0, max_snr=20.0),
        aa.PinkNoiseSNR(min_snr=5.0, max_snr=20.0),
    ]),
    aa.PitchShift(max_steps=4, sr=Config.SR, p=0.2),
    # aa.TimeStretch(max_rate=1.2, p=0.1),  # disabled
    aa.TimeShift(sr=Config.SR),
    aa.VolumeControl(mode="sine", p=0.2),
])

# Alternative pipeline built with the `audiomentations` package.
audio_transform = audioaa.Compose([
    audioaa.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
    audioaa.FrequencyMask(),
    audioaa.TimeMask(),
    audioaa.AddGaussianSNR(min_SNR=0.001, max_SNR=0.6, p=0.5),
    audioaa.PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
    audioaa.AddBackgroundNoise(sounds_path="../data/noise/", p=0.2),
    # Disabled candidates:
    # audioaa.Shift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
    # audioaa.Normalize(),
    # audioaa.PolarityInversion(p=0.5),
    # audioaa.Gain(min_gain_in_db=-12, max_gain_in_db=12, p=0.5),
    # audioaa.ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=40, p=0.5),
])


In [9]:
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

class PANNsLoss(nn.Module):
    """Binary cross-entropy on the model's `clipwise_output` probabilities.

    NaN and +/-inf predictions are replaced by 0 before the loss is taken.
    """

    def __init__(self):
        super().__init__()

        self.bce = nn.BCELoss()

    def forward(self, input, target):
        """
        Args:
            input: model output dict; only `input["clipwise_output"]` is read.
            target: label tensor of the same shape (cast to float here).

        Returns:
            Scalar BCE loss.
        """
        probs = input["clipwise_output"]
        non_finite = torch.isnan(probs) | torch.isinf(probs)
        probs = torch.where(non_finite, torch.zeros_like(probs), probs)

        return self.bce(probs, target.float())
    
def soft_focal_loss(input, target, focus=2.0, raw=True, eps=1e-7):
    """Focal-style loss with part of the elements masked out.

    Args:
        input: logits (when `raw` is True) or probabilities.
        target: tensor of the same shape as `input` with values in [0, 1].
        focus: exponent of the modulating factor (1 - p_true).
        raw: apply a sigmoid to `input` first.
        eps: clamp bound keeping the log finite.

    Returns:
        Scalar mean of the masked focal terms.
    """
    probs = torch.sigmoid(input) if raw else input

    # `order[..., k]` is the class index sitting at rank k.
    # NOTE(review): the mask zeroes position k whenever the class index at
    # rank k equals 1 or 2 -- i.e. it compares class indices, not ranks.
    # That is the existing behaviour and is kept as is; confirm the intent.
    order = probs.argsort(dim=-1, descending=True)
    keep = torch.ones_like(probs)
    keep[(order == 1) | (order == 2)] = 0.

    p_true = (probs * target + (1 - probs) * (1 - target)).clamp(eps, 1 - eps)

    return (-(1.0 - p_true).pow(focus) * p_true.log() * keep).mean()
        
class FocalLoss(nn.Module):
    """Focal loss computed directly from logits.

    Args:
        gamma: focusing exponent.
        use_coeffs: multiply the per-class loss by `coeffs` when True.
        coeffs: per-class weight tensor (required when `use_coeffs` is True).
    """

    def __init__(self, gamma=2, use_coeffs = False, coeffs = None):
        super().__init__()
        self.gamma = gamma
        self.coeffs = coeffs
        self.use_coeffs = use_coeffs

    def forward(self, logit, target):
        """Return the batch mean of the focal loss (summed over dim 1 for 2-D input)."""
        target = target.float()
        n_rows = target.shape[0]

        # Binary cross-entropy with logits, written out in log-sum-exp form.
        neg_part = (-logit).clamp(min=0)
        bce = logit - logit * target + neg_part + \
            ((-neg_part).exp() + (-logit - neg_part).exp()).log()

        # Modulating factor computed in log space.
        focal_weight = (F.logsigmoid(-logit * (target * 2.0 - 1.0)) * self.gamma).exp()
        per_element = focal_weight * bce

        if self.use_coeffs:
            per_element = per_element * self.coeffs.repeat(n_rows, 1)
        if per_element.dim() == 2:
            per_element = per_element.sum(dim=1)

        return per_element.mean()

class ImprovedPANNsLoss(nn.Module):
    """Clip-level BCE plus an auxiliary BCE on the frame-wise maximum.

    Args:
        output_key: which entry of the model output dict feeds the main term.
            With "logit" the main term is BCE-with-logits (optionally using
            `pos_weights`); any other key uses plain BCE.
        weights: `[main_weight, auxiliary_weight]`.
        pos_weights: passed as `pos_weight` to BCEWithLogitsLoss; only used
            when `output_key == "logit"`.
    """

    def __init__(self, output_key="logit", weights=[1, 1], pos_weights =  None):
        super().__init__()

        self.output_key = output_key
        self.normal_loss = (
            nn.BCEWithLogitsLoss(pos_weight=pos_weights)
            if output_key == "logit"
            else nn.BCELoss()
        )
        self.bce = nn.BCELoss()
        self.weights = weights

    def forward(self, input, target):
        """Return `weights[0] * main + weights[1] * auxiliary`."""
        target = target.float()

        main_term = self.normal_loss(input[self.output_key], target)

        # Max over dim 1 of the frame-wise probabilities gives a clip-level score.
        framewise_peak = input["framewise_output"].max(dim=1)[0]
        aux_term = self.bce(framewise_peak, target)

        return self.weights[0] * main_term + self.weights[1] * aux_term
    
class ImprovedFocalLoss(nn.Module):
    """Focal loss on clip logits plus focal loss on the frame-wise maximum logit.

    Args:
        weights: `[main_weight, auxiliary_weight]`.
        use_coeffs: forwarded to FocalLoss; enables per-class weighting.
        coeffs: forwarded to FocalLoss; per-class weight tensor.
    """

    def __init__(self, weights=[1, 1], use_coeffs = False, coeffs = None):
        super().__init__()

        # `use_coeffs` used to be accepted here but never passed on, so the
        # coefficients were silently ignored even when requested.
        self.focal = FocalLoss(use_coeffs=use_coeffs, coeffs=coeffs)

        self.weights = weights

    def forward(self, input, target):
        """Return `weights[0] * main + weights[1] * auxiliary`.

        Reads `input["logit"]` and `input["framewise_logit"]`.
        """
        input_ = input["logit"]
        target = target.float()

        # Max over dim 1 of the frame-wise logits gives a clip-level logit.
        framewise_logit = input["framewise_logit"]
        clipwise_logit_with_max, _ = framewise_logit.max(dim=1)

        normal_loss = self.focal(input_, target)
        auxiliary_loss = self.focal(clipwise_logit_with_max, target)

        return self.weights[0] * normal_loss + self.weights[1] * auxiliary_loss
    
    
class ImprovedLsep(nn.Module):
    """`lsep_loss_stable` on clip logits plus the same loss on the frame-wise maximum.

    Args:
        weights: `[main_weight, auxiliary_weight]`.
    """

    def __init__(self, weights=[1, 1]):
        super().__init__()
        self.weights = weights

    def forward(self, input, target):
        """Return the weighted sum of the clip-level and frame-max LSEP terms.

        Reads `input["logit"]` and `input["framewise_logit"]`.
        """
        target = target.float()

        main_term = lsep_loss_stable(input["logit"], target)

        # Max over dim 1 of the frame-wise logits gives a clip-level logit.
        framewise_peak = input["framewise_logit"].max(dim=1)[0]
        aux_term = lsep_loss_stable(framewise_peak, target)

        return self.weights[0] * main_term + self.weights[1] * aux_term

In [10]:
# Build the SED network from the keyword arguments collected in `model_param`.
model = AudioSEDModel(**model_param)
# Uncomment to start from a previously saved checkpoint instead of fresh weights.
#model.load_state_dict(torch.load('best_model.pt'))

In [11]:
# One device for everything created in this cell; falls back to CPU so the
# cell also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = RainforestDataset(X_train, audio_transforms=audio_transform_train, image_transforms=None)
train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                          num_workers=Config.NUM_WORKERS, drop_last=True)

# Per-class coefficients; not consumed by any loss configured below.
coeffiicients = np.array([0.25, 0.289, 0.238, 0.508, 0.229, 0.221, 0.212,
                          0.285, 0.228, 0.215, 0.218, 0.263, 0.297, 0.216,
                          0.247, 0.279, 0.220, 0.218, 0.360, 0.212, 0.215,
                          0.221, 0.219, 0.240])

# Move the model first so the optimizer is built from the final parameters.
model = model.to(device)

# Created directly on `device`. The previous unconditional `.cuda()` crashed
# on CPU-only hosts even though the model move below it was guarded.
pos_weights = torch.ones(Config.NUM_BIRDS, device=device)
criterion = ImprovedPANNsLoss(pos_weights=pos_weights)

criterion_focal = ImprovedFocalLoss()
# NOTE: this rebinds the name `lsep_loss`, which the import cell also binds to
# a function from `src.loss`. The training loop relies on this instance, so
# the name is kept.
lsep_loss = ImprovedLsep()

optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR_ADAM, weight_decay=0.01)
#optimizer = torch.optim.SGD(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY, momentum=Config.MOMENTUM)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.8)
#scheduler =torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=500, T_mult=1, eta_min=1e-6)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 2, factor = 0.7, mode = "max")
# Stepped once per epoch by the training loop. The logged learning rate falls
# for 10 epochs and then climbs back up, because the run lasts longer than T_max.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

mixer = BatchMixer(p=0.5)
mixup_augmenter = Mixup(mixup_alpha=1.)
loss_function = nn.BCEWithLogitsLoss().to(device)

In [12]:
def load_val_file(record_id, df):
    """Load one recording as consecutive 10-second windows plus its clip labels.

    Args:
        record_id: recording id; the waveform is read from
            `train_np_folder_path + record_id + ".npy"`.
        df: annotation table with `recording_id` and `species_id` columns.

    Returns:
        windows: float32 array of shape (6, 10 * Config.SR) covering the first
            60 seconds, scaled by 10 (same gain as the training dataset).
        labels: float32 multi-hot vector of length Config.NUM_BIRDS.
    """
    # assumes the stored waveform is 1-D -- TODO confirm
    wav = np.load(train_np_folder_path + record_id + ".npy")

    window = 10 * Config.SR
    full_length = 60 * Config.SR

    mel_array = []  # historical name: these are raw waveform windows, not mels
    for start in range(0, full_length, window):
        wav_slice = wav[start:start + window].astype(np.float32) * 10.
        if len(wav_slice) < window:
            # Recording shorter than 60 s: zero-pad so every window has the
            # same length and the list stacks into a rectangular array.
            wav_slice = np.pad(wav_slice, (0, window - len(wav_slice)), "constant")
        mel_array.append(wav_slice)

    val_labels_array = np.zeros(Config.NUM_BIRDS, dtype=np.single)
    # Cast to int: the ids are used as an array index, and a float-typed
    # column cannot be used as an index.
    species_ids = (
        df.loc[df.recording_id == record_id, "species_id"]
        .dropna()
        .unique()
        .astype(int)
    )
    val_labels_array[species_ids] = 1.

    return np.array(mel_array), val_labels_array

def lwlrap(truth, scores):
    """Label-weighted label-ranking average precision via sklearn's LRAP.

    Args:
        truth: (n_samples, n_classes) array; entries > 0 count as positive.
        scores: (n_samples, n_classes) array of predicted scores.

    Returns:
        The overall lwlrap as returned by sklearn.
    """
    positives = truth > 0
    labels_per_sample = np.sum(positives, axis=1)

    # sklearn does not weight label-free samples correctly, so drop them.
    keep = np.flatnonzero(labels_per_sample > 0)

    return label_ranking_average_precision_score(
        positives[keep, :],
        scores[keep, :],
        sample_weight=labels_per_sample[keep])


def validate(model, files_ids, df):
    """Evaluate `model` recording by recording.

    For each recording the frame-wise logits are max-pooled over time and
    then over the recording's windows, giving one score vector per file.

    Args:
        model: network returning a dict with a `framewise_logit` entry.
        files_ids: iterable of recording ids to evaluate.
        df: annotation table used to build the per-recording labels.

    Returns:
        val_loss: list of per-recording LSEP losses.
        val_corr: list of per-recording 0/1 top-1 hits.
        valid_epoch_metric: mean per-recording lwlrap (0.0 if no files).
    """
    val_loss = []
    val_corr = []
    val_metrics = []

    model.eval()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device

    with torch.no_grad():
        for record_id in tqdm(files_ids):
            # Use the `df` argument; the old code ignored it and read the
            # global X_val instead.
            data, target = load_val_file(record_id, df)
            data = torch.tensor(data).float().to(device)
            # Add the batch dimension on every device; it used to be added
            # only in the CUDA branch, which broke the CPU path.
            target = torch.tensor(target).unsqueeze(0).to(device)

            output = model(data)
            framewise_logit = output["framewise_logit"]
            output, _ = framewise_logit.max(dim=1)   # max over time
            output, _ = torch.max(output, 0)         # max over windows
            output = output.unsqueeze(0)

            loss = lsep_loss_stable(output, target)

            val_metric = lwlrap(target.cpu().detach().numpy(), output.cpu().detach().numpy())
            val_metrics.append(val_metric.item())

            _, answers = torch.max(output, 1)
            _, targets = torch.max(target, 1)
            val_corr.append(int((answers == targets).sum().item()))
            val_loss.append(loss.item())

    valid_epoch_metric = sum(val_metrics) / len(val_metrics) if val_metrics else 0.0

    return val_loss, val_corr, valid_epoch_metric

In [13]:
best_corrects = 0   # best epoch-level *training* lwlrap seen so far
files_ids = X_val.recording_id.unique()
mixup=False
# One of these smoothing strengths is drawn at random for every batch.
label_smoothing_list = [0.02, 0.015, 0.01]

# Train loop
print('Starting training loop')
for e in range(0, 150):
    # Stats
    train_loss = []
    train_corr = []
    train_metrics = []

    # Single epoch - train
    model.train()
    # Wrap the loader (not the enumerate object) so tqdm knows the total and
    # can show a real progress bar.
    for batch, (data, target) in enumerate(tqdm(train_loader)):
        data = data.float()
        if mixup:
            mixup_lambda = torch.tensor(mixup_augmenter.get_lambda(len(data)))
            target = do_mixup(target, mixup_lambda)
        #data, target = mixer(data, target)
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()
            if mixup:
                mixup_lambda = mixup_lambda.cuda()

        optimizer.zero_grad()
        if mixup:
            output = model(data, mixup_lambda)
        else:
            output = model(data)

        # Label smoothing towards 0.5 with a randomly chosen strength.
        label_smoothing = random.choice(label_smoothing_list)
        targets_smooth = target * (1 - label_smoothing) + 0.5 * label_smoothing

        # `lsep_loss` is the ImprovedLsep instance created in the setup cell.
        loss = lsep_loss(output, targets_smooth)

        # The training metric uses the un-smoothed targets and the frame-wise max logit.
        framewise_output = output["framewise_logit"]
        clipwise_output_with_max, _ = framewise_output.max(dim=1)
        train_metric = lwlrap(target.cpu().detach().numpy(), clipwise_output_with_max.cpu().detach().numpy())
        train_metrics.append(train_metric.item())

        loss.backward()
        optimizer.step()

        # Stats: top-1 agreement between prediction and target arg-max.
        _, answers = torch.max(output["clipwise_output"], 1)
        _, targets = torch.max(target, 1)
        train_corr.append(int((answers == targets).sum().item()))
        train_loss.append(loss.item())

    train_epoch_metric = sum(train_metrics) / len(train_loss)

    # Stats
    lr = optimizer.param_groups[-1]['lr']
    print('Epoch ' + str(e) + ' training end. LR: ' + str(lr) + ', Loss: ' + str(sum(train_loss) / len(train_loss)) +
          ', Correct answers: ' + str(sum(train_corr)) + '/' + str(len(train_dataset)) + '/train_metric:' + str(train_epoch_metric))

    with torch.no_grad():
        val_loss, val_corr, valid_epoch_metric = validate(model, files_ids, X_val)
    # Stats
    print('Epoch ' + str(e) + ' validation end. LR: ' + str(lr) + ', Loss: ' + str(sum(val_loss) / len(val_loss)) +
          ', Correct answers: ' + str(sum(val_corr)) + '/' + str(len(files_ids)) + ", Val metric: " + str(valid_epoch_metric))

    # Checkpoint whenever the *training* lwlrap improves (hence the file name).
    # NOTE(review): the validation metric is computed above but not used for
    # model selection -- confirm this is intended before relying on
    # 'best_model_train.pt' as the best-generalising checkpoint.
    if train_epoch_metric > best_corrects:
        print('Saving new best model at epoch ') #+ str(e) + ' ' + str(sum(val_corr)) + '/' + str(len(files_ids)))
        torch.save(model.state_dict(), 'best_model_train.pt')
        best_corrects = train_epoch_metric

    # Call every epoch
    #scheduler.step(valid_epoch_metric)
    scheduler.step()

# Free memory
#del model

Starting training loop


775it [04:53,  2.64it/s]
  1%|          | 2/226 [00:00<00:20, 10.68it/s]

Epoch 0 training end. LR: 0.001, Loss: 4.82452334373228, Correct answers: 5171/12415/train_metric:0.637043949459621


100%|██████████| 226/226 [00:26<00:00,  8.49it/s]


Epoch 0 validation end. LR: 0.001, Loss: 2.492921588695155, Correct answers: 84/226, Val metric: 0.8289162340687553
Saving new best model at epoch 


775it [04:39,  2.77it/s]
  1%|          | 2/226 [00:00<00:20, 10.68it/s]

Epoch 1 training end. LR: 0.0009755527298894294, Loss: 3.474335747995684, Correct answers: 7094/12415/train_metric:0.7816924186419162


100%|██████████| 226/226 [00:31<00:00,  7.27it/s]


Epoch 1 validation end. LR: 0.0009755527298894294, Loss: 2.3908939129483384, Correct answers: 53/226, Val metric: 0.8229373335794747
Saving new best model at epoch 


775it [04:35,  2.81it/s]
  1%|          | 2/226 [00:00<00:20, 10.67it/s]

Epoch 2 training end. LR: 0.0009046039886902864, Loss: 2.9824891717972295, Correct answers: 7546/12415/train_metric:0.8204587229984315


100%|██████████| 226/226 [00:37<00:00,  5.97it/s]


Epoch 2 validation end. LR: 0.0009046039886902864, Loss: 2.1517117339952856, Correct answers: 60/226, Val metric: 0.8433822321121962
Saving new best model at epoch 


775it [04:39,  2.77it/s]
  1%|          | 2/226 [00:00<00:21, 10.54it/s]

Epoch 3 training end. LR: 0.0007940987335200905, Loss: 2.7101705900315314, Correct answers: 7681/12415/train_metric:0.8365107827378481


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]


Epoch 3 validation end. LR: 0.0007940987335200905, Loss: 2.054140753450647, Correct answers: 74/226, Val metric: 0.8818493828037737
Saving new best model at epoch 


775it [04:37,  2.79it/s]
  1%|          | 2/226 [00:00<00:20, 10.69it/s]

Epoch 4 training end. LR: 0.0006548539886902864, Loss: 2.420561834150745, Correct answers: 7967/12415/train_metric:0.8553228358043943


100%|██████████| 226/226 [00:24<00:00,  9.31it/s]


Epoch 4 validation end. LR: 0.0006548539886902864, Loss: 2.0381016288183433, Correct answers: 76/226, Val metric: 0.8656358444406385
Saving new best model at epoch 


775it [04:41,  2.75it/s]
  0%|          | 1/226 [00:00<00:23,  9.47it/s]

Epoch 5 training end. LR: 0.0005005000000000001, Loss: 2.190576592876065, Correct answers: 8003/12415/train_metric:0.8709241363224526


100%|██████████| 226/226 [00:27<00:00,  8.23it/s]


Epoch 5 validation end. LR: 0.0005005000000000001, Loss: 2.0061696208683792, Correct answers: 69/226, Val metric: 0.8717274934479176
Saving new best model at epoch 


775it [04:43,  2.74it/s]
  0%|          | 1/226 [00:00<00:24,  9.20it/s]

Epoch 6 training end. LR: 0.0003461460113097139, Loss: 1.9984649855859817, Correct answers: 8269/12415/train_metric:0.8831798615453476


100%|██████████| 226/226 [00:24<00:00,  9.29it/s]


Epoch 6 validation end. LR: 0.0003461460113097139, Loss: 1.9468400858144845, Correct answers: 67/226, Val metric: 0.8846851557645352
Saving new best model at epoch 


775it [04:41,  2.75it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 7 training end. LR: 0.00020690126647990976, Loss: 1.7643275231699789, Correct answers: 8431/12415/train_metric:0.8948647927739994


100%|██████████| 226/226 [00:20<00:00, 10.81it/s]


Epoch 7 validation end. LR: 0.00020690126647990976, Loss: 1.944474306781735, Correct answers: 78/226, Val metric: 0.8940990929896447
Saving new best model at epoch 


775it [04:43,  2.73it/s]
  0%|          | 1/226 [00:00<00:24,  9.04it/s]

Epoch 8 training end. LR: 9.639601130971382e-05, Loss: 1.6069211937150647, Correct answers: 8524/12415/train_metric:0.9033154142103812


100%|██████████| 226/226 [00:34<00:00,  6.54it/s]


Epoch 8 validation end. LR: 9.639601130971382e-05, Loss: 1.9354987967330797, Correct answers: 66/226, Val metric: 0.8841101474007517
Saving new best model at epoch 


775it [04:41,  2.75it/s]
  0%|          | 1/226 [00:00<00:26,  8.64it/s]

Epoch 9 training end. LR: 2.5447270110570814e-05, Loss: 1.5517114788870658, Correct answers: 8629/12415/train_metric:0.9086091045552349


100%|██████████| 226/226 [00:24<00:00,  9.26it/s]


Epoch 9 validation end. LR: 2.5447270110570814e-05, Loss: 1.9096893821142416, Correct answers: 62/226, Val metric: 0.8906126992826287
Saving new best model at epoch 


775it [04:38,  2.78it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 10 training end. LR: 1e-06, Loss: 1.5097053579361208, Correct answers: 8565/12415/train_metric:0.9101671212005683


100%|██████████| 226/226 [00:24<00:00,  9.32it/s]


Epoch 10 validation end. LR: 1e-06, Loss: 1.9269978324923895, Correct answers: 62/226, Val metric: 0.8884646850549685
Saving new best model at epoch 


775it [04:47,  2.70it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 11 training end. LR: 2.5447270110570814e-05, Loss: 1.5187592507177783, Correct answers: 8585/12415/train_metric:0.909870999050855


100%|██████████| 226/226 [00:24<00:00,  9.28it/s]

Epoch 11 validation end. LR: 2.5447270110570814e-05, Loss: 1.9307658229253988, Correct answers: 68/226, Val metric: 0.8935945126906522



775it [04:45,  2.72it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 12 training end. LR: 9.639601130971413e-05, Loss: 1.5512580563945155, Correct answers: 8507/12415/train_metric:0.9076186621518095


100%|██████████| 226/226 [00:27<00:00,  8.17it/s]

Epoch 12 validation end. LR: 9.639601130971413e-05, Loss: 1.812057024609726, Correct answers: 64/226, Val metric: 0.8993511008787798



775it [04:40,  2.76it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 13 training end. LR: 0.00020690126647991054, Loss: 1.5624504542735316, Correct answers: 8592/12415/train_metric:0.9068963487322488


100%|██████████| 226/226 [00:27<00:00,  8.17it/s]

Epoch 13 validation end. LR: 0.00020690126647991054, Loss: 1.9985739952695054, Correct answers: 65/226, Val metric: 0.8859231161934836



775it [04:45,  2.72it/s]
  1%|          | 2/226 [00:00<00:20, 10.67it/s]

Epoch 14 training end. LR: 0.0003461460113097153, Loss: 1.6913398758826717, Correct answers: 8487/12415/train_metric:0.8997227251040868


100%|██████████| 226/226 [00:27<00:00,  8.18it/s]

Epoch 14 validation end. LR: 0.0003461460113097153, Loss: 1.914876093906639, Correct answers: 66/226, Val metric: 0.8914819638399554



775it [04:44,  2.72it/s]
  0%|          | 1/226 [00:00<00:23,  9.78it/s]

Epoch 15 training end. LR: 0.0005005000000000021, Loss: 1.8818242252257562, Correct answers: 8389/12415/train_metric:0.8897472553154606


100%|██████████| 226/226 [00:31<00:00,  7.26it/s]

Epoch 15 validation end. LR: 0.0005005000000000021, Loss: 2.0277607440948486, Correct answers: 78/226, Val metric: 0.8741389591692166



775it [04:39,  2.77it/s]
  0%|          | 1/226 [00:00<00:25,  8.92it/s]

Epoch 16 training end. LR: 0.0006548539886902891, Loss: 1.9556587659159015, Correct answers: 8270/12415/train_metric:0.8868740447990761


100%|██████████| 226/226 [00:47<00:00,  4.79it/s]

Epoch 16 validation end. LR: 0.0006548539886902891, Loss: 2.1291529393829074, Correct answers: 71/226, Val metric: 0.8681212865187937



775it [04:43,  2.74it/s]
  0%|          | 1/226 [00:00<00:24,  9.36it/s]

Epoch 17 training end. LR: 0.0007940987335200938, Loss: 2.064089596579152, Correct answers: 8315/12415/train_metric:0.8796036757479302


100%|██████████| 226/226 [00:24<00:00,  9.32it/s]

Epoch 17 validation end. LR: 0.0007940987335200938, Loss: 1.9072916212335098, Correct answers: 55/226, Val metric: 0.8798090114356998



775it [04:38,  2.78it/s]
  0%|          | 1/226 [00:00<00:24,  9.35it/s]

Epoch 18 training end. LR: 0.0009046039886902903, Loss: 2.0756192231178283, Correct answers: 8167/12415/train_metric:0.8768942004923511


100%|██████████| 226/226 [00:20<00:00, 10.81it/s]

Epoch 18 validation end. LR: 0.0009046039886902903, Loss: 2.0304128617312003, Correct answers: 77/226, Val metric: 0.8774868096434082



775it [04:36,  2.80it/s]
  0%|          | 1/226 [00:00<00:25,  8.70it/s]

Epoch 19 training end. LR: 0.0009755527298894335, Loss: 2.0935554393645255, Correct answers: 8190/12415/train_metric:0.8763208251928009


100%|██████████| 226/226 [00:24<00:00,  9.31it/s]

Epoch 19 validation end. LR: 0.0009755527298894335, Loss: 1.9850751003332898, Correct answers: 68/226, Val metric: 0.889576674616366



775it [04:49,  2.68it/s]
  0%|          | 1/226 [00:00<00:23,  9.45it/s]

Epoch 20 training end. LR: 0.0010000000000000041, Loss: 2.0689009706435666, Correct answers: 8218/12415/train_metric:0.8776142112523153


100%|██████████| 226/226 [00:20<00:00, 10.83it/s]

Epoch 20 validation end. LR: 0.0010000000000000041, Loss: 2.063692388281358, Correct answers: 78/226, Val metric: 0.8692182014139638



775it [04:33,  2.83it/s]
  1%|          | 2/226 [00:00<00:21, 10.51it/s]

Epoch 21 training end. LR: 0.0009755527298894334, Loss: 2.0104563837666665, Correct answers: 8275/12415/train_metric:0.8808319297271924


100%|██████████| 226/226 [00:24<00:00,  9.28it/s]

Epoch 21 validation end. LR: 0.0009755527298894334, Loss: 1.9582361225533274, Correct answers: 88/226, Val metric: 0.88518582716201



775it [04:47,  2.70it/s]
  0%|          | 1/226 [00:00<00:24,  9.07it/s]

Epoch 22 training end. LR: 0.0009046039886902904, Loss: 1.91212434645622, Correct answers: 8351/12415/train_metric:0.8853662705428995


100%|██████████| 226/226 [00:24<00:00,  9.32it/s]

Epoch 22 validation end. LR: 0.0009046039886902904, Loss: 1.9565830167415923, Correct answers: 69/226, Val metric: 0.8795289503332184



775it [04:42,  2.74it/s]
  1%|          | 2/226 [00:00<00:21, 10.22it/s]

Epoch 23 training end. LR: 0.0007940987335200938, Loss: 1.7283028030395509, Correct answers: 8430/12415/train_metric:0.8965041118627208


100%|██████████| 226/226 [00:34<00:00,  6.62it/s]

Epoch 23 validation end. LR: 0.0007940987335200938, Loss: 2.0501531984953756, Correct answers: 76/226, Val metric: 0.8788120060890029



775it [04:34,  2.82it/s]
  0%|          | 1/226 [00:00<00:22,  9.82it/s]

Epoch 24 training end. LR: 0.0006548539886902891, Loss: 1.6172161515297427, Correct answers: 8584/12415/train_metric:0.9044168567980354


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]

Epoch 24 validation end. LR: 0.0006548539886902891, Loss: 2.046553440853558, Correct answers: 71/226, Val metric: 0.8903070124669861



775it [04:45,  2.71it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 25 training end. LR: 0.0005005000000000021, Loss: 1.4855996907141902, Correct answers: 8714/12415/train_metric:0.91089771281408


100%|██████████| 226/226 [00:27<00:00,  8.16it/s]


Epoch 25 validation end. LR: 0.0005005000000000021, Loss: 2.119418817283833, Correct answers: 71/226, Val metric: 0.8801766381249381
Saving new best model at epoch 


775it [04:45,  2.72it/s]
  0%|          | 0/226 [00:00<?, ?it/s]

Epoch 26 training end. LR: 0.00034614601130971535, Loss: 1.357074643411944, Correct answers: 8656/12415/train_metric:0.9156062849764427


100%|██████████| 226/226 [00:29<00:00,  7.70it/s]


Epoch 26 validation end. LR: 0.00034614601130971535, Loss: 1.985338278576336, Correct answers: 75/226, Val metric: 0.8916556895326349
Saving new best model at epoch 


775it [04:40,  2.77it/s]
  1%|          | 2/226 [00:00<00:20, 10.67it/s]

Epoch 27 training end. LR: 0.00020690126647991062, Loss: 1.2222883706515835, Correct answers: 8832/12415/train_metric:0.9248794374576355


100%|██████████| 226/226 [00:24<00:00,  9.32it/s]


Epoch 27 validation end. LR: 0.00020690126647991062, Loss: 2.077532095191753, Correct answers: 73/226, Val metric: 0.8951969013645193
Saving new best model at epoch 


775it [04:40,  2.77it/s]
  1%|          | 2/226 [00:00<00:21, 10.54it/s]

Epoch 28 training end. LR: 9.639601130971425e-05, Loss: 1.1520200047762164, Correct answers: 8882/12415/train_metric:0.9276633205824376


100%|██████████| 226/226 [00:24<00:00,  9.29it/s]


Epoch 28 validation end. LR: 9.639601130971425e-05, Loss: 2.0254648406948665, Correct answers: 72/226, Val metric: 0.8961716201591439
Saving new best model at epoch 


775it [04:38,  2.78it/s]
  1%|          | 2/226 [00:00<00:20, 10.73it/s]

Epoch 29 training end. LR: 2.5447270110570967e-05, Loss: 1.0841361925486594, Correct answers: 8978/12415/train_metric:0.932213421122419


100%|██████████| 226/226 [00:24<00:00,  9.27it/s]


Epoch 29 validation end. LR: 2.5447270110570967e-05, Loss: 2.066255928140826, Correct answers: 73/226, Val metric: 0.8961203073554268
Saving new best model at epoch 


775it [04:47,  2.70it/s]
  0%|          | 1/226 [00:00<00:24,  9.15it/s]

Epoch 30 training end. LR: 1e-06, Loss: 1.0750956148678257, Correct answers: 8897/12415/train_metric:0.932464963163787


100%|██████████| 226/226 [00:34<00:00,  6.56it/s]


Epoch 30 validation end. LR: 1e-06, Loss: 2.1007781387430375, Correct answers: 70/226, Val metric: 0.8942004310523826
Saving new best model at epoch 


775it [04:48,  2.68it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 31 training end. LR: 2.5447270110570814e-05, Loss: 1.0572171018200536, Correct answers: 8877/12415/train_metric:0.9331106511445678


100%|██████████| 226/226 [00:27<00:00,  8.16it/s]


Epoch 31 validation end. LR: 2.5447270110570814e-05, Loss: 2.0875089485033422, Correct answers: 73/226, Val metric: 0.8944925631001448
Saving new best model at epoch 


775it [04:44,  2.73it/s]
  1%|          | 2/226 [00:00<00:21, 10.57it/s]

Epoch 32 training end. LR: 9.639601130971386e-05, Loss: 1.0630783533665442, Correct answers: 8940/12415/train_metric:0.9348420763333994


100%|██████████| 226/226 [00:24<00:00,  9.26it/s]


Epoch 32 validation end. LR: 9.639601130971386e-05, Loss: 2.0977960793317947, Correct answers: 66/226, Val metric: 0.8976993037996593
Saving new best model at epoch 


775it [04:39,  2.77it/s]
  0%|          | 1/226 [00:00<00:23,  9.64it/s]

Epoch 33 training end. LR: 0.00020690126647990997, Loss: 1.0994890101205919, Correct answers: 8825/12415/train_metric:0.931817040623997


100%|██████████| 226/226 [00:27<00:00,  8.16it/s]

Epoch 33 validation end. LR: 0.00020690126647990997, Loss: 2.1542835256694692, Correct answers: 72/226, Val metric: 0.8892765754994685



775it [04:39,  2.77it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 34 training end. LR: 0.0003461460113097143, Loss: 1.2294149330739053, Correct answers: 8836/12415/train_metric:0.9268733791317499


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]

Epoch 34 validation end. LR: 0.0003461460113097143, Loss: 2.0813389963808313, Correct answers: 66/226, Val metric: 0.8894083893192432



775it [04:48,  2.69it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 35 training end. LR: 0.0005005000000000009, Loss: 1.2872935065530962, Correct answers: 8757/12415/train_metric:0.922857295096101


100%|██████████| 226/226 [00:24<00:00,  9.32it/s]

Epoch 35 validation end. LR: 0.0005005000000000009, Loss: 2.1455529833261946, Correct answers: 89/226, Val metric: 0.8853266512635081



775it [04:36,  2.80it/s]
  0%|          | 1/226 [00:00<00:23,  9.60it/s]

Epoch 36 training end. LR: 0.0006548539886902876, Loss: 1.4186541931667636, Correct answers: 8581/12415/train_metric:0.9128200007378078


100%|██████████| 226/226 [00:20<00:00, 10.78it/s]

Epoch 36 validation end. LR: 0.0006548539886902876, Loss: 2.1605464441586384, Correct answers: 72/226, Val metric: 0.8826325464101668



775it [04:51,  2.66it/s]
  0%|          | 1/226 [00:00<00:24,  9.07it/s]

Epoch 37 training end. LR: 0.0007940987335200921, Loss: 1.5200635308604087, Correct answers: 8582/12415/train_metric:0.9095766394930421


100%|██████████| 226/226 [00:24<00:00,  9.28it/s]

Epoch 37 validation end. LR: 0.0007940987335200921, Loss: 2.0474420733156458, Correct answers: 85/226, Val metric: 0.8858174950652487



775it [04:43,  2.74it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 38 training end. LR: 0.0009046039886902883, Loss: 1.6458690556403128, Correct answers: 8493/12415/train_metric:0.9009518747623237


100%|██████████| 226/226 [00:20<00:00, 10.82it/s]

Epoch 38 validation end. LR: 0.0009046039886902883, Loss: 2.3341251411269197, Correct answers: 66/226, Val metric: 0.865736721353072



775it [04:46,  2.71it/s]
  0%|          | 1/226 [00:00<00:25,  8.67it/s]

Epoch 39 training end. LR: 0.0009755527298894316, Loss: 1.6413901689744765, Correct answers: 8426/12415/train_metric:0.9020798454808936


100%|██████████| 226/226 [00:24<00:00,  9.05it/s]

Epoch 39 validation end. LR: 0.0009755527298894316, Loss: 2.220473658722059, Correct answers: 76/226, Val metric: 0.8816982671986238



775it [04:35,  2.81it/s]
  0%|          | 0/226 [00:00<?, ?it/s]

Epoch 40 training end. LR: 0.0010000000000000024, Loss: 1.646692199053303, Correct answers: 8495/12415/train_metric:0.9036318501126804


100%|██████████| 226/226 [00:27<00:00,  8.16it/s]

Epoch 40 validation end. LR: 0.0010000000000000024, Loss: 1.9999736262633738, Correct answers: 85/226, Val metric: 0.8907906332756814



775it [04:42,  2.74it/s]
  1%|          | 2/226 [00:00<00:21, 10.47it/s]

Epoch 41 training end. LR: 0.0009755527298894321, Loss: 1.6352423734049644, Correct answers: 8499/12415/train_metric:0.9018589398999893


100%|██████████| 226/226 [00:27<00:00,  8.15it/s]

Epoch 41 validation end. LR: 0.0009755527298894321, Loss: 2.0610372598192335, Correct answers: 65/226, Val metric: 0.8808658985616045



775it [04:43,  2.73it/s]
  0%|          | 1/226 [00:00<00:26,  8.47it/s]

Epoch 42 training end. LR: 0.0009046039886902886, Loss: 1.5565747088386166, Correct answers: 8540/12415/train_metric:0.9057054373970894


100%|██████████| 226/226 [00:31<00:00,  7.28it/s]

Epoch 42 validation end. LR: 0.0009046039886902886, Loss: 2.0390716843900427, Correct answers: 90/226, Val metric: 0.883334402405915



775it [04:38,  2.79it/s]
  0%|          | 1/226 [00:00<00:27,  8.24it/s]

Epoch 43 training end. LR: 0.0007940987335200926, Loss: 1.4772265680182364, Correct answers: 8700/12415/train_metric:0.911062975399244


100%|██████████| 226/226 [00:20<00:00, 10.81it/s]

Epoch 43 validation end. LR: 0.0007940987335200926, Loss: 2.1336914585754934, Correct answers: 69/226, Val metric: 0.8761405089141298



775it [04:48,  2.69it/s]
  1%|          | 2/226 [00:00<00:21, 10.56it/s]

Epoch 44 training end. LR: 0.0006548539886902891, Loss: 1.3180218724281556, Correct answers: 8701/12415/train_metric:0.9189981601575891


100%|██████████| 226/226 [00:34<00:00,  6.58it/s]

Epoch 44 validation end. LR: 0.0006548539886902891, Loss: 2.150601230891405, Correct answers: 79/226, Val metric: 0.8889247244166556



775it [04:41,  2.75it/s]
  0%|          | 1/226 [00:00<00:25,  8.82it/s]

Epoch 45 training end. LR: 0.0005005000000000015, Loss: 1.2273310175249654, Correct answers: 8729/12415/train_metric:0.9251151671313151


100%|██████████| 226/226 [00:27<00:00,  8.13it/s]

Epoch 45 validation end. LR: 0.0005005000000000015, Loss: 2.156507791671078, Correct answers: 81/226, Val metric: 0.8894096802469791



775it [04:47,  2.70it/s]
  0%|          | 1/226 [00:00<00:26,  8.40it/s]

Epoch 46 training end. LR: 0.0003461460113097149, Loss: 1.0838882282280153, Correct answers: 8811/12415/train_metric:0.9319592964823173


100%|██████████| 226/226 [00:37<00:00,  5.98it/s]

Epoch 46 validation end. LR: 0.0003461460113097149, Loss: 2.1162990616486135, Correct answers: 81/226, Val metric: 0.891841973995319



775it [04:42,  2.75it/s]
  1%|          | 2/226 [00:00<00:21, 10.58it/s]

Epoch 47 training end. LR: 0.00020690126647990968, Loss: 0.9893357729142712, Correct answers: 8969/12415/train_metric:0.9373683604589348


100%|██████████| 226/226 [00:20<00:00, 10.82it/s]


Epoch 47 validation end. LR: 0.00020690126647990968, Loss: 2.1888712760621467, Correct answers: 81/226, Val metric: 0.8936206616265489
Saving new best model at epoch 


775it [04:44,  2.72it/s]
  1%|          | 2/226 [00:00<00:20, 10.69it/s]

Epoch 48 training end. LR: 9.639601130971415e-05, Loss: 0.9283995743336216, Correct answers: 9087/12415/train_metric:0.9418801509877939


100%|██████████| 226/226 [00:27<00:00,  8.14it/s]


Epoch 48 validation end. LR: 9.639601130971415e-05, Loss: 2.1363469984679098, Correct answers: 79/226, Val metric: 0.8995504543175027
Saving new best model at epoch 


775it [04:48,  2.69it/s]
  1%|          | 2/226 [00:00<00:22, 10.15it/s]

Epoch 49 training end. LR: 2.5447270110571207e-05, Loss: 0.8983135244154161, Correct answers: 9094/12415/train_metric:0.9432807640090417


100%|██████████| 226/226 [00:27<00:00,  8.22it/s]


Epoch 49 validation end. LR: 2.5447270110571207e-05, Loss: 2.1322183440216875, Correct answers: 77/226, Val metric: 0.9013044907403311
Saving new best model at epoch 


775it [04:44,  2.73it/s]
  1%|          | 2/226 [00:00<00:20, 10.69it/s]

Epoch 50 training end. LR: 1e-06, Loss: 0.8480809907855527, Correct answers: 9064/12415/train_metric:0.9447262279098692


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]


Epoch 50 validation end. LR: 1e-06, Loss: 2.1534402602541762, Correct answers: 77/226, Val metric: 0.8979977629808445
Saving new best model at epoch 


775it [04:47,  2.69it/s]
  1%|          | 2/226 [00:00<00:21, 10.57it/s]

Epoch 51 training end. LR: 2.5447270110570814e-05, Loss: 0.8737994203644414, Correct answers: 9044/12415/train_metric:0.9433031798679352


100%|██████████| 226/226 [00:27<00:00,  8.22it/s]

Epoch 51 validation end. LR: 2.5447270110570814e-05, Loss: 2.14824210858978, Correct answers: 81/226, Val metric: 0.8982042589896231



775it [04:40,  2.76it/s]
  0%|          | 1/226 [00:00<00:22,  9.89it/s]

Epoch 52 training end. LR: 9.63960113097151e-05, Loss: 0.9130358113204279, Correct answers: 8999/12415/train_metric:0.9426122145509479


100%|██████████| 226/226 [00:34<00:00,  6.59it/s]

Epoch 52 validation end. LR: 9.63960113097151e-05, Loss: 2.1629663273296527, Correct answers: 78/226, Val metric: 0.8947957689326116



775it [04:36,  2.80it/s]
  0%|          | 1/226 [00:00<00:24,  9.21it/s]

Epoch 53 training end. LR: 0.00020690126647991342, Loss: 0.914073918865573, Correct answers: 9014/12415/train_metric:0.9407230863396759


100%|██████████| 226/226 [00:20<00:00, 10.80it/s]

Epoch 53 validation end. LR: 0.00020690126647991342, Loss: 2.212163937830292, Correct answers: 77/226, Val metric: 0.8942184645997764



775it [04:43,  2.73it/s]
  1%|          | 2/226 [00:00<00:21, 10.61it/s]

Epoch 54 training end. LR: 0.000346146011309719, Loss: 0.9862238442417114, Correct answers: 8998/12415/train_metric:0.9390174239216832


100%|██████████| 226/226 [00:20<00:00, 10.82it/s]

Epoch 54 validation end. LR: 0.000346146011309719, Loss: 2.111309722461532, Correct answers: 72/226, Val metric: 0.8925989602754706



775it [04:38,  2.78it/s]
  0%|          | 1/226 [00:00<00:23,  9.42it/s]

Epoch 55 training end. LR: 0.0005005000000000086, Loss: 1.0765516342847579, Correct answers: 8905/12415/train_metric:0.9327954618214155


100%|██████████| 226/226 [00:27<00:00,  8.21it/s]

Epoch 55 validation end. LR: 0.0005005000000000086, Loss: 2.3586338980008015, Correct answers: 90/226, Val metric: 0.8836006811805276



775it [04:47,  2.69it/s]
  0%|          | 0/226 [00:00<?, ?it/s]

Epoch 56 training end. LR: 0.0006548539886902965, Loss: 1.1704868012666703, Correct answers: 8800/12415/train_metric:0.9256001615471415


100%|██████████| 226/226 [00:31<00:00,  7.26it/s]

Epoch 56 validation end. LR: 0.0006548539886902965, Loss: 2.1535642273658144, Correct answers: 70/226, Val metric: 0.8888953221548622



775it [04:45,  2.71it/s]
  1%|          | 2/226 [00:00<00:21, 10.66it/s]

Epoch 57 training end. LR: 0.0007940987335201021, Loss: 1.3058097012389092, Correct answers: 8754/12415/train_metric:0.9209527108725344


100%|██████████| 226/226 [00:31<00:00,  7.26it/s]

Epoch 57 validation end. LR: 0.0007940987335201021, Loss: 2.0735491039478675, Correct answers: 87/226, Val metric: 0.8919961463817278



775it [04:41,  2.76it/s]
  0%|          | 1/226 [00:00<00:22,  9.89it/s]

Epoch 58 training end. LR: 0.0009046039886903006, Loss: 1.385350898734985, Correct answers: 8650/12415/train_metric:0.9155998045093063


100%|██████████| 226/226 [00:24<00:00,  9.26it/s]

Epoch 58 validation end. LR: 0.0009046039886903006, Loss: 2.186882160406197, Correct answers: 79/226, Val metric: 0.8787272002538496



775it [04:48,  2.69it/s]
  1%|          | 2/226 [00:00<00:21, 10.61it/s]

Epoch 59 training end. LR: 0.0009755527298894444, Loss: 1.4027347331277786, Correct answers: 8709/12415/train_metric:0.9140915918546244


100%|██████████| 226/226 [00:37<00:00,  5.97it/s]

Epoch 59 validation end. LR: 0.0009755527298894444, Loss: 2.3718519970379046, Correct answers: 91/226, Val metric: 0.8629504210038094



775it [04:38,  2.78it/s]
  1%|          | 2/226 [00:00<00:21, 10.51it/s]

Epoch 60 training end. LR: 0.0010000000000000159, Loss: 1.413764071868312, Correct answers: 8636/12415/train_metric:0.9142631053174045


100%|██████████| 226/226 [00:31<00:00,  7.28it/s]

Epoch 60 validation end. LR: 0.0010000000000000159, Loss: 2.002259746062017, Correct answers: 79/226, Val metric: 0.8944347603229008



775it [04:44,  2.73it/s]
  0%|          | 1/226 [00:00<00:25,  8.85it/s]

Epoch 61 training end. LR: 0.0009755527298894445, Loss: 1.3850295092598084, Correct answers: 8655/12415/train_metric:0.9157536860832901


100%|██████████| 226/226 [00:24<00:00,  9.29it/s]

Epoch 61 validation end. LR: 0.0009755527298894445, Loss: 2.0376886245423713, Correct answers: 76/226, Val metric: 0.886698973839853



775it [04:39,  2.77it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 62 training end. LR: 0.0009046039886903009, Loss: 1.3671940509542342, Correct answers: 8732/12415/train_metric:0.917470962682062


100%|██████████| 226/226 [00:27<00:00,  8.18it/s]

Epoch 62 validation end. LR: 0.0009046039886903009, Loss: 2.162195385029886, Correct answers: 70/226, Val metric: 0.890991139726343



775it [04:47,  2.70it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 63 training end. LR: 0.0007940987335201026, Loss: 1.277372363305861, Correct answers: 8783/12415/train_metric:0.9221200631498192


100%|██████████| 226/226 [00:27<00:00,  8.13it/s]

Epoch 63 validation end. LR: 0.0007940987335201026, Loss: 2.136439568173569, Correct answers: 81/226, Val metric: 0.8846544854675042



775it [04:49,  2.68it/s]
  0%|          | 1/226 [00:00<00:24,  9.15it/s]

Epoch 64 training end. LR: 0.000654853988690297, Loss: 1.192944724309829, Correct answers: 8802/12415/train_metric:0.9244717208711432


100%|██████████| 226/226 [00:27<00:00,  8.14it/s]

Epoch 64 validation end. LR: 0.000654853988690297, Loss: 2.1957073718045663, Correct answers: 65/226, Val metric: 0.8905018519593576



775it [04:45,  2.72it/s]
  0%|          | 1/226 [00:00<00:23,  9.59it/s]

Epoch 65 training end. LR: 0.0005005000000000093, Loss: 1.067007034286376, Correct answers: 8865/12415/train_metric:0.9325112788280882


100%|██████████| 226/226 [00:27<00:00,  8.14it/s]

Epoch 65 validation end. LR: 0.0005005000000000093, Loss: 2.201848338135576, Correct answers: 78/226, Val metric: 0.8906952953961074



775it [04:42,  2.74it/s]
  1%|          | 2/226 [00:00<00:21, 10.64it/s]

Epoch 66 training end. LR: 0.00034614601130971963, Loss: 0.9776597053100986, Correct answers: 9048/12415/train_metric:0.9390884307509311


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]

Epoch 66 validation end. LR: 0.00034614601130971963, Loss: 2.3128028274637407, Correct answers: 75/226, Val metric: 0.8916182078027778



775it [04:47,  2.70it/s]
  1%|          | 2/226 [00:00<00:21, 10.66it/s]

Epoch 67 training end. LR: 0.00020690126647991398, Loss: 0.8916314784365316, Correct answers: 9031/12415/train_metric:0.9426270528256151


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]

Epoch 67 validation end. LR: 0.00020690126647991398, Loss: 2.2641257733370352, Correct answers: 80/226, Val metric: 0.8891921657006485



775it [04:45,  2.72it/s]
  1%|          | 2/226 [00:00<00:20, 10.67it/s]

Epoch 68 training end. LR: 9.639601130971554e-05, Loss: 0.7986888773018314, Correct answers: 9174/12415/train_metric:0.9490983027618612


100%|██████████| 226/226 [00:20<00:00, 10.83it/s]


Epoch 68 validation end. LR: 9.639601130971554e-05, Loss: 2.236720895345232, Correct answers: 78/226, Val metric: 0.8954899691016949
Saving new best model at epoch 


775it [04:39,  2.77it/s]
  0%|          | 1/226 [00:00<00:23,  9.45it/s]

Epoch 69 training end. LR: 2.5447270110571035e-05, Loss: 0.7576799906357642, Correct answers: 9124/12415/train_metric:0.950907918918347


100%|██████████| 226/226 [00:24<00:00,  9.29it/s]


Epoch 69 validation end. LR: 2.5447270110571035e-05, Loss: 2.234369296943192, Correct answers: 74/226, Val metric: 0.8957084983083473
Saving new best model at epoch 


775it [04:37,  2.79it/s]
  0%|          | 1/226 [00:00<00:22,  9.82it/s]

Epoch 70 training end. LR: 1e-06, Loss: 0.7512704860395001, Correct answers: 9120/12415/train_metric:0.9498720058707998


100%|██████████| 226/226 [00:31<00:00,  7.28it/s]

Epoch 70 validation end. LR: 1e-06, Loss: 2.261326450162229, Correct answers: 73/226, Val metric: 0.8951537669426851



775it [04:39,  2.77it/s]
  1%|          | 2/226 [00:00<00:20, 10.68it/s]

Epoch 71 training end. LR: 2.5447270110570814e-05, Loss: 0.7453123319389359, Correct answers: 9161/12415/train_metric:0.9520514574995509


100%|██████████| 226/226 [00:20<00:00, 10.82it/s]


Epoch 71 validation end. LR: 2.5447270110570814e-05, Loss: 2.266857413064062, Correct answers: 76/226, Val metric: 0.8935484974737251
Saving new best model at epoch 


775it [04:41,  2.76it/s]
  1%|          | 2/226 [00:00<00:20, 10.71it/s]

Epoch 72 training end. LR: 9.63960113097131e-05, Loss: 0.7756285530569091, Correct answers: 9156/12415/train_metric:0.9492678895336062


100%|██████████| 226/226 [00:34<00:00,  6.60it/s]

Epoch 72 validation end. LR: 9.63960113097131e-05, Loss: 2.2955441833597368, Correct answers: 73/226, Val metric: 0.8951664469115977



775it [04:42,  2.75it/s]
  0%|          | 1/226 [00:00<00:25,  8.88it/s]

Epoch 73 training end. LR: 0.0002069012664799077, Loss: 0.8083265879077296, Correct answers: 9120/12415/train_metric:0.9482400447752438


100%|██████████| 226/226 [00:24<00:00,  9.30it/s]

Epoch 73 validation end. LR: 0.0002069012664799077, Loss: 2.3335679273689744, Correct answers: 78/226, Val metric: 0.8916324290660235



775it [04:43,  2.73it/s]
  0%|          | 1/226 [00:00<00:26,  8.57it/s]

Epoch 74 training end. LR: 0.0003461460113097118, Loss: 0.8785748732282269, Correct answers: 9018/12415/train_metric:0.943375730374244


100%|██████████| 226/226 [00:20<00:00, 10.77it/s]

Epoch 74 validation end. LR: 0.0003461460113097118, Loss: 2.427501625719324, Correct answers: 89/226, Val metric: 0.8896404219960002



775it [04:35,  2.81it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 75 training end. LR: 0.0005004999999999965, Loss: 0.9486269805796684, Correct answers: 8966/12415/train_metric:0.9401270372788575


100%|██████████| 226/226 [00:24<00:00,  9.27it/s]

Epoch 75 validation end. LR: 0.0005004999999999965, Loss: 2.3205684307402215, Correct answers: 89/226, Val metric: 0.8878713100950525



775it [04:47,  2.70it/s]
  0%|          | 1/226 [00:00<00:25,  8.68it/s]

Epoch 76 training end. LR: 0.000654853988690283, Loss: 1.0605301184135099, Correct answers: 8914/12415/train_metric:0.9343113086467071


100%|██████████| 226/226 [00:27<00:00,  8.16it/s]

Epoch 76 validation end. LR: 0.000654853988690283, Loss: 2.1562026496482107, Correct answers: 72/226, Val metric: 0.8944443566873292



775it [04:45,  2.71it/s]
  1%|          | 2/226 [00:00<00:20, 10.70it/s]

Epoch 77 training end. LR: 0.0007940987335200873, Loss: 1.1500512392674722, Correct answers: 8860/12415/train_metric:0.9284123794342074


100%|██████████| 226/226 [00:34<00:00,  6.57it/s]

Epoch 77 validation end. LR: 0.0007940987335200873, Loss: 2.327410666288528, Correct answers: 65/226, Val metric: 0.8771671733286432



184it [01:10,  2.62it/s]


KeyboardInterrupt: 

In [14]:
# Persist the weights currently held by `model` to 'best_model_.pt'.
# NOTE(review): the preceding training output ends in a KeyboardInterrupt, so these
# appear to be the weights at the moment of the interrupt rather than the
# best-validation ones — confirm this does not overwrite a checkpoint that the
# training loop already saved under the same name.
torch.save(model.state_dict(), 'best_model_.pt')

In [30]:
# Re-evaluate the saved checkpoint on the validation fold, one recording at a time.
#
# The device is chosen once and used for both the model and the tensors. The
# previous version called model.cuda() unconditionally but moved the tensors (and
# added the batch axis to `target`) only inside an `if torch.cuda.is_available()`
# branch, so the cell could not run on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("best_model_.pt", map_location=device))
model.to(device)
model.eval()

files_ids = X_val.recording_id.unique()

# Per-recording statistics (answers_list / targets_list are reused by the next cell)
answers_list = []
targets_list = []
val_loss = []
val_corr = []
val_metrics = []

with torch.no_grad():
    for recording_id in tqdm(files_ids):
        data, target = load_val_file(recording_id, X_val)
        data = torch.tensor(data).float().to(device)
        # Add a leading batch axis so `target` lines up with the (1, n_classes) `output`
        target = torch.tensor(target).to(device).unsqueeze(0)

        output = model(data)
        framewise_output = output["framewise_output"]
        # Reduce with max over dim 1, then over dim 0, leaving one score per class.
        # NOTE(review): presumably dim 0 indexes the segments of the recording and
        # dim 1 the time frames — confirm against load_val_file / the model.
        output, _ = framewise_output.max(dim=1)
        output, _ = torch.max(output, 0)
        output = output.unsqueeze(0)

        loss = lsep_loss_stable(output, target)
        val_metric = lwlrap(target.cpu().numpy(), output.cpu().numpy())

        # Top-1 predicted class vs. arg-max of the target vector
        _, answers = torch.max(output, 1)
        _, targets = torch.max(target, 1)
        answers_list.append(answers.item())
        targets_list.append(targets.item())
        val_metrics.append(val_metric.item())
        val_corr.append(int((answers == targets).sum().item()))
        val_loss.append(loss.item())


valid_epoch_metric = sum(val_metrics) / len(val_metrics)
# Stats
print('Loss: ' + str(sum(val_loss) / len(val_loss)) +
      ', Correct answers: ' + str(sum(val_corr)) + '/' + str(len(files_ids)) + ", Val metric: " + str(valid_epoch_metric))



100%|██████████| 227/227 [00:26<00:00,  8.49it/s]

Loss: 3.192483188822406, Correct answers: 107/227, Val metric: 0.89345172997735





In [31]:
from collections import Counter

# Ground-truth class of every validation recording whose top-1 prediction was wrong
errors = [
    target
    for answer, target in zip(answers_list, targets_list)
    if answer != target
]

# (class id, count) pairs ordered by class id: misclassified recordings first,
# then all validation recordings for comparison
error_count = sorted(Counter(errors).items(), key=lambda pair: pair[0])
target_count = sorted(Counter(targets_list).items(), key=lambda pair: pair[0])
print(error_count, target_count, sep = "\n")


[(0, 6), (1, 15), (3, 68), (4, 9), (5, 7), (7, 2), (8, 5), (10, 1), (13, 1), (14, 2), (17, 2), (23, 2)]
[(0, 27), (1, 30), (2, 7), (3, 84), (4, 18), (5, 10), (6, 2), (7, 2), (8, 6), (10, 2), (11, 10), (12, 3), (13, 3), (14, 6), (16, 1), (17, 3), (18, 2), (20, 1), (21, 1), (22, 4), (23, 5)]


In [19]:
# Audio settings — already defined above; repeated here for reference
fft, hop = 2048, 512
# Fewer rounding errors this way
sr = 48000
# Number of samples in one 10-second segment
length = sr * 10
fmin, fmax = 84, 15056


def load_test_file(f, folder='../data/train/'):
    """Load one recording and cut it into fixed-length raw-waveform segments.

    The audio is read at its native sample rate (``sr=None``) and split into
    consecutive windows of ``length`` samples (module-level constant). When the
    recording is not an exact multiple of ``length``, the last window is taken
    from the end of the file, so it overlaps the previous one instead of being
    shorter.

    Args:
        f: File name of the recording, relative to ``folder``.
        folder: Directory that contains the recording. Defaults to the training
            folder, which is what this function previously had hard-coded.

    Returns:
        A float32 array with one row of ``length`` samples per segment. An
        empty recording yields an empty array.
    """
    # The file's own sample rate is not needed here; discarding it also avoids
    # shadowing the module-level `sr`.
    wav, _ = librosa.load(os.path.join(folder, f), sr=None)

    # A clip shorter than one window used to be sliced with a negative start
    # index, which returned a tail of the wrong size. Zero-pad it to exactly one
    # full window instead.
    if 0 < len(wav) < length:
        wav = np.pad(wav, (0, length - len(wav)), mode="constant")

    # Split into enough segments to not miss anything
    segments = int(np.ceil(len(wav) / length))
    last_start = len(wav) - length

    waveform_segments = []
    for i in range(segments):
        # Clamp the start so the last segment is read from the end of the file
        start = min(i * length, last_start)
        waveform_segments.append(wav[start:start + length].astype(np.float32))

    return np.array(waveform_segments)

In [21]:
# Run the SED model over every recording and turn its frame-wise scores into a
# table of detected events (one row per consecutive run of above-threshold frames).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AudioSEDModel(**model_param)
model.load_state_dict(torch.load("best_model_0_pseudo_resnet50.pt", map_location=device))
model.to(device)
model.eval()

PERIOD = 10        # seconds covered by one segment returned by load_test_file
global_time = 0.0  # start time (seconds) of the current segment within its file
threshold = 0.1    # minimum frame-wise score for a frame to count as a detection
estimated_event_list = []

# The training recordings are scored here (the result is saved as pseudo-labels)
test_files = train_files


def _events_from_framewise(framewise_outputs, file_id, segment_start):
    """Convert one segment's (frames, classes) score matrix into event dicts.

    For every class, consecutive frames whose score is >= ``threshold`` are
    merged into a single event. Onset/offset are the times of the first and
    last frame of the run, shifted by ``segment_start``.
    """
    events = []
    n_frames = framewise_outputs.shape[0]
    if n_frames == 0:
        return events

    # Seconds per output frame, derived from the output itself instead of a
    # hard-coded 0.01 s.
    # NOTE(review): assumes the frame axis spans the whole PERIOD-second segment — confirm.
    frame_seconds = PERIOD / n_frames

    thresholded = framewise_outputs >= threshold
    for target_idx in range(thresholded.shape[1]):
        detected = np.flatnonzero(thresholded[:, target_idx])
        if detected.size == 0:
            continue

        # Cut the detected frame indices wherever two neighbours are not adjacent
        run_breaks = np.flatnonzero(np.diff(detected) != 1) + 1
        for run in np.split(detected, run_breaks):
            onset_idx, offset_idx = run[0], run[-1]
            # Slice is inclusive of the last frame: the previous `onset:offset`
            # slice dropped it and was empty (so .max() raised) for one-frame events.
            confidences = framewise_outputs[onset_idx:offset_idx + 1, target_idx]
            events.append({
                "file_id": file_id,
                "species_id": target_idx,
                "onset": frame_seconds * onset_idx + segment_start,
                "offset": frame_seconds * offset_idx + segment_start,
                "max_confidence": confidences.max(),
                "mean_confidence": confidences.mean()
            })
    return events


# Inference only: no autograd graph is needed
with torch.no_grad():
    for test_file in test_files:
        global_time = 0.0
        file_id = test_file.split('.')[0]
        # NOTE(review): load_test_file reads the last segment from the end of the
        # file, so for recordings that are not a multiple of PERIOD seconds the
        # times of that last segment are shifted — confirm all files are.
        for part in load_test_file(test_file):
            part = torch.tensor(part).unsqueeze(0).float().to(device)

            output = model(part)
            framewise_outputs = output["framewise_output"].cpu().numpy()[0]

            estimated_event_list.extend(
                _events_from_framewise(framewise_outputs, file_id, global_time)
            )
            global_time += PERIOD

prediction_df = pd.DataFrame(estimated_event_list)

In [22]:
# Number of distinct recordings that have at least one row in prediction_df
len(prediction_df["file_id"].unique())

4718

In [23]:
# Write prediction_df to CSV without the index column.
prediction_df.to_csv("pseudolabels_raw_sed_train.csv", index=False)

In [34]:
# Empty submission frame: the recording id followed by one score column per
# species, named s0 .. s23
submission = pd.DataFrame(
    columns=['recording_id', *(f's{species}' for species in range(24))]
)

In [37]:
# Replace prediction_df with event predictions read from disk.
# NOTE(review): this reads "pseudolabels_raw_sed.csv", which is not the
# "pseudolabels_raw_sed_train.csv" written earlier in this notebook — it presumably
# comes from a separate run; confirm the file exists before a Restart & Run All.
prediction_df = pd.read_csv("pseudolabels_raw_sed.csv")

In [38]:
# Number of distinct recordings present in the loaded predictions
len(prediction_df["file_id"].unique())

1992

In [39]:
#prediction_df = pd.read_csv("pseudolabels_raw_fold1.csv")
# Collapse the per-event predictions into one row per recording: for every
# species the score is the highest `max_confidence` among that recording's
# events; species without any event keep a score of 0.
#
# Rows are collected first and turned into a frame once. `DataFrame.append`
# was removed in pandas 2.0, and calling it per row re-copied the whole frame
# on every iteration.
n_species = len(submission.columns) - 1  # every column except the recording id

recording_ids = []
score_rows = []
for file_id, sub_df in prediction_df.groupby("file_id"):
    label_array = np.zeros(n_species, dtype=np.float32)
    best_per_species = sub_df.groupby("species_id")["max_confidence"].max()
    for species_id, pred_proba in best_per_species.items():
        label_array[int(species_id)] = pred_proba
    recording_ids.append(file_id)
    score_rows.append(label_array)

if score_rows:
    new_rows = pd.DataFrame(np.vstack(score_rows), columns=submission.columns[1:])
    new_rows.insert(0, submission.columns[0], recording_ids)
    # Keep the original "add to whatever is already in `submission`" behaviour,
    # but skip the concat when `submission` is still the empty template.
    if submission.empty:
        submission = new_rows
    else:
        submission = pd.concat([submission, new_rows], ignore_index=True)



In [40]:
# Sanity check: (number of rows, number of columns) of the assembled submission
submission.shape

(1992, 25)

In [41]:
# Write the assembled submission to CSV without the index column.
submission.to_csv("test_submission_from_frames.csv", index=False )