In [1]:
import os
import json
import logging
import copy
import shutil
import time

from tqdm import tqdm

import cv2
import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from torchvision import transforms
from torchvision.transforms import functional as F

from math import ceil
from itertools import cycle
from itertools import chain
from collections import OrderedDict
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

import random
random.seed(42)

import gc

In [2]:
#config
batch_size = 7
epochs = 80
warm_up = 100
labeled_examples = 515
lr = 1e-2
backbone = 101 #the resnet x {50, 101} layers
semi_p_th = 0.6 # positive_threshold for semi-supervised loss
semi_n_th = 0.0 # negative_threshold for semi-supervised loss
unsup_weight = 1.5 # unsupervised weight for semi-supervised loss

config = json.load(open("configs/config_deeplab_v3+_onlyFA_range_selfsupervised.json"))

config['train_supervised']['batch_size'] = batch_size
config['train_unsupervised']['batch_size'] = batch_size
config['warm_selfsupervised']['batch_size'] = batch_size
config['model']['epochs'] = epochs
config['model']['warm_up_epoch'] = warm_up
config['n_labeled_examples'] = labeled_examples
config['model']['resnet'] = backbone
config['optimizer']['args']['lr'] = lr
config['unsupervised_w'] = unsup_weight
config['model']['data_h_w'] = [config['train_supervised']['crop_size'], config['train_supervised']['crop_size']]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger = logging.getLogger("PS-MT")
logger.propagate = False
logger.warning("Training start, 总共 {} epochs".format(str(config['model']['epochs'])))
logger.critical("GPUs: {}".format(device))
logger.critical("DeeplabV3+ with ResNet {} backbone".format(str(config['model']['resnet'])))
logger.critical("Current Labeled Example: {}".format(config['n_labeled_examples']))
logger.critical("Learning rate: other {}, and head is the SAME [world]".format(config['optimizer']['args']['lr']))

logger.critical("Current batch: {} [world]".format(int(config['train_unsupervised']['batch_size']) +
                                                int(config['train_supervised']['batch_size'])) )

logger.critical("Current unsupervised loss function: {}, with weight {} and length {}".format(config['model']['un_loss'],
                                                                                            config['unsupervised_w'],
                                                                                            config['ramp_up']))


#Need to add self-supervised loss function info
print("\nconfig json :")
for i in config:
    print(i, config[i])

Training start, 总共 80 epochs
GPUs: cuda
DeeplabV3+ with ResNet 101 backbone
Current Labeled Example: 515
Learning rate: other 0.01, and head is the SAME [world]
Current batch: 14 [world]
Current unsupervised loss function: semi_ce, with weight 1.5 and length 12



config json :
name PS-MT(DeeplabV3+)
experim_name TEST_warm
n_labeled_examples 515
ramp_up 12
unsupervised_w 1.5
lr_scheduler Poly
gamma 0.5
model {'supervised': False, 'semi': True, 'resnet': 101, 'sup_loss': 'DE', 'un_loss': 'semi_ce', 'epochs': 80, 'warm_up_epoch': 100, 'data_h_w': [224, 224]}
optimizer {'type': 'SGD', 'args': {'lr': 0.01, 'weight_decay': 0.0001, 'momentum': 0.9}}
train_supervised {'data_dir': 'FA', 'batch_size': 7, 'shuffle': True, 'crop_size': 224, 'split': 'train_supervised', 'num_workers': 8}
train_unsupervised {'data_dir': 'range_unFA', 'batch_size': 7, 'shuffle': True, 'crop_size': 224, 'split': 'train_unsupervised', 'num_workers': 8}
warm_selfsupervised {'data_dir': 'range_unFA', 'batch_size': 7, 'shuffle': True, 'crop_size': 224, 'split': 'train_unsupervised', 'num_workers': 8}
val_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'val', 'shuffle': False, 'num_workers': 4}
test_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'test', 'shuffle': False, 

