Multi Mean Teacher

original
FAICG_FA_FA

backbone=DeeplabV3+

In [1]:
import os
import json
import logging
import copy
import shutil
import time

from tqdm import tqdm

import cv2
import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from torchvision import transforms
from torchvision.transforms import functional as F

from math import ceil
from itertools import cycle
from itertools import chain
from collections import OrderedDict
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

import random
random.seed(42)

import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#config
batch_size = 8
epochs = 80
warm_up = 100
labeled_examples = 515
lr = 1e-2
backbone = 101 #" the resnet x {50, 101} layers"
# NOTE(review): semi_p_th / semi_n_th are never written into `config` below and are not used
# anywhere visible in this notebook — confirm whether the model is meant to receive them.
semi_p_th = 0.6 # positive_threshold for semi-supervised loss
semi_n_th = 0.0 # negative_threshold for semi-supervised loss
unsup_weight = 1.5 # unsupervised weight for the semi-supervised loss
seed = 42 # same value passed to random.seed() in the import cell
config_path = "configs/config_deeplab_v3+_onlyFA_range.json"

# The import cell only seeds Python's `random`; seed numpy and torch (all devices) as well so
# weight initialisation and DataLoader shuffling are reproducible across kernel restarts.
np.random.seed(seed)
torch.manual_seed(seed)

# Context manager closes the config file instead of leaving the handle to the garbage collector.
with open(config_path) as config_file:
    config = json.load(config_file)

config['train_supervised']['batch_size'] = batch_size
config['train_unsupervised']['batch_size'] = batch_size
config['model']['epochs'] = epochs
config['model']['warm_up_epoch'] = warm_up
config['n_labeled_examples'] = labeled_examples
config['model']['resnet'] = backbone
config['optimizer']['args']['lr'] = lr
config['unsupervised_w'] = unsup_weight
config['model']['data_h_w'] = [config['train_supervised']['crop_size'], config['train_supervised']['crop_size']]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The logger has no handler and does not propagate, so only WARNING+ records reach stderr
# (via logging's last-resort handler) — which is why informational lines use critical().
logger = logging.getLogger("PS-MT")
logger.propagate = False
logger.warning("Training start, 总共 {} epochs".format(str(config['model']['epochs'])))
logger.critical("GPUs: {}".format(device))
logger.critical("DeeplabV3+ with ResNet {} backbone".format(str(config['model']['resnet'])))
logger.critical("Current Labeled Example: {}".format(config['n_labeled_examples']))
logger.critical("Learning rate: other {}, and head is the SAME [world]".format(config['optimizer']['args']['lr']))

logger.critical("Current batch: {} [world]".format(int(config['train_unsupervised']['batch_size']) +
                                                int(config['train_supervised']['batch_size'])) )

logger.critical("Current unsupervised loss function: {}, with weight {} and length {}".format(config['model']['un_loss'],
                                                                                            config['unsupervised_w'],
                                                                                            config['ramp_up']))
print("\nconfig json :")
for i in config:
    print(i, config[i])


Training start, 总共 80 epochs
GPUs: cuda
DeeplabV3+ with ResNet 101 backbone
Current Labeled Example: 515
Learning rate: other 0.01, and head is the SAME [world]
Current batch: 16 [world]
Current unsupervised loss function: semi_ce, with weight 1.5 and length 12



config json :
name PS-MT(DeeplabV3+)
experim_name TEST_warm
n_labeled_examples 515
ramp_up 12
unsupervised_w 1.5
lr_scheduler Poly
gamma 0.5
model {'supervised': False, 'semi': True, 'resnet': 101, 'sup_loss': 'DE', 'un_loss': 'semi_ce', 'epochs': 80, 'warm_up_epoch': 100, 'data_h_w': [224, 224]}
optimizer {'type': 'SGD', 'args': {'lr': 0.01, 'weight_decay': 0.0001, 'momentum': 0.9}}
train_supervised {'data_dir': 'FA', 'batch_size': 8, 'shuffle': True, 'crop_size': 224, 'split': 'train_supervised', 'num_workers': 8}
train_unsupervised {'data_dir': 'range_unFA', 'batch_size': 8, 'shuffle': True, 'crop_size': 224, 'split': 'train_unsupervised', 'num_workers': 8}
val_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'val', 'shuffle': False, 'num_workers': 4}
test_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'test', 'shuffle': False, 'num_workers': 4}


In [3]:
# DATA LOADERS
from DataLoader.dataset_onlyFA import *
choose_data = "All"

# Config sections of the four splits, in the order they are reported and built below.
split_sections = ('train_supervised', 'train_unsupervised', 'val_loader', 'test_loader')

# Apply the same `choose` filter to every split and echo each section's settings.
for section_name in split_sections:
    section = config[section_name]
    section['choose'] = choose_data
    print(section_name)
    for option, value in section.items():
        print("    ", option, ":", value)


def _make_dataset(section_name):
    """Build a BasicDataset from the data_dir / choose / split of one config section."""
    section = config[section_name]
    return BasicDataset(data_dir=section['data_dir'],
                        choose=section['choose'],
                        split=section['split'])


def _make_loader(dataset, section_name):
    """Wrap `dataset` in a DataLoader using one section's batch_size / shuffle / num_workers."""
    section = config[section_name]
    return DataLoader(dataset=dataset,
                      batch_size=section['batch_size'],
                      shuffle=section['shuffle'],
                      num_workers=section['num_workers'])


supervised_set = _make_dataset('train_supervised')
unsupervised_set = _make_dataset('train_unsupervised')
val_set = _make_dataset('val_loader')
test_set = _make_dataset('test_loader')

print("supervised_loader: ", len(supervised_set))
print("unsupervised_loader: ", len(unsupervised_set))
print("val_loader: ", len(val_set))
print("test_loader: ", len(test_set))

supervised_loader = _make_loader(supervised_set, 'train_supervised')
unsupervised_loader = _make_loader(unsupervised_set, 'train_unsupervised')
val_loader = _make_loader(val_set, 'val_loader')
test_loader = _make_loader(test_set, 'test_loader')



train_supervised
     data_dir : FA
     batch_size : 8
     shuffle : True
     crop_size : 224
     split : train_supervised
     num_workers : 8
     choose : All
