Dataset: https://drive.google.com/u/0/uc?id=1p33nsWQaiZMAgsruDoJLyatoq5XAH-TH&export=download

Baseline code: https://github.com/emma-sjwang/Dofe

## Algorithm Implementation

1. Implement the naive baseline model for comparison. This is typically a simple model with basic feature extraction and classification steps.

2. Choose at least one Domain Generalization (DG) method to implement and compare against the naive baseline. This could be FACT, Dofe, or another method of your choice.

3. Report the segmentation performance in terms of dice coefficient and average surface distance.

In [None]:
# Mount Google Drive so the project folder and the dataset archive are reachable.
from google.colab import drive
drive.mount('/content/gdrive')

# Switch to the project folder, then unpack the Fundus dataset archive into /content.
%cd "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2"
!unzip "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2/Fundus-doFE.zip" -d "/content/"

# Install the third-party libraries imported later in the notebook.
!pip install -U albumentations
!pip install segmentation-models-pytorch
!pip install torchmetrics
# surface-distance is installed from a local clone of its GitHub repository.
!git clone https://github.com/deepmind/surface-distance.git
!pip install surface-distance/

## Dataset info
[cup] Domain1: Drishti-GS dataset [101], including training [50] and testing [51].

[cup] Domain2: RIM-ONE_r3 dataset [159], including training [99] and testing [60].

[cup] Domain3: REFUGE training set [400] (MICCAI 2018 workshop), including training [320] and testing [80].

[cup] Domain4: REFUGE validation set [400], including training [320] and testing [80].
Domain5: ISBI [81], IDRiD challenge



In [None]:
## 3 classes
# Read one training mask as a single-channel image and list the distinct grey
# levels it contains (the recorded output below shows 0, 128 and 255).
# NOTE: cv2 and np are imported in a later cell, so that cell must run first.
label = cv2.imread("/content/train/mask/G-1-L.png", cv2.IMREAD_GRAYSCALE)
np.unique(label)

array([  0, 128, 255], dtype=uint8)

## Naive Baseline Model

In [None]:
# Just for reference
# UNet implementation
def double_conv(in_channels, out_channels):
    """Build one U-Net stage: two 3x3 convolutions (padding 1), each followed by ReLU.

    Args:
        in_channels: number of channels fed to the first convolution.
        out_channels: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of ``[Conv2d, ReLU, Conv2d, ReLU]``.
    """
    stage = []
    stage_in = in_channels
    for _ in range(2):
        stage.append(nn.Conv2d(stage_in, out_channels, 3, padding=1))
        stage.append(nn.ReLU(inplace=True))
        # The second convolution consumes what the first one produced.
        stage_in = out_channels
    return nn.Sequential(*stage)