In [3]:
# DATA LOADERS
from DataLoader.dataset_onlyFA import *
choose_data = "All"

config['train_supervised']['choose'] = choose_data
config['train_unsupervised']['choose'] = choose_data
config['warm_selfsupervised']['choose'] = choose_data
config['val_loader']['choose'] = choose_data
config['test_loader']['choose'] = choose_data

print("train_supervised")
for i in config['train_supervised']:
    print("    ",i, ":", config['train_supervised'][i])
print("train_unsupervised")
for i in config['train_unsupervised']:
    print("    ", i, ":", config['train_unsupervised'][i])
print("warm_selfsupervised")
for i in config['warm_selfsupervised']:
    print("    ", i, ":", config['warm_selfsupervised'][i])
print("val_loader")
for i in config['val_loader']:
    print("    ", i, ":", config['val_loader'][i])
print("test_loader")
for i in config['test_loader']:
    print("    ", i, ":", config['test_loader'][i])

supervised_set = BasicDataset(data_dir=config['train_supervised']['data_dir'], 
                                 choose=config['train_supervised']['choose'],
                                 split=config['train_supervised']['split'])

unsupervised_set = BasicDataset(data_dir=config['train_unsupervised']['data_dir'],
                                   choose=config['train_unsupervised']['choose'],
                                   split=config['train_unsupervised']['split'])

warm_selfsupervised_set = BasicDataset(data_dir=config['warm_selfsupervised']['data_dir'],
                                      choose=config['warm_selfsupervised']['choose'],
                                      split=config['warm_selfsupervised']['split'])

val_set = BasicDataset(data_dir=config['val_loader']['data_dir'],
                          choose=config['val_loader']['choose'],
                          split=config['val_loader']['split'])

test_set = BasicDataset(data_dir=config['test_loader']['data_dir'],
                          choose=config['test_loader']['choose'],
                          split=config['test_loader']['split'])

print("supervised_loader: ",len(supervised_set))
print("unsupervised_loader: ",len(unsupervised_set))
print("warm_selfsupervised_loader: ",len(warm_selfsupervised_set))
print("val_loader: ",len(val_set))
print("test_loader: ",len(test_set))


supervised_loader = DataLoader(dataset=supervised_set, batch_size=config['train_supervised']['batch_size'],
                               shuffle=config['train_supervised']['shuffle'], 
                               num_workers=config['train_supervised']['num_workers'])

unsupervised_loader = DataLoader(dataset=unsupervised_set, batch_size=config['train_unsupervised']['batch_size'],
                               shuffle=config['train_unsupervised']['shuffle'], 
                               num_workers=config['train_unsupervised']['num_workers'])
                               
warm_selfsupervised_loader = DataLoader(dataset=unsupervised_set, batch_size=config['warm_selfsupervised']['batch_size'],
                              shuffle=config['warm_selfsupervised']['shuffle'],
                              num_workers=config['warm_selfsupervised']['num_workers'])

val_loader = DataLoader(dataset=val_set, batch_size=config['val_loader']['batch_size'],
                               shuffle=config['val_loader']['shuffle'], 
                               num_workers=config['val_loader']['num_workers'])

test_loader = DataLoader(dataset=test_set, batch_size=config['test_loader']['batch_size'],
                               shuffle=config['test_loader']['shuffle'], 
                               num_workers=config['test_loader']['num_workers'])



train_supervised
     data_dir : FA
     batch_size : 7
     shuffle : True
     crop_size : 224
     split : train_supervised
     num_workers : 8
     choose : All
train_unsupervised
     data_dir : range_unFA
     batch_size : 7
     shuffle : True
     crop_size : 224
     split : train_unsupervised
     num_workers : 8
     choose : All
warm_selfsupervised
     data_dir : range_unFA
     batch_size : 7
     shuffle : True
     crop_size : 224
     split : train_unsupervised
     num_workers : 8
     choose : All
