Multi Mean Teacher

original
FAICG_FA_FA

backbone=DeeplabV3+

In [1]:
import os
import json
import logging
import copy
import shutil
import time

from tqdm import tqdm

import cv2
import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from torchvision import transforms
from torchvision.transforms import functional as F

from math import ceil
from itertools import cycle
from itertools import chain
from collections import OrderedDict
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

import random
random.seed(42)

import gc

In [None]:
#config
# Notebook-level hyper-parameters; they override the matching entries of the
# JSON configuration loaded below.
batch_size = 7
epochs = 80
warm_up = 100
labeled_examples = 515
lr = 1e-2
backbone = 101 #" the resnet x {50, 101} layers"
# NOTE(review): the two thresholds below are not written into `config` in this
# cell -- confirm whether they are consumed anywhere else.
semi_p_th = 0.6 # positive_threshold for semi-supervised loss
semi_n_th = 0.0 # negative_threshold for semi-supervised loss
unsup_weight = 1.5 # unsupervised weight for the semi-supervised loss

# Read the base configuration inside a context manager so the file handle is
# closed as soon as parsing finishes instead of whenever it gets collected.
with open("configs/config_deeplab_v3+_onlyFA_range_selfsupervised.json") as config_file:
    config = json.load(config_file)

# Apply the notebook overrides on top of the values read from disk.
config['train_supervised']['batch_size'] = batch_size
config['train_unsupervised']['batch_size'] = batch_size
config['model']['epochs'] = epochs
config['model']['warm_up_epoch'] = warm_up
config['n_labeled_examples'] = labeled_examples
config['model']['resnet'] = backbone
config['optimizer']['args']['lr'] = lr
config['unsupervised_w'] = unsup_weight
# The model is told the (square) training crop size as [height, width].
config['model']['data_h_w'] = [config['train_supervised']['crop_size'], config['train_supervised']['crop_size']]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# No handler is attached and propagation is off, so these records are emitted
# by the logging module's last-resort handler (bare messages on stderr).
logger = logging.getLogger("PS-MT")
logger.propagate = False
logger.warning("Training start, 总共 {} epochs".format(str(config['model']['epochs'])))
logger.critical("GPUs: {}".format(device))
logger.critical("DeeplabV3+ with ResNet {} backbone".format(str(config['model']['resnet'])))
logger.critical("Current Labeled Example: {}".format(config['n_labeled_examples']))
logger.critical("Learning rate: other {}, and head is the SAME [world]".format(config['optimizer']['args']['lr']))

logger.critical("Current batch: {} [world]".format(int(config['train_unsupervised']['batch_size']) +
                                                int(config['train_supervised']['batch_size'])) )

logger.critical("Current unsupervised loss function: {}, with weight {} and length {}".format(config['model']['un_loss'],
                                                                                            config['unsupervised_w'],
                                                                                            config['ramp_up']))
# Dump the effective configuration, one top-level key per line.
print("\nconfig json :")
for i in config:
    print(i, config[i])


In [None]:
# DATA LOADERS
from DataLoader.dataset_onlyFA import *

choose_data = "All"

# Config sections that describe one data split each.
_SPLIT_SECTIONS = ("train_supervised", "train_unsupervised", "val_loader", "test_loader")

# Every split reads the same data selection.
for _section in _SPLIT_SECTIONS:
    config[_section]['choose'] = choose_data

# Echo the per-split settings, one indented "key : value" line per entry.
for _section in _SPLIT_SECTIONS:
    print(_section)
    for _key, _value in config[_section].items():
        print("    ", _key, ":", _value)


def _make_dataset(section):
    """Build a BasicDataset from the data_dir/choose/split of a config section."""
    options = config[section]
    return BasicDataset(data_dir=options['data_dir'],
                        choose=options['choose'],
                        split=options['split'])


def _make_loader(dataset, section):
    """Wrap ``dataset`` in a DataLoader using the section's batching options."""
    options = config[section]
    return DataLoader(dataset=dataset,
                      batch_size=options['batch_size'],
                      shuffle=options['shuffle'],
                      num_workers=options['num_workers'])


supervised_set = _make_dataset('train_supervised')
unsupervised_set = _make_dataset('train_unsupervised')
val_set = _make_dataset('val_loader')
test_set = _make_dataset('test_loader')

# Second dataset over the unlabeled split; no loader is built for it here.
self_supervised_set = _make_dataset('train_unsupervised')

