# 𝕊𝕖𝕞𝕚-𝕊𝕦𝕡𝕖𝕣𝕧𝕚𝕤𝕖𝕕 𝕝𝕖𝕒𝕣𝕟𝕚𝕟𝕘 𝕦𝕤𝕚𝕟𝕘 𝕄𝕖𝕒𝕟 𝕋𝕖𝕒𝕔𝕙𝕖𝕣

Implementation of pixel-wise Mean Teacher (MT)
    
This method is proposed in the paper: 
    'Mean Teachers are Better Role Models:
        Weight-Averaged Consistency Targets Improve Semi-Supervised Deep Learning Results'
This implementation only supports Gaussian noise as input perturbation, and the two-heads
outputs trick is not available.

Source:
https://github.com/ZHKKKe/PixelSSL/blob/master/pixelssl/ssl_algorithm/ssl_mt.py


Todo:
* [] get everything on cuda, cpu
* [] metrics for segmentation + unittests

# Imports

In [2]:
import glob
import os
import time
import random
from PIL import Image
import logging


import pandas as pd

import torch
import torch.nn
from torch.autograd import Variable
import torchvision
from torch.utils.tensorboard import SummaryWriter

import segmentation_models_pytorch as smp

from torch.utils.data import DataLoader


import sys
sys.path.insert(0, "helper")
from helper.dataset.mean_teacher import *
# from helper.model.mean_teacher import * 
from helper.sampler.mixed_batch import *
from helper.model.block.noise_block import GaussianNoiseBlock
from helper.compute.bin_seg import BCE_BinSeg_CU
from helper.compute.loss.shape import ShapeLoss
from helper.compute.loss.dice import DiceLoss

#from pixelssl.utils import REGRESSION, CLASSIFICATION
#from pixelssl.utils import logger, cmd, tool
#from pixelssl.nn import func
#from pixelssl.nn.module import patch_replication_callback, GaussianNoiseLayer

KeyboardInterrupt: 

# Experiment Configs

In [2]:
class Configs():
    
    def __init__(self):
        # =============================================================================
        # 
        # =============================================================================
        
        self.prefix = "mt_tmp"
        self.reduced_data = False
        
        # smp unet ++ parameters
        self.encoder_name = "efficientnet-b7"
        self.encoder_weights = "imagenet"
        self.in_channels =  1
        self.n_output_neurons = 2
        
        self.epochs = 100
        
        self.gaussian_noise = 0.1 # None
        
        self.ema_decay = 0.999 # default value
        
        # Sizes of tensors must match except in dimension 1
        # I solved the issue by resizing all the images size divisible to 32.
        self.image_size = 128 # 512
        
        self.num_workers = 0
        self.iterations = 50
        
        # batch size = n_samples_per_class_per_batch * classes
        # for mixed batch sampling
        self.n_samples_per_class_per_batch = 1
        
        self.lbs = 3 #  self.args.labeled_batch_size # .... remove this eventually and replace

        # optimisation
        self.optimiser = "sgd"
        self.base_lr = 0.01
        self.min_lr = 0.0001
        self.weight_decay = 1e-4
        self.momentum = 0.9
        
        # self.is_epoch_lrer = True # epoch or batch based learning rate updater
        
        self.dropout = None
        
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        
        # =============================================================================
        # Paths
        # =============================================================================
        
        self.base_path = r"C:/Users/Prinzessin/projects/decentnet"
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path)
        os.chdir(self.base_path) # this is now the main directory !!!!!!!!!!!!!!!!!!!!
        
        self.csv_filenames = glob.glob(r"datasceyence/data_prep/mt_*.csv")
        
        print(self.csv_filenames)
        
        # input
        self.load_checkpoint_file = None
        
        # all csv files used for run_mean_teacher.ipybn
        #self.csv_data_paths = [
        #    {"path" : r"data/data_ichallenge_amd.csv"}, 
        #    {"path" : r"data/data_ichallenge_non_amd.csv"}
        #]
        
        # output
        self.logger_path = f"results/{self.prefix}/logs"
        if not os.path.exists(self.logger_path):
            os.makedirs(self.logger_path)
            
        self.save_checkpoint_path = f"results/{self.prefix}/ckpts"
        if not os.path.exists(self.save_checkpoint_path):
            os.makedirs(self.save_checkpoint_path)

        
            
            
            
        
    def log(self):
        # =============================================================================
        # save all class variables to file "configs.txt"
        # =============================================================================
        c = pd.DataFrame.from_dict({'key': self.__dict__.keys(), 'value': self.__dict__.values()})
        c.to_csv(os.path.join(self.logger_path, "configs.txt"), sep=':', index=False)