train_unsupervised
     data_dir : range_unFA
     batch_size : 8
     shuffle : True
     crop_size : 224
     split : train_unsupervised
     num_workers : 8
     choose : All
val_loader
     data_dir : FA
     batch_size : 1
     split : val
     shuffle : False
     num_workers : 4
     choose : All
test_loader
     data_dir : FA
     batch_size : 1
     split : test
     shuffle : False
     num_workers : 4
     choose : All
supervised_loader:  515
unsupervised_loader:  5263
val_loader:  162
test_loader:  157


In [4]:
#model setting (1 teacher + 1 student)
from torch import optim

from Model.Deeplabv3_plus.psmt_model import *
from Utils.ramps import *

# Unsupervised-loss weight ramps up (cosine) from epoch 0 to config['ramp_up'].
cons_w_unsup = ConsistencyWeight(final_w=config['unsupervised_w'],
                                 iters_per_epoch=len(unsupervised_loader),
                                 rampup_starts=0,
                                 rampup_ends=config['ramp_up'],
                                 ramp_type="cosine_rampup")

model_t1 = Teacher_Net(num_classes=2, config=config['model']).to(device)
model_s = Student_Net(num_classes=2, config=config['model'], cons_w_unsup=cons_w_unsup).to(device)


def _make_sgd(params):
    """SGD over `params` with the lr / momentum / weight_decay from config['optimizer']['args']."""
    sgd_args = config['optimizer']['args']
    return optim.SGD(params,
                     lr=sgd_args['lr'],
                     momentum=sgd_args['momentum'],
                     weight_decay=sgd_args['weight_decay'])


optimizer_t1 = _make_sgd(model_t1.parameters())
optimizer_s = _make_sgd(model_s.parameters())

Load model, Time usage:
	IO: 1.0073959827423096, initialize parameters: 0.03242826461791992


In [5]:
#calculate metrics
from sklearn import metrics, neighbors
from sklearn.metrics import confusion_matrix

#compute mean IoU & DSC of one predicted mask against its target mask
def information_index(outputs, targets):
    """Compute (mIoU, Dice) for one binary prediction / target pair.

    Any non-zero pixel counts as foreground, so both {0, 1} and {0, 255} masks work.

    Args:
        outputs: predicted mask (array-like, any shape).
        targets: ground-truth mask with the same number of elements as ``outputs``.

    Returns:
        (mean_iou, mean_dice) where
        mean_iou  = (TP/(TP+FP+FN) + TN/(TN+FN+FP)) / 2   (foreground and background IoU), and
        mean_dice = 2TP / (2TP+FP+FN)                      (foreground Dice).
        ``eps`` guards each denominator, so a class absent from both masks contributes 0
        instead of raising.
    """
    eps = np.finfo(np.float64).eps
    pred_fg = np.asarray(outputs).ravel() != 0
    true_fg = np.asarray(targets).ravel() != 0

    # Count directly rather than via sklearn's confusion_matrix: with labels inferred from the
    # data, it returns a 1x1 matrix when both masks hold a single class, and the 4-way unpack
    # of .ravel() then raised ValueError.
    TP = np.count_nonzero(pred_fg & true_fg)
    FP = np.count_nonzero(pred_fg & ~true_fg)
    FN = np.count_nonzero(~pred_fg & true_fg)
    TN = np.count_nonzero(~pred_fg & ~true_fg)

    mean_iou = (TP / (TP + FP + FN + eps) + TN / (TN + FN + FP + eps)) / 2
    mean_dice = 2 * TP / (2 * TP + FP + FN + eps)

    return mean_iou, mean_dice

#compute mean IoU & DSC over a folder of predicted masks vs. a folder of target masks
def count_index(pre, tar):
    """Average the per-image (mIoU, Dice) of every predicted mask in a folder.

    Args:
        pre: directory of predicted mask images.
        tar: directory of target mask images; every file name in ``pre`` must also exist here.

    Returns:
        (mean mIoU, mean Dice) over all files in ``pre``, each image scored by
        ``information_index`` after binarising both masks at a grey level of 128.

    Raises:
        ValueError: if ``pre`` contains no files (previously a ZeroDivisionError).
        FileNotFoundError: if a prediction or its target cannot be read as an image.
    """
    file_names = os.listdir(pre)
    if not file_names:
        raise ValueError("no prediction images found in {}".format(pre))

    total_iou = 0
    total_dice = 0
    for file_name in file_names:
        target_mask = _read_binary_mask(os.path.join(tar, file_name))
        pred_mask = _read_binary_mask(os.path.join(pre, file_name))

        iou, dice = information_index(pred_mask, target_mask)
        total_iou += iou
        total_dice += dice

    return total_iou / len(file_names), total_dice / len(file_names)


def _read_binary_mask(path):
    """Read ``path`` as greyscale and threshold it to {0, 255} (pixels > 128 become 255)."""
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        raise FileNotFoundError("cannot read image: {}".format(path))
    _, mask = cv2.threshold(image, 128, 255, cv2.THRESH_BINARY)
    return mask

In [6]:
# from medpy import metric
# def calculate_metric_percase(pred, gt):
#     pred[pred > 0] = 1
#     gt[gt > 0] = 1
#     if pred.sum() > 0:
#         dice = metric.binary.dc(pred, gt)
#         hd95 = metric.binary.hd95(pred, gt)
#         return dice, hd95
#     else:
#         return 0, 0
# def test_single_volume(label, output, classes=2):
#     label = torch.clamp(label, 0, 1)
#     label = label.squeeze(0).cpu().detach().numpy()
#     output = output.cpu().detach().numpy()
#     metric_list = []
#     for i in range(1, classes):
#         metric_list.append(calculate_metric_percase(
#             output == i, label == i))
    
#     return metric_list

In [7]:
# warm
from Utils.early_stop import EarlyStopping

# The early stopper only starts counting once the student's Dice (measured on the supervised
# training batches saved below) exceeds this value.
WARM_UP_DICE_GATE = 0.95

early_stopper = EarlyStopping(patience=5, delta=0.0001)

