In [1]:
import os
import json
import logging
import copy
import shutil
import time

from tqdm import tqdm

import cv2
import numpy as np
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from torchvision import transforms
from torchvision.transforms import functional as F

from math import ceil
from itertools import cycle
from itertools import chain
from collections import OrderedDict
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

import random
random.seed(42)

In [2]:
#config
batch_size = 8
epochs = 80
warm_up = 100
labeled_examples = 515
lr = 1e-2
backbone = 101 #" the resnet x {50, 101} layers"
semi_p_th = 0.6 # positive_threshold for semi-supervised loss
semi_n_th = 0.0 # negative_threshold for semi-supervised loss
unsup_weight = 1.5 # unsupervised weight for the semi-supervised loss

config = json.load(open("configs/config_deeplab_v3+_onlyFA_range_selfsupervised.json"))

config['train_supervised']['batch_size'] = batch_size
config['train_unsupervised']['batch_size'] = batch_size
config['warm_selfsupervised']['batch_size'] = batch_size
config['model']['epochs'] = epochs
config['model']['warm_up_epoch'] = warm_up
config['n_labeled_examples'] = labeled_examples
config['model']['resnet'] = backbone
config['optimizer']['args']['lr'] = lr
config['unsupervised_w'] = unsup_weight
config['model']['data_h_w'] = [config['train_supervised']['crop_size'], config['train_supervised']['crop_size']]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

logger = logging.getLogger("PS-MT")
logger.propagate = False
logger.warning("Training start, 总共 {} epochs".format(str(config['model']['epochs'])))
logger.critical("GPUs: {}".format(device))
logger.critical("DeeplabV3+ with ResNet {} backbone".format(str(config['model']['resnet'])))
logger.critical("Current Labeled Example: {}".format(config['n_labeled_examples']))
logger.critical("Learning rate: other {}, and head is the SAME [world]".format(config['optimizer']['args']['lr']))

logger.critical("Current batch: {} [world]".format(int(config['train_unsupervised']['batch_size']) +
                                                int(config['train_supervised']['batch_size'])) )

logger.critical("Current unsupervised loss function: {}, with weight {} and length {}".format(config['model']['un_loss'],
                                                                                            config['unsupervised_w'],
                                                                                            config['ramp_up']))
print("\nconfig json :")
for i in config:
    print(i, config[i])


Training start, 总共 80 epochs
GPUs: cuda
DeeplabV3+ with ResNet 101 backbone
Current Labeled Example: 515
Learning rate: other 0.01, and head is the SAME [world]
Current batch: 16 [world]
Current unsupervised loss function: semi_ce, with weight 1.5 and length 12



