# Mean Teacher

This method was proposed by Tarvainen et al. Its general approach is similar to Temporal Ensembling, but instead of averaging the model's predictions over training, it keeps an exponential moving average (EMA) of the model parameters.
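The parameter EMA can be sketched in a few lines. This is a minimal illustration, assuming parameters are stored as a dict of NumPy arrays. The helper name `ema_update` and the decay value `alpha` are hypothetical and not taken from this project's code. Each step moves the teacher weights a fraction `1 - alpha` of the way toward the current student weights.

```python
import numpy as np

def ema_update(teacher_params, student_params, alpha=0.99):
    """Return teacher weights updated as an EMA of the student weights.

    theta_teacher <- alpha * theta_teacher + (1 - alpha) * theta_student
    (hypothetical helper; parameters are dicts of NumPy arrays)
    """
    return {
        name: alpha * teacher_params[name] + (1.0 - alpha) * student_params[name]
        for name in teacher_params
    }

# Toy usage: with alpha = 0.5 the teacher moves halfway toward the student each step.
teacher = {"w": np.zeros(3)}
student = {"w": np.ones(3)}
for _ in range(2):
    teacher = ema_update(teacher, student, alpha=0.5)
print(teacher["w"])  # [0.75 0.75 0.75]
```

A larger `alpha` gives a slower-moving, smoother teacher. The teacher is never updated by gradients, only by this averaging.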

The key idea is to maintain two models, a “Student” and a “Teacher”. The student is a regular model trained with dropout. The teacher has the same architecture as the student, but its weights are not trained by gradient descent; instead, they are set to an exponential moving average of the student's weights. For each labeled or unlabeled image, we create two randomly augmented versions. The student predicts a label distribution for the first augmented image, and the teacher predicts a label distribution for the second. The squared difference between these two predictions is used as the consistency loss. For labeled images, we also compute the cross-entropy loss against the ground-truth label. The final loss is a weighted sum of these two terms, where a time-dependent weight w(t) controls how much the consistency loss contributes to the total.
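The loss described above can be sketched for a classification batch as follows. This is an illustration under stated assumptions, not this project's implementation. The function names are hypothetical. The ramp-up schedule `w(t) = w_max * exp(-5 (1 - t/T)^2)` is an assumed sigmoid-shaped choice, with `rampup_steps` as T, and the consistency term sums the squared difference over classes and averages over samples. The helpers take student logits computed on the first augmented view and teacher logits computed on the second.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the class axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def consistency_weight(step, max_weight=1.0, rampup_steps=1000):
    # Assumed schedule: sigmoid-shaped ramp-up w(t) = max_weight * exp(-5 (1 - t/T)^2),
    # held at max_weight once step >= rampup_steps.
    t = np.clip(step / rampup_steps, 0.0, 1.0)
    return max_weight * float(np.exp(-5.0 * (1.0 - t) ** 2))

def mean_teacher_loss(student_logits, teacher_logits, labels, labeled_mask, step):
    """Cross-entropy on labeled samples + w(t) * consistency on all samples.

    student_logits: (N, C) student outputs on the first augmented view
    teacher_logits: (N, C) teacher outputs on the second augmented view
    labels:         (N,) integer class labels (ignored where labeled_mask is False)
    labeled_mask:   (N,) boolean, True for labeled samples
    """
    p_student = softmax(student_logits)
    # In a real framework the teacher output is detached: it is a target, not trained.
    p_teacher = softmax(teacher_logits)
    # Consistency: squared difference between the two distributions, averaged over samples.
    consistency = np.mean(np.sum((p_student - p_teacher) ** 2, axis=-1))
    # Supervised cross-entropy, computed only on the labeled samples.
    idx = np.flatnonzero(labeled_mask)
    if idx.size > 0:
        ce = -np.mean(np.log(p_student[idx, labels[idx]] + 1e-12))
    else:
        ce = 0.0
    return float(ce + consistency_weight(step) * consistency)

# Toy usage: identical predictions give zero consistency loss, so only CE remains.
logits = np.zeros((2, 2))
loss = mean_teacher_loss(logits, logits, np.array([0, 1]), np.array([True, False]), step=1000)
print(round(loss, 4))  # 0.6931
```

For segmentation, as in the kidney model here, the same two terms would be computed per voxel, for example after reshaping the outputs to `(N * H * W, C)`. Whether the squared difference is summed or averaged over classes only rescales the consistency term, which w(t) can absorb.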

In [None]:
import sys
import os
import numpy as np
# Make the project root (two directories up) importable so the local packages below resolve
sys.path.append(os.path.abspath(os.path.join('..', '..')))

from data.dataloaders import load_segmentation_data
from models.segmentation.semi_supervised import kidney_segmentor
from utils import save_figures_and_show
from evaluation.classification.evaluate import bootstrap_ci


## Load Data

In [None]:
# Build train/validation loaders for the labeled and unlabeled kidney data

labeled_train_loader, labeled_val_loader = load_segmentation_data("labeled", "kidney")
unlabeled_train_loader, unlabeled_val_loader = load_segmentation_data("unlabeled", "kidney")

## Train Kidney Mean Teacher (kMT) Model

In [None]:
(val_interval, epoch_loss_values, metric_values,
 metric_values_kidney, metric_values_tumor) = kidney_segmentor(
    labeled_train_loader, labeled_val_loader, unlabeled_train_loader
)

In [None]:
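# Summarize the kidney Dice scores returned by training, with a bootstrap confidence interval
low_ic, up_ic = bootstrap_ci(metric_values_kidney)
print(f"Kidney Dice (mean ± std): {np.mean(metric_values_kidney):.3f} ± {np.std(metric_values_kidney):.3f}, CI ({low_ic:.3f}, {up_ic:.3f})")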
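The internals of `bootstrap_ci` are not shown in this notebook. One common way to obtain such an interval is the percentile bootstrap of the mean, sketched below. The helper `percentile_bootstrap_ci`, its defaults (1000 resamples, 95% confidence, fixed seed), and the toy input are hypothetical. They are not the project's function and not results from this model.

```python
import numpy as np

def percentile_bootstrap_ci(values, n_resamples=1000, confidence=0.95, seed=0):
    """Percentile bootstrap CI for the mean of `values` (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Draw n_resamples resamples (with replacement), each the same size as the data.
    idx = rng.integers(0, len(values), size=(n_resamples, len(values)))
    boot_means = values[idx].mean(axis=1)
    # Take the central `confidence` mass of the bootstrap distribution of the mean.
    alpha = 1.0 - confidence
    low, high = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(low), float(high)

scores = [0.60, 0.70, 0.75, 0.80, 0.85]  # toy values for illustration only
low, high = percentile_bootstrap_ci(scores)
print(f"mean={np.mean(scores):.3f}, 95% CI=({low:.3f}, {high:.3f})")
```

The interval describes uncertainty in the mean of whatever values are passed in. Its meaning therefore depends on what each entry of `metric_values_kidney` represents.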


## Plot and Save Metrics

In [None]:
# Plot (and save) the training Dice loss and the validation Dice scores: overall, kidney, and tumor
save_figures_and_show(val_interval, epoch_loss_values, metric_values, metric_values_kidney, metric_values_tumor)

# Evaluate kMT