# Output folders for the masks that count_index compares each epoch. They are loop-invariant,
# so create them once here; this also keeps the names bound even if the loader yields nothing.
folder_name = os.path.join("see_image")
folder_target = os.path.join(folder_name, "target")
folder_t1_prob = os.path.join(folder_name, "t1_prob")
folder_s_prob = os.path.join(folder_name, "s_prob")
for _folder in (folder_target, folder_t1_prob, folder_s_prob):
    os.makedirs(_folder, exist_ok=True)


def _save_target_masks(targets, ids, folder):
    """Write each target in the batch to ``folder/<id>.png`` as uint8 values (no rescaling)."""
    for i in range(int(targets.size(0))):
        image = Image.fromarray(np.uint8(targets[i].detach().cpu().numpy()))
        image.save(os.path.join(folder, str(ids[i]) + ".png"))


def _save_predicted_masks(logits, ids, folder):
    """Write argmax over dim 0 of each sample's output to ``folder/<id>.png``, scaled by 255."""
    for i in range(int(logits.size(0))):
        pred = torch.argmax(logits[i].squeeze().detach(), dim=0).cpu().numpy()
        image = Image.fromarray((pred * 255).astype(np.uint8))
        image.save(os.path.join(folder, str(ids[i]) + ".png"))


for epoch in range(config['model']['warm_up_epoch']):

    # Accumulate Python floats: summing the loss tensors themselves keeps every batch's
    # autograd graph alive until the end of the epoch.
    epoch_loss_t1 = 0.0
    epoch_loss_s = 0.0

    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, config['model']['warm_up_epoch'], time.asctime())
    print(header)
    print('-' * len(header))

    for batch in tqdm(supervised_loader):
        image_FA, image_ICG, target_l, id_l = batch
        image_FA, image_ICG, target_l = image_FA.to(device), image_ICG.to(device), target_l.to(device)

        _save_target_masks(target_l, id_l, folder_target)

        # warm teacher (student set to eval mode)
        model_t1.train()
        model_s.eval()
        optimizer_t1.zero_grad()
        loss_t1, outputs = model_t1(x_FA=image_FA, x_ICG=image_ICG, target_l=target_l,
                                    warm_up=True, mix_up=False)
        _save_predicted_masks(outputs["sup_pred"], id_l, folder_t1_prob)

        epoch_loss_t1 += loss_t1.item()
        loss_t1.backward()
        optimizer_t1.step()

        # warm student (teacher set to eval mode)
        model_t1.eval()
        model_s.train()
        optimizer_s.zero_grad()
        loss_s, outputs = model_s(x_FA=image_FA, target_l=target_l, x_ul=None, target_ul=None,
                                  warm_up=True)
        _save_predicted_masks(outputs["sup_pred"], id_l, folder_s_prob)

        epoch_loss_s += loss_s.item()
        loss_s.backward()
        optimizer_s.step()

    # Scores come from the masks written above, i.e. from the supervised *training* batches.
    t1_mIoU, t1_mDice = count_index(folder_t1_prob, folder_target)
    s_mIoU, s_mDice = count_index(folder_s_prob, folder_target)

    # NOTE(review): the summed batch losses (presumably per-batch means) are divided by the number
    # of samples, not batches; kept so the EarlyStopping delta keeps operating on the same scale.
    mean_loss_t1 = epoch_loss_t1 / len(supervised_set)
    mean_loss_s = epoch_loss_s / len(supervised_set)
    print(f'Epoch{epoch+1} loss : \nteacher 1 loss = {mean_loss_t1},student loss = {mean_loss_s}')

    if s_mDice > WARM_UP_DICE_GATE:
        # A plain float, so the stopper does not hold on to any autograd graph.
        early_stopper(mean_loss_s)
        print(f'teacher 1 : mIoU = {t1_mIoU}, DCS = {t1_mDice}')
        print(f'student : mIoU = {s_mIoU}, DCS = {s_mDice}')
        if early_stopper.early_stop:
            print("Early stopping")
            break

Epoch: 1/100 --- < Starting Time : Tue Aug 13 14:26:00 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:08<00:00,  7.54it/s]


Epoch1 loss : 
teacher 1 loss = 0.04230799153447151,student loss = 0.029643038287758827
Epoch: 2/100 --- < Starting Time : Tue Aug 13 14:26:19 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.44it/s]


Epoch2 loss : 
teacher 1 loss = 0.033194661140441895,student loss = 0.014235884882509708
Epoch: 3/100 --- < Starting Time : Tue Aug 13 14:26:37 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.43it/s]


Epoch3 loss : 
teacher 1 loss = 0.03075830452144146,student loss = 0.010624725371599197
Epoch: 4/100 --- < Starting Time : Tue Aug 13 14:26:55 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.42it/s]


Epoch4 loss : 
teacher 1 loss = 0.027949031442403793,student loss = 0.009253444150090218
Epoch: 5/100 --- < Starting Time : Tue Aug 13 14:27:13 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.48it/s]


Epoch5 loss : 
teacher 1 loss = 0.025159431621432304,student loss = 0.007280655205249786
Epoch: 6/100 --- < Starting Time : Tue Aug 13 14:27:31 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch6 loss : 
teacher 1 loss = 0.021464068442583084,student loss = 0.006096147000789642
Epoch: 7/100 --- < Starting Time : Tue Aug 13 14:27:48 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch7 loss : 
teacher 1 loss = 0.02048063464462757,student loss = 0.00581568805500865
Epoch: 8/100 --- < Starting Time : Tue Aug 13 14:28:06 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.62it/s]


Epoch8 loss : 
teacher 1 loss = 0.018916551023721695,student loss = 0.005946884397417307
Epoch: 9/100 --- < Starting Time : Tue Aug 13 14:28:24 2024 >
-------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.61it/s]


Epoch9 loss : 
teacher 1 loss = 0.016780897974967957,student loss = 0.005225147120654583
Epoch: 10/100 --- < Starting Time : Tue Aug 13 14:28:42 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.61it/s]


Epoch10 loss : 
teacher 1 loss = 0.01386942621320486,student loss = 0.0046220337972044945
Epoch: 11/100 --- < Starting Time : Tue Aug 13 14:29:00 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.58it/s]


