# Introduction
I wanted to share something that worked pretty well for me early on in this competition. The idea comes from a [2018 paper](https://arxiv.org/pdf/1703.01780.pdf) titled *Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results* by Antti Tarvainen and Harri Valpola. 

### Mean Teacher
Biefly, the idea is to use two models. A student model with weights trained the standard way, using backprop. And a teacher model with weights that are an exponential moving average of the student's weights. The teacher is the *mean* of the student \*ba dum tss\*. The student is then trained using two different losses, a standard classification loss and a consistency loss that penalizes student predictions that deviate from the teaher's. 

![](https://raw.githubusercontent.com/CuriousAI/mean-teacher/master/mean_teacher.png)

Mean teachers are useful in a semi-supervised context where we have both labeled and unlabeled samples. The consistency loss on the unlabeled samples acts as a form of regularization and helps the model generalize better. As an added bonus the final teacher model is a temporal ensemble which tends to perform better than the results at the end of a single epoch. 

### Missing Labels
As a few others have pointed out, there are a lot of missing labels. If we were to randomly sample a segment from the training data, we might consider it completely unlabeled rather than rely on the provided labels. We'll train our mean teacher model(s) on two classes of data, carefully selected positive samples and randomly selected unlabeled samples. The classification loss won't apply to the unlabeled samples. 

![](https://www.googleapis.com/download/storage/v1/b/kaggle-forum-message-attachments/o/inbox%2F4704212%2F9ca088bb386abf7114543c019c1d8a5f%2Ffig.png?generation=1609892974092435&alt=media)

*Thanks to [shinmura0](https://www.kaggle.com/shinmurashinmura) for the great visualization!*

### Results
For me, mean teacher worked a good bit better than baseline models with similar configurations. 

|                                         | Baseline | Mean Teacher |
|-----------------------------------------|----------|--------------|
| Well Tuned, 5 fold, from my local setup | 0.847        | **0.865**            |
| Single fold Expt1 on Kaggle                   | 0.592**        | **0.786**            |
| Single fold Expt2 on Kaggle                   | 0.826        | **0.830**            |
| 5 Fold on Kaggle***                        | 0.844        | **0.857**           |

\*\* I might have accidentally sabatoged this run.

\*\*\* There was a major bug in v21 of the notebook where the consistence_ramp was set to 1000 which means it was just normal / non-mean-teacher training. Setting consisteny_ramp to 6 and using the mean teacher, we get an improvement of 0.13.

In [1]:
import audiomentations as A
import os, time, librosa, random
from functools import partial
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from tqdm import tqdm
import soundfile as sf
from contextlib import nullcontext
import datetime
from torch.utils.tensorboard import SummaryWriter
from timm.models import resnet34d, resnet34, resnext50d_32x4d, densenet121
import warnings
warnings.filterwarnings('ignore')

In [2]:
# from resnest.torch import resnest50, resnest101, resnest200
# from torchvision.models import resnet34, resnet50, resnet101, resnet152, densenet121, densenet169, densenet201, mobilenet_v2, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
# from efficientnet_pytorch import EfficientNet

In [3]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    
GLOBAL_SEED = 10
setup_seed(GLOBAL_SEED)

# Config
We'll start by setting up some global config variable that we'll access later.

In [4]:
# Global Vars
NO_LABEL = -1
NUM_CLASSES = 24

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
class config:
    seed = GLOBAL_SEED
    
    train_tp_csv = '/dev/shm/data/train_tp.csv'
    test_csv = '/dev/shm/data/sample_submission.csv'
    save_path = '/root/s/RFCX/model_save'
    res_path = '/root/s/RFCX/res'
    train_data_path = "/dev/shm/data/train"
    test_data_path = "/dev/shm/data/test"
    tensorboard_path = '/root/s/RFCX/tensorboard'
    model_name = "densenet121-teacher"
    model = densenet121
    
    percent_unlabeled = 1.0
    consistency_weight = 100.0
    consistency_rampup = 6
    ema_decay = 0.995
    positive_weight = 3.0
    
    lr = 1e-3
    epochs = 40
    batch_size = 32
    num_workers = 4
    train_5_folds = True
    
    period = 8 # 6 second clips
    step = 1
    model_params = {
        'sample_rate': 32000,
        'window_size': 2048,
        'hop_size': 512,
        'mel_bins': 256,
        'fmin': 20,
        'fmax': 16000,
        'classes_num': NUM_CLASSES
    }
    
    augmenter = A.Compose([
        A.AddGaussianNoise(p=0.33, max_amplitude=0.02),
        A.AddGaussianSNR(p=0.33),
        A.FrequencyMask(min_frequency_band=0.01,  max_frequency_band=0.25, p=0.33),
        A.TimeMask(min_band_part=0.01, max_band_part=0.25, p=0.33),
        A.Gain(p=0.33)
    ])


In [5]:
## Utils - Not much interesting going on here.

def get_n_fold_df(csv_path, folds=5):
    df = pd.read_csv(csv_path)
    df_group = df.groupby("recording_id")[["species_id"]].first().reset_index()
    df_group = df_group.sample(frac=1, random_state=config.seed).reset_index(drop=True)
    df_group.loc[:, 'fold'] = -1

    X = df_group["recording_id"].values
    y = df_group["species_id"].values

    kfold = StratifiedKFold(n_splits=folds, random_state=config.seed)
    for fold, (t_idx, v_idx) in enumerate(kfold.split(X, y)):
        df_group.loc[v_idx, "fold"] = fold

    return df.merge(df_group[['recording_id', 'fold']], on="recording_id", how="left")
    

def init_layer(layer):
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, "bias"):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.0)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MetricMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.y_true = []
        self.y_pred = []

    def update(self, y_true, y_pred):
        try:
            self.y_true.extend(y_true.detach().cpu().numpy().tolist())
            self.y_pred.extend(torch.sigmoid(y_pred).cpu().detach().numpy().tolist())
        except:
            print("UPDATE FAILURE")

    def update_list(self, y_true, y_pred):
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    @property
    def avg(self):
        score_class, weight = lwlrap(np.array(self.y_true), np.array(self.y_pred))
        self.score = (score_class * weight).sum()

        return self.score
    

def interpolate(x: torch.Tensor, ratio: int):
    """Interpolate data in time domain. This is used to compensate the
    resolution reduction in downsampling of a CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled

def _one_sample_positive_class_precisions(scores, truth):
    num_classes = scores.shape[0]
    pos_class_indices = np.flatnonzero(truth > 0)

    if not len(pos_class_indices):
        return pos_class_indices, np.zeros(0)

    retrieved_classes = np.argsort(scores)[::-1]

    class_rankings = np.zeros(num_classes, dtype=np.int)
    class_rankings[retrieved_classes] = range(num_classes)

    retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
    retrieved_class_true[class_rankings[pos_class_indices]] = True

    retrieved_cumulative_hits = np.cumsum(retrieved_class_true)

    precision_at_hits = (
            retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
            (1 + class_rankings[pos_class_indices].astype(np.float)))
    return pos_class_indices, precision_at_hits


def lwlrap(truth, scores):
    assert truth.shape == scores.shape
    num_samples, num_classes = scores.shape
    precisions_for_samples_by_classes = np.zeros((num_samples, num_classes))
    for sample_num in range(num_samples):
        pos_class_indices, precision_at_hits = _one_sample_positive_class_precisions(scores[sample_num, :],
                                                                                     truth[sample_num, :])
        precisions_for_samples_by_classes[sample_num, pos_class_indices] = precision_at_hits

    labels_per_class = np.sum(truth > 0, axis=0)
    weight_per_class = labels_per_class / float(np.sum(labels_per_class))

    per_class_lwlrap = (np.sum(precisions_for_samples_by_classes, axis=0) /
                        np.maximum(1, labels_per_class))
    return per_class_lwlrap, weight_per_class


def pretty_print_metrics(fold, epoch, optimizer, train_loss_metrics, val_loss_metrics):
    print(f"""
    {time.ctime()} \n
    Fold:{fold}, Epoch:{epoch}, LR:{optimizer.param_groups[0]['lr']:.7}, Cons. Weight: {train_loss_metrics['consistency_weight']}\n
    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                {train_loss_metrics['loss']:0.4f}   |   {val_loss_metrics['loss']:0.4f}\n
    LWLRAP:              {train_loss_metrics['lwlrap']:0.4f}   |   {val_loss_metrics['lwlrap']:0.4f}\n
    Class Loss:          {train_loss_metrics['class_loss']:0.4f}   |   {val_loss_metrics['class_loss']:0.4f}\n
    Consistency Loss:    {train_loss_metrics['consistency_loss']:0.4f}   |   {val_loss_metrics['consistency_loss']:0.4f}\n
    --------------------------------------------------------\n
    """)
    

class TestDataset(Dataset):
    def __init__(self, df, data_path, period=10, step=1):
        self.data_path = data_path
        self.period = period
        self.step = step
        self.recording_ids = list(df["recording_id"].unique())

    def __len__(self):
        return len(self.recording_ids)

    def __getitem__(self, idx):
        recording_id = self.recording_ids[idx]

        y, sr = sf.read(f"{self.data_path}/{recording_id}.wav")

        len_y = len(y)
        effective_length = sr * self.period
        effective_step = sr * self.step

        y_ = []
        i = 0
        while i+effective_length <= len_y:
            y__ = y[i:i + effective_length]
            y_.append(y__)
            i = i + effective_step

        y = np.stack(y_)

        label = np.zeros(NUM_CLASSES, dtype='f')

        return {
            "waveform": y,
            "target": torch.tensor(label, dtype=torch.float),
            "id": recording_id
        }


def predict_on_test(model, test_loader):
    model.eval()
    pred_list = []
    id_list = []
    with torch.no_grad():
        t = tqdm(test_loader)
        for i, sample in enumerate(t):
            input = sample["waveform"].cuda()
            bs, seq, w = input.shape
            input = input.reshape(bs * seq, w)
            id = sample["id"]
            output, _ = model(input)
            output = output.reshape(bs, seq, -1)
            output, _ = torch.max(output, dim=1)
            
            output = output.cpu().detach().numpy().tolist()
            pred_list.extend(output)
            id_list.extend(id)

    return pred_list, id_list

# Model
The model should look pretty familiar if you're using [SED](https://arxiv.org/abs/1912.04761). (Huge thanks to [Hidehisa Arai](https://www.kaggle.com/hidehisaarai1213) and their [SED Notebook](https://www.kaggle.com/hidehisaarai1213/introduction-to-sound-event-detection)!) You could use any model you'd like here. There's just one small tweak we need to make for our mean teacher setup. We need to "detach" the teacher's parameters so they aren't updated by the optimizer.

In [6]:
class feature_extractor(nn.Module):
    def __init__(self, original):
        super().__init__()
        self.model = original
    def forward(self, x):
        x= self.model.extract_features(x)
        return x


class AttentionHead(nn.Module):
    
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.conv_attention = nn.Conv1d(in_channels=in_features, 
                                        out_channels=out_features,
                                        kernel_size=1, stride=1, 
                                        padding=0, bias=True)
        self.conv_classes = nn.Conv1d(in_channels=in_features, 
                                      out_channels=out_features,
                                      kernel_size=1, stride=1, 
                                      padding=0, bias=True)
        self.batch_norm_attention = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        init_layer(self.conv_attention)
        init_layer(self.conv_classes)
        init_bn(self.batch_norm_attention)

    def forward(self, x):
        norm_att = torch.softmax(torch.tanh(self.conv_attention(x)), dim=-1)
        classes = self.conv_classes(x)
        x = torch.sum(norm_att * classes, dim=2)
        return x, norm_att, classes


class SEDAudioClassifier(nn.Module):

    def __init__(self, sample_rate, window_size, hop_size, 
                 mel_bins, fmin, fmax, classes_num):
        super().__init__()
        self.interpolate_ratio = 32

        self.spectrogram_extractor = Spectrogram(n_fft=window_size, 
                                                 hop_length=hop_size,
                                                 win_length=window_size, 
                                                 window='hann', center=True,
                                                 pad_mode='reflect', 
                                                 freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
                                                 n_mels=mel_bins, fmin=fmin, 
                                                 fmax=fmax, ref=1.0, 
                                                 amin=1e-10, top_db=None, 
                                                 freeze_parameters=True)

        self.batch_norm = nn.BatchNorm2d(mel_bins)
        
        self.model = config.model(pretrained=True, in_chans=1)
        if config.model_name.startswith('densenet'):
            self.in_features = self.model.classifier.in_features
        else:
            self.in_features = self.model.fc.in_features
        self.fc = nn.Linear(self.in_features, 
                            1024, bias=True)
        self.att_head = AttentionHead(1024, classes_num)
        self.avg_pool = nn.modules.pooling.AdaptiveAvgPool2d((1, 1))

        self.init_weight()

    def init_weight(self):
        init_bn(self.batch_norm)
        init_layer(self.fc)
        self.att_head.init_weights()

    def forward(self, input, spec_aug=False, 
                mixup_lambda=None, return_encoding=False):
        x = self.spectrogram_extractor(input.float())
        x = self.logmel_extractor(x)
        
        x = x.transpose(1, 3)
        x = self.batch_norm(x)
        x = x.transpose(1, 3)
        x = self.model.forward_features(x)
        x = torch.mean(x, dim=3)
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)
        x = F.relu_(self.fc(x))
        x = x.transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_head(x)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        return clipwise_output, framewise_output


def get_model(is_mean_teacher=False):
    model = SEDAudioClassifier(**config.model_params).cuda()
    
    # Detach params for Exponential Moving Average Model (aka the Mean Teacher).
    # We'll manually update these params instead of using backprop.
    if is_mean_teacher:
        for param in model.parameters():
            param.detach_()
    return model

# Loss Function
The loss function has 2 components:

1. A classification loss that only applies to labeled samples.
2. A consistency loss that applies to all samples. 

For the consistency loss we'll use the mean square error between the student and teacher predictions. We'll slowly ramp up the influence of the consistency loss since we don't want bad, early predictions having too much influence. 

Notice that we're weighting the positive samples for the classification loss. This is because we know the positives are correct while we're less sure about the negatives due to the missing labels issue. I found that this works better in practice. 

In [7]:
def sigmoid_mse_loss(input_logits, target_logits):
    assert input_logits.size() == target_logits.size()
    input_softmax = torch.sigmoid(input_logits)
    target_softmax = torch.sigmoid(target_logits)
    num_classes = input_logits.size()[1]
    return F.mse_loss(input_softmax, target_softmax, size_average=False
                     ) / num_classes


class MeanTeacherLoss(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.positive_weight = torch.ones(NUM_CLASSES) * config.positive_weight
        self.class_criterion = nn.BCEWithLogitsLoss(reduction='none', pos_weight=self.positive_weight)
        self.consistency_criterion = sigmoid_mse_loss

    def make_safe(self, pred):
        pred = torch.where(torch.isnan(pred), torch.zeros_like(pred), pred)
        return torch.where(torch.isinf(pred), torch.zeros_like(pred), pred)
        
    def get_consistency_weight(self, epoch):
        # Consistency ramp-up from https://arxiv.org/abs/1610.02242
        return config.consistency_weight * sigmoid_rampup(
            epoch, config.consistency_rampup)
    
    def forward(self, student_pred, teacher_pred, target, classif_weights, epoch):
        student_pred = self.make_safe(student_pred)
        teacher_pred = self.make_safe(teacher_pred).detach().data

        batch_size = len(target)
        labeled_batch_size = target.ne(NO_LABEL).all(axis=1).sum().item() + 1e-3

        student_classif, student_consistency = student_pred, student_pred
        student_class_loss = (self.class_criterion(
            student_classif, target) * classif_weights / labeled_batch_size).sum()

        consistency_weights = self.get_consistency_weight(epoch)
        consistency_loss = consistency_weights * self.consistency_criterion(
            student_consistency, teacher_pred) / batch_size
        loss = student_class_loss + consistency_loss
        return loss, student_class_loss, consistency_loss, consistency_weights

# Data Loader
The data loader produces two types of samples:

1. Labeled samples with the audio centered in the clip.
2. Random unlabeled clips without labels selected from files with at least one true positive label.

Each sample contains 2 different inputs, one for the student and one for the teacher. Different augmentations are applied to each input.

In [8]:
class MeanTeacherDataset(Dataset):
    
    def __init__(self, df, transforms, period=5, 
                 data_path=config.train_data_path, 
                 val=False, percent_unlabeled=0.0):
        self.period = period
        self.transforms = transforms
        self.data_path = data_path
        self.val = val
        self.percent_unlabeled = percent_unlabeled

        dfgby = df.groupby("recording_id").agg(lambda x: list(x)).reset_index()
        self.recording_ids = dfgby["recording_id"].values
        self.species_ids = dfgby["species_id"].values
        self.t_mins = dfgby["t_min"].values
        self.t_maxs = dfgby["t_max"].values

    def __len__(self):
        return int(len(self.recording_ids) * (1 + self.percent_unlabeled))

    def __getitem__(self, idx):
        if idx >= len(self.recording_ids):
            audio, label, rec_id, sr = self.get_unlabeled_item(idx)
            # For unlabeled samples, we zero out the classification loss.
            classif_weights = np.zeros(NUM_CLASSES, dtype='f')
        else:
            audio, label, rec_id, sr = self.get_labeled_item(idx)
            classif_weights = np.ones(NUM_CLASSES, dtype='f')

        audio_teacher = np.copy(audio)

        # The 2 samples fed to the 2 models have should have different augmentations.
        audio = self.transforms(samples=audio, sample_rate=sr)
        audio_teacher = self.transforms(samples=audio_teacher, sample_rate=sr)
        # assert (audio != audio_teacher).any()
        
        return {
            "waveform": audio,
            "teacher_waveform": audio_teacher,
            "target": torch.tensor(label, dtype=torch.float),
            "classification_weights": classif_weights,
            "id": rec_id
        }

    def get_labeled_item(self, idx):
        recording_id = self.recording_ids[idx]
        species_id = self.species_ids[idx]
        t_min, t_max = self.t_mins[idx], self.t_maxs[idx]

        rec, sr = sf.read(f"{self.data_path}/{recording_id}.wav")

        len_rec = len(rec)
        effective_length = sr * self.period
        rint = np.random.randint(len(t_min))
        tmin, tmax = round(sr * t_min[rint]), round(sr * t_max[rint])
        dur = tmax - tmin
        min_dur = min(dur, round(sr * self.period))

        center = round((tmin + tmax) / 2)
        rand_start = center - effective_length + max(min_dur - dur//2, 0)
        if rand_start < 0:
            rand_start = 0
        rand_end = center - max(min_dur - dur//2, 0)
        start = np.random.randint(rand_start, rand_end)
        rec = rec[start:start + effective_length]
        if len(rec) < effective_length:
            new_rec = np.zeros(effective_length, dtype=rec.dtype)
            start1 = np.random.randint(effective_length - len(rec))
            new_rec[start1:start1 + len(rec)] = rec
            rec = new_rec.astype(np.float32)
        else:
            rec = rec.astype(np.float32)

        start_time = start / sr
        end_time = (start + effective_length) / sr

        label = np.zeros(NUM_CLASSES, dtype='f')

        for i in range(len(t_min)):
            if (t_min[i] >= start_time) & (t_max[i] <= end_time):
                label[species_id[i]] = 1
            elif start_time <= ((t_min[i] + t_max[i]) / 2) <= end_time:
                label[species_id[i]] = 1

        return rec, label, recording_id, sr

    def get_unlabeled_item(self, idx, random_sample=False):
        real_idx = idx - len(self.recording_ids)
        # We want our validation set to be fixed.
        if self.val:
            rec_id = self.recording_ids[real_idx]
        else:
            rec_id = random.sample(list(self.recording_ids), 1)[0]

        rec, sr = sf.read(f"{self.data_path}/{rec_id}.wav")
        effective_length = int(sr * self.period)
        max_end = len(rec) - effective_length
        if self.val:
            # Fixed start for validation. Probaably a better way to do this.
            start = int(idx * 16963 % max_end)
        else:
            start = np.random.randint(0, max_end)
        rec = rec[start:(start+effective_length)]
        rec = rec.astype(np.float32)

        label = np.ones(NUM_CLASSES, dtype='f') * NO_LABEL

        return rec, label, rec_id, sr

    
def get_data_loader(df, is_val=False):
    dataset = MeanTeacherDataset(
        df=df,
        transforms=config.augmenter,
        period=config.period,
        percent_unlabeled=config.percent_unlabeled
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=not is_val,
        drop_last=not is_val,
        num_workers=config.num_workers
    )

# Training
At the end of each training step we update the teacher weights by averaging in the latest student weights.

In [9]:
# Update teacher to be exponential moving average of student params.
def update_teacher_params(student, teacher, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(teacher.parameters(), student.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)


def train_one_epoch(student, mean_teacher, loader, 
                    criterion, optimizer, scheduler, epoch, is_val=False):
    global_step = 0
    losses = AverageMeter()
    consistency_loss_avg = AverageMeter()
    class_loss_avg = AverageMeter()
    comp_metric = MetricMeter()
    
    if is_val:
        student.eval()
        mean_teacher.eval()
        context = torch.no_grad()
    else:
        student.train()
        mean_teacher.train()
        context = nullcontext()
    
    with context:
        t = tqdm(loader)
        for i, sample in enumerate(t):
            student_input = sample['waveform'].cuda()
            teacher_input = sample['teacher_waveform'].cuda()
            target = sample['target'].cuda()
            classif_weights = sample['classification_weights'].cuda()
            batch_size = len(target)

            student_pred, _  = student(student_input)
            teacher_pred, _  = mean_teacher(teacher_input)

            loss, class_loss, consistency_loss, consistency_weight = criterion(
                student_pred, teacher_pred, target, classif_weights, epoch)

            if not is_val:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                update_teacher_params(student, mean_teacher, 
                                      config.ema_decay, global_step)

                scheduler.step()

            comp_metric.update(target, student_pred)
            losses.update(loss.item(), batch_size)
            consistency_loss_avg.update(consistency_loss.item(), batch_size)
            class_loss_avg.update(class_loss.item(), batch_size)
            global_step += 1

            t.set_description(f"Epoch:{epoch} - Loss:{losses.avg:0.4f}")
        t.close()
    return {'lwlrap':comp_metric.avg, 
            'loss':losses.avg, 
            'consistency_loss':consistency_loss_avg.avg, 
            'class_loss':class_loss_avg.avg, 
            'consistency_weight':consistency_weight}

Finally putting everything together...

In [None]:
def train(df, fold, writer=None):
    train_df = df[df.fold != fold]
    val_df = df[df.fold == fold]
    train_loader = get_data_loader(train_df)
    val_loader = get_data_loader(val_df)

    student_model = get_model()
    teacher_model = get_model(is_mean_teacher=True)
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.lr, weight_decay=0.01)
    warmup_prob = 0.2
    num_train_steps = int(len(train_loader) * config.epochs)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_train_steps)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, epochs=config.epochs, steps_per_epoch=len(train_loader), pct_start=warmup_prob, div_factor=25, anneal_strategy='cos', cycle_momentum=True)
    criterion = MeanTeacherLoss().cuda()
    best_val_metric = -np.inf
    val_metrics = []
    train_metrics = []
    for epoch in range(0, config.epochs):
        train_loss_metrics = train_one_epoch(
            student_model, teacher_model, train_loader, 
            criterion, optimizer, scheduler, epoch)
        val_loss_metrics = train_one_epoch(
            student_model, teacher_model, val_loader, 
            criterion, optimizer, scheduler, epoch, is_val=True)

        train_metrics.append(train_loss_metrics)
        val_metrics.append(val_loss_metrics)
        pretty_print_metrics(fold, epoch, optimizer, 
                             train_loss_metrics, val_loss_metrics)
        n_iters = len(train_loader) * (epoch + 1)
        writer.add_scalar('fold_{}/train_loss'.format(fold), train_loss_metrics['loss'], n_iters)
        writer.add_scalar('fold_{}/train_LWLRAP'.format(fold), train_loss_metrics['lwlrap'], n_iters)
        writer.add_scalar('fold_{}/learning_rate'.format(fold), scheduler.get_last_lr()[0], n_iters)
        writer.add_scalar('fold_{}/validate_loss'.format(fold), val_loss_metrics['loss'], n_iters)
        writer.add_scalar('fold_{}/validate_LWLRAP'.format(fold), val_loss_metrics['lwlrap'], n_iters)
        
        if val_loss_metrics['lwlrap'] > best_val_metric:
            print(f"    LWLRAP Improved from {best_val_metric} --> {val_loss_metrics['lwlrap']}\n")
            torch.save(teacher_model.state_dict(), os.path.join(config.save_path, f'{config.model_name}-fold-{fold}.bin'))
            best_val_metric = val_loss_metrics['lwlrap']
#     torch.save(teacher_model.state_dict(), os.path.join(config.save_path, f'{config.model_name}-fold-{fold}.bin'))
    

df = get_n_fold_df(config.train_tp_csv)
time_stamp = '{0:%m_%d_%H_%M}'.format(datetime.datetime.now())
writer=None
writer = SummaryWriter(log_dir=os.path.join(config.tensorboard_path, '{}_{}'.format(config.model_name, time_stamp)))
for fold in range(5 if config.train_5_folds else 1):
    train(df, fold, writer)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth" to /root/.cache/torch/hub/checkpoints/densenet121_ra-50efcf5c.pth
Epoch:0 - Loss:14.5036: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:0 - Loss:10.5784: 100%|██████████| 14/14 [00:04<00:00,  3.02it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:01:59 2021 

    Fold:0, Epoch:0, LR:7.669937e-05, Cons. Weight: 0.6737946999085467

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                14.5036   |   10.5784

    LWLRAP:              0.1760   |   0.2212

    Class Loss:          14.4482   |   10.5633

    Consistency Loss:    0.0553   |   0.0151

    --------------------------------------------------------

    
    LWLRAP Improved from -inf --> 0.2212203733374102



Epoch:1 - Loss:10.4970: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:1 - Loss:8.9118: 100%|██████████| 14/14 [00:04<00:00,  3.18it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:02:34 2021 

    Fold:0, Epoch:1, LR:0.0001811856, Cons. Weight: 3.1047958479329627

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                10.4970   |   8.9118

    LWLRAP:              0.2124   |   0.3078

    Class Loss:          10.4155   |   8.8688

    Consistency Loss:    0.0815   |   0.0430

    --------------------------------------------------------

    
    LWLRAP Improved from 0.2212203733374102 --> 0.30784085104042924



Epoch:2 - Loss:9.0896: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:2 - Loss:8.1459: 100%|██████████| 14/14 [00:04<00:00,  3.12it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:03:09 2021 

    Fold:0, Epoch:2, LR:0.0003374814, Cons. Weight: 10.836802322189582

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                9.0896   |   8.1459

    LWLRAP:              0.3209   |   0.4017

    Class Loss:          8.8325   |   7.9929

    Consistency Loss:    0.2571   |   0.1529

    --------------------------------------------------------

    
    LWLRAP Improved from 0.30784085104042924 --> 0.40167516058803876



Epoch:3 - Loss:8.3922: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:3 - Loss:7.3820: 100%|██████████| 14/14 [00:04<00:00,  3.21it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:03:44 2021 

    Fold:0, Epoch:3, LR:0.0005216868, Cons. Weight: 28.650479686019008

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                8.3922   |   7.3820

    LWLRAP:              0.4112   |   0.5206

    Class Loss:          7.8055   |   6.8553

    Consistency Loss:    0.5867   |   0.5267

    --------------------------------------------------------

    
    LWLRAP Improved from 0.40167516058803876 --> 0.520587289584405



Epoch:4 - Loss:7.5844: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:4 - Loss:6.8672: 100%|██████████| 14/14 [00:04<00:00,  3.12it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:04:19 2021 

    Fold:0, Epoch:4, LR:0.0007056342, Cons. Weight: 57.375342073743276

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                7.5844   |   6.8672

    LWLRAP:              0.5304   |   0.6357

    Class Loss:          6.6143   |   5.9313

    Consistency Loss:    0.9701   |   0.9359

    --------------------------------------------------------

    
    LWLRAP Improved from 0.520587289584405 --> 0.6357295221409716



Epoch:5 - Loss:6.5765: 100%|██████████| 56/56 [00:30<00:00,  1.87it/s]
Epoch:5 - Loss:6.1757: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s]



    Sat Feb 13 16:04:53 2021 

    Fold:0, Epoch:5, LR:0.0008611956, Cons. Weight: 87.03247258333906

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                6.5765   |   6.1757

    LWLRAP:              0.6599   |   0.6936

    Class Loss:          5.3253   |   4.8662

    Consistency Loss:    1.2512   |   1.3095

    --------------------------------------------------------

    
    LWLRAP Improved from 0.6357295221409716 --> 0.6936470189978963



Epoch:6 - Loss:5.9334: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:6 - Loss:5.8902: 100%|██████████| 14/14 [00:04<00:00,  3.22it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:05:28 2021 

    Fold:0, Epoch:6, LR:0.0009645834, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                5.9334   |   5.8902

    LWLRAP:              0.7187   |   0.7434

    Class Loss:          4.6365   |   4.3414

    Consistency Loss:    1.2970   |   1.5488

    --------------------------------------------------------

    
    LWLRAP Improved from 0.6936470189978963 --> 0.7434038479851006



Epoch:7 - Loss:5.3950: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:7 - Loss:6.0759: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:06:03 2021 

    Fold:0, Epoch:7, LR:0.0009999992, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                5.3950   |   6.0759

    LWLRAP:              0.7395   |   0.7603

    Class Loss:          4.2165   |   4.3066

    Consistency Loss:    1.1784   |   1.7693

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7434038479851006 --> 0.760283210809797



Epoch:8 - Loss:4.8909: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:8 - Loss:5.3818: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:06:38 2021 

    Fold:0, Epoch:8, LR:0.0009975057, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.8909   |   5.3818

    LWLRAP:              0.7818   |   0.7867

    Class Loss:          3.7077   |   3.8335

    Consistency Loss:    1.1832   |   1.5483

    --------------------------------------------------------

    
    LWLRAP Improved from 0.760283210809797 --> 0.7866831467602393



Epoch:9 - Loss:4.4040: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:9 - Loss:4.7662: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:07:13 2021 

    Fold:0, Epoch:9, LR:0.0009902209, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                4.4040   |   4.7662

    LWLRAP:              0.8236   |   0.8042

    Class Loss:          3.3030   |   3.4850

    Consistency Loss:    1.1010   |   1.2812

    --------------------------------------------------------

    
    LWLRAP Improved from 0.7866831467602393 --> 0.804202350111961



Epoch:10 - Loss:3.8657: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:10 - Loss:5.2758: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:07:48 2021 

    Fold:0, Epoch:10, LR:0.0009782151, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.8657   |   5.2758

    LWLRAP:              0.8342   |   0.7956

    Class Loss:          2.8796   |   4.0383

    Consistency Loss:    0.9861   |   1.2375

    --------------------------------------------------------

    


Epoch:11 - Loss:3.9101: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:11 - Loss:4.9435: 100%|██████████| 14/14 [00:04<00:00,  3.13it/s]



    Sat Feb 13 16:08:23 2021 

    Fold:0, Epoch:11, LR:0.0009616038, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.9101   |   4.9435

    LWLRAP:              0.8488   |   0.8250

    Class Loss:          2.8295   |   3.4987

    Consistency Loss:    1.0806   |   1.4448

    --------------------------------------------------------

    
    LWLRAP Improved from 0.804202350111961 --> 0.8249842668344872



Epoch:12 - Loss:3.4293: 100%|██████████| 56/56 [00:29<00:00,  1.87it/s]
Epoch:12 - Loss:5.1043: 100%|██████████| 14/14 [00:04<00:00,  3.26it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:08:57 2021 

    Fold:0, Epoch:12, LR:0.000940547, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.4293   |   5.1043

    LWLRAP:              0.8634   |   0.7972

    Class Loss:          2.4859   |   3.6687

    Consistency Loss:    0.9433   |   1.4356

    --------------------------------------------------------

    


Epoch:13 - Loss:3.4359: 100%|██████████| 56/56 [00:30<00:00,  1.84it/s]
Epoch:13 - Loss:4.4113: 100%|██████████| 14/14 [00:04<00:00,  3.10it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:09:32 2021 

    Fold:0, Epoch:13, LR:0.0009152475, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.4359   |   4.4113

    LWLRAP:              0.8629   |   0.8326

    Class Loss:          2.4507   |   3.4067

    Consistency Loss:    0.9853   |   1.0046

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8249842668344872 --> 0.8325659812543871



Epoch:14 - Loss:3.0140: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:14 - Loss:5.2966: 100%|██████████| 14/14 [00:04<00:00,  3.14it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:10:07 2021 

    Fold:0, Epoch:14, LR:0.000885949, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.0140   |   5.2966

    LWLRAP:              0.8922   |   0.8117

    Class Loss:          2.0897   |   3.9737

    Consistency Loss:    0.9244   |   1.3229

    --------------------------------------------------------

    


Epoch:15 - Loss:3.1196: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:15 - Loss:4.6319: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:10:42 2021 

    Fold:0, Epoch:15, LR:0.0008529336, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                3.1196   |   4.6319

    LWLRAP:              0.8816   |   0.8086

    Class Loss:          2.2261   |   3.4411

    Consistency Loss:    0.8936   |   1.1908

    --------------------------------------------------------

    


Epoch:16 - Loss:2.6369: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:16 - Loss:4.6409: 100%|██████████| 14/14 [00:04<00:00,  3.20it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:11:17 2021 

    Fold:0, Epoch:16, LR:0.0008165193, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.6369   |   4.6409

    LWLRAP:              0.9032   |   0.8040

    Class Loss:          1.7884   |   3.4709

    Consistency Loss:    0.8485   |   1.1699

    --------------------------------------------------------

    


Epoch:17 - Loss:2.7288: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:17 - Loss:5.7599: 100%|██████████| 14/14 [00:04<00:00,  3.06it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:11:51 2021 

    Fold:0, Epoch:17, LR:0.0007770567, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.7288   |   5.7599

    LWLRAP:              0.9076   |   0.7989

    Class Loss:          1.8141   |   4.3494

    Consistency Loss:    0.9147   |   1.4105

    --------------------------------------------------------

    


Epoch:18 - Loss:2.5437: 100%|██████████| 56/56 [00:30<00:00,  1.87it/s]
Epoch:18 - Loss:4.3165: 100%|██████████| 14/14 [00:04<00:00,  3.22it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:12:26 2021 

    Fold:0, Epoch:18, LR:0.000734926, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.5437   |   4.3165

    LWLRAP:              0.9144   |   0.8424

    Class Loss:          1.7109   |   3.1636

    Consistency Loss:    0.8328   |   1.1529

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8325659812543871 --> 0.8424419370077845



Epoch:19 - Loss:2.7906: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:19 - Loss:4.8063: 100%|██████████| 14/14 [00:04<00:00,  3.16it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:13:01 2021 

    Fold:0, Epoch:19, LR:0.0006905328, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.7906   |   4.8063

    LWLRAP:              0.9071   |   0.8308

    Class Loss:          1.9137   |   3.6256

    Consistency Loss:    0.8769   |   1.1806

    --------------------------------------------------------

    


Epoch:20 - Loss:2.1940: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:20 - Loss:4.0270: 100%|██████████| 14/14 [00:04<00:00,  3.08it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:13:35 2021 

    Fold:0, Epoch:20, LR:0.0006443047, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.1940   |   4.0270

    LWLRAP:              0.9255   |   0.8374

    Class Loss:          1.4549   |   3.0663

    Consistency Loss:    0.7392   |   0.9606

    --------------------------------------------------------

    


Epoch:21 - Loss:2.0450: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:21 - Loss:3.9722: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:14:10 2021 

    Fold:0, Epoch:21, LR:0.0005966869, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                2.0450   |   3.9722

    LWLRAP:              0.9330   |   0.8425

    Class Loss:          1.3063   |   3.0535

    Consistency Loss:    0.7387   |   0.9188

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8424419370077845 --> 0.8424838945108336



Epoch:22 - Loss:1.7801: 100%|██████████| 56/56 [00:30<00:00,  1.84it/s]
Epoch:22 - Loss:4.5900: 100%|██████████| 14/14 [00:04<00:00,  3.14it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:14:45 2021 

    Fold:0, Epoch:22, LR:0.000548138, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.7801   |   4.5900

    LWLRAP:              0.9448   |   0.8226

    Class Loss:          1.1292   |   3.5806

    Consistency Loss:    0.6510   |   1.0093

    --------------------------------------------------------

    


Epoch:23 - Loss:1.8173: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:23 - Loss:4.2262: 100%|██████████| 14/14 [00:04<00:00,  3.14it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:15:20 2021 

    Fold:0, Epoch:23, LR:0.0004991254, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.8173   |   4.2262

    LWLRAP:              0.9367   |   0.8299

    Class Loss:          1.1330   |   3.3697

    Consistency Loss:    0.6843   |   0.8566

    --------------------------------------------------------

    


Epoch:24 - Loss:1.8325: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:24 - Loss:4.0908: 100%|██████████| 14/14 [00:04<00:00,  3.18it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:15:55 2021 

    Fold:0, Epoch:24, LR:0.0004501214, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.8325   |   4.0908

    LWLRAP:              0.9452   |   0.8382

    Class Loss:          1.1717   |   3.2999

    Consistency Loss:    0.6608   |   0.7909

    --------------------------------------------------------

    


Epoch:25 - Loss:1.6239: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:25 - Loss:4.3682: 100%|██████████| 14/14 [00:04<00:00,  3.10it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:16:30 2021 

    Fold:0, Epoch:25, LR:0.0004015977, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.6239   |   4.3682

    LWLRAP:              0.9539   |   0.8326

    Class Loss:          0.9966   |   3.4648

    Consistency Loss:    0.6273   |   0.9034

    --------------------------------------------------------

    


Epoch:26 - Loss:1.4095: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:26 - Loss:3.7348: 100%|██████████| 14/14 [00:04<00:00,  3.13it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:17:05 2021 

    Fold:0, Epoch:26, LR:0.0003540217, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.4095   |   3.7348

    LWLRAP:              0.9581   |   0.8656

    Class Loss:          0.8026   |   2.9191

    Consistency Loss:    0.6069   |   0.8157

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8424838945108336 --> 0.8656460917730684



Epoch:27 - Loss:1.4517: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:27 - Loss:4.3264: 100%|██████████| 14/14 [00:04<00:00,  3.20it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:17:40 2021 

    Fold:0, Epoch:27, LR:0.0003078515, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.4517   |   4.3264

    LWLRAP:              0.9582   |   0.8297

    Class Loss:          0.8057   |   3.4419

    Consistency Loss:    0.6460   |   0.8846

    --------------------------------------------------------

    


Epoch:28 - Loss:1.3813: 100%|██████████| 56/56 [00:30<00:00,  1.87it/s]
Epoch:28 - Loss:3.5410: 100%|██████████| 14/14 [00:04<00:00,  3.04it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:18:14 2021 

    Fold:0, Epoch:28, LR:0.0002635319, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.3813   |   3.5410

    LWLRAP:              0.9620   |   0.8574

    Class Loss:          0.8085   |   2.8231

    Consistency Loss:    0.5727   |   0.7179

    --------------------------------------------------------

    


Epoch:29 - Loss:1.2818: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:29 - Loss:3.7229: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:18:49 2021 

    Fold:0, Epoch:29, LR:0.0002214896, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2818   |   3.7229

    LWLRAP:              0.9671   |   0.8549

    Class Loss:          0.7273   |   3.0361

    Consistency Loss:    0.5545   |   0.6868

    --------------------------------------------------------

    


Epoch:30 - Loss:1.2338: 100%|██████████| 56/56 [00:29<00:00,  1.87it/s]
Epoch:30 - Loss:4.1070: 100%|██████████| 14/14 [00:04<00:00,  3.05it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:19:24 2021 

    Fold:0, Epoch:30, LR:0.0001821295, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2338   |   4.1070

    LWLRAP:              0.9655   |   0.8458

    Class Loss:          0.7312   |   3.3359

    Consistency Loss:    0.5026   |   0.7712

    --------------------------------------------------------

    


Epoch:31 - Loss:1.0813: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:31 - Loss:3.8815: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:19:58 2021 

    Fold:0, Epoch:31, LR:0.0001458307, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0813   |   3.8815

    LWLRAP:              0.9722   |   0.8365

    Class Loss:          0.6166   |   3.1893

    Consistency Loss:    0.4647   |   0.6923

    --------------------------------------------------------

    


Epoch:32 - Loss:1.1262: 100%|██████████| 56/56 [00:30<00:00,  1.84it/s]
Epoch:32 - Loss:4.0720: 100%|██████████| 14/14 [00:04<00:00,  3.01it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:20:33 2021 

    Fold:0, Epoch:32, LR:0.0001129428, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1262   |   4.0720

    LWLRAP:              0.9736   |   0.8328

    Class Loss:          0.5990   |   3.4242

    Consistency Loss:    0.5272   |   0.6478

    --------------------------------------------------------

    


Epoch:33 - Loss:1.0130: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:33 - Loss:4.0810: 100%|██████████| 14/14 [00:04<00:00,  3.22it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:21:08 2021 

    Fold:0, Epoch:33, LR:8.378251e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0130   |   4.0810

    LWLRAP:              0.9758   |   0.8409

    Class Loss:          0.5216   |   3.4756

    Consistency Loss:    0.4915   |   0.6053

    --------------------------------------------------------

    


Epoch:34 - Loss:1.2001: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s]
Epoch:34 - Loss:3.6795: 100%|██████████| 14/14 [00:04<00:00,  3.09it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:21:43 2021 

    Fold:0, Epoch:34, LR:5.86306e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.2001   |   3.6795

    LWLRAP:              0.9626   |   0.8743

    Class Loss:          0.6895   |   3.1121

    Consistency Loss:    0.5106   |   0.5674

    --------------------------------------------------------

    
    LWLRAP Improved from 0.8656460917730684 --> 0.8743407743214244



Epoch:35 - Loss:1.0787: 100%|██████████| 56/56 [00:30<00:00,  1.87it/s]
Epoch:35 - Loss:3.8851: 100%|██████████| 14/14 [00:04<00:00,  3.07it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:22:18 2021 

    Fold:0, Epoch:35, LR:3.772935e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0787   |   3.8851

    LWLRAP:              0.9736   |   0.8486

    Class Loss:          0.5869   |   3.1201

    Consistency Loss:    0.4918   |   0.7650

    --------------------------------------------------------

    


Epoch:36 - Loss:1.1218: 100%|██████████| 56/56 [00:29<00:00,  1.87it/s]
Epoch:36 - Loss:3.4808: 100%|██████████| 14/14 [00:04<00:00,  3.24it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:22:52 2021 

    Fold:0, Epoch:36, LR:2.128003e-05, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.1218   |   3.4808

    LWLRAP:              0.9710   |   0.8611

    Class Loss:          0.6389   |   2.8351

    Consistency Loss:    0.4829   |   0.6457

    --------------------------------------------------------

    


Epoch:37 - Loss:0.9338: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:37 - Loss:3.6722: 100%|██████████| 14/14 [00:04<00:00,  3.13it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:23:27 2021 

    Fold:0, Epoch:37, LR:9.441067e-06, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9338   |   3.6722

    LWLRAP:              0.9773   |   0.8556

    Class Loss:          0.4826   |   2.9849

    Consistency Loss:    0.4512   |   0.6873

    --------------------------------------------------------

    


Epoch:38 - Loss:0.9430: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:38 - Loss:3.7942: 100%|██████████| 14/14 [00:04<00:00,  3.15it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:24:01 2021 

    Fold:0, Epoch:38, LR:2.326474e-06, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                0.9430   |   3.7942

    LWLRAP:              0.9794   |   0.8496

    Class Loss:          0.4779   |   3.1905

    Consistency Loss:    0.4651   |   0.6036

    --------------------------------------------------------

    


Epoch:39 - Loss:1.0561: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:39 - Loss:3.8480: 100%|██████████| 14/14 [00:04<00:00,  3.21it/s]



    Sat Feb 13 16:24:36 2021 

    Fold:0, Epoch:39, LR:4.768355e-09, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                1.0561   |   3.8480

    LWLRAP:              0.9737   |   0.8558

    Class Loss:          0.5787   |   3.2287

    Consistency Loss:    0.4774   |   0.6193

    --------------------------------------------------------

    


Epoch:0 - Loss:13.2505: 100%|██████████| 56/56 [00:30<00:00,  1.86it/s]
Epoch:0 - Loss:9.8836: 100%|██████████| 14/14 [00:04<00:00,  3.19it/s] 
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:25:15 2021 

    Fold:1, Epoch:0, LR:7.669937e-05, Cons. Weight: 0.6737946999085467

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                13.2505   |   9.8836

    LWLRAP:              0.1568   |   0.2149

    Class Loss:          13.2070   |   9.8718

    Consistency Loss:    0.0435   |   0.0119

    --------------------------------------------------------

    
    LWLRAP Improved from -inf --> 0.21490962606489594



Epoch:1 - Loss:10.4105: 100%|██████████| 56/56 [00:29<00:00,  1.87it/s]
Epoch:1 - Loss:9.2928: 100%|██████████| 14/14 [00:04<00:00,  3.17it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:25:50 2021 

    Fold:1, Epoch:1, LR:0.0001811856, Cons. Weight: 3.1047958479329627

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                10.4105   |   9.2928

    LWLRAP:              0.2254   |   0.2534

    Class Loss:          10.3244   |   9.2448

    Consistency Loss:    0.0862   |   0.0480

    --------------------------------------------------------

    
    LWLRAP Improved from 0.21490962606489594 --> 0.25337313771533826



Epoch:2 - Loss:9.4091: 100%|██████████| 56/56 [00:30<00:00,  1.85it/s] 
Epoch:2 - Loss:8.3727: 100%|██████████| 14/14 [00:04<00:00,  3.02it/s]



    Sat Feb 13 16:26:25 2021 

    Fold:1, Epoch:2, LR:0.0003374814, Cons. Weight: 10.836802322189582

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                9.4091   |   8.3727

    LWLRAP:              0.2879   |   0.3719

    Class Loss:          9.1582   |   8.1746

    Consistency Loss:    0.2509   |   0.1980

    --------------------------------------------------------

    
    LWLRAP Improved from 0.25337313771533826 --> 0.3719296644709902



Epoch:3 - Loss:8.3112: 100%|██████████| 56/56 [00:54<00:00,  1.02it/s]
Epoch:3 - Loss:8.2104: 100%|██████████| 14/14 [00:08<00:00,  1.71it/s]



    Sat Feb 13 16:27:28 2021 

    Fold:1, Epoch:3, LR:0.0005216868, Cons. Weight: 28.650479686019008

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                8.3112   |   8.2104

    LWLRAP:              0.4268   |   0.4914

    Class Loss:          7.7517   |   7.5470

    Consistency Loss:    0.5595   |   0.6634

    --------------------------------------------------------

    
    LWLRAP Improved from 0.3719296644709902 --> 0.4913750661103602



Epoch:4 - Loss:7.5719: 100%|██████████| 56/56 [00:57<00:00,  1.02s/it]
Epoch:4 - Loss:7.0433: 100%|██████████| 14/14 [00:08<00:00,  1.72it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:28:34 2021 

    Fold:1, Epoch:4, LR:0.0007056342, Cons. Weight: 57.375342073743276

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                7.5719   |   7.0433

    LWLRAP:              0.5380   |   0.5828

    Class Loss:          6.6278   |   6.0833

    Consistency Loss:    0.9441   |   0.9601

    --------------------------------------------------------

    
    LWLRAP Improved from 0.4913750661103602 --> 0.5827901063753803



Epoch:5 - Loss:6.5933: 100%|██████████| 56/56 [00:56<00:00,  1.00s/it]
Epoch:5 - Loss:6.5576: 100%|██████████| 14/14 [00:08<00:00,  1.67it/s]



    Sat Feb 13 16:29:39 2021 

    Fold:1, Epoch:5, LR:0.0008611956, Cons. Weight: 87.03247258333906

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                6.5933   |   6.5576

    LWLRAP:              0.6388   |   0.6947

    Class Loss:          5.3768   |   5.0049

    Consistency Loss:    1.2165   |   1.5526

    --------------------------------------------------------

    
    LWLRAP Improved from 0.5827901063753803 --> 0.694725555213344



Epoch:6 - Loss:6.0863: 100%|██████████| 56/56 [00:55<00:00,  1.00it/s]
Epoch:6 - Loss:6.6991: 100%|██████████| 14/14 [00:08<00:00,  1.72it/s]
  0%|          | 0/56 [00:00<?, ?it/s]


    Sat Feb 13 16:30:43 2021 

    Fold:1, Epoch:6, LR:0.0009645834, Cons. Weight: 100.0

    --------------------------------------------------------
    Metric:              Train    |   Val
    --------------------------------------------------------
    Loss:                6.0863   |   6.6991

    LWLRAP:              0.6995   |   0.6828

    Class Loss:          4.7192   |   5.0322

    Consistency Loss:    1.3671   |   1.6669

    --------------------------------------------------------

    


Epoch:7 - Loss:4.6191:   0%|          | 0/56 [00:01<?, ?it/s]

# Predict on Test Set
We'll predict using the teacher model but you could also use the student or a combination of the two. Inference works just like it would for a vanilla baseline model.

In [None]:
def test(test_df, train_fold):
    test_dataset = TestDataset(
        df=test_df,
        data_path=config.test_data_path,
        period=config.period,
        step=config.step
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=config.num_workers
    )
    
    weights_path = os.path.join(config.save_path, f'{config.model_name}-fold-{fold}.bin')
    model = get_model()
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')), strict=False)
    
    test_pred, ids = predict_on_test(model, test_loader)

    # Build Submission File
    test_pred_df = pd.DataFrame({
        "recording_id": test_df.recording_id.values
    })
    target_cols = test_df.columns[1:].values.tolist()
    test_pred_df = test_pred_df.join(pd.DataFrame(np.array(test_pred), 
                                                  columns=target_cols))
    test_pred_df.to_csv(os.path.join(config.save_path, 
                                     f"{config.model_name}-fold-{train_fold}-submission.csv"), 
                        index=False)
    
    
test_df = pd.read_csv(config.test_csv)
for fold in range(5 if config.train_5_folds else 1):
    test(test_df, fold)

## 5 Fold Ensemble
For 5 fold runs, we'll create a single ensemble prediction by simply averaging all of the folds.

In [None]:
def ensemble(submission_path):
    dfs = [pd.read_csv(os.path.join(
        config.save_path, f"{config.model_name}-fold-{i}-submission.csv")) for i in range(5)]
    anchor = dfs[0].copy()
    cols = anchor.columns[1:]
    for c in cols:
        total = 0
        for df in dfs:
            total += df[c]
        anchor[c] = total / len(dfs)
    anchor.to_csv(submission_path, index=False)


submission_path = os.path.join(config.res_path, f"{config.model_name}-submission.csv")
if config.train_5_folds:
    ensemble(submission_path)
else:
    fold0_submission = os.path.join(config.res_path, f"fold-0-submission.csv")
    os.rename(fold0_submission, submission_path)

# Conclusion 
Thanks for reading! I dropped some unrelated tricks from this and didn't spend much time tuning so there's almost definetely room for improvement.

I know it's pretty late in the competition for new notebooks, but considering that there are a few other public notebooks that score higher, I'm hoping this won't cause a significant shakeup. 