val_loader
     data_dir : FA
     batch_size : 1
     split : val
     shuffle : False
     num_workers : 4
     choose : All
test_loader
     data_dir : FA
     batch_size : 1
     split : test
     shuffle : False
     num_workers : 4
     choose : All
supervised_loader:  515
unsupervised_loader:  5263
warm_selfsupervised_loader:  5263
val_loader:  162
test_loader:  157


In [4]:
#model setting (1 teacher + 1 student)
from torch import optim

from Utils.losses import ConsistencyWeight
from Model.selfsupervised.selfsupervised_model import *
from Utils.ramps import *

cons_w_unsup = ConsistencyWeight(final_w=config['unsupervised_w'], iters_per_epoch=len(unsupervised_loader),
                                 rampup_starts=0, rampup_ends=config['ramp_up'],  ramp_type="cosine_rampup")


model_t1 = Teacher_Net(num_classes=2, config=config['model'])
model_t1 = model_t1.to(device)

model_s = Student_Net(num_classes=2, config=config['model'],  cons_w_unsup=cons_w_unsup)
model_s = model_s.to(device)

optimizer_t1 = optim.SGD(model_t1.parameters(), 
                      lr=config['optimizer']['args']['lr'],
                      momentum=config['optimizer']['args']['momentum'],
                      weight_decay=config['optimizer']['args']['weight_decay'])

optimizer_s = optim.SGD(model_s.parameters(), 
                      lr=config['optimizer']['args']['lr'],
                      momentum=config['optimizer']['args']['momentum'],
                      weight_decay=config['optimizer']['args']['weight_decay'])

In [5]:
#calculate metrics
from sklearn import metrics, neighbors
from sklearn.metrics import confusion_matrix

#compute mean IoU & DSC  of given outputs & targets per patient (5 time points)
def information_index(outputs, targets):
    eps = np.finfo(np.float64).eps
    output = outputs.flatten()
    target = targets.flatten()
    
    # Compute the confusion matrix
    cm = confusion_matrix(target, output).ravel()
    
    # Handle different shapes of confusion matrix
    if len(cm) == 4:
        TN, FP, FN, TP = cm
    elif len(cm) == 1:
        # Case where only one class is present in both target and output
        if target[0] == 0:
            TN, FP, FN, TP = cm[0], 0, 0, 0  # All True Negatives
        else:
            TN, FP, FN, TP = 0, 0, 0, cm[0]  # All True Positives
    elif len(cm) == 2:
        # Case where target and output contain only a single class
        if target[0] == 0:
            TN, FP, FN, TP = cm[0], cm[1], 0, 0  # No False Negatives or True Positives
        else:
            TN, FP, FN, TP = 0, 0, cm[0], cm[1]  # No True Negatives or False Positives
    else:
        raise ValueError("Unexpected confusion matrix size.")

    # Compute IoU and Dice coefficients
    index_MIou = (TP / (TP + FP + FN + eps) + TN / (TN + FN + FP + eps)) / 2
    mean_iou = np.mean(index_MIou)
    index_dice = 2 * TP / (2 * TP + FP + FN + eps)
    mean_dice = np.mean(index_dice)

    return mean_iou, mean_dice
#compute mean IoU & DSC of validaton (test set)
def count_index(pre, tar):
        path_pre = pre
        path_target = tar
        dirs = os.listdir(path_pre)
        # print(len(dirs))
        con_mIOU = 0
        con_mdice = 0
        for imgs in dirs:
            pre_path = path_pre + '/' + str(imgs)
            target_path = path_target + '/' + str(imgs)

            target = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
            _, tar = cv2.threshold(target, 128, 1, cv2.THRESH_BINARY)

            predict = cv2.imread(pre_path, cv2.IMREAD_GRAYSCALE)
            _, pre = cv2.threshold(predict, 128, 1, cv2.THRESH_BINARY)

            tIOU, tdice = information_index(pre,tar)
            con_mIOU += tIOU
            con_mdice += tdice
        val_mIoU = con_mIOU/len(dirs)
        val_mDice = con_mdice/len(dirs)
        
        return val_mIoU, val_mDice

