# 𝕊𝕖𝕞𝕚-𝕊𝕦𝕡𝕖𝕣𝕧𝕚𝕤𝕖𝕕 𝕝𝕖𝕒𝕣𝕟𝕚𝕟𝕘 𝕦𝕤𝕚𝕟𝕘 𝕄𝕖𝕒𝕟 𝕋𝕖𝕒𝕔𝕙𝕖𝕣

Implementation of pixel-wise Mean Teacher (MT)
    
This method is proposed in the paper: 
    'Mean Teachers are Better Role Models:
        Weight-Averaged Consistency Targets Improve Semi-Supervised Deep Learning Results'
This implementation supports only Gaussian noise as the input perturbation; the
two-head output trick is not available.

Source:
https://github.com/ZHKKKe/PixelSSL/blob/master/pixelssl/ssl_algorithm/ssl_mt.py

# Imports

In [1]:
import glob
import os
import time
import random
from PIL import Image

import pandas as pd

import torch
import torch.nn
from torch.autograd import Variable
import torchvision

import segmentation_models_pytorch as smp

from torch.utils.data import DataLoader

import sys
sys.path.insert(0, "helper")
from helper.dataset.mean_teacher import *
# from helper.model.mean_teacher import * 
from helper.sampler.mixed_batch import *
from helper.model.block.noise_block import GaussianNoiseBlock
from helper.compute.compute_bin_seg import BCE_BinSeg_CU
from helper.compute.loss.shape import ShapeLoss
from helper.compute.loss.dice import DiceLoss

#from pixelssl.utils import REGRESSION, CLASSIFICATION
#from pixelssl.utils import logger, cmd, tool
#from pixelssl.nn import func
#from pixelssl.nn.module import patch_replication_callback, GaussianNoiseLayer

# Experiment Configs

In [2]:
class Configs:
    """Experiment configuration for the Mean Teacher run.

    Instantiating this class has side effects: it creates ``base_path`` if it
    does not exist, changes the process working directory to it, globs the
    ``mt_*.csv`` files relative to it, and creates the logger and checkpoint
    output directories.
    """

    def __init__(self):
        # =============================================================================
        # General
        # =============================================================================

        self.prefix = "tmp"
        self.reduced_data = False

        # smp unet ++ parameters
        self.encoder_name = "efficientnet-b7"
        self.encoder_weights = "imagenet"
        self.in_channels = 1
        self.classes = 1

        self.epochs = 100

        self.gaussian_noise = 0.1  # None

        self.ema_decay = 0.999  # default value

        # "Sizes of tensors must match except in dimension 1" was resolved by
        # resizing all images to a size divisible by 32.
        self.image_size = 128  # 512

        self.num_workers = 0
        self.iterations = 50

        # batch size = n_samples_per_class_per_batch * classes
        # for mixed batch sampling
        self.n_samples_per_class_per_batch = 1

        self.lbs = 3  # self.args.labeled_batch_size # .... remove this eventually and replace

        # optimisation
        self.optimiser = "sgd"
        self.learning_rate = 0.01
        self.min_learning_rate = 0.0001
        self.weight_decay = 1e-4
        self.momentum = 0.9

        # self.is_epoch_lrer = True # epoch or batch based learning rate updater

        self.dropout = None

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # =============================================================================
        # Paths
        # =============================================================================

        self.base_path = r"C:/Users/Prinzessin/projects/decentnet"
        # exist_ok avoids the check-then-create race of `if not exists: makedirs`
        os.makedirs(self.base_path, exist_ok=True)
        os.chdir(self.base_path)  # this is now the main directory: all relative paths below resolve against it

        self.csv_filenames = glob.glob(r"datasceyence/data_prep/mt_*.csv")

        print(self.csv_filenames)

        # input
        self.load_checkpoint_file = None

        # all csv files used for run_mean_teacher.ipynb
        #self.csv_data_paths = [
        #    {"path" : r"data/data_ichallenge_amd.csv"}, 
        #    {"path" : r"data/data_ichallenge_non_amd.csv"}
        #]

        # output
        self.logger_path = f"results/{self.prefix}"
        os.makedirs(self.logger_path, exist_ok=True)

        self.save_checkpoint_path = f"results/{self.prefix}/ckpts"
        os.makedirs(self.save_checkpoint_path, exist_ok=True)

    def log(self):
        """Write every instance attribute to ``<logger_path>/configs.txt``.

        The file is a two-column table (``key``, ``value``) separated by ``:``,
        one row per attribute in definition order.
        """
        # Materialise the dict views as lists so the DataFrame constructor
        # receives plain sequences of equal length.
        c = pd.DataFrame.from_dict({'key': list(self.__dict__.keys()),
                                    'value': list(self.__dict__.values())})
        c.to_csv(os.path.join(self.logger_path, "configs.txt"), sep=':', index=False)