class UNet(nn.Module):
    """Reference U-Net with three down-sampling steps for 3-channel input.

    The encoder widens the features 64 -> 128 -> 256 -> 512 while halving the
    resolution three times; the decoder upsamples bilinearly and concatenates
    the encoder output of the same resolution before each convolution stage.
    A 1x1 convolution followed by a sigmoid produces ``num_classes`` maps.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Encoder stages.
        self.conv_down1 = double_conv(3, 64)
        self.conv_down2 = double_conv(64, 128)
        self.conv_down3 = double_conv(128, 256)
        self.conv_down4 = double_conv(256, 512)

        # Shared, parameter-free resampling layers.
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder stages: input width = skip-connection channels + upsampled channels.
        self.conv_up3 = double_conv(256 + 512, 256)
        self.conv_up2 = double_conv(128 + 256, 128)
        self.conv_up1 = double_conv(128 + 64, 64)

        self.last_conv = nn.Conv2d(64, num_classes, kernel_size=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise convolution weights; reset any batch-norm layer to identity."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # Encoder: remember each stage's output for the skip connections.
        skips = []
        for encoder in (self.conv_down1, self.conv_down2, self.conv_down3):
            features = encoder(x)
            skips.append(features)
            x = self.maxpool(features)

        x = self.conv_down4(x)

        # Decoder: upsample, attach the skip of the same resolution, convolve.
        for decoder in (self.conv_up3, self.conv_up2, self.conv_up1):
            x = self.upsample(x)
            x = torch.cat([x, skips.pop()], dim=1)
            x = decoder(x)

        return torch.sigmoid(self.last_conv(x))

NameError: ignored

In [None]:
# Test the model
# Smoke test: push one random 3x256x256 image through the reference UNet and
# print the shape of what comes out.
# NOTE(review): `device` is only assigned in the training cell further down,
# and torch/nn are imported in a later cell — run those first.
model = UNet(num_classes=1).to(device)
output = model(torch.randn(1,3,256,256).to(device))
print(f'Output shape: {output.shape}')

## FACT

https://github.com/MediaBrain-SJTU/FACT

In [None]:
# Standard library imports
import os
import shutil
import re
import math
import copy
from copy import deepcopy
import random as randpy
import gc
import numpy as np
import pandas as pd
from math import sqrt

# PyTorch imports for deep learning
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, ConcatDataset, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torchvision.models import resnet50 as _resnet50
from torchsummary import summary
import segmentation_models_pytorch as smp

# Other utilities
import cv2
from pylab import *
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from google.colab.patches import cv2_imshow
import albumentations as Augm
from albumentations.pytorch import ToTensorV2
from ast import In
import surface_distance as surfdist
from google.colab.patches import cv2_imshow
from torchmetrics.classification import BinaryJaccardIndex

In [None]:
%cd "/content"

/content


In [None]:
# Destination folders that the split files are copied into
# (each one gets an image/ and a mask/ sub-folder in a later cell).
train_path = "./train"
val_path = "./val"
test_path = "./test"

# Mask folders of the four domains, one [train masks, test masks] pair per domain.
# FundusDataset derives the matching image folder by replacing "mask" with
# "image" in each of these paths.
D_mask_dir = [["./Fundus/Domain1/train/mask", "./Fundus/Domain1/test/mask"],
              ["./Fundus/Domain2/train/mask", "./Fundus/Domain2/test/mask"],
              ["./Fundus/Domain3/train/mask", "./Fundus/Domain3/test/mask"],
              ["./Fundus/Domain4/train/mask", "./Fundus/Domain4/test/mask"]]

# Define the fundus dataset class to load the data by functions
class FundusDataset(Dataset):
    """Index fundus mask/image pairs per domain and stage them into split folders.

    The constructor scans the given mask folders and records, for every
    readable mask with at least one non-zero pixel, the mask path and the
    matching image path in ``self.path_df``.  ``split_traindataset`` and
    ``split_testdataset`` copy those files into the module-level
    ``train_path`` / ``val_path`` / ``test_path`` folders and return the
    transformed tensors.

    Args:
        path: list of per-domain lists of mask folders.
        transform: callable applied to every image.
        transform_mask: callable applied to every mask.
        split_idx: stored on the instance; not used by this class.
    """

    # Per domain: (mask extension, image extension).  Domain1 images share the
    # mask's file name, so it needs no entry here.
    _EXT_SWAP = {
        'Domain2': ('.png', '.jpg'),
        'Domain3': ('.bmp', '.jpg'),
        'Domain4': ('.bmp', '.jpg'),
    }

    def __init__(self, path, transform, transform_mask, split_idx):
        self.path = path
        self.transform = transform
        self.transform_mask = transform_mask
        self.split_idx = split_idx
        self.path_df = self.create_path_df()

    @staticmethod
    def _image_name_for(domain, mask_file):
        """Return the image file name paired with ``mask_file``; None for an unknown domain."""
        if domain == 'Domain1':
            return mask_file
        swap = FundusDataset._EXT_SWAP.get(domain)
        if swap is None:
            return None
        mask_ext, image_ext = swap
        stem, ext = os.path.splitext(mask_file)
        # Only the extension is swapped; a "png"/"bmp" inside the stem is left alone.
        return stem + image_ext if ext == mask_ext else mask_file

    @staticmethod
    def _stem_key(file_name):
        """Sort key that orders files by stem so images and masks line up despite different extensions."""
        return os.path.splitext(file_name)[0], file_name

    def create_path_df(self):
        """Build a DataFrame with one row per usable mask.

        Columns: ``direcImg``, ``images``, ``direcMask``, ``masks``.

        Raises:
            ValueError: if a folder path contains no ``Domain<digit>`` part or
                names a domain this class has no file-name rule for.
        """
        dirsImg, images, dirsMask, masks = [], [], [], []

        for mask_dir in self.path:
            for img_path in mask_dir:
                found = re.search(r'Domain\d', img_path)
                if found is None:
                    raise ValueError(f"Cannot tell the domain of mask folder: {img_path}")
                domain = found.group()

                for _, _, files in os.walk(img_path):
                    for file in files:
                        check_mask = cv2.imread(os.path.join(img_path, file), cv2.IMREAD_GRAYSCALE)

                        # Skip files that are not readable images and masks that are entirely zero.
                        if check_mask is None or not check_mask.sum() > 0:
                            continue

                        image_name = self._image_name_for(domain, file)
                        if image_name is None:
                            raise ValueError(f"No image naming rule for {domain} ({img_path})")

                        # All four lists grow together so the DataFrame columns stay aligned.
                        images.append(image_name)
                        masks.append(file)
                        dirsMask.append(img_path + '/')
                        dirsImg.append(img_path.replace('mask', 'image') + '/')

        return pd.DataFrame({'direcImg': dirsImg, 'images': images, 'direcMask': dirsMask, 'masks': masks})

    def split_traindataset(self):
        """Copy the first 80% of the rows to the train folder and the rest to the val folder.

        Returns:
            ``[[train images, train masks], [val images, val masks]]`` as
            produced by ``trainGetitem``.
        """
        cutoff = int(0.8 * len(self.path_df))
        for i, row in enumerate(self.path_df.itertuples(index=False)):
            destination_path = train_path if i < cutoff else val_path
            shutil.copy(row.direcImg + row.images, destination_path + "/image/")
            shutil.copy(row.direcMask + row.masks, destination_path + "/mask/")

        return self.trainGetitem()

    def split_testdataset(self):
        """Copy every row's image and mask to the test folder and return ``testGetitem()``."""
        for row in self.path_df.itertuples(index=False):
            shutil.copy(row.direcImg + row.images, test_path + "/image/")
            shutil.copy(row.direcMask + row.masks, test_path + "/mask/")

        return self.testGetitem()

    def __len__(self):
        """Number of image/mask pairs indexed by the constructor."""
        return len(self.path_df)

    def process_images(self, directory, transform, transform_mask=None):
        """Load and transform every file below ``directory``.

        Sub-folders whose name contains ``mask`` are read as grey-scale masks
        (and rounded after transforming); all other sub-folders are read as
        images.  Files are visited in stem order so that the i-th image and the
        i-th mask belong together.

        Args:
            directory: folder holding the ``image`` and ``mask`` sub-folders.
            transform: callable applied to images.
            transform_mask: callable applied to masks; defaults to ``transform``.

        Returns:
            ``(images, masks)`` — two lists of transformed items.
        """
        if transform_mask is None:
            transform_mask = transform

        images = []
        masks = []

        for folder in os.listdir(directory):
            folder_path = os.path.join(directory, folder)
            is_mask_folder = 'mask' in folder

            for file in sorted(os.listdir(folder_path), key=self._stem_key):
                direc = os.path.join(folder_path, file)

                if is_mask_folder:
                    mask = cv2.imread(direc, cv2.IMREAD_GRAYSCALE)
                    mask = transform_mask(mask)
                    mask = torch.round(mask)
                    masks.append(mask)
                else:
                    fundus_img = cv2.imread(direc)
                    fundus_img = transform(fundus_img)
                    images.append(fundus_img)

        # NOTE(review): image and mask are transformed independently, so a
        # random transform (e.g. RandomCrop) may cut them differently.
        return images, masks

    def trainGetitem(self):
        """Return ``[[train images, train masks], [val images, val masks]]``."""
        trainImg, trainMask = self.process_images(train_path, self.transform, self.transform_mask)
        valImg, valMask = self.process_images(val_path, self.transform, self.transform_mask)
        return [[trainImg, trainMask], [valImg, valMask]]

    def testGetitem(self):
        """Return ``[[test images, test masks]]``."""
        testImg, testMask = self.process_images(test_path, self.transform, self.transform_mask)
        return [[testImg, testMask]]

    def direc(self):
        """Return the image paths and mask paths found in the train folder, both in stem order."""
        FundusImg_direc = []
        Mask_direc = []

        for folder in os.listdir(train_path):
            folder_path = os.path.join(train_path, folder)
            target = Mask_direc if 'mask' in folder else FundusImg_direc

            for file in sorted(os.listdir(folder_path), key=self._stem_key):
                target.append(os.path.join(folder_path, file))

        return FundusImg_direc, Mask_direc

    def stack_datasets(self, datasets):
        """Turn each ``[images, masks]`` pair of tensor lists into a ``TensorDataset``."""
        stacked_datasets = []

        for fundus, mask in ((dataset[0], dataset[1]) for dataset in datasets):
            inputs = torch.stack(fundus)
            labels = torch.stack(mask)
            stacked_datasets.append(TensorDataset(inputs, labels))

        return stacked_datasets

    def stack_train(self):
        """Return ``[train TensorDataset, val TensorDataset]``."""
        return self.stack_datasets(self.trainGetitem())

    def stack_test(self):
        """Return the test ``TensorDataset``."""
        return self.stack_datasets(self.testGetitem())[0]