Epoch11 loss : 
teacher 1 loss = 0.012473135255277157,student loss = 0.004267718642950058
Epoch: 12/100 --- < Starting Time : Tue Aug 13 14:29:17 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch12 loss : 
teacher 1 loss = 0.011685425415635109,student loss = 0.004053571727126837
Epoch: 13/100 --- < Starting Time : Tue Aug 13 14:29:35 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch13 loss : 
teacher 1 loss = 0.009935845620930195,student loss = 0.0038674138486385345
Epoch: 14/100 --- < Starting Time : Tue Aug 13 14:29:53 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch14 loss : 
teacher 1 loss = 0.007917976006865501,student loss = 0.003721783170476556
Epoch: 15/100 --- < Starting Time : Tue Aug 13 14:30:11 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.57it/s]


Epoch15 loss : 
teacher 1 loss = 0.006693030707538128,student loss = 0.003494328586384654
Epoch: 16/100 --- < Starting Time : Tue Aug 13 14:30:29 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch16 loss : 
teacher 1 loss = 0.006309970747679472,student loss = 0.0033766082488000393
Epoch: 17/100 --- < Starting Time : Tue Aug 13 14:30:47 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.50it/s]


Epoch17 loss : 
teacher 1 loss = 0.006149244029074907,student loss = 0.003411476267501712
Epoch: 18/100 --- < Starting Time : Tue Aug 13 14:31:05 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.48it/s]


Epoch18 loss : 
teacher 1 loss = 0.005906295962631702,student loss = 0.0032902806997299194
Epoch: 19/100 --- < Starting Time : Tue Aug 13 14:31:22 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.47it/s]


Epoch19 loss : 
teacher 1 loss = 0.0059502036310732365,student loss = 0.0032758554443717003
Epoch: 20/100 --- < Starting Time : Tue Aug 13 14:31:40 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.61it/s]


Epoch20 loss : 
teacher 1 loss = 0.005530723370611668,student loss = 0.0031273390632122755
Epoch: 21/100 --- < Starting Time : Tue Aug 13 14:31:58 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.58it/s]


Epoch21 loss : 
teacher 1 loss = 0.006185369100421667,student loss = 0.003277196316048503
Epoch: 22/100 --- < Starting Time : Tue Aug 13 14:32:16 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch22 loss : 
teacher 1 loss = 0.004811024758964777,student loss = 0.0029972095508128405
Epoch: 23/100 --- < Starting Time : Tue Aug 13 14:32:34 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.60it/s]


Epoch23 loss : 
teacher 1 loss = 0.004666403401643038,student loss = 0.002971774898469448
Epoch: 24/100 --- < Starting Time : Tue Aug 13 14:32:52 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.55it/s]


Epoch24 loss : 
teacher 1 loss = 0.004473497159779072,student loss = 0.002836960833519697
Epoch: 25/100 --- < Starting Time : Tue Aug 13 14:33:10 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.47it/s]


Epoch25 loss : 
teacher 1 loss = 0.004073349758982658,student loss = 0.002674855524674058
Epoch: 26/100 --- < Starting Time : Tue Aug 13 14:33:28 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch26 loss : 
teacher 1 loss = 0.004135705530643463,student loss = 0.002655591582879424
Epoch: 27/100 --- < Starting Time : Tue Aug 13 14:33:46 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch27 loss : 
teacher 1 loss = 0.0040905713103711605,student loss = 0.002590848132967949
Epoch: 28/100 --- < Starting Time : Tue Aug 13 14:34:04 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.51it/s]


Epoch28 loss : 
teacher 1 loss = 0.003818752244114876,student loss = 0.0025830399245023727
Epoch: 29/100 --- < Starting Time : Tue Aug 13 14:34:22 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.59it/s]


Epoch29 loss : 
teacher 1 loss = 0.003638321068137884,student loss = 0.002492950763553381
Epoch: 30/100 --- < Starting Time : Tue Aug 13 14:34:39 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.59it/s]


Epoch30 loss : 
teacher 1 loss = 0.0034551064018160105,student loss = 0.0024460142012685537
Epoch: 31/100 --- < Starting Time : Tue Aug 13 14:34:57 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.59it/s]


Epoch31 loss : 
teacher 1 loss = 0.0035080169327557087,student loss = 0.0024342213291674852
Epoch: 32/100 --- < Starting Time : Tue Aug 13 14:35:15 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.55it/s]


Epoch32 loss : 
teacher 1 loss = 0.003371230326592922,student loss = 0.002356921788305044
Epoch: 33/100 --- < Starting Time : Tue Aug 13 14:35:33 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.63it/s]


Epoch33 loss : 
teacher 1 loss = 0.003291322151198983,student loss = 0.002383887069299817
Epoch: 34/100 --- < Starting Time : Tue Aug 13 14:35:51 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch34 loss : 
teacher 1 loss = 0.0031447489745914936,student loss = 0.0022966095712035894
Epoch: 35/100 --- < Starting Time : Tue Aug 13 14:36:09 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch35 loss : 
teacher 1 loss = 0.0031403463799506426,student loss = 0.0023027395363897085
Epoch: 36/100 --- < Starting Time : Tue Aug 13 14:36:26 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch36 loss : 
teacher 1 loss = 0.003082557348534465,student loss = 0.002203643787652254
Epoch: 37/100 --- < Starting Time : Tue Aug 13 14:36:44 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch37 loss : 
teacher 1 loss = 0.003039778210222721,student loss = 0.0022415476851165295
Epoch: 38/100 --- < Starting Time : Tue Aug 13 14:37:02 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch38 loss : 
teacher 1 loss = 0.0030373898334801197,student loss = 0.0021796012297272682
Epoch: 39/100 --- < Starting Time : Tue Aug 13 14:37:20 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch39 loss : 
teacher 1 loss = 0.003068830817937851,student loss = 0.0022350505460053682
Epoch: 40/100 --- < Starting Time : Tue Aug 13 14:37:38 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch40 loss : 
teacher 1 loss = 0.0032755343709141016,student loss = 0.002179386094212532
Epoch: 41/100 --- < Starting Time : Tue Aug 13 14:37:56 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.49it/s]