# Routine

In [3]:
class RoutineMT:

    def __init__(self, configs):
        """Build everything the Mean Teacher routine needs from `configs`.

        Creates the student ('s') and teacher ('t') UNet++ models, the Gaussian
        noise block, one computing unit per model, the student optimiser and
        LR scheduler, the loss criterions and the train/val data loaders.
        If `configs.load_checkpoint_file` is set, the stored states are restored.

        Parameters
        ----------
        configs : Configs
            Experiment configuration; read for model, optimiser, data and
            checkpoint settings.
        """
        super(RoutineMT, self).__init__()

        self.configs = configs

        self.prefix = configs.prefix
        self.ema_decay = configs.ema_decay

        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only host
        self.load_ckpt = (
            torch.load(configs.load_checkpoint_file, map_location=configs.device)
            if configs.load_checkpoint_file is not None
            else None
        )

        self.step_counter = 0

        # =============================================================================
        # Models
        # =============================================================================
        def build_unetpp():
            # student and teacher share the exact same architecture settings
            return smp.UnetPlusPlus(
                encoder_name=self.configs.encoder_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                encoder_weights=self.configs.encoder_weights,  # use `imagenet` pre-trained weights for encoder initialization
                in_channels=self.configs.in_channels,          # model input channels (1 for gray-scale images, 3 for RGB, etc.)
                classes=self.configs.classes,                  # model output channels (number of classes in your dataset)
            )

        s_model = build_unetpp()
        t_model = build_unetpp()

        # detach the teacher model: it is never optimised by gradient descent
        for param in t_model.parameters():
            param.detach_()

        self.models = {'s': s_model,
                       't': t_model}

        # add gaussian noise
        # currently not in use
        # follow configs.device instead of an unconditional .cuda(), which fails without a GPU
        self.gaussian_noiser = GaussianNoiseBlock(self.configs.gaussian_noise).to(self.configs.device)

        # =============================================================================
        # Computing Units
        # =============================================================================
        self.computing_unit = {
            "s": BCE_BinSeg_CU(),
            "t": BCE_BinSeg_CU()
        }

        # =============================================================================
        # Optimisers
        # =============================================================================
        # NOTE(review): configs.weight_decay is defined but not passed to SGD here — confirm whether that is intended.
        self.optims = {
            's': torch.optim.SGD(self.models["s"].parameters(),
                                 lr=self.configs.learning_rate,
                                 momentum=self.configs.momentum)
        }

        # =============================================================================
        # Learning rate schedulers
        # =============================================================================
        self.lrsers = {
            's': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optims["s"],
                T_0=32,  # number of iterations for the first restart.
                eta_min=self.configs.min_learning_rate
            )
        }

        # =============================================================================
        # Loss functions
        # =============================================================================
        # TODO: support more types of the consistency criterion
        # something with head and each head has a loss function attached??
        self.criterions = {'s_shape': ShapeLoss(),
                           's_pixel': DiceLoss(mode="BINARY_MODE"),
                           # TODO
                           # BINARY_MODE MULTICLASS_MODE - the loss wants a different encoding ...
                           # ground_truth = torch.nn.functional.one_hot(ground_truth, num_classes)  # N,H*W -> N,H*W, C
                           # ground_truth = ground_truth.permute(0, 2, 1)  # N, C, H*W
                           # RuntimeError: one_hot is only applicable to index tensor.
                           'cons': torch.nn.MSELoss()
                           }

        # =============================================================================
        # Datasets: train, val, test
        # =============================================================================

        train_set = MeanTeacherTrainDataset(mode="train", channels=self.configs.in_channels, image_size=self.configs.image_size, csv_filenames=self.configs.csv_filenames)
        train_mbs = MixedBatchSampler(train_set.get_mbs_labels(), n_samples_per_class_per_batch=self.configs.n_samples_per_class_per_batch)

        val_set = MeanTeacherValDataset(mode="val", channels=self.configs.in_channels, image_size=self.configs.image_size, csv_filenames=self.configs.csv_filenames)

        self.dataloader = {"train": DataLoader(train_set, batch_sampler=train_mbs),
                           "val": DataLoader(val_set)
                           }

        # =============================================================================
        # Resume training
        # =============================================================================
        if self.load_ckpt:
            self.models["s"].load_state_dict(self.load_ckpt['s_model'])
            self.models["t"].load_state_dict(self.load_ckpt['t_model'])
            # run_cleanup stores the optimiser state under 's_optim'; the old
            # 's_optimizer' key is still accepted for externally produced checkpoints.
            optim_key = 's_optim' if 's_optim' in self.load_ckpt else 's_optimizer'
            self.optims["s"].load_state_dict(self.load_ckpt[optim_key])
            self.lrsers["s"].load_state_dict(self.load_ckpt['s_lrer'])


    
    def run_training(self, epoch):
        
        mode="train"
        
        self.models["s"].train()
        self.models["t"].train()
        
        for idx, item in enumerate(self.dataloader[mode]):
            # =============================================================================
            # Process Batch
            # =============================================================================
                        
            # reset student optimiser
            self.optims["s"].zero_grad()
            
            s_model_output = self.models["s"](item["img"]) # img_t
            
            s_model_output = torch.nn.functional.softmax(s_model_output, dim=0) # I DON'T KNOOOOW
            
            print("s model")
            print(s_model_output)
            print(item["msk"])
            print(s_model_output.shape)
            print(item["msk"].shape)
            print("next")
            
            # run student batch
            s_epoch_collector = self.computing_unit["s"].run_batch(configs=self.configs, criterions=self.criterions, model_output=s_model_output, ground_truth=item["msk"], mode=mode)

            print("sigmoid")
            
            def sigmoid_rampup(current, rampup_length):
                """ Exponential rampup from https://arxiv.org/abs/1610.02242 . 
                """
                if rampup_length == 0:
                    return 1.0
                else:
                    current = np.clip(current, 0.0, rampup_length)
                    phase = 1.0 - current / rampup_length
                    return float(np.exp(-5.0 * phase * phase))
            
            # calculate the ramp-up coefficient of the consistency constraint
            # use mean squared error as the consistency cost and ramp up its weight from 0 to its final value during the first 80 epochs. 
            self.step_counter += 1
            total_steps = 5 # len(self.dataloader["train"]) * 100 # self.args.cons_rampup_epochs # ????
            cons_rampup_scale = sigmoid_rampup(self.step_counter, total_steps)
            
            # =============================================================================
            # Teacher Model
            # =============================================================================
            
            # forward the teacher model
            with torch.no_grad():
                
                t_model_output = self.models["t"](item["img"]) # img_t
                                
                self.computing_unit["t"].run_batch(configs=self.configs, criterions=self.criterions, model_output=t_model_output, ground_truth=item["msk"], mode=mode)
                
                """
                t_resulter, t_debugger = self.t_model.forward(t_inp)
                if not 'pred' in t_resulter.keys():
                    self._pred_err()
                t_pred = tool.dict_value(t_resulter, 'pred')
                t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')
            
                # calculate 't_task_loss' for recording
                l_t_pred = func.split_tensor_tuple(t_pred, 0, lbs)
                l_t_inp = func.split_tensor_tuple(t_inp, 0, lbs)
                t_task_loss = self.s_criterion.forward(l_t_pred, l_gt, l_t_inp)
                t_task_loss = torch.mean(t_task_loss)
                self.meters.update('t_task_loss', t_task_loss.data)
                """
            # =============================================================================
            # Consistency Loss
            # =============================================================================
            # calculate the consistency constraint from the teacher model to the student model
            t_pseudo_gt = Variable(t_model_output[0].detach().data, requires_grad=False)

            if True: # self.args.cons_for_labeled:
                cons_loss = self.criterions["cons"](s_model_output[0], t_pseudo_gt)
            elif False: # self.args.unlabeled_batch_size > 0:
                cons_loss = self.criterions["cons"](s_model_output[0][lbs:, ...], t_pseudo_gt[lbs:, ...])
            else:
                cons_loss = self.zero_tensor
            cons_loss = cons_rampup_scale * torch.mean(cons_loss) # self.args.cons_scale * 

            # =============================================================================
            # Backprop for student model
            # =============================================================================
            loss = s_epoch_collector["loss"] + cons_loss
            loss.backward()
            
            print("loss")
            print(loss)
            
            
            self.optims["s"].step()
            
            

            # =============================================================================
            # EMA for teacher model
            # =============================================================================
            # self._update_ema_variables(self.s_model, self.t_model, self.args.ema_decay, cur_step)
            self.ema_decay = min(1 - 1 / (self.step_counter + 1), self.ema_decay)
            for t_param, s_param in zip(self.models["t"].parameters(), self.models["s"].parameters()):
                t_param.data.mul_(self.ema_decay).add_(1 - self.ema_decay, s_param.data)
            
            # =============================================================================
            # LR Scheduler (Batch)
            # =============================================================================
            # if not self.configs.is_epoch_lrer:
            self.lrsers["s"].step()
        
        
        # =============================================================================
        # Process Epoch
        # =============================================================================
        
        self.computing_unit["s"].run_epoch()
        self.computing_unit["t"].run_epoch()
        
        # =============================================================================
        # LR Scheduler (Epoch)
        # =============================================================================
        #if self.configs.is_epoch_lrer:
        #    self.lrsers["s"].step()

    def run_validation(self, data_loader, epoch):
        """Evaluate student and teacher on the validation loader and log the metrics.

        NOTE(review): this method is not yet adapted to this class. It reads
        `self.s_model`, `self.t_model`, `self.s_criterion`, `self.cons_criterion`,
        `self.meters` and `self.args`, none of which are set in `__init__`, and it
        uses `s_inp`, `t_inp`, `gt`, `tool`, `func` and `logger`, which are not
        defined or imported in this file (the corresponding pixelssl imports are
        commented out). Calling it as-is raises on the first line.

        Parameters
        ----------
        data_loader
            Only used for `len(data_loader)` in the log message; the loop itself
            iterates `self.dataloader["val"]`.
        epoch : int
            Current epoch index, used for logging and visualisation file names.
        """
        self.s_model.eval()
        self.t_model.eval()
        
        # =============================================================================
        # for each batch
        # =============================================================================

        for idx, item in enumerate(self.dataloader["val"]):
            
            timer = time.time()
            
            # =============================================================================
            # Student
            # =============================================================================

            # NOTE(review): `s_inp` is never derived from `item` — presumably _batch_prehandle was meant to be called first.
            s_resulter, s_debugger = self.s_model.forward(s_inp)
            if not 'pred' in s_resulter.keys() or not 'activated_pred' in s_resulter.keys():
                self._pred_err()
            s_pred = tool.dict_value(s_resulter, 'pred')
            s_activated_pred = tool.dict_value(s_resulter, 'activated_pred')

            s_task_loss = self.s_criterion.forward(s_pred, gt, s_inp)
            s_task_loss = torch.mean(s_task_loss)
            self.meters.update('s_task_loss', s_task_loss.data)

            # =============================================================================
            # Teacher
            # =============================================================================
            
            t_resulter, t_debugger = self.t_model.forward(t_inp)
            if not 'pred' in t_resulter.keys() or not 'activated_pred' in t_resulter.keys():
                self._pred_err()
            t_pred = tool.dict_value(t_resulter, 'pred')
            t_activated_pred = tool.dict_value(t_resulter, 'activated_pred')

            # the teacher's task loss is computed with the student criterion
            t_task_loss = self.s_criterion.forward(t_pred, gt, t_inp)
            t_task_loss = torch.mean(t_task_loss)
            self.meters.update('t_task_loss', t_task_loss.data)
            
            # =============================================================================
            # Pseudo ???
            # =============================================================================

            # consistency loss between the student prediction and the detached teacher prediction
            t_pseudo_gt = Variable(t_pred[0].detach().data, requires_grad=False)
            cons_loss = self.cons_criterion(s_pred[0], t_pseudo_gt)
            cons_loss = self.args.cons_scale * torch.mean(cons_loss)
            self.meters.update('cons_loss', cons_loss.data)

            #self.task_func.metrics(s_activated_pred, gt, s_inp, self.meters, id_str='student')
            #self.task_func.metrics(t_activated_pred, gt, t_inp, self.meters, id_str='teacher')
            
            # =============================================================================
            # Logger
            # =============================================================================
            
            self.meters.update('batch_time', time.time() - timer)
            if idx % self.args.log_freq == 0:
                logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                                '  student-{3}\t=>\t'
                                's-task-loss: {meters[s_task_loss]:.6f}\t'
                                's-cons-loss: {meters[cons_loss]:.6f}\n'
                                '  teacher-{3}\t=>\t'
                                't-task-loss: {meters[t_task_loss]:.6f}\n'
                                .format(epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

            if self.args.visualize and idx % self.args.visual_freq == 0:
                self._visualize(epoch, idx, False, 
                                func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                                func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))
    
        # =============================================================================
        # Metrics
        # =============================================================================
        # collect every meter whose key starts with 'student' / 'teacher' into one line per model
        metrics_info = {'student': '', 'teacher': ''}
        for key in sorted(list(self.meters.keys())):
            #if self.task_func.METRIC_STR in key:
            if True:
                for id_str in metrics_info.keys():
                    if key.startswith(id_str):
                        metrics_info[id_str] += '{0}: {1:.6}\t'.format(key, self.meters[key])

        logger.log_info('Validation metrics:\n  student-metrics\t=>\t{0}\n  teacher-metrics\t=>\t{1}\n'
            .format(metrics_info['student'].replace('_', '-'), metrics_info['teacher'].replace('_', '-')))

    def run_cleanup(self, epoch):
        # =============================================================================
        # Logger
        # =============================================================================
        
        # if save_model is True:
        
        if epoch > 5 and self.computing_unit["s"].epoch_collector["fscore"] > self.computing_unit["s"].best["fscore"]:
            state = {
                'name': self.prefix,
                'epoch': epoch, 
                's_model': self.models["s"].state_dict(),
                't_model': self.models["t"].state_dict(),
                's_optim': self.optims["s"].state_dict(),
                's_lrer': self.lrsers["s"].state_dict()
            }

            checkpoint = os.path.join(self.configs.save_checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
            torch.save(state, checkpoint)
            
            self.computing_unit["s"].best["fscore"] = self.computing_unit["s"].epoch_collector["fscore"]
            
        self.computing_unit["s"].reset_epoch()
        self.computing_unit["t"].reset_epoch()
        
        
    def log(self, epoch):
        """Log per-step timing and losses and optionally trigger a visualisation.

        NOTE(review): this method is not yet adapted to this class. It uses
        `timer`, `idx`, `data_loader`, `s_inp`, `t_inp`, `s_activated_pred`,
        `t_activated_pred` and `gt`, which are neither parameters nor locals here,
        reads `self.meters` and `self.args`, which `__init__` never sets, and
        calls `logger` and `func`, whose imports are commented out in this file.

        Parameters
        ----------
        epoch : int
            Current epoch index; reported as `epoch + 1` in the log line.
        """
        # =============================================================================
        # Logger
        # =============================================================================
        
        self.meters.update('batch_time', time.time() - timer)
        if idx % self.args.log_freq == 0:
            logger.log_info('step: [{0}][{1}/{2}]\tbatch-time: {meters[batch_time]:.3f}\n'
                            '  student-{3}\t=>\t'
                            's-task-loss: {meters[s_task_loss]:.6f}\t'
                            's-cons-loss: {meters[cons_loss]:.6f}\n'
                            '  teacher-{3}\t=>\t'
                            't-task-loss: {meters[t_task_loss]:.6f}\n'
                            .format(epoch + 1, idx, len(data_loader), self.args.task, meters=self.meters))

        # visualization
        # only the first element along dim 0 of each tensor tuple is passed on (split 0..1)
        if self.args.visualize and idx % self.args.visual_freq == 0:
            self._visualize(epoch, idx, True, 
                            func.split_tensor_tuple(s_inp, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(s_activated_pred, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(t_inp, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(t_activated_pred, 0, 1, reduce_dim=True),
                            func.split_tensor_tuple(gt, 0, 1, reduce_dim=True))

    # -------------------------------------------------------------------------------------------
    # Tool Functions for SSL_MT
    # -------------------------------------------------------------------------------------------

    def _visualize(self, epoch, idx, is_train, 
                   s_inp, s_pred, t_inp, t_pred, gt):

        visualize_path = self.args.visual_train_path if is_train else self.args.visual_val_path
        out_path = os.path.join(visualize_path, '{0}_{1}'.format(epoch, idx))

        #self.task_func.visualize(out_path, id_str='student', inp=s_inp, pred=s_pred, gt=gt)
        #self.task_func.visualize(out_path, id_str='teacher', inp=t_inp, pred=t_pred, gt=gt)

    def _batch_prehandle(self, inp, gt, is_train):
        # add extra data augmentation process here if necessary

        # 'self.gaussian_noiser' will add the noise to the first input element
        s_inp_var, t_inp_var = [], []
        for idx, i in enumerate(inp):
            if is_train and idx == 0:
                s_inp_var.append(self.gaussian_noiser.forward(Variable(i).cuda())) 
                t_inp_var.append(self.gaussian_noiser.forward(Variable(i).cuda())) 
            else:
                s_inp_var.append(Variable(i).cuda()) 
                t_inp_var.append(Variable(i).cuda())
        s_inp = tuple(s_inp_var)
        t_inp = tuple(t_inp_var)
        
        gt_var = []
        for g in gt:
            gt_var.append(Variable(g).cuda())
        gt = tuple(gt_var)

        return s_inp, t_inp, gt



# Run

In [4]:
# Configs
configs = Configs()
configs.log()

# Run
run = RoutineMT(configs)

for epoch in range(configs.epochs):

    run.run_training(epoch=epoch)
    #run.run_validation(epoch=epoch)
    #run.run_cleanup(epoch=epoch)

    # The original call `routine.log()` raised NameError: the instance is named
    # `run`, and RoutineMT.log requires an `epoch` argument. It stays disabled
    # like the calls above until RoutineMT.log is adapted to this class.
    #run.log(epoch=epoch)

['datasceyence/data_prep\\mt_data_ichallenge_amd.csv', 'datasceyence/data_prep\\mt_data_ichallenge_glaucoma.csv', 'datasceyence/data_prep\\mt_data_ichallenge_non_amd.csv', 'datasceyence/data_prep\\mt_data_ichallenge_unlabelled.csv']
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\images_AMD\A0044.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\Disc_Masks_bin\A0044.bmp
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\images_AMD\A0057.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\Disc_Masks_bin\A0057.bmp
C:\Users\Prinzessin\projects\image_data\iChallenge_Glaucoma_OD_Fovea\images_Glaucoma\g0007.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_Glaucoma_OD_Fovea\Disc_Masks_bin\g0007.bmp
s model
tensor([[[[0.2450, 0.2615, 0.2796,  ..., 0.2385, 0.2138, 0.2352],
          [0.2468, 0.3105, 0.2773,  ..., 0.2630, 0.2452, 0.2448],
          [0.2420, 0.2571, 0.2047,  ..., 0.2765, 0.3044, 0

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\python_arg_parser.cpp:1420.)
  t_param.data.mul_(self.ema_decay).add_(1 - self.ema_decay, s_param.data)


C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\images_AMD\A0063.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\Disc_Masks_bin\A0063.bmp
C:\Users\Prinzessin\projects\image_data\iChallenge_Glaucoma_OD_Fovea\images_Glaucoma\g0040.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_Glaucoma_OD_Fovea\Disc_Masks_bin\g0040.bmp
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\images_AMD\A0058.jpg
C:\Users\Prinzessin\projects\image_data\iChallenge_AMD_OD_Fovea_lesions\Disc_Masks_bin\A0058.bmp
s model
tensor([[[[0.2520, 0.2481, 0.2803,  ..., 0.2503, 0.2835, 0.2467],
          [0.2733, 0.2980, 0.2621,  ..., 0.1942, 0.2235, 0.2448],
          [0.2757, 0.3023, 0.2656,  ..., 0.2196, 0.2918, 0.2940],
          ...,
          [0.2374, 0.2529, 0.2358,  ..., 0.2951, 0.3170, 0.2565],
          [0.2425, 0.2690, 0.2461,  ..., 0.2561, 0.2393, 0.2160],
          [0.2429, 0.2321, 0.2542,  ..., 0.2433, 0.2282, 0.2233]]],


        

KeyboardInterrupt: 

In [None]:
import torch.nn

# Scratch check of softmax along dim=1: for a (1, 2, 3) tensor the two entries
# along dim 1 sum to 1 at every position.
m = torch.nn.Softmax(dim=1)
logits = torch.randn(1, 2, 3)  # not named `input`, which would shadow the builtin for later cells
output = m(logits)