In [None]:
# Start from empty train/val/test folders, each holding an image/ and a mask/ sub-folder
for path in [train_path, val_path, test_path]:
    if os.path.exists(path):
        shutil.rmtree(path)
    for sub_folder in ("image", "mask"):
        # makedirs also creates the parent split folder on the first pass
        os.makedirs(os.path.join(path, sub_folder))

# Preprocessing used while staging the files: resize, random 64x64 crop, convert to tensor
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(64),
    transforms.RandomCrop(64),
    transforms.ToTensor(),
]
normalization = transforms.Compose(preprocessing_steps)

# Domains 1-3 feed the train/val split, Domain 4 is held out for testing
train_class = FundusDataset(
    D_mask_dir[:3],
    transform=normalization,
    transform_mask=normalization,
    split_idx=1,
)
test_class = FundusDataset(
    D_mask_dir[3:],
    transform=normalization,
    transform_mask=normalization,
    split_idx=1,
)

# Copy the files into the split folders
train_class.split_traindataset()
test_class.split_testdataset()

### data_utils.py

In [None]:
%cd "/content/"

/content


In [None]:
def dataset_info(filepath):
    """Return index-aligned lists of image paths and mask paths below ``filepath``.

    ``filepath`` is expected to contain sub-folders; every sub-folder whose
    name contains ``mask`` contributes mask paths, every other sub-folder
    contributes image paths.  Only the sub-folder name is inspected, so a
    ``mask`` elsewhere in the path or in a file name cannot misclassify an
    image.  Files are listed in stem order (file name without extension) so
    that ``images[i]`` and ``masks[i]`` refer to the same sample even when
    their extensions differ.

    Args:
        filepath: folder that holds the image and mask sub-folders.

    Returns:
        ``(img_direcs, mask_direcs)`` — two lists of ``filepath/folder/file`` strings.
    """
    def stem_key(file_name):
        return os.path.splitext(file_name)[0], file_name

    img_direcs = []
    mask_direcs = []
    for folder in sorted(os.listdir(filepath)):
        folder_path = filepath + '/' + folder
        # Stray files next to the sub-folders (notes, .DS_Store, ...) are ignored.
        if not os.path.isdir(folder_path):
            continue
        target = mask_direcs if 'mask' in folder else img_direcs
        for file in sorted(os.listdir(folder_path), key=stem_key):
            target.append(folder_path + '/' + file)
    return img_direcs, mask_direcs


