Dataset: https://drive.google.com/u/0/uc?id=1p33nsWQaiZMAgsruDoJLyatoq5XAH-TH&export=download

Baseline code: https://github.com/emma-sjwang/Dofe

## Algorithm Implementation

1. Implement the naive baseline model for comparison. This is typically a simple model with basic feature extraction and classification steps.

2. Choose at least one Domain Generalization (DG) method to implement and compare against the naive baseline. This could be FACT, Dofe, or another method of your choice.

3. Report the segmentation performance in terms of dice coefficient and average surface distance.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

%cd "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2"
!unzip "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2/Fundus-doFE.zip" -d "/content/"

## pip install external lib
!pip install -U albumentations
!pip install torchmetrics

## Dataset info
[cup]Domian1: Drishti-GS dataset [101] including training[50] and testing[51]

[cup]Domain2: RIM-ONE_r3 dataset [159] including training and[99] testing[60]. 

[cup]Domain3: REFUGE training [400]  MICCAI 2018 workshop including training and[320] testing[80]. 

[cup]Domian4: REFUGE val [400]  including training and[320] testing[80]. 
Domain5: ISBI [81]  IDRID chanllenge



## Naive Baseline Model

In [None]:
# Just for reference
# UNet implementation
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True))

class UNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
                
        self.conv_down1 = double_conv(3, 64)      # Number of channel
        self.conv_down2 = double_conv(64, 128)
        self.conv_down3 = double_conv(128, 256)
        self.conv_down4 = double_conv(256, 512)        

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        
        
        self.conv_up3 = double_conv(256 + 512, 256)
        self.conv_up2 = double_conv(128 + 256, 128)
        self.conv_up1 = double_conv(128 + 64, 64)
        
        self.last_conv = nn.Conv2d(64, num_classes, kernel_size=1)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        conv1 = self.conv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.conv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.conv_down3(x)
        x = self.maxpool(conv3)
        x = self.conv_down4(x)
        x = self.upsample(x)
        
        x = torch.cat([x, conv3], dim=1)
        x = self.conv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.conv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        
        x = self.conv_up1(x)
        out = self.last_conv(x)
        out = torch.sigmoid(out)
        
        return out

In [None]:
# Test the model
model = UNet(num_classes=1).to(device)
output = model(torch.randn(1,3,256,256).to(device))
print(f'Output shape: {output.shape}')

## FACT

https://github.com/MediaBrain-SJTU/FACT

In [None]:
# Standard library imports
import os
import shutil
import math
import copy
import random
import gc
import numpy as np
import pandas as pd
from math import sqrt

# PyTorch imports for deep learning
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, ConcatDataset, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torchvision.models import resnet50 as _resnet50
from torchsummary import summary

# Other utilities
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from google.colab.patches import cv2_imshow
import albumentations as Augm
from albumentations.pytorch import ToTensorV2

In [None]:
%cd "/content"

/content


In [None]:
train_path = "./train"
val_path = "./val"
test_path = "./test"
D1_Mask_dir = ["/content/Fundus/Domain1/test/mask", "/content/Fundus/Domain1/train/mask"]
D2_Mask_dir = ["/content/Fundus/Domain2/test/mask", "/content/Fundus/Domain2/train/mask"]
D3_Mask_dir = ["/content/Fundus/Domain3/test/mask", "/content/Fundus/Domain3/train/mask"]
D4_Mask_dir = ["/content/Fundus/Domain4/test/mask", "/content/Fundus/Domain4/train/mask"]