Epoch41 loss : 
teacher 1 loss = 0.0028566380497068167,student loss = 0.002067418536171317
Epoch: 42/100 --- < Starting Time : Tue Aug 13 14:38:14 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.49it/s]


Epoch42 loss : 
teacher 1 loss = 0.0027331120800226927,student loss = 0.002012968994677067
Epoch: 43/100 --- < Starting Time : Tue Aug 13 14:38:32 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.58it/s]


Epoch43 loss : 
teacher 1 loss = 0.0026959734968841076,student loss = 0.002007765229791403
Epoch: 44/100 --- < Starting Time : Tue Aug 13 14:38:49 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.60it/s]


Epoch44 loss : 
teacher 1 loss = 0.0026628291234374046,student loss = 0.0019237218657508492
Epoch: 45/100 --- < Starting Time : Tue Aug 13 14:39:07 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch45 loss : 
teacher 1 loss = 0.002617275109514594,student loss = 0.0018773437477648258
Epoch: 46/100 --- < Starting Time : Tue Aug 13 14:39:25 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.58it/s]


Epoch46 loss : 
teacher 1 loss = 0.002527697244659066,student loss = 0.0018980383174493909
Epoch: 47/100 --- < Starting Time : Tue Aug 13 14:39:43 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch47 loss : 
teacher 1 loss = 0.002544473623856902,student loss = 0.0019101044163107872
Epoch: 48/100 --- < Starting Time : Tue Aug 13 14:40:01 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch48 loss : 
teacher 1 loss = 0.0024472458753734827,student loss = 0.001850693253800273
Epoch: 49/100 --- < Starting Time : Tue Aug 13 14:40:19 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.51it/s]


Epoch49 loss : 
teacher 1 loss = 0.002386206528171897,student loss = 0.0018569071544334292
Epoch: 50/100 --- < Starting Time : Tue Aug 13 14:40:37 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch50 loss : 
teacher 1 loss = 0.0022752350196242332,student loss = 0.0018240032950416207
Epoch: 51/100 --- < Starting Time : Tue Aug 13 14:40:55 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.51it/s]


Epoch51 loss : 
teacher 1 loss = 0.0023661586456000805,student loss = 0.0017691327957436442
Epoch: 52/100 --- < Starting Time : Tue Aug 13 14:41:12 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.55it/s]


Epoch52 loss : 
teacher 1 loss = 0.0023841639049351215,student loss = 0.001780339633114636
Epoch: 53/100 --- < Starting Time : Tue Aug 13 14:41:30 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.61it/s]


Epoch53 loss : 
teacher 1 loss = 0.002373626222833991,student loss = 0.0018058603163808584
Epoch: 54/100 --- < Starting Time : Tue Aug 13 14:41:48 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.59it/s]


Epoch54 loss : 
teacher 1 loss = 0.002369434107095003,student loss = 0.0017900635721161962
Epoch: 55/100 --- < Starting Time : Tue Aug 13 14:42:06 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.58it/s]


Epoch55 loss : 
teacher 1 loss = 0.0022601443342864513,student loss = 0.001736452803015709
Epoch: 56/100 --- < Starting Time : Tue Aug 13 14:42:24 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.53it/s]


Epoch56 loss : 
teacher 1 loss = 0.002259261906147003,student loss = 0.0017281308537349105
Epoch: 57/100 --- < Starting Time : Tue Aug 13 14:42:41 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch57 loss : 
teacher 1 loss = 0.0021386209409683943,student loss = 0.0016921552596613765
Epoch: 58/100 --- < Starting Time : Tue Aug 13 14:42:59 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch58 loss : 
teacher 1 loss = 0.0023500421084463596,student loss = 0.001723550260066986
Epoch: 59/100 --- < Starting Time : Tue Aug 13 14:43:17 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.56it/s]


Epoch59 loss : 
teacher 1 loss = 0.0024901563301682472,student loss = 0.0016520177014172077
Epoch: 60/100 --- < Starting Time : Tue Aug 13 14:43:35 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch60 loss : 
teacher 1 loss = 0.0023514628410339355,student loss = 0.0017406577244400978
Epoch: 61/100 --- < Starting Time : Tue Aug 13 14:43:53 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.52it/s]


Epoch61 loss : 
teacher 1 loss = 0.0022547703702002764,student loss = 0.0016704668523743749
Epoch: 62/100 --- < Starting Time : Tue Aug 13 14:44:11 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.52it/s]


Epoch62 loss : 
teacher 1 loss = 0.002183044096454978,student loss = 0.0016643520211800933
Epoch: 63/100 --- < Starting Time : Tue Aug 13 14:44:28 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.49it/s]


Epoch63 loss : 
teacher 1 loss = 0.00207954621873796,student loss = 0.001631007413379848
Epoch: 64/100 --- < Starting Time : Tue Aug 13 14:44:46 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.49it/s]


Epoch64 loss : 
teacher 1 loss = 0.0020518130622804165,student loss = 0.001591887790709734
teacher 1 : mIoU = 0.9409825589939406, DCS = 0.9388461307874452
student : mIoU = 0.9527473859041209, DCS = 0.9505998704320805
Epoch: 65/100 --- < Starting Time : Tue Aug 13 14:45:04 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.45it/s]


Epoch65 loss : 
teacher 1 loss = 0.001912886742502451,student loss = 0.0015544912312179804
EarlyStopping counter: 1 out of 5
teacher 1 : mIoU = 0.9441479993792371, DCS = 0.9421559858524381
student : mIoU = 0.9529227271823548, DCS = 0.9507161277303886
Epoch: 66/100 --- < Starting Time : Tue Aug 13 14:45:22 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.47it/s]


Epoch66 loss : 
teacher 1 loss = 0.002025735331699252,student loss = 0.0015659742057323456
EarlyStopping counter: 2 out of 5
teacher 1 : mIoU = 0.941863084614813, DCS = 0.939414844821166
student : mIoU = 0.9540556995684945, DCS = 0.9524773460852598
Epoch: 67/100 --- < Starting Time : Tue Aug 13 14:45:40 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.54it/s]