for _label, _dataset in (("supervised_loader: ", supervised_set),
                         ("unsupervised_loader: ", unsupervised_set),
                         ("val_loader: ", val_set),
                         ("test_loader: ", test_set)):
    print(_label, len(_dataset))

supervised_loader = _make_loader(supervised_set, 'train_supervised')
unsupervised_loader = _make_loader(unsupervised_set, 'train_unsupervised')
val_loader = _make_loader(val_set, 'val_loader')
test_loader = _make_loader(test_set, 'test_loader')



In [None]:
#model setting (1 teacher + 1 student)
from torch import optim

from Model.Deeplabv3_plus.psmt_model import *
from Utils.ramps import *

# Weight schedule for the unsupervised loss; it is handed to the student below.
cons_w_unsup = ConsistencyWeight(
    final_w=config['unsupervised_w'],
    iters_per_epoch=len(unsupervised_loader),
    rampup_starts=0,
    rampup_ends=config['ramp_up'],
    ramp_type="cosine_rampup",
)

# Two-class teacher and student, both placed on the selected device.
model_t1 = Teacher_Net(num_classes=2, config=config['model']).to(device)
model_s = Student_Net(num_classes=2, config=config['model'], cons_w_unsup=cons_w_unsup).to(device)


def _make_sgd(model):
    """Create an SGD optimizer over ``model`` with the configured hyper-parameters."""
    sgd_args = config['optimizer']['args']
    return optim.SGD(model.parameters(),
                     lr=sgd_args['lr'],
                     momentum=sgd_args['momentum'],
                     weight_decay=sgd_args['weight_decay'])


optimizer_t1 = _make_sgd(model_t1)
optimizer_s = _make_sgd(model_s)

In [5]:
#calculate metrics
from sklearn import metrics, neighbors
from sklearn.metrics import confusion_matrix

#compute mean IoU & DSC  of given outputs & targets per patient (5 time points)
def information_index(outputs, targets):
    """Return ``(mean IoU, Dice)`` for one binary prediction/target pair.

    Any non-zero element counts as foreground, so 0/1, 0/255 and boolean masks
    are all accepted. The counts are computed directly with NumPy, which keeps
    the function working when an image contains a single class only (for
    example an all-background mask).

    Args:
        outputs: array-like predicted mask.
        targets: array-like ground-truth mask with the same number of elements.

    Returns:
        ``(mean_iou, mean_dice)`` where ``mean_iou`` averages the foreground
        and background IoU. A small epsilon in the denominators avoids a
        division by zero, so a class absent from both masks scores 0.

    Raises:
        ValueError: if the two masks do not have the same number of elements.
    """
    eps = np.finfo(np.float64).eps
    predicted = np.asarray(outputs).ravel() != 0
    actual = np.asarray(targets).ravel() != 0
    if predicted.size != actual.size:
        raise ValueError(
            "outputs and targets must have the same number of elements "
            "({} != {})".format(predicted.size, actual.size)
        )

    # Confusion counts with the target as ground truth.
    TP = int(np.count_nonzero(predicted & actual))
    FP = int(np.count_nonzero(predicted & ~actual))
    FN = int(np.count_nonzero(~predicted & actual))
    TN = int(predicted.size) - TP - FP - FN

    foreground_iou = TP / (TP + FP + FN + eps)
    background_iou = TN / (TN + FN + FP + eps)
    mean_iou = (foreground_iou + background_iou) / 2
    mean_dice = 2 * TP / (2 * TP + FP + FN + eps)

    return mean_iou, mean_dice