class Fundus_Dataset(Dataset):
    def __init__(self, path, transform, transform_mask, split_idx):
        self.PATH = path
        self.transform = transform
        self.transform_mask = transform_mask
        self.split_idx = split_idx
        dirsImg, images, dirsMask, masks = [], [], [], []
        for PATH in self.PATH:
            for img_path in PATH:
                for root, folders, files in os.walk(img_path):
                    for file in files:
                        check_mask = cv2.imread(img_path+'/'+file, cv2.IMREAD_GRAYSCALE)
                        if check_mask.sum() > 0:
                            match (img_path[16:23]):  ## Domain idx
                                case 'Domain1':
                                    images.append(file)
                                case 'Domain2':
                                    images.append(file.replace('png','jpg'))
                                case 'Domain3':
                                    images.append(file.replace('bmp','jpg'))
                                case 'Domain4':
                                    images.append(file.replace('bmp','jpg'))
                            masks.append(file)
                            dirsMask.append(img_path+'/')
                            dirsImg.append(img_path.replace('mask','image')+'/')
        path_df = pd.DataFrame({'direcImg':dirsImg, 'images':images, 'direcMask':dirsMask, 'masks':masks})
        self.path_df = path_df
    
    def split_traindataset(self):
        for i in range(len(self.path_df)):
            image_direc = self.path_df.loc[i]['direcImg'][:-1]+'/'+self.path_df.loc[i]['images']
            mask_direc = self.path_df.loc[i]['direcMask'][:-1]+'/'+self.path_df.loc[i]['masks']
            match(self.split_idx):
                case 1:
                    if "Domain4" not in image_direc:
                        if i < int(0.8*len(self.path_df)):
                            shutil.copy(image_direc, train_path+"/image/")
                            shutil.copy(mask_direc, train_path+"/mask/")
                        else:
                            shutil.copy(image_direc, val_path+"/image/")
                            shutil.copy(mask_direc, val_path+"/mask/")
                case 2:
                    if "Domain3" not in image_direc:
                        if i < int(0.8*len(self.path_df)):
                            shutil.copy(image_direc, train_path+"/image/")
                            shutil.copy(mask_direc, train_path+"/mask/")
                        else:
                            shutil.copy(image_direc, val_path+"/image/")
                            shutil.copy(mask_direc, val_path+"/mask/")
                case 3:
                    if "Domain2" not in image_direc:
                        if i < int(0.8*len(self.path_df)):
                            shutil.copy(image_direc, train_path+"/image/")
                            shutil.copy(mask_direc, train_path+"/mask/")
                        else:
                            shutil.copy(image_direc, val_path+"/image/")
                            shutil.copy(mask_direc, val_path+"/mask/")
                case 4:
                    if "Domain1" not in image_direc:
                        if i < int(0.8*len(self.path_df)):
                            shutil.copy(image_direc, train_path+"/image/")
                            shutil.copy(mask_direc, train_path+"/mask/")
                        else:
                            shutil.copy(image_direc, val_path+"/image/")
                            shutil.copy(mask_direc, val_path+"/mask/")

        return self.trainGetitem()

    def split_testdataset(self):
        for i in range(len(self.path_df)):
            image_direc = self.path_df.loc[i]['direcImg'][:-1]+'/'+self.path_df.loc[i]['images']
            mask_direc = self.path_df.loc[i]['direcMask'][:-1]+'/'+self.path_df.loc[i]['masks']
            shutil.copy(image_direc, test_path+"/image/")
            shutil.copy(mask_direc, test_path+"/mask/")
        return self.testGetitem()

    def __len__(self):
        return len(self.path_df)
  
    def trainGetitem(self):
        trainImg, trainMask = [], []
        valImg, valMask = [], []
        train_folders = os.listdir(train_path)
        for folder in train_folders:
            files = os.listdir(train_path+'/'+folder)
            for file in files:
                direc = train_path+'/'+folder+'/'+file
                if 'mask' in folder:
                    mask = cv2.imread(direc, cv2.IMREAD_GRAYSCALE)
                    mask = self.transform(mask)
                    mask = torch.round(mask)
                    trainMask.append(mask)
                else:
                    fundus_img = cv2.imread(direc)
                    fundus_img = self.transform(fundus_img)
                    trainImg.append(fundus_img)

        val_folders = os.listdir(val_path)
        for folder in val_folders:
            files = os.listdir(val_path+'/'+folder)
            for file in files:
                direc = val_path+'/'+folder+'/'+file
                if 'mask' in folder:
                    mask = cv2.imread(direc, cv2.IMREAD_GRAYSCALE)
                    mask = self.transform(mask)
                    mask = torch.round(mask)
                    valMask.append(mask)
                else:
                    fundus_img = cv2.imread(direc)
                    fundus_img = self.transform(fundus_img)
                    valImg.append(fundus_img)
        return [[trainImg,trainMask], [valImg,valMask]]

    def testGetitem(self):
        testImg, testMask = [], []
        test_folders = os.listdir(test_path)
        for folder in test_folders:
            files = os.listdir(test_path+'/'+folder)
            for file in files:
                direc = test_path+'/'+folder+'/'+file
                if 'mask' in folder:
                    mask = cv2.imread(direc, cv2.IMREAD_GRAYSCALE)
                    mask = self.transform(mask)
                    mask = torch.round(mask)
                    testMask.append(mask)
                else:
                    fundus_img = cv2.imread(direc)
                    fundus_img = self.transform(fundus_img)
                    testImg.append(fundus_img)
        return [[testImg,testMask]]

    def direc(self):
        FundusImg_direc = []
        Mask_direc = []
        train_folders = os.listdir(train_path)
        for folder in train_folders:
            files = os.listdir(train_path+'/'+folder)
            for file in files:
                direc = train_path+'/'+folder+'/'+file
                if 'mask' in folder:
                    Mask_direc.append(direc)
                else:
                    FundusImg_direc.append(direc)
        return FundusImg_direc, Mask_direc

    def stack_train(self, batch_size=8):
        All_dataset = []
        DATASET = self.trainGetitem()
        for dataset in DATASET:
            fundus = dataset[0]
            mask = dataset[1]
            inputs = (fundus[0])
            labels = (mask[0])
            inputs = torch.stack(fundus)
            labels = torch.stack(mask)
            All_dataset.append(TensorDataset(inputs,labels))
        return All_dataset[0], All_dataset[1]

    def stack_test(self, batch_size=8):
        All_dataset = []
        DATASET = self.testGetitem()
        for dataset in DATASET:
            fundus = dataset[0]
            mask = dataset[1]
            inputs = (fundus[0])
            labels = (mask[0])
            inputs = torch.stack(fundus)
            labels = torch.stack(mask)
            All_dataset.append(TensorDataset(inputs,labels))
        return All_dataset[0]