# Routine

In [3]:
class RoutineMT:

    def __init__(self, configs):
        super(RoutineMT, self).__init__()
        
        self.configs = configs
        
        self.prefix = configs.prefix
        self.ema_decay = configs.ema_decay
        
        self.load_ckpt = torch.load(configs.load_checkpoint_file) if configs.load_checkpoint_file is not None else None
        
        self.step_counter = 0
        
        
        # =============================================================================
        # Logger
        # =============================================================================
        
        self.writer =  SummaryWriter(log_dir=self.configs.logger_path)
        logging.basicConfig(filename=os.path.join(self.configs.logger_path, 'logger.log'), encoding='utf-8', level=logging.DEBUG)
        
        # =============================================================================
        # Models
        # =============================================================================
        s_model = smp.UnetPlusPlus(
                        encoder_name=self.configs.encoder_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                        encoder_weights=self.configs.encoder_weights,  # use `imagenet` pre-trained weights for encoder initialization
                        in_channels=self.configs.in_channels,          # model input channels (1 for gray-scale images, 3 for RGB, etc.)
                        classes=self.configs.n_output_neurons,         # model output channels (number of classes in your dataset)
                    )
        
        t_model = smp.UnetPlusPlus(
                        encoder_name=self.configs.encoder_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
                        encoder_weights=self.configs.encoder_weights,  # use `imagenet` pre-trained weights for encoder initialization
                        in_channels=self.configs.in_channels,          # model input channels (1 for gray-scale images, 3 for RGB, etc.)
                        classes=self.configs.n_output_neurons,         # model output channels (number of classes in your dataset)
                    )
        # detach the teacher model
        for param in t_model.parameters():
            param.detach_()
            
        self.models = {'s': s_model, 
                       't': t_model}
        
        # add gaussian noise
        # currently not in use
        self.gaussian_noiser = GaussianNoiseBlock(self.configs.gaussian_noise).cuda()
        
        # =============================================================================
        # Computing Units
        # =============================================================================
        self.computing_unit = {
            "s_train" : BCE_BinSeg_CU(n_output_neurons=self.configs.n_output_neurons, mode="train"),
            "s_val" : BCE_BinSeg_CU(n_output_neurons=self.configs.n_output_neurons, mode="val"),
            "t_train" : BCE_BinSeg_CU(n_output_neurons=self.configs.n_output_neurons, mode="train"),
            "t_val" : BCE_BinSeg_CU(n_output_neurons=self.configs.n_output_neurons, mode="val")
        }
        
        
        
        # =============================================================================
        # Optimisers
        # =============================================================================
        self.optims = {'s': torch.optim.SGD(self.models["s"].parameters(), lr=self.configs.base_lr, momentum=self.configs.momentum) # optimizer_funcs[0](self.models["s"].module.param_groups)
                          }

        # =============================================================================
        # Learning rate schedulers
        # =============================================================================
        self.lrsers = {'s': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optims["s"], 
                                                                              T_0 = 1, # number of iterations for the first restart.
                                                                              eta_min = self.configs.min_lr
                                                                               )
                      
                     } # lrer_funcs[0](self.optimizers['s_optimizer'])
        
        # =============================================================================
        # Loss functions
        # =============================================================================
        # TODO: support more types of the consistency criterion
        # something with head and each head has a loss function attached??
        self.criterions = {'shape': ShapeLoss(), # criterion_funcs[0](self.args),
                           'pixel': DiceLoss(n_output_neurons=self.configs.n_output_neurons),
                           # TODO
                           # BINARY_MODE MULTICLASS_MODE - the loss wants a different encoding ...
                           # ground_truth = torch.nn.functional.one_hot(ground_truth, num_classes)  # N,H*W -> N,H*W, C
                           # ground_truth = ground_truth.permute(0, 2, 1)  # N, C, H*W
                           # RuntimeError: one_hot is only applicable to index tensor.
                           'cons': torch.nn.MSELoss()
                          }
        
        um = UncertaintyMetric(n_noise=4, n_repeat=2, n_output_neurons=self.configs.n_output_neurons)
        
        # =============================================================================
        # Datasets: train, val, test
        # =============================================================================
                
        train_set = MeanTeacherTrainDataset(mode="train", channels=self.configs.in_channels, image_size=self.configs.image_size, csv_filenames=self.configs.csv_filenames)
        train_mbs = MixedBatchSampler(train_set.get_mbs_labels(), n_samples_per_class_per_batch=self.configs.n_samples_per_class_per_batch)
        
        val_set = MeanTeacherValDataset(mode="val", channels=self.configs.in_channels, image_size=self.configs.image_size, csv_filenames=self.configs.csv_filenames)
                
        self.dataloader = {"train" : DataLoader(train_set, batch_sampler=train_mbs),
                           "val" :   DataLoader(val_set)
                          }
          
        # =============================================================================
        # Resume training
        # =============================================================================
        if self.load_ckpt:
            self.models["s"].load_state_dict(self.load_ckpt['s_model'])
            self.models["t"].load_state_dict(self.load_ckpt['t_model'])
            self.optims["s"].load_state_dict(self.load_ckpt['s_optimizer'])
            self.lrsers["s"].load_state_dict(self.load_ckpt['s_lrer'])


    
    def run_training(self, epoch):
        # =============================================================================
        # Training
        # =============================================================================
        
        mode="train"
        
        self.models["s"].train()
        self.models["t"].train()
        
        for i_item, item in enumerate(self.dataloader[mode]):
            # =============================================================================
            # Process Batch
            # =============================================================================
            
            # unlabelled for consistency loss
            unlabelled_batch_ids = np.where(item["has_mask"] == False) 
            # labelled for task loss
            labelled_batch_ids = np.where(item["has_mask"] == True) 
            
            # get predictions of student model for all images
            s_model_output = self.models["s"](item["img"]) # we want all images (task loss and consistency loss)
            
            # run batch for student model
            self.computing_unit["s_train"].run_batch(configs=self.configs, criterions=self.criterions, model_output=s_model_output[labelled_batch_ids], ground_truth=item["msk"][labelled_batch_ids])
            
            if False:
                print(labelled_batch_ids)
                print(unlabelled_batch_ids)
                print("item"*40)
                print(item)
                print("item"*40)
                print("s model")
                print(s_model_output)
                print(item["msk"])
                print(s_model_output.shape)
                print(item["msk"].shape)
                print("next")

            # =============================================================================
            # Teacher Model
            # =============================================================================
            
            # forward the teacher model
            with torch.no_grad():
                
                # get predictions of teacher model for all images
                t_model_output = self.models["t"](item["img"])                
                self.computing_unit["t_train"].run_batch(configs=self.configs, criterions=self.criterions, model_output=t_model_output[labelled_batch_ids], ground_truth=item["msk"][labelled_batch_ids])
                
                
            
            uncertainy_map = um.run(self.models["t"], item["img"])
        
            # =============================================================================
            # Consistency Loss
            # =============================================================================
            
            def sigmoid_rampup(current, rampup_length):
                print("sigmoid")
                # https://github.com/HiLab-git/SSL4MIS/blob/master/code/utils/ramps.py
                # Consistency ramp-up from https://arxiv.org/abs/1610.02242
                if rampup_length == 0:
                    return 1.0
                else:
                    current = np.clip(current, 0.0, rampup_length)
                    phase = 1.0 - current / rampup_length
                    return float(np.exp(-5.0 * phase * phase))
            
            # calculate the ramp-up coefficient of the consistency constraint
            # use mean squared error as the consistency cost and ramp up its weight from 0 to its final value during the first 80 epochs. 
            #self.step_counter += 1
            #total_steps = 5 # len(self.dataloader["train"]) * 100 # self.args.cons_rampup_epochs # ????
            #cons_rampup_scale = sigmoid_rampup(self.step_counter, total_steps)
            cons_rampup_scale = 0.5
            
            cons_weight = 0.5 # not in use rn
            
            
            # calculate the consistency constraint from the teacher model to the student model
            t_pseudo_gt = Variable(t_model_output.detach().data, requires_grad=False)
            
            if unlabelled_batch_ids:
                cons_loss = self.criterions["cons"](s_model_output[unlabelled_batch_ids], t_pseudo_gt[unlabelled_batch_ids])
            else:
                cons_loss = self.zero_tensor
                
            cons_loss = torch.mean(cons_loss) 

            # =============================================================================
            # Backprop for student model
            # =============================================================================
            
            # combined loss
            
            
            loss = self.computing_unit["s_train"].task_loss + cons_loss * cons_rampup_scale #  * cons_weight
            
            if False: # todo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx x x x x x x x x x x
                # x sajldn asffas y y y y x x x x x x x x  x x x x x x x x x x x x x 
                # reset student optimiser
                self.optims["s"].zero_grad()
                loss.backward()
                self.optims["s"].step()
            
            print("training loss")
            print(loss)
            
            # =============================================================================
            # EMA for teacher model
            # https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_uncertainty_aware_mean_teacher_2D.py
            # https://github.com/ZHKKKe/PixelSSL/blob/master/pixelssl/ssl_algorithm/ssl_mt.py
            # =============================================================================
            self.ema_decay = min(1 - 1 / (self.step_counter + 1), self.ema_decay)
            for t_param, s_param in zip(self.models["t"].parameters(), self.models["s"].parameters()):
                t_param.data.mul_(self.ema_decay).add_(1 - self.ema_decay, s_param.data)
        
        
        # =============================================================================
        # Epoch Process (basically logging)
        # =============================================================================
        self.computing_unit["s_train"].run_epoch()
        self.computing_unit["t_train"].run_epoch()
        
        # =============================================================================
        # Epoch LR Scheduler
        # =============================================================================
        self.lrsers["s"].step()
        
        # =============================================================================
        # Epoch log and reset (training)
        # =============================================================================
        self.computing_unit["s_train"].log()
        self.computing_unit["t_train"].log()
        self.computing_unit["s_train"].reset_epoch()
        self.computing_unit["t_train"].reset_epoch()
        
        logger.info("Traning of epoch %d done" % epoch)

    def run_validation(self, data_loader, epoch):
        self.s_model.eval()
        self.t_model.eval()
        
        # =============================================================================
        # for each batch
        # =============================================================================

        for i_item, item in enumerate(self.dataloader["val"]):
            
            timer = time.time()
            
            # =============================================================================
            # Student
            # =============================================================================

            with torch.no_grad():
                # get predictions of student model for all images
                s_model_output = self.models["s"](item["img"]) # we want all images (task loss and consistency loss)
                s_model_output = torch.nn.functional.softmax(s_model_output, dim=0) # I DON'T KNOOOOW

                # run batch for student model
                self.computing_unit["s_val"].run_batch(configs=self.configs, criterions=self.criterions, model_output=s_model_output, ground_truth=item["msk"])
            
            # =============================================================================
            # Teacher
            # =============================================================================
            
            with torch.no_grad():
                # get predictions of teacher model for all images
                t_model_output = self.models["t"](item["img"])                
            
                self.computing_unit["t_val"].run_batch(configs=self.configs, criterions=self.criterions, model_output=t_model_output[labelled_batch_ids], ground_truth=item["msk"][labelled_batch_ids])
                
            # =============================================================================
            # Consistency loss
            # =============================================================================

            with torch.no_grad():
                # calculate the consistency constraint from the teacher model to the student model
                t_pseudo_gt = Variable(t_model_output.detach().data, requires_grad=False)
                cons_loss = self.criterions["cons"](s_model_output, t_pseudo_gt)    
                cons_loss = torch.mean(cons_loss) # * self.args.cons_scale

            
        # =============================================================================
        # Epoch log and reset (validation)
        # =============================================================================
        self.computing_unit["s_val"].log()
        self.computing_unit["t_val"].log()
        self.computing_unit["s_val"].reset_epoch()
        self.computing_unit["t_val"].reset_epoch()
        
        logger.info("Validation of epoch %d done" % epoch)
        

    def log(self, epoch):
        # =============================================================================
        # Save checkpoint
        # =============================================================================
                
        if epoch > 5 and self.computing_unit["s_val"].epoch_collector["fscore"] > self.computing_unit["s_val"].best["fscore"]:
            state = {
                'name': self.prefix,
                'epoch': epoch, 
                's_model': self.models["s"].state_dict(),
                't_model': self.models["t"].state_dict(),
                's_optim': self.optims["s"].state_dict(),
                's_lrer': self.lrsers["s"].state_dict()
            }

            checkpoint = os.path.join(self.configs.save_checkpoint_path, 'checkpoint_{0}.ckpt'.format(epoch))
            torch.save(state, checkpoint)
            
            self.computing_unit["s_val"].best["fscore"] = self.computing_unit["s_val"].epoch_collector["fscore"]
            
            logger.info("Saved model at epoch %d" % epoch)
            
        
        