# compute mean IoU & DSC over the validation (test) set
def count_index(pre, tar):
    """Average mean IoU and Dice over every image stored in a prediction folder.

    Each entry of ``pre`` is paired with the entry of the same name in ``tar``.
    Both images are read as grayscale and binarised with a fixed threshold of
    128 before being scored by ``information_index``.

    NOTE(review): masks saved with 0/1 values would become all-background
    after the 128 threshold -- confirm the value range of the saved targets.

    Args:
        pre: directory holding the predicted masks.
        tar: directory holding the target masks (same file names).

    Returns:
        ``(mean IoU, mean Dice)`` averaged over the entries of ``pre``.

    Raises:
        ValueError: if ``pre`` contains no entries.
        FileNotFoundError: if a prediction or its target cannot be read.
    """
    file_names = os.listdir(pre)
    if not file_names:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("no prediction images found in {!r}".format(pre))

    total_miou = 0
    total_dice = 0
    for file_name in file_names:
        prediction_path = os.path.join(pre, str(file_name))
        target_path = os.path.join(tar, str(file_name))

        target_image = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
        prediction_image = cv2.imread(prediction_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if target_image is None:
            raise FileNotFoundError("could not read target image: {}".format(target_path))
        if prediction_image is None:
            raise FileNotFoundError("could not read predicted image: {}".format(prediction_path))

        _, target_mask = cv2.threshold(target_image, 128, 255, cv2.THRESH_BINARY)
        _, prediction_mask = cv2.threshold(prediction_image, 128, 255, cv2.THRESH_BINARY)

        image_miou, image_dice = information_index(prediction_mask, target_mask)
        total_miou += image_miou
        total_dice += image_dice

    val_mIoU = total_miou / len(file_names)
    val_mDice = total_dice / len(file_names)

    return val_mIoU, val_mDice

In [6]:
# from medpy import metric
# def calculate_metric_percase(pred, gt):
#     pred[pred > 0] = 1
#     gt[gt > 0] = 1
#     if pred.sum() > 0:
#         dice = metric.binary.dc(pred, gt)
#         hd95 = metric.binary.hd95(pred, gt)
#         return dice, hd95
#     else:
#         return 0, 0
# def test_single_volume(label, output, classes=2):
#     label = torch.clamp(label, 0, 1)
#     label = label.squeeze(0).cpu().detach().numpy()
#     output = output.cpu().detach().numpy()
#     metric_list = []
#     for i in range(1, classes):
#         metric_list.append(calculate_metric_percase(
#             output == i, label == i))
    
#     return metric_list

In [None]:
# warm
# Warm-up phase: teacher and student are trained one after the other on the
# unlabeled loader, and their "self_pred" outputs are dumped to disk so they can
# be scored against the dumped inputs with count_index().
import torch.nn.functional
from Utils.early_stop import EarlyStopping


def _save_batch_as_png(batch, ids, folder):
    """Write each (C, H, W) sample of ``batch`` to ``folder`` as ``<id>.png``.

    Values are scaled by 255 and clipped to [0, 255] before the uint8 cast.
    """
    for index in range(int(batch.size(0))):
        pixels = batch[index].detach().cpu().numpy().transpose(1, 2, 0)
        pixels = np.clip(pixels * 255, 0, 255)
        picture = Image.fromarray(pixels.astype(np.uint8))
        picture.save(os.path.join(folder, str(ids[index]) + ".png"))


# NOTE(review): the semi-supervised cell imports `EarlyStopper` from the same
# module -- confirm that both class names exist in Utils.early_stop.
early_stopper = EarlyStopping(patience=5, delta=0.0001)

# Output folders are created once up front, so the names exist even when the
# loader yields no batch.
folder_name = os.path.join("see_image", "self-warm")
folder_target = os.path.join(folder_name, "target")
folder_t1_prob = os.path.join(folder_name, "t1_prob")
folder_s_prob = os.path.join(folder_name, "s_prob")
for _folder in (folder_name, folder_target, folder_t1_prob, folder_s_prob):
    os.makedirs(_folder, exist_ok=True)

for epoch in range(config['model']['warm_up_epoch']):

    # Plain Python floats: adding the loss tensors themselves would keep
    # autograd history alive for the whole epoch.
    epoch_loss_t1 = 0.0
    epoch_loss_s = 0.0

    localtime = time.asctime( time.localtime(time.time()) )
    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['warm_up_epoch'],localtime)
    print(header)
    print('-' * len(header))

    for batch in tqdm(unsupervised_loader):
        input_ul, target_ul, id_ul = batch
        input_ul, target_ul = input_ul.to(device), target_ul.to(device)

        # The input images are the reference the outputs are scored against.
        _save_batch_as_png(input_ul, id_ul, folder_target)

        # warm teacher
        model_t1.train()
        model_s.eval()
        optimizer_t1.zero_grad()
        # The input itself is passed as the target (presumably a
        # reconstruction-style objective -- confirm against Teacher_Net).
        loss_t1, outputs = model_t1(input_ul=input_ul, target_ul=input_ul, warm_up=True, mix_up=False)
        output_t1 = outputs["self_pred"]
        _save_batch_as_png(output_t1, id_ul, folder_t1_prob)

        loss_t1.backward()
        optimizer_t1.step()
        epoch_loss_t1 += loss_t1.item()

        # warm student
        model_t1.eval()
        model_s.train()
        optimizer_s.zero_grad()
        loss_s, outputs = model_s(x_ul=input_ul, warm_up=True, mix_up=False)
        output_s = outputs["self_pred"]
        _save_batch_as_png(output_s, id_ul, folder_s_prob)

        loss_s.backward()
        optimizer_s.step()
        epoch_loss_s += loss_s.item()

    t1_mIoU, t1_mDice = count_index(folder_t1_prob, folder_target)
    s_mIoU, s_mDice = count_index(folder_s_prob, folder_target)

    # NOTE(review): the summed per-batch losses are divided by the size of the
    # *supervised* dataset although the loop runs over the unlabeled loader.
    # Kept as is because the early-stopping delta is applied to this scale;
    # confirm whether len(unsupervised_loader) was intended.
    print(f'Epoch{epoch+1} loss : \nteacher 1 loss = {epoch_loss_t1/len(supervised_set)},student loss = {epoch_loss_s/len(supervised_set)}') 
    print(f'teacher 1 dsc loss = {t1_mDice}, sutdent dsc loss = {s_mDice}')


    # Early stopping is only considered once the student output is close
    # enough to the input (Dice above 0.95).
    if s_mDice > 0.95:
        early_stopper(epoch_loss_s/len(supervised_set))
        print(f'teacher 1 : mIoU = {t1_mIoU}, DCS = {t1_mDice}')
        print(f'student : mIoU = {s_mIoU}, DCS = {s_mDice}')
        if early_stopper.early_stop: 
            print("Early stopping")   
            break