In [6]:
# warm
import torch.nn.functional
from Utils.early_stop import EarlyStopping

early_stopper = EarlyStopping(patience=5, delta=0.0001)

for epoch in range(config['model']['warm_up_epoch']):

    epoch_loss_t1 = 0
    epoch_loss_s = 0
    t1_mIoU, t1_mDice = 0, 0
    s_mIoU, s_mDice = 0, 0
    
    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['warm_up_epoch'],localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['warm_up_epoch'],localtime)))

    folder_name = None
    folder_name = os.path.join("see_image", "self-warm")
    os.makedirs(folder_name, exist_ok=True)

    for batch in tqdm(warm_selfsupervised_loader):
        input_ul, target_ul, id_ul = batch
        input_ul, target_ul, id_ul = input_ul.to(device), target_ul.to(device), id_ul


        for i in range(0, int(target_ul.size(0))):
            folder_target = os.path.join(folder_name, "target")
            os.makedirs(folder_target, exist_ok=True)
            image = Image.fromarray(np.int8(target_ul[i].detach().cpu().numpy()))
            image.save(os.path.join(folder_target, str(id_ul[i]) + ".png"))

        # warm teacher
        model_t1.train()
        model_s.eval()
        optimizer_t1.zero_grad()
        loss_t1, outputs = model_t1(input_ul=input_ul, target_ul=input_ul, warm_up=True, mix_up=False)
        output_t1 = outputs["self_pred"]
        

        for i in range(0, int(output_t1.size(0))):
            folder_t1_prob = os.path.join(folder_name, "t1_prob")
            os.makedirs(folder_t1_prob, exist_ok=True)
            image_prob = output_t1[i].squeeze().detach()
            image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
            image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
            image_prob.save(os.path.join(folder_t1_prob, str(id_ul[i]) + ".png"))


        epoch_loss_t1 += loss_t1
        loss_t1.backward()
        optimizer_t1.step()
        
        # warm student
        model_t1.eval()
        model_s.train()
        optimizer_s.zero_grad()
        loss_s, outputs = model_s(x_ul=input_ul, warm_up=True, mix_up=False)
        output_s = outputs["self_pred"]
        
        for i in range(0, int(output_s.size(0))):
            folder_s_prob = os.path.join(folder_name, "s_prob")
            os.makedirs(folder_s_prob, exist_ok=True)
            image_prob = output_s[i].squeeze().detach()
            image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
            image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
            image_prob.save(os.path.join(folder_s_prob, str(id_ul[i]) + ".png"))

        epoch_loss_s += loss_s
        loss_s.backward()
        optimizer_s.step()

    t1_mIoU, t1_mDice = count_index(folder_t1_prob, folder_target)
    s_mIoU, s_mDice = count_index(folder_s_prob, folder_target)

    print(f'Epoch{epoch+1} loss : \nteacher 1 loss = {epoch_loss_t1/len(supervised_set)},student loss = {epoch_loss_s/len(supervised_set)}') 


    if s_mDice > 0.95:
        early_stopper(epoch_loss_s/len(supervised_set))
        print(f'teacher 1 : mIoU = {t1_mIoU}, DCS = {t1_mDice}')
        print(f'student : mIoU = {s_mIoU}, DCS = {s_mDice}')
        if early_stopper.early_stop: 
            print("Early stopping")   
            break

  0%|          | 0/752 [00:00<?, ?it/s]

