# import

In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import sys

sys.executable

'/users/samova/lcances/.miniconda3/envs/pytorch-dev/bin/python'

In [9]:
import os

os.environ["MKL_NUM_THREADS"] = "2"
os.environ["NUMEXPR_NU M_THREADS"] = "2"
os.environ["OMP_NUM_THREADS"] = "2"
import time
import pprint

import numpy
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

In [10]:
from SSL.util.loaders import (
    load_dataset,
    load_optimizer,
    load_callbacks,
    load_preprocesser,
)
from SSL.util.model_loader import load_model
from SSL.util.checkpoint import CheckPoint, mSummaryWriter
from SSL.util.mixup import MixUpBatchShuffle
from SSL.util.utils import (
    reset_seed,
    get_datetime,
    track_maximum,
    DotDict,
    get_train_format,
    DotDict,
)
from SSL.ramps import Warmup, sigmoid_rampup
from SSL.loss import JensenShanon

from metric_utils.metrics import CategoricalAccuracy, FScore, ContinueAverage, MAP

# Arguments

In [11]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--from_config", default="", type=str)
parser.add_argument("-d", "--dataset_root", default="../../datasets", type=str)
parser.add_argument("-D", "--dataset", default="ComParE2021_PRS", type=str)

group_t = parser.add_argument_group("Commun parameters")
group_t.add_argument("-m", "--model", default="wideresnet28_2", type=str)
group_t.add_argument("--supervised_ratio", default=0.1, type=float)
group_t.add_argument("--batch_size", default=128, type=int)
group_t.add_argument("--nb_epoch", default=125000, type=int)
group_t.add_argument("--learning_rate", default=0.001, type=float)
group_t.add_argument("--resume", action="store_true", default=False)
group_t.add_argument("--seed", default=1234, type=int)
group_t.add_argument("--num_classes", default=5, type=int)

group_u = parser.add_argument_group("Datasets parameters")
group_u.add_argument(
    "-t", "--train_folds", nargs="+", default=[1, 2, 3, 4, 5, 6, 7, 8, 9], type=int
)
group_u.add_argument("-v", "--val_folds", nargs="+", default=[10], type=int)

group_s = parser.add_argument_group("Student teacher parameters")
group_s.add_argument("--ema_alpha", default=0.999, type=float)
group_s.add_argument("--warmup_length", default=50, type=int)
group_s.add_argument("--lambda_cost_max", default=1, type=float)
group_s.add_argument("--teacher_noise", default=0, type=float)
group_s.add_argument("--ccost_softmax", action="store_true", default=False)
group_s.add_argument("--ccost_method", type=str, default="mse")

group_mixup = parser.add_argument_group("Mixup parameters")
group_mixup.add_argument("--mixup", action="store_true", default=False)
group_mixup.add_argument("--mixup_alpha", type=float, default=0.4)
group_mixup.add_argument("--mixup_max", action="store_true", default=False)
group_mixup.add_argument("--mixup_label", action="store_true", default=False)

group_l = parser.add_argument_group("Logs")
group_l.add_argument("--checkpoint_root", default="../../model_save/", type=str)
group_l.add_argument("--tensorboard_root", default="../../tensorboard/", type=str)
group_l.add_argument("--checkpoint_path", default="mean-teacher_mixup", type=str)
group_l.add_argument("--tensorboard_path", default="mean-teacher_mixup", type=str)
group_l.add_argument("--tensorboard_sufix", default="", type=str)

args = parser.parse_args(["--mixup", "--mixup_max", "--ccost_softmax"])

tensorboard_path = os.path.join(
    args.tensorboard_root, args.dataset, args.tensorboard_path
)
checkpoint_path = os.path.join(args.checkpoint_root, args.dataset, args.checkpoint_path)