In [8]:
def freeze_teachers_parameters(model):
    """Turn off gradient tracking for all encoder and decoder parameters of ``model``.

    Parameters that live outside ``model.encoder`` / ``model.decoder`` are left
    untouched.
    """
    for part_name in ("encoder", "decoder"):
        for weight in getattr(model, part_name).parameters():
            weight.requires_grad = False

# Stop gradient tracking for the teacher's encoder/decoder: in the
# semi-supervised loop the teacher is refreshed through update_teachers()
# instead of an optimizer step.
freeze_teachers_parameters(model_t1)

def _ema_state_dict(teacher_state, student_state, keep_rate, part):
    """Blend two state dicts entry by entry, following the teacher's keys."""
    blended = OrderedDict()
    for key, teacher_value in teacher_state.items():
        if key not in student_state:
            raise KeyError("{} is not found in student {} model".format(key, part))
        student_value = student_state[key]
        if teacher_value.is_floating_point() or teacher_value.is_complex():
            blended[key] = student_value * (1 - keep_rate) + teacher_value * keep_rate
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot hold
            # a fractional average; blending them would be truncated on load,
            # so the student's value is taken as is.
            blended[key] = student_value.clone()
    return blended


def update_teachers(teacher, student, keep_rate=0.996):
    """Move the teacher's encoder/decoder towards the student's with an EMA.

    Every floating-point entry of the teacher's encoder and decoder state
    becomes ``keep_rate * teacher + (1 - keep_rate) * student``; non-floating
    entries are copied from the student.

    Args:
        teacher: model whose ``encoder`` and ``decoder`` are updated in place.
        student: model providing the new values (same encoder/decoder keys).
        keep_rate: share of the teacher's current value that is kept.

    Raises:
        KeyError: if a teacher encoder/decoder entry is missing from the
            student. Nothing is loaded into the teacher in that case.
    """
    # Both blended dicts are built before anything is loaded, so a missing key
    # cannot leave the teacher half-updated.
    new_teacher_encoder_dict = _ema_state_dict(
        teacher.encoder.state_dict(), student.encoder.state_dict(), keep_rate, "encoder"
    )
    new_teacher_decoder_dict = _ema_state_dict(
        teacher.decoder.state_dict(), student.decoder.state_dict(), keep_rate, "decoder"
    )
    teacher.encoder.load_state_dict(new_teacher_encoder_dict, strict=True)
    teacher.decoder.load_state_dict(new_teacher_decoder_dict, strict=True)