Epoch67 loss : 
teacher 1 loss = 0.002027894603088498,student loss = 0.0015468953642994165
EarlyStopping counter: 3 out of 5
teacher 1 : mIoU = 0.9433891698848176, DCS = 0.9423558374378442
student : mIoU = 0.9535409537054232, DCS = 0.9513591946499226
Epoch: 68/100 --- < Starting Time : Tue Aug 13 14:45:58 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.51it/s]


Epoch68 loss : 
teacher 1 loss = 0.001890041516162455,student loss = 0.001503169653005898
EarlyStopping counter: 4 out of 5
teacher 1 : mIoU = 0.9460555899984708, DCS = 0.9450222220300722
student : mIoU = 0.9555082590094238, DCS = 0.9543239636006127
Epoch: 69/100 --- < Starting Time : Tue Aug 13 14:46:16 2024 >
--------------------------------------------------------------


100%|██████████| 65/65 [00:07<00:00,  8.52it/s]


Epoch69 loss : 
teacher 1 loss = 0.00188035040628165,student loss = 0.0015231571160256863
EarlyStopping counter: 5 out of 5
teacher 1 : mIoU = 0.9466917108369775, DCS = 0.9455752142842838
student : mIoU = 0.9556718484221477, DCS = 0.9553598934574623
Early stopping


In [8]:
def freeze_teachers_parameters(model):
    """Turn off gradient tracking for a teacher's encoder and decoder.

    Only the parameters reachable through ``model.encoder`` and ``model.decoder`` are
    modified; any other submodule of ``model`` keeps its current ``requires_grad`` flag.

    Args:
        model: a module exposing ``encoder`` and ``decoder`` submodules.
    """
    for param in chain(model.encoder.parameters(), model.decoder.parameters()):
        param.requires_grad_(False)

# The teacher's encoder/decoder weights are meant to follow the student through
# update_teachers (an EMA of the student's state), not through backprop, so gradient
# tracking is switched off for them.
freeze_teachers_parameters(model_t1)

def _ema_state_dict(teacher_part, student_part, keep_rate, part_name):
    """Build an EMA-blended state dict for one teacher submodule (encoder or decoder).

    Floating-point entries become ``student * (1 - keep_rate) + teacher * keep_rate``.
    Non-floating entries (e.g. BatchNorm's integer ``num_batches_tracked`` counter) cannot
    be blended with fractional weights -- the float result would be truncated when loaded
    back into the integer buffer -- so they are copied from the student instead.

    Args:
        teacher_part: teacher submodule whose keys drive the iteration.
        student_part: matching student submodule.
        keep_rate: weight kept from the teacher's current value.
        part_name: "encoder" / "decoder", used only in the error message.

    Returns:
        OrderedDict with the same keys as ``teacher_part.state_dict()``.

    Raises:
        Exception: if a teacher key has no counterpart in the student.
    """
    student_state = student_part.state_dict()
    blended = OrderedDict()
    for key, teacher_value in teacher_part.state_dict().items():
        if key not in student_state:
            raise Exception("{} is not found in student {} model".format(key, part_name))
        student_value = student_state[key]
        if teacher_value.is_floating_point():
            blended[key] = student_value * (1 - keep_rate) + teacher_value * keep_rate
        else:
            blended[key] = student_value
    return blended


def update_teachers(teacher, student, keep_rate=0.996):
    """EMA-update ``teacher``'s encoder and decoder in place from ``student``.

    Each floating-point entry becomes ``keep_rate * teacher + (1 - keep_rate) * student``;
    integer buffers are copied from the student (see ``_ema_state_dict``).

    Args:
        teacher: module with ``encoder``/``decoder`` submodules; modified in place.
        student: module with matching ``encoder``/``decoder`` submodules; only read.
        keep_rate: fraction of the teacher's current value that is kept.

    Raises:
        Exception: if a teacher encoder/decoder key is missing from the student.
    """
    # Both dicts are built before anything is loaded, so a key mismatch in either part
    # raises without leaving the teacher half-updated.
    new_teacher_encoder_dict = _ema_state_dict(teacher.encoder, student.encoder, keep_rate, "encoder")
    new_teacher_decoder_dict = _ema_state_dict(teacher.decoder, student.decoder, keep_rate, "decoder")
    teacher.encoder.load_state_dict(new_teacher_encoder_dict, strict=True)
    teacher.decoder.load_state_dict(new_teacher_decoder_dict, strict=True)

In [9]:
def predict_with_out_grad(model_t1, model_t2, image):
    """Run two teachers on ``image`` without autograd.

    Each teacher's decoder output (second element of its return tuple) is bilinearly
    resized (``align_corners=True``) to the spatial size of ``image``.

    Args:
        model_t1, model_t2: modules exposing ``encoder`` and ``decoder(f, data_shape=...)``.
        image: input batch; its last two dims are used as the target (H, W).

    Returns:
        (pred_t1, pred_t2) -- the two resized predictions, asserted to have equal shapes.
    """
    height, width = image.shape[-2], image.shape[-1]

    def _resized_prediction(teacher):
        features = teacher.encoder(image)
        _, raw = teacher.decoder(features, data_shape=[height, width])
        return torch.nn.functional.interpolate(raw,
                                               size=(height, width),
                                               mode='bilinear',
                                               align_corners=True)

    with torch.no_grad():
        pred_first = _resized_prediction(model_t1)
        pred_second = _resized_prediction(model_t2)
        assert pred_first.shape == pred_second.shape, "Expect two prediction in same shape,"
    return pred_first, pred_second

In [10]:
# semi train
# Single-teacher mean-teacher loop: model_t1 provides predictions on the unlabeled batch
# (passed to the student as target_ul, and also as both t1 and t2), model_s is trained by
# backprop, and model_t1 follows model_s through update_teachers (EMA). Validation is run
# with the teacher; the best teacher/student weights by validation Dice are kept in memory
# (best_model_t1_params / best_model_s_params) for the save cell.
from Utils.early_stop import EarlyStopper

early_stopper = EarlyStopper(patience=5, delta=0.001)