def get_img_transform(train=False, image_size=224, crop=False, jitter=0):
    """Compose the tensor-producing transform for images.

    Evaluation (``train=False``): resize to ``image_size`` x ``image_size``,
    convert to tensor, normalise with mean 0.5 / std 0.5.
    Training: random resized crop (scale 0.8-1.0) when ``crop`` is set,
    otherwise a plain resize; optional colour jitter when ``jitter > 0``;
    random horizontal flip; then tensor conversion and the same normalisation.

    Returns:
        A ``transforms.Compose`` that starts from an ndarray/tensor (``ToPILImage``).
    """
    mean, std = [0.5], [0.5]

    steps = [transforms.ToPILImage()]
    if train and crop:
        steps.append(transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0]))
    else:
        steps.append(transforms.Resize((image_size, image_size)))

    if train:
        if jitter > 0:
            steps.append(transforms.ColorJitter(brightness=jitter,
                                                contrast=jitter,
                                                saturation=jitter,
                                                hue=min(0.5, jitter)))
        steps.append(transforms.RandomHorizontalFlip())

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)

def get_label_transform(train=False, image_size=224, crop=False, jitter=0):
    """Compose the tensor-producing transform for segmentation masks.

    Masks carry class information in their grey levels, so this pipeline only
    applies geometric operations and keeps the values intact:

    * resizing/cropping uses nearest-neighbour interpolation, so no grey level
      is produced that was not in the mask;
    * no colour jitter is applied (``jitter`` is accepted only so existing
      calls keep working, and is ignored);
    * no mean/std normalisation is applied; ``ToTensor`` alone scales the
      8-bit grey levels to ``[0, 1]``.

    NOTE(review): the random crop and flip draw their own random parameters,
    independent of the image pipeline, so with ``train=True`` the mask may be
    cropped/flipped differently from its image. Aligning them needs a joint
    image+mask transform at the call site.

    Args:
        train: add the random horizontal flip (and the random crop if ``crop``).
        image_size: output height and width.
        crop: with ``train``, use a random resized crop (scale 0.8-1.0)
            instead of a plain resize.
        jitter: ignored.

    Returns:
        A ``transforms.Compose`` that starts from an ndarray/tensor (``ToPILImage``).
    """
    nearest = transforms.InterpolationMode.NEAREST

    steps = [transforms.ToPILImage()]
    if train and crop:
        steps.append(transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0], interpolation=nearest))
    else:
        steps.append(transforms.Resize((image_size, image_size), interpolation=nearest))
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
    

def get_pre_transform(image_size=224, crop=False, jitter=0):
    """Compose the augmentation applied before Fourier mixing.

    The pipeline resizes (or random-resized-crops when ``crop`` is set),
    optionally colour-jitters, randomly flips horizontally and finally hands
    back a NumPy array instead of a tensor.
    """
    if crop:
        resize_step = transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0])
    else:
        resize_step = transforms.Resize((image_size, image_size))

    steps = [transforms.ToPILImage(), resize_step]
    if jitter > 0:
        steps.append(transforms.ColorJitter(brightness=jitter,
                                            contrast=jitter,
                                            saturation=jitter,
                                            hue=min(0.5, jitter)))
    steps.append(transforms.RandomHorizontalFlip())
    # Back to an ndarray so the FFT-based mixing can work on it.
    steps.append(lambda x: np.asarray(x))
    return transforms.Compose(steps)

def get_post_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    """Compose the final step after Fourier mixing: tensor conversion plus per-channel normalisation."""
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean, std)
    return transforms.Compose([to_tensor, normalize])