In [12]:
cfg = DotDict(
    hardware=DotDict(
        nb_cpu=0,
        nb_gpu=1,
    ),
    dataset=DotDict(
        dataset="ComParE2021_PRS",
        num_classes=5,
    ),
    model=DotDict(
        model="MobileNetV2",
    ),
    train_param=DotDict(
        supervised_ratio=0.1,
        batch_size=128,
        nb_iteration=75000,
        learning_rate=0.001,
        seed=1234,
        resume=False,
        train_folds=None,
        val_folds=None,
    ),
    mt=DotDict(
        alpha=0.999,
        warmup_length=50,
        lambda_ccost_max=1,
        use_softmax=True,
        ccost_method="mse",
    ),
    mixup=DotDict(use=False, alpha=1.0, max=True, label=True),
    specaugment=DotDict(
        use=False,
        time_drop_width=32,
        time_stripe_num=1,
        freq_drop_width=4,
        freq_stripe_num=1,
    ),
)

cfg["path"] = DotDict(
    dataset_root="../../datasets",
    checkpoint_root="../../model_save",
    tensorboard_root="../../tensorboard",
)
cfg.path[
    "checkpoint_path"
] = f"{cfg.path.checkpoint_root}/{cfg.dataset.dataset}/fixmatch"
cfg.path[
    "tensorboard_path"
] = f"{cfg.path.tensorboard_root}/{cfg.dataset.dataset}/fixmatch"

In [13]:
pprint.pprint(vars(args))

{'batch_size': 128,
 'ccost_method': 'mse',
 'ccost_softmax': True,
 'checkpoint_path': 'mean-teacher_mixup',
 'checkpoint_root': '../../model_save/',
 'dataset': 'ComParE2021_PRS',
 'dataset_root': '../../datasets',
 'ema_alpha': 0.999,
 'from_config': '',
 'lambda_cost_max': 1,
 'learning_rate': 0.001,
 'mixup': True,
 'mixup_alpha': 0.4,
 'mixup_label': False,
 'mixup_max': True,
 'model': 'wideresnet28_2',
 'nb_epoch': 125000,
 'num_classes': 5,
 'resume': False,
 'seed': 1234,
 'supervised_ratio': 0.1,
 'teacher_noise': 0,
 'tensorboard_path': 'mean-teacher_mixup',
 'tensorboard_root': '../../tensorboard/',
 'tensorboard_sufix': '',
 'train_folds': [1, 2, 3, 4, 5, 6, 7, 8, 9],
 'val_folds': [10],
 'warmup_length': 50}


# initialisation

In [14]:
reset_seed(cfg.train_param.seed)

# Prepare the dataset

In [15]:
train_transform, val_transform = load_preprocesser(cfg.dataset.dataset, "mean-teacher")
train_transform



Sequential(
  (0): PadUpTo()
  (1): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
  (2): AmplitudeToDB()
  (3): Squeeze()
)

In [16]:
manager, train_loader, val_loader = load_dataset(
    cfg.dataset.dataset,
    "mean-teacher",
    dataset_root=cfg.path.dataset_root,
    supervised_ratio=cfg.train_param.supervised_ratio,
    batch_size=cfg.train_param.batch_size,
    train_folds=cfg.train_param.train_folds,
    val_folds=cfg.train_param.val_folds,
    train_transform=train_transform,
    val_transform=val_transform,
    num_workers=cfg.hardware.nb_cpu,
    pin_memory=False,
    verbose=1,
)

/users/samova/lcances/semi-supervised/SSL/dataset/speechcommands.py
<bound method cache_feature.<locals>.decorator of <SSL.dataset_loader.ComParE2021_PRS.ComParE2021_PRS object at 0x7f5b3728efa0>>
cache path:  .ComParE2021_PRS/.cache_batch_size=128_seed=1234
split ready, loading cache file
s_idx:  693
u_idx:  6222
Sort the classes


In [30]:
input_shape = train_loader._iterables[0].dataset[0][0].shape
input_shape

torch.Size([64, 94])

# Prep model

In [31]:
torch.cuda.empty_cache()

model_func = load_model(cfg.dataset.dataset, cfg.model.model)

student = model_func(input_shape=input_shape, num_classes=cfg.dataset.num_classes)
teacher = model_func(input_shape=input_shape, num_classes=cfg.dataset.num_classes)

student = student.cuda()
teacher = teacher.cuda()

# We do not need gradient for the teacher model
for p in teacher.parameters():
    p.detach()