In [None]:
for path in [train_path, val_path, test_path]:
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)

for path in [train_path, val_path, test_path]:
    os.makedirs(os.path.join(path, "image"))
    os.makedirs(os.path.join(path, "mask"))

normalization = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(64),
    transforms.RandomCrop(64),
    transforms.ToTensor(),
])


## Raw Dataset
Train_class = Fundus_Dataset([D1_Mask_dir,D2_Mask_dir,D3_Mask_dir], transform=normalization, transform_mask=normalization, split_idx=1)
Test_class = Fundus_Dataset([D4_Mask_dir], transform=normalization, transform_mask=normalization, split_idx=1)
Train_class.split_traindataset()
Test_class.split_testdataset()
train_dataset, val_dataset = Train_class.stack_train()
test_dataset = Test_class.stack_test()

train_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)


print(f'Shape of image: {train_loader.dataset[0][0].shape}')       
print(f'Number of training batches: {len(train_loader)}')          
print(f'Number of validation batches: {len(val_loader)}')          
print(f'Number of test batches: {len(test_loader)}')               

Shape of image: torch.Size([3, 64, 64])
Number of training batches: 50
Number of validation batches: 17
Number of test batches: 400


### data_utils.py

In [None]:
%cd "/content/"

In [222]:
def dataset_info(filepath):
    img_direcs = []
    mask_direcs = []
    for folder in os.listdir(filepath):
        files = os.listdir(filepath+'/'+folder)
        for file in files:
            direc = filepath+'/'+folder+'/'+file
            if 'mask' in direc: mask_direcs.append(direc)
            else: img_direcs.append(direc)
    return img_direcs, mask_direcs