Epoch: 1/100 --- < Starting Time : Mon Sep 23 22:57:32 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:13<00:00,  5.63it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch1 loss : 
teacher 1 loss = 0.044037364423274994,student loss = 0.04407580941915512
Epoch: 2/100 --- < Starting Time : Mon Sep 23 23:00:38 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:36<00:00,  4.81it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch2 loss : 
teacher 1 loss = 0.04334758594632149,student loss = 0.04333049803972244
Epoch: 3/100 --- < Starting Time : Mon Sep 23 23:04:10 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:41<00:00,  4.67it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch3 loss : 
teacher 1 loss = 0.04336659237742424,student loss = 0.043339282274246216
Epoch: 4/100 --- < Starting Time : Mon Sep 23 23:07:48 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:43<00:00,  4.59it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch4 loss : 
teacher 1 loss = 0.043422769755125046,student loss = 0.043397240340709686
Epoch: 5/100 --- < Starting Time : Mon Sep 23 23:11:27 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:44<00:00,  4.57it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch5 loss : 
teacher 1 loss = 0.04341918230056763,student loss = 0.043406181037425995
Epoch: 6/100 --- < Starting Time : Mon Sep 23 23:15:09 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:44<00:00,  4.57it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch6 loss : 
teacher 1 loss = 0.04336138814687729,student loss = 0.04329809918999672
Epoch: 7/100 --- < Starting Time : Mon Sep 23 23:18:50 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:44<00:00,  4.57it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch7 loss : 
teacher 1 loss = 0.04340789467096329,student loss = 0.0433475598692894
Epoch: 8/100 --- < Starting Time : Mon Sep 23 23:22:31 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:46<00:00,  4.51it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch8 loss : 
teacher 1 loss = 0.04333417862653732,student loss = 0.04327346384525299
Epoch: 9/100 --- < Starting Time : Mon Sep 23 23:26:13 2024 >
-------------------------------------------------------------


100%|██████████| 752/752 [02:46<00:00,  4.52it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch9 loss : 
teacher 1 loss = 0.0433523990213871,student loss = 0.0432744026184082
Epoch: 10/100 --- < Starting Time : Mon Sep 23 23:29:57 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:45<00:00,  4.54it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch10 loss : 
teacher 1 loss = 0.04330949857831001,student loss = 0.04316982626914978
Epoch: 11/100 --- < Starting Time : Mon Sep 23 23:33:39 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:45<00:00,  4.55it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch11 loss : 
teacher 1 loss = 0.043358657509088516,student loss = 0.04261606186628342
Epoch: 12/100 --- < Starting Time : Mon Sep 23 23:37:21 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:44<00:00,  4.58it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch12 loss : 
teacher 1 loss = 0.04324383661150932,student loss = 0.028575142845511436
Epoch: 13/100 --- < Starting Time : Mon Sep 23 23:41:01 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:51<00:00,  4.39it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch13 loss : 
teacher 1 loss = 0.043027911335229874,student loss = 0.024160167202353477
Epoch: 14/100 --- < Starting Time : Mon Sep 23 23:44:53 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:45<00:00,  4.55it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch14 loss : 
teacher 1 loss = 0.03464961796998978,student loss = 0.022775663062930107
Epoch: 15/100 --- < Starting Time : Mon Sep 23 23:48:38 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:51<00:00,  4.39it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch15 loss : 
teacher 1 loss = 0.024566488340497017,student loss = 0.021190933883190155
Epoch: 16/100 --- < Starting Time : Mon Sep 23 23:52:30 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:53<00:00,  4.33it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch16 loss : 
teacher 1 loss = 0.0231124609708786,student loss = 0.018612105399370193
Epoch: 17/100 --- < Starting Time : Mon Sep 23 23:56:23 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:50<00:00,  4.41it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch17 loss : 
teacher 1 loss = 0.021575195714831352,student loss = 0.01592775247991085
Epoch: 18/100 --- < Starting Time : Tue Sep 24 00:00:11 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:48<00:00,  4.45it/s]
  0%|          | 0/752 [00:00<?, ?it/s]

Epoch18 loss : 
teacher 1 loss = 0.01938559114933014,student loss = 0.014094416983425617
Epoch: 19/100 --- < Starting Time : Tue Sep 24 00:03:58 2024 >
--------------------------------------------------------------