# Run

In [None]:
# Configs
configs = Configs()
configs.log()

# Run
run = RoutineMT(configs)

for epoch in range(configs.epochs):
    
    run.run_training(epoch=epoch)
    #run.run_validation(epoch=epoch)
    #run.log(epoch=epoch)
    


['datasceyence/data_prep\\mt_data_ichallenge_amd.csv', 'datasceyence/data_prep\\mt_data_ichallenge_glaucoma.csv', 'datasceyence/data_prep\\mt_data_ichallenge_non_amd.csv', 'datasceyence/data_prep\\mt_data_ichallenge_unlabelled.csv']
loss
tensor(1.8001, grad_fn=<AddBackward0>)


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\python_arg_parser.cpp:1420.)
  t_param.data.mul_(self.ema_decay).add_(1 - self.ema_decay, s_param.data)


loss
tensor(0.6012, grad_fn=<AddBackward0>)
loss
tensor(0.5841, grad_fn=<AddBackward0>)
loss
tensor(0.5913, grad_fn=<AddBackward0>)
loss
tensor(0.6336, grad_fn=<AddBackward0>)
loss
tensor(0.5909, grad_fn=<AddBackward0>)
loss
tensor(0.4830, grad_fn=<AddBackward0>)
Traceback (most recent call last):
  File "C:\Users\Prinzessin\anaconda3\envs\feta\lib\site-packages\IPython\core\interactiveshell.py", line 3524, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\Prinzessin\AppData\Local\Temp\ipykernel_12496\3393773207.py", line 10, in <cell line: 8>
    run.run_training(epoch=epoch)
  File "C:\Users\Prinzessin\AppData\Local\Temp\ipykernel_12496\3466807000.py", line 164, in run_training
    t_model_output = self.models["t"](item["img"])
  File "C:\Users\Prinzessin\anaconda3\envs\feta\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Prinzessin\anaconda3\envs\feta\lib\site-packages

In [4]:
import torch.nn

m = torch.nn.Softmax(dim=1)
input = torch.randn(1, 2, 3)
output = m(input)

ERROR! Session/line number was not unique in database. History logging moved to new session 1756