config json :
name PS-MT(DeeplabV3+)
experim_name TEST_warm
n_labeled_examples 515
ramp_up 12
unsupervised_w 1.5
lr_scheduler Poly
gamma 0.5
model {'supervised': False, 'semi': True, 'resnet': 101, 'sup_loss': 'DE', 'un_loss': 'semi_ce', 'epochs': 80, 'warm_up_epoch': 100, 'data_h_w': [224, 224]}
optimizer {'type': 'SGD', 'args': {'lr': 0.01, 'weight_decay': 0.0001, 'momentum': 0.9}}
train_supervised {'data_dir': 'FA', 'batch_size': 8, 'shuffle': True, 'crop_size': 224, 'split': 'train_supervised', 'num_workers': 8}
train_unsupervised {'data_dir': 'range_unFA', 'batch_size': 8, 'shuffle': True, 'crop_size': 224, 'split': 'train_unsupervised', 'num_workers': 8}
warm_selfsupervised {'data_dir': 'range_unFA', 'batch_size': 8, 'shuffle': True, 'crop_size': 224, 'split': 'train_unsupervised', 'num_workers': 8}
val_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'val', 'shuffle': False, 'num_workers': 4}
test_loader {'data_dir': 'FA', 'batch_size': 1, 'split': 'test', 'shuffle': False, 

In [3]:
# DATA LOADERS
from DataLoader.dataset_onlyFA import *
choose_data = "All"

config['train_supervised']['choose'] = choose_data
config['train_unsupervised']['choose'] = choose_data
config['warm_selfsupervised']['choose'] = choose_data
config['val_loader']['choose'] = choose_data
config['test_loader']['choose'] = choose_data

print("train_supervised")
for i in config['train_supervised']:
    print("    ",i, ":", config['train_supervised'][i])
print("train_unsupervised")
for i in config['train_unsupervised']:
    print("    ", i, ":", config['train_unsupervised'][i])
print("warm_selfsupervised")
for i in config['warm_selfsupervised']:
    print("    ", i, ":", config['warm_selfsupervised'][i])
print("val_loader")
for i in config['val_loader']:
    print("    ", i, ":", config['val_loader'][i])
print("test_loader")
for i in config['test_loader']:
    print("    ", i, ":", config['test_loader'][i])

supervised_set = BasicDataset(data_dir=config['train_supervised']['data_dir'], 
                                 choose=config['train_supervised']['choose'],
                                 split=config['train_supervised']['split'])

unsupervised_set = BasicDataset(data_dir=config['train_unsupervised']['data_dir'],
                                   choose=config['train_unsupervised']['choose'],
                                   split=config['train_unsupervised']['split'])

warm_selfsupervised_set = BasicDataset(data_dir=config['warm_selfsupervised']['data_dir'],
                                      choose=config['warm_selfsupervised']['choose'],
                                      split=config['warm_selfsupervised']['split'])

val_set = BasicDataset(data_dir=config['val_loader']['data_dir'],
                          choose=config['val_loader']['choose'],
                          split=config['val_loader']['split'])

test_set = BasicDataset(data_dir=config['test_loader']['data_dir'],
                          choose=config['test_loader']['choose'],
                          split=config['test_loader']['split'])

print("supervised_loader: ",len(supervised_set))
print("unsupervised_loader: ",len(unsupervised_set))
print("warm_selfsupervised_loader: ",len(warm_selfsupervised_set))
print("val_loader: ",len(val_set))
print("test_loader: ",len(test_set))


supervised_loader = DataLoader(dataset=supervised_set, batch_size=config['train_supervised']['batch_size'],
                               shuffle=config['train_supervised']['shuffle'], 
                               num_workers=config['train_supervised']['num_workers'])

unsupervised_loader = DataLoader(dataset=unsupervised_set, batch_size=config['train_unsupervised']['batch_size'],
                               shuffle=config['train_unsupervised']['shuffle'], 
                               num_workers=config['train_unsupervised']['num_workers'])
                               
warm_selfsupervised_loader = DataLoader(dataset=unsupervised_set, batch_size=config['warm_selfsupervised']['batch_size'],
                              shuffle=config['warm_selfsupervised']['shuffle'],
                              num_workers=config['warm_selfsupervised']['num_workers'])

val_loader = DataLoader(dataset=val_set, batch_size=config['val_loader']['batch_size'],
                               shuffle=config['val_loader']['shuffle'], 
                               num_workers=config['val_loader']['num_workers'])

test_loader = DataLoader(dataset=test_set, batch_size=config['test_loader']['batch_size'],
                               shuffle=config['test_loader']['shuffle'], 
                               num_workers=config['test_loader']['num_workers'])



train_supervised
     data_dir : FA
     batch_size : 8
     shuffle : True
     crop_size : 224
     split : train_supervised
     num_workers : 8
     choose : All
train_unsupervised
     data_dir : range_unFA
     batch_size : 8
     shuffle : True
     crop_size : 224
     split : train_unsupervised
     num_workers : 8
     choose : All
warm_selfsupervised
     data_dir : range_unFA
     batch_size : 8
     shuffle : True
     crop_size : 224
     split : train_unsupervised
     num_workers : 8
     choose : All
val_loader
     data_dir : FA
     batch_size : 1
     split : val
     shuffle : False
     num_workers : 4
     choose : All
test_loader
     data_dir : FA
     batch_size : 1
     split : test
     shuffle : False
     num_workers : 4
     choose : All
supervised_loader:  515
unsupervised_loader:  5263
warm_selfsupervised_loader:  5263
val_loader:  162
test_loader:  157


In [4]:
#calculate metrics
from sklearn import metrics, neighbors
from sklearn.metrics import confusion_matrix

#compute mean IoU & DSC  of given outputs & targets per patient (5 time points)
def information_index(outputs, targets):
    eps = np.finfo(np.float64).eps
    output = outputs.flatten()
    target = targets.flatten()
    
    # Compute the confusion matrix
    cm = confusion_matrix(target, output).ravel()
    
    # Handle different shapes of confusion matrix
    if len(cm) == 4:
        TN, FP, FN, TP = cm
    elif len(cm) == 1:
        # Case where only one class is present in both target and output
        if target[0] == 0:
            TN, FP, FN, TP = cm[0], 0, 0, 0  # All True Negatives
        else:
            TN, FP, FN, TP = 0, 0, 0, cm[0]  # All True Positives
    elif len(cm) == 2:
        # Case where target and output contain only a single class
        if target[0] == 0:
            TN, FP, FN, TP = cm[0], cm[1], 0, 0  # No False Negatives or True Positives
        else:
            TN, FP, FN, TP = 0, 0, cm[0], cm[1]  # No True Negatives or False Positives
    else:
        raise ValueError("Unexpected confusion matrix size.")

    # Compute IoU and Dice coefficients
    index_MIou = (TP / (TP + FP + FN + eps) + TN / (TN + FN + FP + eps)) / 2
    mean_iou = np.mean(index_MIou)
    index_dice = 2 * TP / (2 * TP + FP + FN + eps)
    mean_dice = np.mean(index_dice)

    return mean_iou, mean_dice
#compute mean IoU & DSC of validaton (test set)
def count_index(pre, tar):
        path_pre = pre
        path_target = tar
        dirs = os.listdir(path_pre)
        # print(len(dirs))
        con_mIOU = 0
        con_mdice = 0
        for imgs in dirs:
            pre_path = path_pre + '/' + str(imgs)
            target_path = path_target + '/' + str(imgs)

            target = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
            _, tar = cv2.threshold(target, 128, 1, cv2.THRESH_BINARY)

            predict = cv2.imread(pre_path, cv2.IMREAD_GRAYSCALE)
            _, pre = cv2.threshold(predict, 128, 1, cv2.THRESH_BINARY)

            tIOU, tdice = information_index(pre,tar)
            con_mIOU += tIOU
            con_mdice += tdice
        val_mIoU = con_mIOU/len(dirs)
        val_mDice = con_mdice/len(dirs)
        
        return val_mIoU, val_mDice

In [5]:
#calculate DSC & HD

from medpy import metric

#calculate DSC & HD of each image
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    else:
        return 0, 0

#append DSC & HD of each patient (5 time points)
def test_single_volume(label, output, classes=2):
    label = torch.clamp(label, 0, 1)
    label = label.squeeze(0).cpu().detach().numpy()
    output = output.cpu().detach().numpy()
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(
            output == i, label == i))
    
    return metric_list