def colorful_spectrum_mix(img1, img2, alpha, ratio=1.0):
    """Swap amplitude information between two images in the Fourier domain.

    Both inputs are ndarrays of shape [H, W, C] with identical shapes.  A
    mixing weight ``lam`` is drawn uniformly from ``[0, alpha]``; inside a
    centred window covering ``ratio`` of the (shifted) spectrum area, each
    image's amplitude becomes ``lam * other + (1 - lam) * own``.  Phases are
    left untouched.

    Returns:
        ``(img21, img12)`` as uint8 arrays: the first keeps the phase of
        ``img1``, the second the phase of ``img2``.
    """
    lam = np.random.uniform(0, alpha)

    assert img1.shape == img2.shape
    h, w, _ = img1.shape
    spatial = (0, 1)

    # Centred window of the shifted spectrum in which amplitudes are blended.
    crop_h = int(h * sqrt(ratio))
    crop_w = int(w * sqrt(ratio))
    top = h // 2 - crop_h // 2
    left = w // 2 - crop_w // 2
    window = (slice(top, top + crop_h), slice(left, left + crop_w))

    spectrum1 = np.fft.fft2(img1, axes=spatial)
    spectrum2 = np.fft.fft2(img2, axes=spatial)
    amp1, pha1 = np.abs(spectrum1), np.angle(spectrum1)
    amp2, pha2 = np.abs(spectrum2), np.angle(spectrum2)

    # Move the low frequencies to the centre so the window covers them.
    amp1 = np.fft.fftshift(amp1, axes=spatial)
    amp2 = np.fft.fftshift(amp2, axes=spatial)

    # Snapshot both windows first: each blend needs the other's original values.
    own1 = np.copy(amp1[window])
    own2 = np.copy(amp2[window])
    amp1[window] = lam * own2 + (1 - lam) * own1
    amp2[window] = lam * own1 + (1 - lam) * own2

    mixed = []
    for amplitude, phase in ((amp1, pha1), (amp2, pha2)):
        amplitude = np.fft.ifftshift(amplitude, axes=spatial)
        spectrum = amplitude * (np.e ** (1j * phase))
        pixels = np.real(np.fft.ifft2(spectrum, axes=spatial))
        mixed.append(np.uint8(np.clip(pixels, 0, 255)))

    img21, img12 = mixed
    return img21, img12

### 1.1 Data Read (Fourier) & Load

In [None]:
class DGDataset(Dataset):
    """Dataset of (image, mask) pairs read from index-aligned path lists.

    Args:
        names: list of image file paths.
        labels: list of mask file paths; ``labels[i]`` belongs to ``names[i]``.
        transformer: optional callable applied to each image.
        label_transformer: optional callable applied to each mask.  When
            omitted, ``transformer`` is applied to the mask as well, which is
            the behaviour of earlier revisions.  Passing a separate callable
            lets masks skip photometric steps such as colour jitter or
            normalisation.
    """

    def __init__(self, names, labels, transformer=None, label_transformer=None):
        self.names = names
        self.labels = labels
        self.transformer = transformer
        self.label_transformer = transformer if label_transformer is None else label_transformer

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img_name = self.names[index]
        mask_name = self.labels[index]

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(img_name)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_name}")
        mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_name}")

        if self.transformer is not None:
            img = self.transformer(img)
        if self.label_transformer is not None:
            mask = self.label_transformer(mask)
        return img, mask