100%|██████████| 752/752 [02:47<00:00,  4.48it/s]


In [7]:
#keep the parameters of teacher unchanged
def freeze_teachers_parameters(model):
    for p in model.maskAutoEncoder.encoder.parameters():
        p.requires_grad = False
    for p in model.maskAutoEncoder.decoder.parameters():
        p.requires_grad = False

freeze_teachers_parameters(model_t1)

# update teacher's parameters according to student's parameters
def update_teachers(teacher, student, keep_rate=0.996):
    student_encoder_dict = student.encoder.state_dict()
    student_decoder_dict = student.decoder.state_dict()
    new_teacher_encoder_dict = OrderedDict()
    new_teacher_decoder_dict = OrderedDict()

    for key, value in teacher.encoder.state_dict().items():

        if key in student_encoder_dict.keys():
            new_teacher_encoder_dict[key] = (
                    student_encoder_dict[key] * (1 - keep_rate) + value * keep_rate
            )
        else:
            raise Exception("{} is not found in student encoder model".format(key))

    for key, value in teacher.decoder.state_dict().items():

        if key in student_decoder_dict.keys():
            new_teacher_decoder_dict[key] = (
                    student_decoder_dict[key] * (1 - keep_rate) + value * keep_rate
            )
        else:
            raise Exception("{} is not found in student decoder model".format(key))
    teacher.encoder.load_state_dict(new_teacher_encoder_dict, strict=True)
    teacher.decoder.load_state_dict(new_teacher_decoder_dict, strict=True)

In [None]:
# semi train
from Utils.early_stop import EarlyStopper

early_stopper = EarlyStopper(patience=5, delta=0.001)

best_model_t1_params = copy.deepcopy(model_t1.state_dict())
best_model_s_params = copy.deepcopy(model_s.state_dict())
do_best_Dice = 0
do_best_mIoU = 0
# do_best_hd95 = 100000
do_best_epoch = 0