In [6]:
#model setting (1 teacher + 1 student)
from torch import optim

from Model.selfsupervised.selfsupervised_model import *
from Utils.ramps import *

cons_w_unsup = ConsistencyWeight(final_w=config['unsupervised_w'], iters_per_epoch=len(unsupervised_loader),
                                 rampup_starts=0, rampup_ends=config['ramp_up'],  ramp_type="cosine_rampup")


model_t1 = Teacher_Net(num_classes=2, config=config['model'])
model_t1 = model_t1.to(device)

model_s = Student_Net(num_classes=2, config=config['model'],  cons_w_unsup=cons_w_unsup)
model_s = model_s.to(device)

optimizer_t1 = optim.SGD(model_t1.parameters(), 
                      lr=config['optimizer']['args']['lr'],
                      momentum=config['optimizer']['args']['momentum'],
                      weight_decay=config['optimizer']['args']['weight_decay'])

optimizer_s = optim.SGD(model_s.parameters(), 
                      lr=config['optimizer']['args']['lr'],
                      momentum=config['optimizer']['args']['momentum'],
                      weight_decay=config['optimizer']['args']['weight_decay'])

In [7]:
# val HD

PRED_MODEL_t1 = './saved_models/original_models/original_epoch_2_dsc_0.2770_best_t1.pth'
PRED_MODEL_s = './saved_models//original_models/original_epoch_2_dsc_0.2770_best_s.pth'

model_t1.load_state_dict(torch.load(PRED_MODEL_t1, map_location=device))
model_s.load_state_dict(torch.load(PRED_MODEL_s, map_location=device))

folder_name = None
folder_name = os.path.join("see_image")
os.makedirs(folder_name, exist_ok=True)