best_model_t1_params = copy.deepcopy(model_t1.state_dict())
best_model_s_params = copy.deepcopy(model_s.state_dict())
do_best_Dice = 0
do_best_mIoU = 0
# do_best_hd95 = 100000
do_best_epoch = 0

# Validation output folders are loop-invariant: bind and create them once, up front, so
# they exist (and count_index below has its arguments) even if val_loader yields nothing.
folder_val = os.path.join(folder_name, "val_original_target")
folder_val_prob = os.path.join(folder_name, "val_original_prob")
folder_val_best_prob = os.path.join(folder_name, "val_original_best_prob")
os.makedirs(folder_val, exist_ok=True)
os.makedirs(folder_val_prob, exist_ok=True)

for epoch in range(config['model']['epochs']):

    model_s.train()
    epoch_loss = 0
    epoch_dsc = 0

    # The labeled loader is cycled so every unlabeled batch is paired with a labeled one.
    dataloader = iter(zip(cycle(supervised_loader), unsupervised_loader))

    localtime = time.asctime(time.localtime(time.time()))
    epoch_header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, config['model']['epochs'], localtime)
    print(epoch_header)
    print('-' * len(epoch_header))

    tbar = tqdm(range(len(unsupervised_loader)))
    for batch_idx in tbar:
        (image_FA, image_ICG, target_l, id_l), (input_ul, target_ul, id_ul) = next(dataloader)
        image_FA, image_ICG, target_l = image_FA.to(device), image_ICG.to(device), target_l.to(device)
        input_ul, target_ul = input_ul.to(device), target_ul.to(device)
        optimizer_s.zero_grad()

        # Teacher prediction on the unlabeled batch, resized to the input resolution.
        with torch.no_grad():
            f = model_t1.encoder(input_ul)
            _, predict_target_ul1 = model_t1.decoder(f, data_shape=[input_ul.shape[-2], input_ul.shape[-1]])

            predict_target_ul1 = torch.nn.functional.interpolate(predict_target_ul1,
                                                                 size=(input_ul.shape[-2], input_ul.shape[-1]),
                                                                 mode='bilinear',
                                                                 align_corners=True)

        total_loss, cur_losses, outputs = model_s(x_FA=image_FA, x_ICG=image_ICG, target_l=target_l,
                                                  x_ul=input_ul, target_ul=predict_target_ul1,
                                                  epoch=epoch, curr_iter=batch_idx, warm_up=False,
                                                  mix_up=False, t1=model_t1, t2=model_t1)

        # Accumulate a detached value: adding the live tensor would chain every iteration's
        # autograd graph onto epoch_loss and keep it alive for the whole epoch.
        epoch_loss += total_loss.detach()
        total_loss.backward()
        optimizer_s.step()

        with torch.no_grad():
            update_teachers(teacher=model_t1,
                            student=model_s)

    # validation
    # Validation uses the teacher, so the teacher must be in eval mode as well (otherwise
    # BatchNorm would use batch statistics and update its running stats on validation
    # data). Its previous mode is restored afterwards so training behaviour is unchanged.
    model_s.eval()
    teacher_was_training = model_t1.training
    model_t1.eval()
    for batch in tqdm(val_loader):
        image_val, label, id_val = batch
        image_val, label = image_val.to(device), label.to(device)

        H, W = label.size(1), label.size(2)
        # Network input is resized up to the next multiple of 8 in each spatial dim
        # (presumably a backbone stride requirement -- TODO confirm).
        up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)

        for i in range(int(label.size(0))):
            image = Image.fromarray(np.uint8(label[i].detach().cpu().numpy()))
            image.save(os.path.join(folder_val, str(id_val[i]) + ".png"))

        data = torch.nn.functional.interpolate(image_val, size=(up_sizes[0], up_sizes[1]),
                                               mode='bilinear', align_corners=True)

        with torch.no_grad():
            f = model_t1.encoder(data)
            _, output = model_t1.decoder(f, data_shape=[data.shape[-2], data.shape[-1]])
        output = torch.nn.functional.interpolate(output, size=(H, W),
                                                 mode='bilinear', align_corners=True)

        # Save the argmax class map of each prediction, scaled by 255.
        for i in range(int(output.size(0))):
            image_prob = output[i].squeeze().detach()
            image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
            image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
            image_prob.save(os.path.join(folder_val_prob, str(id_val[i]) + ".png"))
    model_t1.train(teacher_was_training)

    # show epoch mIoU, mDice
    index_mIoU, index_mDice = count_index(folder_val_prob, folder_val)
    print(f'Epoch {epoch+1}' + " val:" + f'mIoU: {index_mIoU:.4f}, DSC: {index_mDice:.4f}')

    # find the best mIoU, mDice
    if index_mDice > do_best_Dice:
        # A new best resets the early stopper by hand before recording the checkpoint.
        early_stopper.best_score = index_mDice
        early_stopper.counter = 0
        do_best_Dice = index_mDice
        do_best_mIoU = index_mIoU
        do_best_epoch = epoch+1
        best_model_t1_params = copy.deepcopy(model_t1.state_dict())
        best_model_s_params = copy.deepcopy(model_s.state_dict())
        print("Change Best: " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')

        # save the best validation predictions
        os.makedirs(folder_val_best_prob, exist_ok=True)
        for file_name in os.listdir(folder_val_prob):
            source_path = os.path.join(folder_val_prob, file_name)
            destination_path = os.path.join(folder_val_best_prob, file_name)
            shutil.copyfile(source_path, destination_path)
    else:
        # The early stopper is only consulted after the first 5 epochs.
        if epoch >= 5:
            early_stopper(index_mDice)

    if early_stopper.early_stop:
        print("Early stopping")
        break

    # show the best mIoU, mDice
    print("Best model : " + f'epoch : {do_best_epoch} mIoU: {do_best_mIoU:.4f}, DSC: {do_best_Dice:.4f}')
   

Epoch: 1/80 --- < Starting Time : Tue Aug 13 14:46:35 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:27<00:00,  4.46it/s]
100%|██████████| 162/162 [00:02<00:00, 72.56it/s]