class FourierDGDataset(Dataset):
    """Dataset that pairs every sample with a second one and mixes their Fourier amplitudes.

    Each item is ``(images, masks)`` where ``images`` is
    ``[original, sampled, original-with-mixed-amplitude, sampled-with-mixed-amplitude]``
    and ``masks`` is ``[mask_o, mask_s, mask_o, mask_s]``: amplitude mixing
    keeps the phase, so the mixed images keep the mask of the image whose
    phase they carry.

    Args:
        names: image paths, either one flat list or one list per domain.
        labels: mask paths in the same layout as ``names``.
        transformer: callable turning a loaded image into an ndarray
            (it is passed to ``colorful_spectrum_mix``).
        lbl_transformer: callable applied to each loaded mask.
        from_domain: where the second image is drawn from: ``'all'`` (any
            domain), ``'inter'`` (a different domain when one exists) or
            ``'intra'`` (the same domain).
        alpha: upper bound of the amplitude mixing weight.

    NOTE(review): ``transformer`` and ``lbl_transformer`` draw their random
    parameters independently, so random crops/flips may differ between an
    image and its mask.
    """

    def __init__(self, names, labels, transformer=None, lbl_transformer=None, from_domain=None, alpha=1.0):
        self.names = names
        self.labels = labels
        self.transformer = transformer
        self.lbl_transformer = lbl_transformer
        self.post_transform = get_post_transform()
        self.from_domain = from_domain
        self.alpha = alpha
        self._index_samples(names, labels)

    def _index_samples(self, names, labels):
        """Flatten ``names``/``labels`` and remember which flat indices belong to which domain."""
        nested = len(names) > 0 and isinstance(names[0], (list, tuple))
        # A flat list of paths is treated as one single domain.
        groups = list(zip(names, labels)) if nested else [(names, labels)]

        self.flat_names = []
        self.flat_labels = []
        self.flat_domains = []
        self._domain_members = {}
        for domain_idx, (domain_names, domain_labels) in enumerate(groups):
            start = len(self.flat_names)
            self.flat_names += list(domain_names)
            self.flat_labels += list(domain_labels)
            self.flat_domains += [domain_idx] * len(domain_names)
            if len(domain_names) > 0:
                self._domain_members[domain_idx] = list(range(start, len(self.flat_names)))

    @staticmethod
    def _read(path, grayscale=False):
        """Load an image with OpenCV, raising instead of returning None on failure."""
        data = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if grayscale else cv2.imread(path)
        if data is None:
            raise FileNotFoundError(f"Could not read file: {path}")
        return data

    def __len__(self):
        return len(self.flat_names)

    def __getitem__(self, index):
        img = self._read(self.flat_names[index])
        mask = self._read(self.flat_labels[index], grayscale=True)
        img_o = self.transformer(img)
        mask_o = self.lbl_transformer(mask)

        img_s, mask_s, _ = self.sample_image(self.flat_domains[index])  ## random pick image
        img_s2o, img_o2s = colorful_spectrum_mix(img_o, img_s, alpha=self.alpha)  ## mix their amplitude
        img_o, img_s = self.post_transform(img_o), self.post_transform(img_s)
        img_s2o, img_o2s = self.post_transform(img_s2o), self.post_transform(img_o2s)
        img = [img_o, img_s, img_s2o, img_o2s]  ## [original, img2, img1 x img2, img2 x img1]
        mask = [mask_o, mask_s, mask_o, mask_s]
        return img, mask

    def _sample_index(self, domain=None):
        """Pick ``(flat index, domain index)`` of the partner image according to ``from_domain``.

        Raises:
            ValueError: for an empty dataset, an unknown ``from_domain``, or
                ``'intra'`` sampling without a valid ``domain``.
        """
        domains = list(self._domain_members)
        if not domains:
            raise ValueError("Cannot sample from an empty dataset")

        if self.from_domain == 'all':
            domain_idx = domains[randpy.randint(0, len(domains) - 1)]
        elif self.from_domain == 'inter':
            # Prefer another domain; with a single domain there is nothing to exclude.
            candidates = [d for d in domains if d != domain] or domains
            domain_idx = randpy.sample(candidates, 1)[0]
        elif self.from_domain == 'intra':
            if domain not in self._domain_members:
                raise ValueError("'intra' sampling needs the domain index of the anchor image")
            domain_idx = domain
        else:
            raise ValueError("Not implemented")

        members = self._domain_members[domain_idx]
        return members[randpy.randint(0, len(members) - 1)], domain_idx

    def sample_image(self, domain=None):
        """Load and transform a randomly chosen partner sample.

        Args:
            domain: domain index of the anchor image; used by the ``'inter'``
                and ``'intra'`` modes.

        Returns:
            ``(transformed image, transformed mask, domain index of the sample)``.
        """
        img_idx, domain_idx = self._sample_index(domain)
        img_sampled = self._read(self.flat_names[img_idx])
        label_sampled = self._read(self.flat_labels[img_idx], grayscale=True)
        label_sampled = self.lbl_transformer(label_sampled)
        return self.transformer(img_sampled), label_sampled, domain_idx


def get_dataset(path, train=False, image_size=64, crop=False, jitter=0):
    """Build a ``DGDataset`` from the image/mask folders under ``path``.

    The transform returned by ``get_img_transform`` is handed to the dataset
    as its only transformer.
    """
    image_paths, mask_paths = dataset_info(path)
    return DGDataset(
        image_paths,
        mask_paths,
        get_img_transform(train, image_size, crop, jitter),
    )

def get_fourier_dataset(path, image_size=64, crop=False, jitter=0, from_domain='all', alpha=1.0):
    """Build a ``FourierDGDataset`` from the image/mask folders under ``path``.

    Args:
        path: folder holding the image and mask sub-folders.
        image_size: output height and width of images and masks.
        crop: use random resized crops instead of plain resizing.
        jitter: colour-jitter strength for the images.  It is deliberately not
            forwarded to the label transform: jittering a mask would change
            its grey levels, i.e. its class values.
        from_domain: sampling mode for the partner image.
        alpha: upper bound of the amplitude mixing weight.
    """
    names, labels = dataset_info(path)
    img_transform = get_pre_transform(image_size, crop, jitter)
    lbl_transform = get_label_transform(train=True, image_size=image_size, crop=crop, jitter=0)
    return FourierDGDataset(names, labels, img_transform, lbl_transform, from_domain, alpha)

In [None]:
train_path = "./train"
val_path = "./val"
test_path = "./test"

# Training data: Fourier-mixed pairs with random crop and colour jitter.
train_fourier_dataset = get_fourier_dataset(path=train_path, crop=True, jitter=0.1)
# Evaluation data is only resized: random crops/flips/jitter would make the
# reported metrics change from run to run, and would be drawn separately for
# an image and its mask.
val_dataset = get_dataset(path=val_path, train=False)
test_dataset = get_dataset(path=test_path, train=False)

# Only the training loader shuffles; evaluation order is kept fixed.
train_fourier_loader = torch.utils.data.DataLoader(train_fourier_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)


print(f'Shape of image: {train_fourier_loader.dataset[0][0][0].numpy().shape}')
print(f'Number of training batches: {len(train_fourier_loader)}')
print(f'Number of validation batches: {len(val_loader)}')
print(f'Number of test batches: {len(test_loader)}')