In [32]:
from torchsummary import summary

s = summary(student, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 94]             288
         AvgPool2d-2           [-1, 32, 32, 47]               0
       BatchNorm2d-3           [-1, 32, 32, 47]              64
             ReLU6-4           [-1, 32, 32, 47]               0
            Conv2d-5           [-1, 32, 32, 47]             288
         AvgPool2d-6           [-1, 32, 32, 47]               0
       BatchNorm2d-7           [-1, 32, 32, 47]              64
             ReLU6-8           [-1, 32, 32, 47]               0
            Conv2d-9           [-1, 16, 32, 47]             512
      BatchNorm2d-10           [-1, 16, 32, 47]              32
 InvertedResidual-11           [-1, 16, 32, 47]               0
           Conv2d-12           [-1, 96, 32, 47]           1,536
      BatchNorm2d-13           [-1, 96, 32, 47]             192
            ReLU6-14           [-1, 96,

# training parameters

In [33]:
tensorboard_root = f"{cfg.model.model}/{cfg.train_param.supervised_ratio}/{get_datetime()}_{model_func.__name__}"
checkpoint_root = (
    f"{cfg.model.model}/{cfg.train_param.supervised_ratio}/{model_func.__name__}"
)

# mea teacher parameters
sufix_title = f"_{cfg.mt.ema_alpha}-emaa"
sufix_title += f"_{cfg.mt.warmup_length}-wl"
sufix_title += f"_{cfg.mt.lambda_cost_max}-lccm"

# mixup parameters
if cfg.mixup.use:
    sufix_title += "_mixup"
    if cfg.mixup.mixup_max:
        sufix_title += "-max"
    if cfg.mixup.mixup_label:
        sufix_title += "-label"
    sufix_title += f"-{cfg.mixup.mixup_alpha}-a"

# ccost function and method
if cfg.mt.ccost_method:
    sufix_title += "_cc-MSE"
if cfg.mt.use_softmax:
    sufix_title += "-SOFTMAX"

# normale training parameters
sufix_title += f"_{cfg.train_param.learning_rate}-lr"
sufix_title += f"_{cfg.train_param.supervised_ratio}-sr"
sufix_title += f"_{cfg.train_param.nb_epoch}-e"
sufix_title += f"_{cfg.train_param.batch_size}-bs"
sufix_title += f"_{cfg.train_param.seed}-seed"

tensorboard_title = tensorboard_root + sufix_title
checkpoint_title = checkpoint_root + sufix_title

In [34]:
tensorboard = mSummaryWriter(
    log_dir="%s/%s" % (tensorboard_path, tensorboard_title), comment=model_func.__name__
)
print(os.path.join(tensorboard_path, tensorboard_title))

../../tensorboard/ComParE2021_PRS/mean-teacher_mixup/MobileNetV2/0.1/2021-03-19_10:01:58_MobileNetV2_None-emaa_50-wl_None-lccm_cc-MSE-SOFTMAX_0.001-lr_0.1-sr_None-e_128-bs_1234-seed


## optimizer & callbacks

In [35]:
optimizer = load_optimizer(
    cfg.dataset.dataset,
    "mean-teacher",
    student=student,
    learning_rate=cfg.train_param.learning_rate,
)
callbacks = load_callbacks(
    cfg.dataset.dataset,
    "mean-teacher",
    optimizer=optimizer,
    nb_epoch=cfg.train_param.nb_iteration,
)

In [36]:
# losses
loss_ce = nn.CrossEntropyLoss(reduction="mean")  # Supervised loss

if cfg.mt.ccost_method == "mse":
    consistency_cost = nn.MSELoss(reduction="mean")  # Unsupervised loss
elif cfg.mt.ccost_method == "js":
    consistency_cost = JensenShanon

lambda_cost = Warmup(cfg.mt.lambda_ccost_max, cfg.mt.warmup_length, sigmoid_rampup)
callbacks += [lambda_cost]

# Checkpoint
checkpoint = CheckPoint(
    [student, teacher],
    optimizer,
    mode="max",
    name="%s/%s.torch" % (checkpoint_path, checkpoint_title),
)


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

checkpoint initialise at:  /users/samova/lcances/semi-supervised/model_save/ComParE2021_PRS/mean-teacher_mixup/MobileNetV2/0.1/MobileNetV2_None-emaa_50-wl_None-lccm_cc-MSE-SOFTMAX_0.001-lr_0.1-sr_None-e_128-bs_1234-seed.torch
name:  MobileNetV2_None-emaa_50-wl_None-lccm_cc-MSE-SOFTMAX_0.001-lr_0.1-sr_None-e_128-bs_1234-seed.torch
mode:  max


In [37]:
consistency_cost

MSELoss()

## Metrics definition

In [38]:
m_ = lambda x: x.mean(size=100)

In [39]:
def metrics_calculator():
    def c(logits, y):
        with torch.no_grad():
            y_one_hot = F.one_hot(y, num_classes=cfg.dataset.num_classes)

            pred = torch.softmax(logits, dim=1)
            arg = torch.argmax(logits, dim=1)

            acc = m_(c.fn.acc(arg, y))
            f1 = m_(c.fn.f1(pred, y_one_hot))

            return (
                acc,
                f1,
            )

    c.fn = DotDict(
        acc=CategoricalAccuracy(),
        f1=FScore(),
    )

    return c

In [40]:
metrics = DotDict(
    calc_student_s_metrics=metrics_calculator(),
    calc_student_u_metrics=metrics_calculator(),
    calc_teacher_s_metrics=metrics_calculator(),
    calc_teacher_u_metrics=metrics_calculator(),
    avg_Sce=ContinueAverage(),
    avg_Tce=ContinueAverage(),
    avg_ccost=ContinueAverage(),
)

val_metrics = DotDict(
    calc_student_s_metrics=metrics_calculator(),
    calc_student_u_metrics=metrics_calculator(),
    calc_teacher_s_metrics=metrics_calculator(),
    calc_teacher_u_metrics=metrics_calculator(),
    student_mAP=MAP(),
    teacher_mAP=MAP(),
    avg_Sce=ContinueAverage(),
    avg_Tce=ContinueAverage(),
    avg_ccost=ContinueAverage(),
)

softmax_fn = lambda x: x
if cfg.mt.ccost_softmax:
    softmax_fn = nn.Softmax(dim=1)


def reset_metrics(dd):
    for m in dd.values():
        if not isinstance(m, (ContinueAverage, MAP)):
            for m_ in m.fn.values():
                m_.reset()

        else:
            m.reset()


maximum_tracker = track_maximum()

## training function

In [41]:
header_form = "{:<8.8} {:<6.6} - {:<6.6} - {:<10.8} {:<8.6} {:<8.6} {:<8.6} {:<8.6} {:<8.6} {:<8.6} | {:<10.8} {:<8.6} {:<8.6} {:<8.6} {:<8.6} {:<8.6} - {:<8.6}"
value_form = "{:<8.8} {:<6d} - {:<6d} - {:<10.8} {:<8.4f} {:<8.4f} {:<8.4f} {:<8.4f} {:<8.4f} {:<8.4f} | {:<10.8} {:<8.4f} {:<8.4f} {:<8.4f} {:<8.4f} {:<8.4f} - {:<8.4f}"
header = header_form.format(
    ".               ",
    "Epoch",
    "%",
    "Student:",
    "ce",
    "ccost",
    "acc_s",
    "f1_s",
    "mAP",
    "AUR",
    "Teacher:",
    "ce",
    "acc_s",
    "f1_s",
    "mAP",
    "AUR",
    "Time",
)

UNDERLINE_SEQ = "\033[1;4m"
RESET_SEQ = "\033[0m"

train_format = value_form
val_format = UNDERLINE_SEQ + value_form + RESET_SEQ

In [42]:
def update_teacher_model(student_model, teacher_model, alpha, epoch):

    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (epoch + 1), alpha)

    for param, ema_param in zip(student_model.parameters(), teacher_model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

In [45]:
mixup_fn = MixUpBatchShuffle(
    alpha=cfg.mt.alpha, apply_max=cfg.mt.max, mix_labels=cfg.mt.label
)

In [46]:
def train(epoch, S, U, start_time):
    # aliases
    M = metrics
    T = tensorboard.add_scalar

    student.train()

    x_s, y_s = S
    x_u, y_u = U

    # Apply mixup if needed, otherwise no mixup.
    n_x_s, n_y_s, n_x_u, n_y_u = x_s, y_s, x_u, y_u
    if cfg.mixup.use:
        n_x_s, n_y_s = mixup_fn(x_s, y_s)
        n_x_u, n_y_u = mixup_fn(x_u, y_u)

    n_x_s, n_x_u = n_x_s.cuda(), n_x_u.cuda()
    x_s, x_u = x_s.cuda(), x_u.cuda()
    y_s, y_u = y_s.cuda(), y_u.cuda()

    # Predictions
    student_s_logits = student(x_s)
    student_u_logits = student(x_u)
    teacher_s_logits = teacher(n_x_s)
    teacher_u_logits = teacher(n_x_u)

    # Calculate supervised loss (only student on S)
    loss = loss_ce(student_s_logits, y_s)

    # Calculate consistency cost (mse(student(x), teacher(x))) x is S + U
    student_logits = torch.cat((student_s_logits, student_u_logits), dim=0)
    teacher_logits = torch.cat((teacher_s_logits, teacher_u_logits), dim=0)
    ccost = consistency_cost(softmax_fn(student_logits), softmax_fn(teacher_logits))

    total_loss = loss + lambda_cost() * ccost

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    with torch.set_grad_enabled(False):
        # Teacher prediction (for metrics purpose)
        _teacher_loss = loss_ce(teacher_s_logits, y_s)

        # Update teacher
        update_teacher_model(student, teacher, args.ema_alpha, epoch)

        # Compute the metrics for the student
        student_s_metrics = M.calc_student_s_metrics(student_s_logits, y_s)
        student_u_metrics = M.calc_student_u_metrics(student_u_logits, y_u)
        student_s_acc, student_s_f1, student_u_acc, student_u_f1 = (
            *student_s_metrics,
            *student_u_metrics,
        )

        # Compute the metrics for the teacher
        teacher_s_metrics = M.calc_teacher_s_metrics(teacher_s_logits, y_s)
        teacher_u_metrics = M.calc_teacher_u_metrics(teacher_u_logits, y_u)
        teacher_s_acc, teacher_s_f1, teacher_u_acc, teacher_u_f1 = (
            *teacher_s_metrics,
            *teacher_u_metrics,
        )

        # Running average of the two losses
        student_running_loss = m_(M.avg_Sce(loss.item()))
        teacher_running_loss = m_(M.avg_Tce(_teacher_loss.item()))
        running_ccost = m_(M.avg_ccost(ccost.item()))

        # logs
        print(
            train_format.format(
                "Training: ",
                epoch + 1,
                cfg.train_param.nb_iteration,
                "",
                student_running_loss,
                running_ccost,
                student_s_acc,
                student_s_f1,
                0.0,
                0.0,
                "",
                teacher_running_loss,
                teacher_s_acc,
                teacher_s_f1,
                0.0,
                0.0,
                time.time() - start_time,
            ),
            end="\r",
        )

    tensorboard.add_scalar("train/student_acc_s", student_s_acc, epoch)
    tensorboard.add_scalar("train/student_acc_u", student_u_acc, epoch)
    tensorboard.add_scalar("train/student_f1_s", student_s_f1, epoch)
    tensorboard.add_scalar("train/student_f1_u", student_u_f1, epoch)

    tensorboard.add_scalar("train/teacher_acc_s", teacher_s_acc, epoch)
    tensorboard.add_scalar("train/teacher_acc_u", teacher_u_acc, epoch)
    tensorboard.add_scalar("train/teacher_f1_s", teacher_s_f1, epoch)
    tensorboard.add_scalar("train/teacher_f1_u", teacher_u_f1, epoch)

    tensorboard.add_scalar("train/student_loss", student_running_loss, epoch)
    tensorboard.add_scalar("train/teacher_loss", teacher_running_loss, epoch)
    tensorboard.add_scalar("train/consistency_cost", running_ccost, epoch)

In [50]:
def val(epoch):
    # aliases
    M = val_metrics
    uar_avg = ContinueAverage()
    mAP_avg = ContinueAverage()

    start_time = time.time()
    print("")
    reset_metrics(val_metrics)
    student.eval()

    with torch.set_grad_enabled(False):
        for i, (X, y) in enumerate(val_loader):
            X = X.cuda()
            y = y.cuda()

            # Predictions
            student_logits = student(X)
            teacher_logits = teacher(X)

            # Calculate supervised loss (only student on S)
            loss = loss_ce(student_logits, y)
            _teacher_loss = loss_ce(teacher_logits, y)  # for metrics only
            ccost = consistency_cost(
                softmax_fn(student_logits), softmax_fn(teacher_logits)
            )

            # Compute the metrics
            y_one_hot = F.one_hot(y, num_classes=cfg.dataset.num_classes)

            # ---- student ----
            student_metrics = M.calc_student_s_metrics(student_logits, y)
            student_acc, student_f1 = student_metrics

            # ---- teacher ----
            teacher_metrics = M.calc_teacher_s_metrics(teacher_logits, y)
            teacher_acc, teacher_f1 = teacher_metrics

            # Running average of the two losses
            student_running_loss = m_(M.avg_Sce(loss.item()))
            teacher_running_loss = m_(M.avg_Tce(_teacher_loss.item()))
            running_ccost = m_(M.avg_ccost(ccost.item()))

            student_pred = F.one_hot(
                torch.argmax(student_logits, dim=1), num_classes=cfg.dataset.num_classes
            )
            teacher_pred = F.one_hot(
                torch.argmax(teacher_logits, dim=1), num_classes=cfg.dataset.num_classes
            )
            val_student_mAP = m_(
                M.student_mAP(
                    student_pred.cpu().reshape(-1), y_one_hot.cpu().reshape(-1)
                )
            )
            val_teacher_mAP = m_(
                M.teacher_mAP(
                    teacher_pred.cpu().reshape(-1), y_one_hot.cpu().reshape(-1)
                )
            )

            # logs
            print(
                val_format.format(
                    "Validation: ",
                    epoch + 1,
                    int(100 * (i + 1) / len(val_loader)),
                    "",
                    student_running_loss,
                    running_ccost,
                    student_acc,
                    student_f1,
                    val_student_mAP,
                    0.0,
                    "",
                    teacher_running_loss,
                    teacher_acc,
                    teacher_f1,
                    val_teacher_mAP,
                    0.0,
                    time.time() - start_time,
                ),
                end="\r",
            )

    tensorboard.add_scalar("val/student_acc", student_acc, epoch)
    tensorboard.add_scalar("val/student_f1", student_f1, epoch)
    tensorboard.add_scalar("val/student_mAP", val_student_mAP, epoch)
    tensorboard.add_scalar("val/teacher_acc", teacher_acc, epoch)
    tensorboard.add_scalar("val/teacher_f1", teacher_f1, epoch)
    tensorboard.add_scalar("val/teacher_mAP", val_teacher_mAP, epoch)
    tensorboard.add_scalar("val/student_loss", student_running_loss, epoch)
    tensorboard.add_scalar("val/teacher_loss", teacher_running_loss, epoch)
    tensorboard.add_scalar("val/consistency_cost", running_ccost, epoch)

    tensorboard.add_scalar("hyperparameters/learning_rate", get_lr(optimizer), epoch)
    tensorboard.add_scalar("hyperparameters/lambda_cost_max", lambda_cost(), epoch)

    tensorboard.add_scalar(
        "max/student_acc", maximum_tracker("student_acc", student_acc), epoch
    )
    tensorboard.add_scalar(
        "max/teacher_acc", maximum_tracker("teacher_acc", teacher_acc), epoch
    )
    tensorboard.add_scalar(
        "max/student_f1", maximum_tracker("student_f1", student_f1), epoch
    )
    tensorboard.add_scalar(
        "max/teacher_f1", maximum_tracker("teacher_f1", teacher_f1), epoch
    )

    for c in callbacks:
        c.step()

    return val_teacher_mAP

# Training

In [None]:
if cfg.train_param.resume:
    checkpoint.load_last()

start_iteration = checkpoint.epoch_counter
end_iteration = cfg.train_param.nb_iteration

train_iterator = iter(train_loader)
start_time = time.time()

print(header)
for e in range(start_iteration, end_iteration):
    # Validation every 500 iteration
    if e % 500 == 0:
        val_teacher_mAP = val(e)
        print("")
        checkpoint.step(val_teacher_mAP)
        tensorboard.flush()

    train(e, *next(train_iterator), start_time)

.        Epoch  - %      - Student:   ce       ccost    acc_s    f1_s     mAP      AUR      | Teacher:   ce       acc_s    f1_s     mAP      AUR      - Time    
Training 500    - 75000  -            0.7739   1.8919   0.7008   0.6787   0.0000   0.0000   |            0.7235   0.7200   0.6834   0.0000   0.0000   - 180.3780
[1;4mValidati 501    - 100    -            0.5515   1.0535   0.3412   0.3225   0.2544   0.0000   |            0.6353   0.2921   0.2755   0.2075   0.0000   - 2.0570  [0m

 better performance: saving ...
Training 1000   - 75000  -            0.5407   1.7824   0.7958   0.7877   0.0000   0.0000   |            0.5141   0.8192   0.8061   0.0000   0.0000   - 267.8175
[1;4mValidati 1001   - 100    -            0.4890   0.9291   0.3555   0.3449   0.2699   0.0000   |            0.5698   0.3317   0.3150   0.2449   0.0000   - 2.0784  [0m

 better performance: saving ...
Training 1500   - 75000  -            0.4026   1.6589   0.8767   0.8706   0.0000   0.0000   |            0.38

## Save the hyper parameters and the metrics

In [1]:
hparams = {}
for key, value in args.__dict__.items():
    hparams[key] = str(value)

final_metrics = {
    "max_acc_student": maximum_tracker.max["student_acc"],
    "max_f1_student": maximum_tracker.max["student_f1"],
    "max_acc_teacher": maximum_tracker.max["teacher_acc"],
    "max_f1_teacher": maximum_tracker.max["teacher_f1"],
}

tensorboard.add_hparams(hparams, final_metrics)

tensorboard.flush()
tensorboard.close()

NameError: name 'args' is not defined

# test

In [174]:
from SSL.dataset_loader.ComParE2021_PRS import ComParE2021_PRS
from torch.utils.data import DataLoader
from metric_utils.metrics import MAP
from sklearn.metrics import recall_score, average_precision_score

In [175]:
test_dataset = ComParE2021_PRS(
    root="../../datasets", subset="test", transform=val_transform
)
test_loader = DataLoader(
    test_dataset, batch_size=64, shuffle=False, num_workers=cfg.hardware.nb_cpu
)

In [176]:
checkpoint.load_best()

['state_dict', 'optimizer', 'epoch', 'best_metric']


In [177]:
student.eval()
S = nn.Softmax(dim=1)
nb_batch = len(val_loader)

all_pred = []
all_targets = []

for i, (x, y) in enumerate(val_loader):
    x, y = x.cuda(), y.cuda()

    pred = torch.argmax(S(student(x)), dim=1)
    pred = F.one_hot(pred, num_classes=cfg.dataset.num_classes)
    y = F.one_hot(y, num_classes=cfg.dataset.num_classes)

    all_pred.append(pred.detach().cpu())
    all_targets.append(y.detach().cpu())

    print(f"{i} / {nb_batch}", end="\r")

54 / 55

In [178]:
all_pred_ = torch.vstack(all_pred)
all_targets_ = torch.vstack(all_targets)

In [179]:
uar = recall_score(all_targets_, all_pred_, average="macro")

In [180]:
mAP = average_precision_score(all_targets_, all_pred_, average=None)

In [181]:
print("uar: ", uar)
print("mAP: ", mAP.mean())

uar:  0.5746229771610213
mAP:  0.4657394601030007