def get_img_transform(train=False, image_size=224, crop=False, jitter=0):
    mean = [0.5]
    std = [0.5]
    if train:
        if crop:
            img_transform = [transforms.ToPILImage(),transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0])]
        else:
            img_transform = [transforms.ToPILImage(),transforms.Resize((image_size, image_size))]
        if jitter > 0:
            img_transform.append(transforms.ColorJitter(brightness=jitter,
                                                        contrast=jitter,
                                                        saturation=jitter,
                                                        hue=min(0.5, jitter)))
        img_transform += [transforms.RandomHorizontalFlip(),
                          transforms.ToTensor(),
                          transforms.Normalize(mean, std)]
        img_transform = transforms.Compose(img_transform)
    else:
        img_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    return img_transform
    

def get_pre_transform(image_size=224, crop=False, jitter=0):
    if crop:
        img_transform = [transforms.ToPILImage(), transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0])]
    else:
        img_transform = [transforms.ToPILImage(), transforms.Resize((image_size, image_size))]
    if jitter > 0:
        img_transform.append(transforms.ColorJitter(brightness=jitter,
                                                    contrast=jitter,
                                                    saturation=jitter,
                                                    hue=min(0.5, jitter)))
    img_transform += [transforms.RandomHorizontalFlip(), transforms.ToTensor(), lambda x: np.asarray(x)]
    img_transform = transforms.Compose(img_transform)
    return img_transform

def get_post_transform(mean=[0.5], std=[0.5]):
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    return img_transform

def colorful_spectrum_mix(img1, img2, alpha, ratio=1.0):
    """Input image size: ndarray of [H, W, C]"""
    lam = np.random.uniform(0, alpha)

    assert img1.shape == img2.shape
    h, w, c = img1.shape
    h_crop = int(h * sqrt(ratio))
    w_crop = int(w * sqrt(ratio))
    h_start = h // 2 - h_crop // 2
    w_start = w // 2 - w_crop // 2

    img1_fft = np.fft.fft2(img1, axes=(0, 1))
    img2_fft = np.fft.fft2(img2, axes=(0, 1))
    img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft)
    img2_abs, img2_pha = np.abs(img2_fft), np.angle(img2_fft)

    img1_abs = np.fft.fftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.fftshift(img2_abs, axes=(0, 1))

    img1_abs_ = np.copy(img1_abs)
    img2_abs_ = np.copy(img2_abs)
    img1_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \
        lam * img2_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + (1 - lam) * img1_abs_[
                                                                                          h_start:h_start + h_crop,
                                                                                          w_start:w_start + w_crop]
    img2_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \
        lam * img1_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + (1 - lam) * img2_abs_[
                                                                                          h_start:h_start + h_crop,
                                                                                          w_start:w_start + w_crop]

    img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.ifftshift(img2_abs, axes=(0, 1))

    img21 = img1_abs * (np.e ** (1j * img1_pha))
    img12 = img2_abs * (np.e ** (1j * img2_pha))
    img21 = np.real(np.fft.ifft2(img21, axes=(0, 1)))
    img12 = np.real(np.fft.ifft2(img12, axes=(0, 1)))
    img21 = np.uint8(np.clip(img21, 0, 255))
    img12 = np.uint8(np.clip(img12, 0, 255))

    return img21, img12

 ### 1.1 Data Reader

In [225]:
class DGDataset(Dataset):
    def __init__(self, names, labels, transformer=None):
        self.names = names
        self.labels = labels
        self.transformer = transformer

    def __len__(self):
        return len(self.names)

    def __getitem__(self):
        Img, Mask = [], []
        for idx in range(len(self.names)):
            img_name = self.names[idx]
            mask_name = self.labels[idx]
            img = cv2.imread(img_name)
            mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
            if self.transformer is not None:
                img = self.transformer(img)
                mask = self.transformer(mask)
            Img.append(img)
            Mask.append(mask)
        return Img, Mask

    def stack_dataset(self, batch_size=8):
        IMAGE, MASK = self.__getitem__()
        images = IMAGE[0]
        masks = MASK[0]
        inputs = (images[0])
        labels = (masks[0])
        inputs = torch.stack(IMAGE)
        labels = torch.stack(MASK)
        dataset = TensorDataset(inputs,labels)
        return dataset