In [9]:
def predict_with_out_grad(model_t1, model_t2, image):
    """Run both teachers on ``image`` without recording gradients.

    Each model's encoder output is passed to its decoder, the second element of
    the decoder's return value is taken as the prediction, and that prediction
    is bilinearly resized to the spatial size of ``image``.

    Returns:
        The two resized predictions, in the order ``(model_t1, model_t2)``.
    """
    height, width = image.shape[-2], image.shape[-1]

    with torch.no_grad():
        raw_predictions = []
        for teacher in (model_t1, model_t2):
            features = teacher.encoder(image)
            _, prediction = teacher.decoder(features, data_shape=[height, width])
            raw_predictions.append(prediction)

        predict_target_ul1, predict_target_ul2 = [
            torch.nn.functional.interpolate(prediction,
                                            size=(height, width),
                                            mode='bilinear',
                                            align_corners=True)
            for prediction in raw_predictions
        ]

        assert predict_target_ul1.shape == predict_target_ul2.shape, "Expect two prediction in same shape,"
    return predict_target_ul1, predict_target_ul2

In [None]:
# semi train
# Semi-supervised phase: the student is optimised on labeled batches plus
# teacher pseudo-labels for unlabeled batches, the teacher follows the student
# through update_teachers(), and the teacher is validated after every epoch.
from Utils.early_stop import EarlyStopper


def _repeat_forever(loader):
    """Yield batches from ``loader`` endlessly, starting a new pass each time.

    Unlike ``itertools.cycle`` nothing is cached: every pass re-iterates the
    loader, so batches are not kept in memory and a shuffling loader produces a
    new order on each pass. Stops if a pass yields no batch at all.
    """
    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            return


# NOTE(review): the warm-up cell imports `EarlyStopping` from the same module
# and feeds it a loss, while this one is fed a Dice score -- confirm both
# classes exist and expect those directions.
early_stopper = EarlyStopper(patience=5, delta=0.001)

best_model_t1_params = copy.deepcopy(model_t1.state_dict())
best_model_s_params = copy.deepcopy(model_s.state_dict())
do_best_Dice = 0
do_best_mIoU = 0
# do_best_hd95 = 100000
do_best_epoch = 0

# Validation artefacts are written below `folder_name`, which must already be
# defined by the warm-up cell. Folders are created once, not once per image.
folder_val = os.path.join(folder_name, "val_original_target")
folder_val_prob = os.path.join(folder_name, "val_original_prob")
folder_val_best_prob = os.path.join(folder_name, "val_original_best_prob")
os.makedirs(folder_val, exist_ok=True)
os.makedirs(folder_val_prob, exist_ok=True)