Epoch 1 val:mIoU: 0.5385, DSC: 0.2175
Change Best: epoch : 1 mIoU: 0.5385, DSC: 0.2175
Best model : epoch : 1 mIoU: 0.5385, DSC: 0.2175
Epoch: 2/80 --- < Starting Time : Tue Aug 13 14:49:08 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:42<00:00,  4.04it/s]
100%|██████████| 162/162 [00:01<00:00, 82.35it/s] 


Epoch 2 val:mIoU: 0.6401, DSC: 0.4737
Change Best: epoch : 2 mIoU: 0.6401, DSC: 0.4737
Best model : epoch : 2 mIoU: 0.6401, DSC: 0.4737
Epoch: 3/80 --- < Starting Time : Tue Aug 13 14:51:56 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:44<00:00,  4.00it/s]
100%|██████████| 162/162 [00:01<00:00, 82.49it/s] 


Epoch 3 val:mIoU: 0.6858, DSC: 0.5569
Change Best: epoch : 3 mIoU: 0.6858, DSC: 0.5569
Best model : epoch : 3 mIoU: 0.6858, DSC: 0.5569
Epoch: 4/80 --- < Starting Time : Tue Aug 13 14:54:45 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:38<00:00,  4.16it/s]
100%|██████████| 162/162 [00:01<00:00, 82.35it/s] 


Epoch 4 val:mIoU: 0.6853, DSC: 0.5570
Change Best: epoch : 4 mIoU: 0.6853, DSC: 0.5570
Best model : epoch : 4 mIoU: 0.6853, DSC: 0.5570
Epoch: 5/80 --- < Starting Time : Tue Aug 13 14:57:28 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:47<00:00,  3.92it/s]
100%|██████████| 162/162 [00:01<00:00, 82.58it/s] 


Epoch 5 val:mIoU: 0.6854, DSC: 0.5573
Change Best: epoch : 5 mIoU: 0.6854, DSC: 0.5573
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 6/80 --- < Starting Time : Tue Aug 13 15:00:21 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:42<00:00,  4.04it/s]
100%|██████████| 162/162 [00:01<00:00, 82.60it/s] 


Epoch 6 val:mIoU: 0.6850, DSC: 0.5568
EarlyStop DSC =  0.557278733088501
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 7/80 --- < Starting Time : Tue Aug 13 15:03:09 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:44<00:00,  3.99it/s]
100%|██████████| 162/162 [00:01<00:00, 82.99it/s] 


Epoch 7 val:mIoU: 0.6840, DSC: 0.5550
EarlyStop DSC =  0.5549714703108378
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 1 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 8/80 --- < Starting Time : Tue Aug 13 15:05:59 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:45<00:00,  3.99it/s]
100%|██████████| 162/162 [00:01<00:00, 83.04it/s] 


Epoch 8 val:mIoU: 0.6841, DSC: 0.5554
EarlyStop DSC =  0.5554049079029798
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 2 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 9/80 --- < Starting Time : Tue Aug 13 15:08:49 2024 >
------------------------------------------------------------


100%|██████████| 658/658 [02:45<00:00,  3.97it/s]
100%|██████████| 162/162 [00:01<00:00, 82.74it/s] 


Epoch 9 val:mIoU: 0.6834, DSC: 0.5544
EarlyStop DSC =  0.5543874594261096
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 3 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 10/80 --- < Starting Time : Tue Aug 13 15:11:40 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:42<00:00,  4.04it/s]
100%|██████████| 162/162 [00:01<00:00, 82.58it/s] 


Epoch 10 val:mIoU: 0.6848, DSC: 0.5568
EarlyStop DSC =  0.557278733088501
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 11/80 --- < Starting Time : Tue Aug 13 15:14:28 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:45<00:00,  3.98it/s]
100%|██████████| 162/162 [00:01<00:00, 82.56it/s]


Epoch 11 val:mIoU: 0.6843, DSC: 0.5559
EarlyStop DSC =  0.5558736343862044
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 1 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 12/80 --- < Starting Time : Tue Aug 13 15:17:18 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:44<00:00,  4.01it/s]
100%|██████████| 162/162 [00:01<00:00, 82.50it/s] 


Epoch 12 val:mIoU: 0.6832, DSC: 0.5534
EarlyStop DSC =  0.5534141605488178
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 2 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 13/80 --- < Starting Time : Tue Aug 13 15:20:07 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:41<00:00,  4.08it/s]
100%|██████████| 162/162 [00:01<00:00, 81.33it/s] 


Epoch 13 val:mIoU: 0.6833, DSC: 0.5540
EarlyStop DSC =  0.5540308281487685
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 3 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 14/80 --- < Starting Time : Tue Aug 13 15:22:54 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:42<00:00,  4.04it/s]
100%|██████████| 162/162 [00:01<00:00, 82.25it/s] 


Epoch 14 val:mIoU: 0.6842, DSC: 0.5555
EarlyStop DSC =  0.5555455992442307
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 4 out of 5
Best model : epoch : 5 mIoU: 0.6854, DSC: 0.5573
Epoch: 15/80 --- < Starting Time : Tue Aug 13 15:25:42 2024 >
-------------------------------------------------------------


100%|██████████| 658/658 [02:35<00:00,  4.24it/s]
100%|██████████| 162/162 [00:01<00:00, 82.42it/s] 


Epoch 15 val:mIoU: 0.6832, DSC: 0.5540
EarlyStop DSC =  0.5539935899995714
EarlyStop Best DSC =  0.556278733088501
EarlyStopping counter: 5 out of 5
Early stopping


In [11]:
# save the best teacher1&2 and student model
# Both checkpoints go to saved_models/original_models; each file name records the best
# epoch and its validation Dice.
model_path = "saved_models"
model_params = os.path.join(model_path, "original_models")
os.makedirs(model_params, exist_ok=True)  # creates model_path too

best_checkpoints = (
    (best_model_t1_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_t1.pth'),
    (best_model_s_params, f'original_epoch_{do_best_epoch}_dsc_{do_best_Dice:.4f}_best_s.pth'),
)
for ckpt_params, ckpt_file in best_checkpoints:
    torch.save(ckpt_params, os.path.join(model_params, ckpt_file))
print("Best model : " + f'epoch : {do_best_epoch}  DSC : {do_best_Dice:.4f}')

Best model : epoch : 5  DSC : 0.5573