for epoch in range(config['model']['epochs']):

    model_s.train()
    epoch_loss = 0
    epoch_dsc = 0

    dataloader = iter(zip(cycle(supervised_loader), unsupervised_loader))

    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['epochs'],localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,config['model']['epochs'],localtime)))

    tbar = tqdm(range(len(unsupervised_loader)))
    for batch_idx in tbar:
        (image_FA, image_ICG, target_l, id_l), (input_ul, target_ul, id_ul) = next(dataloader)
        image_FA, image_ICG, target_l, id_l = image_FA.to(device), image_ICG.to(device), target_l.to(device), id_l
        input_ul, target_ul, id_ul = input_ul.to(device), target_ul.to(device), id_ul
        input_ul = torch.nn.functional.interpolate(input_ul, size=(input_ul.shape[-2], input_ul.shape[-1]), mode='bilinear', align_corners=True)
        
        optimizer_s.zero_grad()

        with torch.no_grad():
           loss_t1, predict_target_ul1 = model_t1(input_ul=input_ul, target_ul=target_ul)
           predict_target_ul1 = torch.nn.functional.interpolate(predict_target_ul1,
                                                                    size=(input_ul.shape[-2], input_ul.shape[-1]),
                                                                    mode='bilinear',
                                                                    align_corners=True)
           
        total_loss, cur_losses, outputs = model_s(x_FA=image_FA, x_ICG=image_ICG, target_l=target_l,
                                                  x_ul=input_ul, target_ul=predict_target_ul1,
                                                  epoch=epoch, curr_iter=batch_idx, warm_up=False,
                                                  mix_up=False, t1=model_t1, t2=model_t1)
        
        epoch_loss += total_loss
        total_loss.backward()
        optimizer_s.step()

        with torch.no_grad():
            update_teachers(teacher=model_t1.segmentationMAE,
                            student=model_s.segmentationMAE)

    # validation
    # metric_list = 0.0
    folder_name = os.path.join("see_image", "val")
    model_s.eval()
    for batch in tqdm(val_loader):
        image_val, label, id_val = batch
        image_val, label, id_val = image_val.to(device), label.to(device), id_val

        H, W = label.size(1), label.size(2)
        up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)

        for i in range(0, int(label.size(0))):
            folder_val = os.path.join(folder_name, "val_original_target")
            os.makedirs(folder_val, exist_ok=True)
            image = Image.fromarray(np.uint8(label[i].detach().cpu().numpy()))
            image.save(os.path.join(folder_val, str(id_val[i]) + ".png"))

        data = torch.nn.functional.interpolate(image_val, size=(up_sizes[0], up_sizes[1]),
                                               mode='bilinear', align_corners=True)
        
        with torch.no_grad():
            loss_t1, output = model_t1(input_ul=data,  target_ul=label)
            
            output = torch.nn.functional.interpolate(output, size=(H, W),
                                                 mode='bilinear', align_corners=True)

        for i in range(0, int(output.size(0))):
            folder_val_prob = os.path.join(folder_name, "val_original_prob")
            os.makedirs(folder_val_prob, exist_ok=True)
            image_prob = output[i].squeeze().detach()
            image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
            image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
            image_prob.save(os.path.join(folder_val_prob, str(id_val[i]) + ".png"))
        

    
    # show epoch mIoU, mDice
    index_mIoU, index_mDice = count_index(folder_val_prob, folder_val)
    # print(f'Epoch {epoch+1}' + " val:" + f'DSC: {index_mDice:.4f}, HD95 = {index_mean_hd95:.4f}')
    print(f'Epoch {epoch+1}' + " val:" + f'mIoU: {index_mIoU:.4f}, DSC: {index_mDice:.4f}')

    # find the best mIoU, mDice
    # if index_mDice > do_best_Dice and (index_mDice > do_best_Dice or index_mean_hd95 < do_best_hd95):
    if index_mDice > do_best_Dice:
        early_stopper.best_score = index_mDice
        early_stopper.counter = 0
        do_best_Dice = index_mDice
        do_best_mIoU = index_mIoU
        # do_best_hd95 = index_mean_hd95
        do_best_epoch = epoch+1
        best_model_t1_params = copy.deepcopy(model_t1.state_dict())
        best_model_s_params = copy.deepcopy(model_s.state_dict())
        # print("Change Best: " + f'epoch: {do_best_epoch} DSC: {do_best_Dice:.4f}, HD95 = {do_best_hd95:.4f}')
        print("Change Best: " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')

        # save the best valiation prod
        folder_val_best_prob = os.path.join(folder_name, "val_original_best_prob")
        os.makedirs(folder_val_best_prob, exist_ok=True)
        file_names = os.listdir(folder_val_prob)
        for file_name in file_names:
            source_path = os.path.join(folder_val_prob, file_name)
            destination_path = os.path.join(folder_val_best_prob, file_name)
            shutil.copyfile(source_path, destination_path)
    else:
        if epoch >= 5:
            early_stopper(index_mDice)
    
    if early_stopper.early_stop: 
        print("Early stopping")   
        break

    # show the best mIoU, mDice
    # print("Best model : " + f'epoch : {do_best_epoch}  DSC : {do_best_Dice:.4f}, HD95 = {do_best_hd95:.4f}')
    print("Best model : " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')
   

In [None]:
# save the best teacher1&2 and student model
model_path = None
model_path = os.path.join("saved_models")
os.makedirs(model_path, exist_ok=True)

model_params = os.path.join(model_path, "original_models")
os.makedirs(model_params, exist_ok=True)

torch.save(best_model_t1_params, os.path.join(model_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_t1.pth'))
torch.save(best_model_s_params, os.path.join(model_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_s.pth'))
print("Best model : " + f'epoch : {do_best_epoch}  DSC : {do_best_Dice:.4f}')