for epoch in range(config['model']['epochs']):

    model_s.train()
    epoch_loss = 0.0  # float sum; loss tensors would carry autograd history
    epoch_dsc = 0  # never updated in this loop; kept so the name stays defined

    # One step per unlabeled batch; labeled batches are repeated as needed.
    dataloader = zip(_repeat_forever(supervised_loader), unsupervised_loader)

    localtime = time.asctime( time.localtime(time.time()) )
    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['epochs'],localtime)
    print(header)
    print('-' * len(header))

    tbar = tqdm(range(len(unsupervised_loader)))
    for batch_idx in tbar:
        (image_FA, image_ICG, target_l, id_l), (input_ul, target_ul, id_ul) = next(dataloader)
        image_FA, image_ICG, target_l = image_FA.to(device), image_ICG.to(device), target_l.to(device)
        input_ul, target_ul = input_ul.to(device), target_ul.to(device)
        optimizer_s.zero_grad()

        # Teacher prediction for the unlabeled batch, resized to the input
        # size; it is passed to the student as `target_ul`.
        with torch.no_grad():
            f = model_t1.encoder(input_ul)
            _, predict_target_ul1 = model_t1.decoder(f, data_shape=[input_ul.shape[-2], input_ul.shape[-1]])

            predict_target_ul1 = torch.nn.functional.interpolate(predict_target_ul1,
                                                                    size=(input_ul.shape[-2], input_ul.shape[-1]),
                                                                    mode='bilinear',
                                                                    align_corners=True)

        # The single teacher is passed for both teacher slots of the student.
        total_loss, cur_losses, outputs = model_s(x_FA=image_FA, x_ICG=image_ICG, target_l=target_l,
                                                  x_ul=input_ul, target_ul=predict_target_ul1,
                                                  epoch=epoch, curr_iter=batch_idx, warm_up=False,
                                                  mix_up=False, t1=model_t1, t2=model_t1)

        total_loss.backward()
        optimizer_s.step()
        epoch_loss += total_loss.item()

        with torch.no_grad():
            update_teachers(teacher=model_t1,
                            student=model_s)

    # validation
    # NOTE(review): the student is switched to eval mode here, but the
    # predictions that get scored come from the teacher (model_t1) -- confirm
    # the teacher is in the intended mode at this point.
    # metric_list = 0.0
    model_s.eval()
    for batch in tqdm(val_loader):
        image_val, label, id_val = batch
        image_val, label = image_val.to(device), label.to(device)

        H, W = label.size(1), label.size(2)
        # Round the spatial size up to a multiple of 8 for the network input.
        up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)

        # Store the ground-truth masks next to the predictions.
        for i in range(0, int(label.size(0))):
            image = Image.fromarray(np.uint8(label[i].detach().cpu().numpy()))
            image.save(os.path.join(folder_val, str(id_val[i]) + ".png"))

        data = torch.nn.functional.interpolate(image_val, size=(up_sizes[0], up_sizes[1]),
                                               mode='bilinear', align_corners=True)

        with torch.no_grad():
            f = model_t1.encoder(data)
            _, output = model_t1.decoder(f, data_shape=[data.shape[-2], data.shape[-1]])
            # Back to the label resolution before taking the arg-max.
            output = torch.nn.functional.interpolate(output, size=(H, W),
                                                     mode='bilinear', align_corners=True)

        # Arg-max over dim 0 of each sample, stored as a 0/255 image.
        for i in range(0, int(output.size(0))):
            image_prob = output[i].squeeze().detach()
            image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
            image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
            image_prob.save(os.path.join(folder_val_prob, str(id_val[i]) + ".png"))

    #     out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
    #     metric_i = test_single_volume(label, out, classes=2)
    #     metric_list += np.array(metric_i)
    # metric_list = metric_list / len(val_set)
    # # index_mDice = np.mean(metric_list, axis=0)[0]
    # index_mean_hd95 = np.mean(metric_list, axis=0)[1]

    # show epoch mIoU, mDice
    index_mIoU, index_mDice = count_index(folder_val_prob, folder_val)
    # print(f'Epoch {epoch+1}' + " val:" + f'DSC: {index_mDice:.4f}, HD95 = {index_mean_hd95:.4f}')
    print(f'Epoch {epoch+1}' + " val:" + f'mIoU: {index_mIoU:.4f}, DSC: {index_mDice:.4f}')

    # find the best mIoU, mDice
    # if index_mDice > do_best_Dice and (index_mDice > do_best_Dice or index_mean_hd95 < do_best_hd95):
    if index_mDice > do_best_Dice:
        # A new best resets the stopper by hand instead of calling it.
        early_stopper.best_score = index_mDice
        early_stopper.counter = 0
        do_best_Dice = index_mDice
        do_best_mIoU = index_mIoU
        # do_best_hd95 = index_mean_hd95
        do_best_epoch = epoch+1
        best_model_t1_params = copy.deepcopy(model_t1.state_dict())
        best_model_s_params = copy.deepcopy(model_s.state_dict())
        # print("Change Best: " + f'epoch: {do_best_epoch} DSC: {do_best_Dice:.4f}, HD95 = {do_best_hd95:.4f}')
        print("Change Best: " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')

        # keep a copy of the validation predictions of the best epoch
        os.makedirs(folder_val_best_prob, exist_ok=True)
        for file_name in os.listdir(folder_val_prob):
            source_path = os.path.join(folder_val_prob, file_name)
            destination_path = os.path.join(folder_val_best_prob, file_name)
            shutil.copyfile(source_path, destination_path)
    else:
        # Non-improving epochs only count towards early stopping from the
        # sixth epoch on.
        if epoch >= 5:
            early_stopper(index_mDice)

    if early_stopper.early_stop: 
        print("Early stopping")   
        break

    # show the best mIoU, mDice
    # print("Best model : " + f'epoch : {do_best_epoch}  DSC : {do_best_Dice:.4f}, HD95 = {do_best_hd95:.4f}')
    print("Best model : " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')
   

In [None]:
# save the best teacher (t1) and student weights
model_path = "saved_models"
model_params = os.path.join(model_path, "original_models")
# Creating the nested folder also creates `model_path` when it is missing.
os.makedirs(model_params, exist_ok=True)

# File names carry the best epoch and its Dice score.
_teacher_file = os.path.join(model_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_t1.pth')
_student_file = os.path.join(model_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_s.pth')

torch.save(best_model_t1_params, _teacher_file)
torch.save(best_model_s_params, _student_file)
print("Best model : " + f'epoch : {do_best_epoch}  DSC : {do_best_Dice:.4f}')