class FourierDGDataset(Dataset):
    def __init__(self, names, labels, transformer=None, from_domain=None, alpha=1.0):
        self.names = names
        self.labels = labels
        self.transformer = transformer
        self.post_transform = get_post_transform()
        self.from_domain = from_domain
        self.alpha = alpha
        
        self.flat_names = []
        self.flat_labels = []
        self.flat_domains = []
        for i in range(len(names)):
            self.flat_domains += [i] * len(names[i])
            self.flat_names += names[i]
            self.flat_labels += labels[i]

    def __len__(self):
        return len(self.flat_names)

    def __getitem__(self):
        Img, Mask = [], []
        for idx in range(len(self.names)):
            img_name = self.names[idx]
            mask_name = self.labels[idx]
            img = cv2.imread(img_name)
            mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
            img_o = self.transformer(img)
            mask_o = self.transformer(mask)

            img_s, mask_s, _ = self.sample_image()  ## random pick image
            img_s2o, img_o2s = colorful_spectrum_mix(img_o, img_s, alpha=self.alpha)  ## mix their amplitude
            img_o, img_s = self.post_transform(img_o), self.post_transform(img_s)
            img_s2o, img_o2s = self.post_transform(img_s2o), self.post_transform(img_o2s)

            img = [img_o, img_s, img_s2o, img_o2s]
            mask = [mask, mask_s, mask, mask_s]
            Img.append(img)
            Mask.append(mask)
            # domain = [domain, domain_s, domain, domain_s]
        return Img, Mask

    def sample_image(self, domain=None):
        if self.from_domain == 'all':
            domain_idx = random.randint(0, len(self.names)-1)
        elif self.from_domain == 'inter':
            domains = list(range(len(self.names)))
            # domains.remove(domain)
            domain_idx = random.sample(domains, 1)[0]
        elif self.from_domain == 'intra':
            domain_idx = domain
        else:
            raise ValueError("Not implemented")
        img_idx = random.randint(0, len(self.names[domain_idx])-1)
        imgn_ame_sampled = self.names[img_idx]
        img_sampled = cv2.imread(imgn_ame_sampled)
        label_ame_sampled = self.labels[img_idx]
        label_sampled = cv2.imread(label_ame_sampled)
        return self.transformer(img_sampled), self.transformer(label_sampled), domain_idx

    def stack_dataset(self, batch_size=8):
        IMAGE, MASK = self.__getitem__()
        images = IMAGE[0]
        masks = MASK[0]
        inputs = (images[0])
        labels = (masks[0])
        inputs = torch.stack(IMAGE)
        labels = torch.stack(MASK)
        dataset = TensorDataset(inputs,labels)
        return dataset


def get_dataset(path, train=False, image_size=64, crop=False, jitter=0):
    names, labels = dataset_info(path)
    img_transform = get_img_transform(train, image_size, crop, jitter)
    return DGDataset(names, labels, img_transform)

def get_fourier_dataset(path, image_size=64, crop=False, jitter=0, from_domain='all', alpha=1.0):
    names, labels = dataset_info(path)
    img_transform = get_pre_transform(image_size, crop, jitter)
    return FourierDGDataset(names, labels, img_transform, from_domain, alpha)

In [226]:
trian_path = "./train"
val_path = "./val"
test_path = "./test"

# train_dataset = get_dataset(path=train_path, train=True,crop=True,jitter=0.1).stack_dataset()
train_fourier_dataset = get_fourier_dataset(path=train_path,crop=True,jitter=0.1).stack_dataset()


# print(f'Shape of image: {train_loader.dataset[0][0].shape}')

TypeError: ignored