# validation
metric_list = 0.0
list = []
model_t1.eval()
model_s.eval()
for batch in tqdm(val_loader):
    image_val, label, id_val = batch
    image_val, label, id_val = image_val.to(device), label.to(device), id_val

    H, W = label.size(1), label.size(2)
    up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)
    data = torch.nn.functional.interpolate(image_val, size=(up_sizes[0], up_sizes[1]),
                                            mode='bilinear', align_corners=True)
    with torch.no_grad():
        loss_t1, output = model_t1(input_ul=data, target_ul=label)
    output = torch.nn.functional.interpolate(output, size=(H, W),
                                                mode='bilinear', align_corners=True)
    
    out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
    metric_i = test_single_volume(label, out, classes=2)
    list.append(metric_i[0][1])
    metric_list += np.array(metric_i)
metric_list = metric_list / len(val_set)
print(list)
# index_mDice = np.mean(metric_list, axis=0)[0]
index_mean_hd95 = np.mean(metric_list, axis=0)[1]



# show the best HD
print("original_models valiation : " + f'HD95 = {index_mean_hd95:.4f}') 

100%|██████████| 162/162 [00:03<00:00, 45.15it/s]

[73.82377664290077, 51.19863204015901, 104.4877954824302, 165.66577766554886, 54.037024344425184, 27.582598675073818, 58.75794352136386, 36.05551275463989, 29.17532117774409, 73.12591450262427, 36.235341863986875, 56.00535586144822, 34.9828507202755, 45.55764686119848, 61.08511747271744, 69.11946170882587, 106.16025245284726, 48.30113870293329, 102.94950174442525, 46.87216658103186, 26.832815729997478, 86.34002473492443, 90.77306615441785, 55.323141463735, 29.154759474226502, 52.0, 76.02631123499285, 112.90438315871492, 33.19629283530026, 55.98258213080535, 60.452433948779344, 67.4166151627327, 36.069377593742864, 33.06055050963308, 30.610455730027933, 39.59797974644666, 32.55764119219941, 150.40810957816836, 48.507731342539614, 26.970329614269, 86.58059163133099, 149.10566697511024, 49.768402284594885, 86.01045769134377, 105.4324425047794, 36.345563690772494, 27.65863337187866, 36.134470520342774, 75.13421311940617, 66.89088722214798, 67.18333116708747, 86.78565115969761, 72.945184899




In [8]:
# Test

model_t1.eval()
model_s.eval()
for batch in tqdm(test_loader):
    image_test, label, id_test = batch
    image_test, label, id_test = image_test.to(device), label.to(device), id_test

    H, W = label.size(1), label.size(2)
    up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)

    for i in range(0, int(label.size(0))):
        folder_test = os.path.join(folder_name, "test_target")
        os.makedirs(folder_test, exist_ok=True)
        image = Image.fromarray(np.uint8(label[i].detach().cpu().numpy()))
        image.save(os.path.join(folder_test, str(id_test[i]) + ".png"))

    data = torch.nn.functional.interpolate(image_test, size=(up_sizes[0], up_sizes[1]),
                                            mode='bilinear', align_corners=True)
    with torch.no_grad():
         loss_t1, output = model_t1(input_ul=data,  target_ul=label)
    output = torch.nn.functional.interpolate(output, size=(H, W),
                                                mode='bilinear', align_corners=True)

    for i in range(0, int(output.size(0))):
        folder_test_prob = os.path.join(folder_name, "test_original_models_prob")
        os.makedirs(folder_test_prob, exist_ok=True)
        image_prob = output[i].squeeze().detach()
        image_prob = torch.argmax(image_prob, dim=0).cpu().numpy()
        image_prob = Image.fromarray((image_prob * 255).astype(np.uint8))
        image_prob.save(os.path.join(folder_test_prob, str(id_test[i]) + ".png"))

    out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
    metric_i = test_single_volume(label, out, classes=2)
    metric_list += np.array(metric_i)
metric_list = metric_list / len(test_set)

# performance = np.mean(metric_list, axis=0)[0]

mean_hd95 = np.mean(metric_list, axis=0)[1]

# show epoch mIoU, mDice
index_mIoU, index_mDice = count_index(folder_test_prob, folder_test)


print(f'DSC = {index_mDice}, HD95 = {mean_hd95}')

100%|██████████| 157/157 [00:03<00:00, 46.52it/s]


DSC = 0.3011820534715776, HD95 = 58.47744415479772