Shape of image: (3, 64, 64)
Number of training batches: 66
Number of validation batches: 17
Number of test batches: 50


### 2 Build Model

#### Unet2D

In [None]:
## Load model
# U-Net from segmentation_models_pytorch with a ResNet-50 encoder initialised
# from ImageNet weights; it takes 3-channel input and predicts 3 output channels.
# NOTE(review): the datasets above load masks as single-channel images, so a
# 3-channel prediction presumably needs the targets converted to a matching
# layout before a loss is applied — confirm.
Unet2D_model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
)

#### Mean Teacher Model

In [None]:
# Mean Teacher Model
# The student is whatever network is passed in; the teacher starts as a copy of it.
class MeanTeacherModel(nn.Module):
    """Student network plus an exponential-moving-average (EMA) teacher copy.

    Args:
        student_model: the network trained by gradient descent.
        ema_decay: upper bound of the EMA momentum used for the teacher.
    """

    def __init__(self, student_model, ema_decay):
        super().__init__()
        self.student_model = student_model
        self.teacher_model = deepcopy(student_model)
        # The teacher is only ever changed by update_teacher_model, never by
        # back-propagation, so its parameters do not need gradients.
        for teacher_param in self.teacher_model.parameters():
            teacher_param.requires_grad_(False)
        self.ema_decay = ema_decay

    def forward(self, x):
        """Run the student network."""
        return self.student_model(x)

    def update_teacher_model(self, current_epoch, momentum=None):
        """Move the teacher weights towards the student weights.

        The effective momentum is ``min(1 - 1 / (current_epoch + 1), cap)``,
        so at epoch 0 the teacher simply copies the student and later epochs
        average more slowly.

        Args:
            current_epoch: zero-based epoch counter driving the warm-up.
            momentum: optional cap on the momentum; defaults to ``self.ema_decay``.
        """
        cap = self.ema_decay if momentum is None else momentum
        momentum = min(1 - 1 / (current_epoch + 1), cap)
        with torch.no_grad():
            for student_params, teacher_params in zip(self.student_model.parameters(), self.teacher_model.parameters()):
                teacher_params.mul_(momentum).add_(student_params, alpha=1 - momentum)

    def sigmoid_rampup(self, current_epoch):
        """Ramp-up factor for the consistency loss.

        Rises from ``exp(-5)`` at epoch 0 to 1.0 at epoch 5 and stays at 1.0
        afterwards, so early (unreliable) teacher predictions count less.
        """
        current_epoch = np.clip(current_epoch, 0.0, 5.0)
        phase = 1.0 - current_epoch / 5.0
        return np.exp(-5.0 * phase * phase).astype(np.float32)

    def get_consistency_weight(self, epoch):
        """Weight of the consistency loss at ``epoch``: 2.0 times the ramp-up factor."""
        return 2.0 * self.sigmoid_rampup(epoch)

### Medical Image Metrics

In [None]:
class SoftDiceLoss(nn.Module):
    """Soft Dice loss, averaged over the batch.

    ``weight`` and ``size_average`` are accepted for signature compatibility
    and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    ## dice loss
    def forward(self, logits, targets, smooth=1):
        """Return ``1 - mean Dice`` between ``logits`` and ``targets``.

        Args:
            logits: predictions, used as they are (no sigmoid/softmax is
                applied here), with the batch on dimension 0.
            targets: ground truth with the same number of elements per sample.
            smooth: additive smoothing that keeps the ratio defined for empty
                masks; two empty masks give a Dice of exactly 1.

        Returns:
            Scalar tensor: one minus the per-sample Dice averaged over the batch.
        """
        num = targets.size(0)
        # reshape (unlike view) also accepts non-contiguous tensors.
        m1 = logits.reshape(num, -1)
        m2 = targets.reshape(num, -1)
        intersection = (m1 * m2).sum(1)

        # Smoothing is added once to numerator and denominator so the score
        # never exceeds 1.
        score = (2. * intersection + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        return 1 - score.sum() / num

    def dice_coeff(self, logits, targets):
        """Return the batch-averaged Dice coefficient (``1 - loss``)."""
        return 1 - (self.forward(logits, targets))

### 3. Training

In [169]:
def gpu_clean():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def train_mean_teacher(model, train_loader, val_loader, optimizer, consistency_criterion, supervised_criterion, device, epochs, scheduler=None):
    """Train a mean-teacher model on Fourier-augmented batches.

    Every batch from ``train_loader`` is a pair ``(images, labels)`` of lists
    of tensors; each list is concatenated along the batch dimension, and the
    first half of the result is treated as original data, the second half as
    amplitude-mixed data.  The loss is the mean of the two supervised losses
    plus the weighted consistency losses between student and teacher.

    Args:
        model: object exposing ``student_model``, ``teacher_model``,
            ``update_teacher_model(current_epoch=...)`` and
            ``get_consistency_weight(epoch)``.
        train_loader: iterable of ``(list of image tensors, list of label tensors)``.
        val_loader: accepted for signature compatibility; not used.
        optimizer: optimizer updating the student parameters.
        consistency_criterion: callable ``(log_probs, probs, reduction=...)``.
        supervised_criterion: callable ``(scores, labels)``.
        device: device the batches are moved to.
        epochs: number of passes over ``train_loader``.
        scheduler: optional learning-rate scheduler stepped once per epoch.
            When omitted, a notebook-level ``my_lr_scheduler`` is used if one
            exists, which is what earlier revisions relied on.

    Returns:
        ``(model, loss_train_list)`` where ``loss_train_list[i]`` is the loss
        summed over all batches of epoch ``i``.
    """
    if scheduler is None:
        scheduler = globals().get("my_lr_scheduler")

    loss_train_list = []
    model.student_model.train()
    model.teacher_model.train()
    for epoch in range(epochs):
        # Accumulated over the whole epoch (not reset per batch).
        total_loss_train = 0.0

        # Free unreferenced tensors once per epoch; doing it every iteration
        # only slows training down.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        ## Train
        for batch, label in train_loader:
            batch = torch.cat(batch, dim=0).to(device)
            label = torch.cat(label, dim=0).to(device)

            optimizer.zero_grad()

            scores = model.student_model(batch)
            with torch.no_grad():
                scores_teacher = model.teacher_model(batch)

            # First half: original images, second half: amplitude-mixed images.
            assert batch.size(0) % 2 == 0
            split_idx = batch.size(0) // 2
            scores_ori, scores_aug = torch.split(scores, split_idx)
            scores_ori_tea, scores_aug_tea = torch.split(scores_teacher, split_idx)
            scores_ori_tea, scores_aug_tea = scores_ori_tea.detach(), scores_aug_tea.detach()
            labels_ori, labels_aug = torch.split(label, split_idx)
            assert scores_ori.size(0) == scores_aug.size(0)

            # NOTE(review): the two supervised terms receive targets of
            # different dtype (int64 vs. the loader's dtype), and the label
            # layout has to match what supervised_criterion expects — confirm
            # against the chosen loss.
            loss_cls = supervised_criterion(scores_ori, labels_ori.to(torch.int64))
            loss_aug = supervised_criterion(scores_aug, labels_aug)

            # Temperature-softened class probabilities (T = 10).
            p_ori, p_aug = F.softmax(scores_ori / 10.0, dim=1), F.softmax(scores_aug / 10.0, dim=1)
            p_ori_tea, p_aug_tea = F.softmax(scores_ori_tea / 10.0, dim=1), F.softmax(scores_aug_tea / 10.0, dim=1)

            # Cross consistency: the student on one view should agree with the
            # teacher on the other view.
            loss_ori_tea = consistency_criterion(p_aug.log(), p_ori_tea, reduction='batchmean')
            loss_aug_tea = consistency_criterion(p_ori.log(), p_aug_tea, reduction='batchmean')

            const_weight = model.get_consistency_weight(epoch)

            total_loss = 0.5 * loss_cls + 0.5 * loss_aug + \
                         const_weight * loss_ori_tea + const_weight * loss_aug_tea

            total_loss.backward()
            optimizer.step()

            # The teacher follows the freshly updated student.
            model.update_teacher_model(current_epoch=epoch)
            total_loss_train += total_loss.item()

        loss_train_list.append(total_loss_train)
        if scheduler is not None:
            scheduler.step()
        print("Epoch [{}/{}], Loss: {:.4f}".format(
            epoch+1,epochs,
            total_loss_train/max(len(train_loader), 1),
        ))

    return model, loss_train_list


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Place the student on the selected device instead of assuming a GPU.
base_model = Unet2D_model
base_model = base_model.to(device)

mean_teacher_model = MeanTeacherModel(base_model, ema_decay=0.99)
mean_teacher_model = mean_teacher_model.to(device)

# Only the student is trained by gradient descent; the teacher is updated
# through the moving average in update_teacher_model.
optimizer = torch.optim.Adam(mean_teacher_model.student_model.parameters(), lr=0.0001)
consistency_criterion = F.kl_div
# `classes` expects a list of class indices to include; leaving it out lets
# every output channel contribute to the loss.
supervised_criterion = smp.losses.DiceLoss(mode='multilabel')

max_epoch = 2
my_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epoch)
mean_teacher_model, TRAIN_LOSS = train_mean_teacher(mean_teacher_model, train_fourier_loader, val_loader, 
                                                    optimizer, consistency_criterion, supervised_criterion, device, epochs=max_epoch)

RuntimeError: ignored

In [172]:
# Scratch check of the tensor handling used in the training loop: concatenate
# four (8, 1, 64, 64) tensors along the batch dimension, split the result in
# half and flatten one half per sample.
a = torch.rand(8,1,64,64)
print(a.to(torch.int64))  # torch.rand yields values in [0, 1), so the int64 cast prints zeros
A = [a,a,a,a]
final = torch.cat(A, dim=0)
split_idx = int(final.size(0) / 2)
final0, final1 = torch.split(final, split_idx)
print(final0.shape)
final0.reshape((16,64*64)).shape

tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        ...,


        [[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]]],


        [[[0, 0, 0,  ..., 0, 0, 0],
         

torch.Size([16, 4096])