Dataset: https://drive.google.com/u/0/uc?id=1p33nsWQaiZMAgsruDoJLyatoq5XAH-TH&export=download

Baseline code: https://github.com/emma-sjwang/Dofe

## Algorithm Implementation

1. Implement the naive baseline model for comparison. This is typically a simple model with basic feature extraction and classification steps.

2. Choose at least one Domain Generalization (DG) method to implement and compare against the naive baseline. This could be FACT, Dofe, or another method of your choice.

3. Report the segmentation performance in terms of dice coefficient and average surface distance.

In [None]:
# Colab setup: mount Google Drive, unpack the dataset archive, install extra packages.
# NOTE: the `%` and `!` lines are IPython magics / shell escapes, so this cell only
# runs inside a Jupyter/Colab notebook, not as a plain Python script.
from google.colab import drive
drive.mount('/content/gdrive')

# Change into the project folder on Drive.
%cd "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2"
# Extract Fundus-doFE.zip into /content/ (the data reader later reads
# /content/Fundus/DomainN/...).
!unzip "/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2/Fundus-doFE.zip" -d "/content/"

## pip install external lib
# albumentations provides the augmentation pipelines used for the training set;
# torchmetrics is installed but not imported in the cells shown here.
!pip install -U albumentations
!pip install torchmetrics

## Dataset info
[cup] Domain1: Drishti-GS dataset [101 images], split into training [50] and testing [51].

[cup] Domain2: RIM-ONE_r3 dataset [159 images], split into training [99] and testing [60].

[cup] Domain3: REFUGE training set [400 images] (MICCAI 2018 workshop), split into training [320] and testing [80].

[cup] Domain4: REFUGE validation set [400 images], split into training [320] and testing [80].

Domain5: ISBI [81 images], IDRiD challenge.



## Naive Baseline Model

In [None]:
# Just for reference
# UNet implementation
def double_conv(in_channels, out_channels):
    """Return two 3x3 Conv2d + ReLU layers.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
    )


class UNet(nn.Module):
    """Three-level U-Net that outputs one sigmoid probability map per class.

    Args:
        num_classes: number of output channels.

    Input is an (N, 3, H, W) tensor. The output is (N, num_classes, H, W) with
    values in [0, 1].

    The decoder resizes each feature map to the exact spatial size of its skip
    connection before concatenating. Inputs whose H/W are not multiples of 8
    therefore work; a fixed 2x upsample made ``torch.cat`` fail for them. For
    sizes divisible by 8, the result equals a 2x bilinear upsample with
    ``align_corners=True``.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.conv_down1 = double_conv(3, 64)      # 3 input channels (RGB)
        self.conv_down2 = double_conv(64, 128)
        self.conv_down3 = double_conv(128, 256)
        self.conv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        # Kept so existing attribute access keeps working; forward() resizes to
        # the skip-connection size with _upsample_to instead.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder inputs are (upsampled channels + skip channels).
        self.conv_up3 = double_conv(256 + 512, 256)
        self.conv_up2 = double_conv(128 + 256, 128)
        self.conv_up1 = double_conv(128 + 64, 64)

        self.last_conv = nn.Conv2d(64, num_classes, kernel_size=1)

        # Weight initialization. The BatchNorm branch never fires here because
        # double_conv contains no BatchNorm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial (H, W) size of ``ref``."""
        return nn.functional.interpolate(
            x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        conv1 = self.conv_down1(x)
        conv2 = self.conv_down2(self.maxpool(conv1))
        conv3 = self.conv_down3(self.maxpool(conv2))
        x = self.conv_down4(self.maxpool(conv3))

        # Decoder: upsample to the skip's size, concatenate on channels, convolve.
        x = self.conv_up3(torch.cat([self._upsample_to(x, conv3), conv3], dim=1))
        x = self.conv_up2(torch.cat([self._upsample_to(x, conv2), conv2], dim=1))
        x = self.conv_up1(torch.cat([self._upsample_to(x, conv1), conv1], dim=1))

        return torch.sigmoid(self.last_conv(x))

In [None]:
# Test the model
# Smoke test: a random 256x256 RGB batch through UNet(num_classes=1) should
# give an output of shape (1, 1, 256, 256).
# NOTE(review): `device` is not defined in any visible cell before this one, and
# `torch` / `nn` are only imported in a later cell. Run the import cell and define
# `device` (e.g. torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
# before running this cell.
model = UNet(num_classes=1).to(device)
output = model(torch.randn(1,3,256,256).to(device))
print(f'Output shape: {output.shape}')

## FACT

https://github.com/MediaBrain-SJTU/FACT

In [2]:
# Standard library imports
import os
import shutil
import math
import copy
import random
import gc
import numpy as np
import pandas as pd
from math import sqrt

# PyTorch imports for deep learning
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, ConcatDataset, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torchvision.models import resnet50 as _resnet50
from torchsummary import summary

# Other utilities
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from google.colab.patches import cv2_imshow
import albumentations as Augm
from albumentations.pytorch import ToTensorV2

### 1.1 Data Reader

In [None]:
# Return to /content, where the Fundus archive was unzipped. The relative
# ./train, ./val and ./test working folders below are created from here.
%cd "/content"

In [196]:
from data.data_utils import *

# Scratch folders (relative to the current working directory) that the dataset
# split methods copy the selected image/mask files into before loading them.
train_path, val_path, test_path = "./train", "./val", "./test"

# Mask directories for each domain, listed as [test split, train split].
D1_Mask_dir = ["/content/Fundus/Domain1/test/mask", "/content/Fundus/Domain1/train/mask"]
D2_Mask_dir = ["/content/Fundus/Domain2/test/mask", "/content/Fundus/Domain2/train/mask"]
D3_Mask_dir = ["/content/Fundus/Domain3/test/mask", "/content/Fundus/Domain3/train/mask"]
D4_Mask_dir = ["/content/Fundus/Domain4/test/mask", "/content/Fundus/Domain4/train/mask"]


class Fundus_Dataset(Dataset):
    """Index fundus image/mask pairs for some domains and materialise splits on disk.

    Args:
        path: list of per-domain lists of mask directories, e.g.
            ``[D1_Mask_dir, D2_Mask_dir]``. The image directory for a mask
            directory is obtained by replacing ``'mask'`` with ``'image'``.
        transform: callable applied to every image (the array returned by
            ``cv2.imread``, i.e. BGR channel order) when loading. It is also
            used for masks when ``transform_mask`` is None.
        transform_mask: callable applied to every grayscale mask when loading.
        split_idx: selects the domain excluded from train/val by
            ``split_traindataset``: 1 -> Domain4, 2 -> Domain3, 3 -> Domain2,
            4 -> Domain1. Any other value copies nothing.

    Only masks with at least one non-zero pixel are indexed. The index is kept
    in ``self.path_df`` with columns ``direcImg``, ``images``, ``direcMask`` and
    ``masks``.

    ``split_traindataset`` / ``split_testdataset`` copy the indexed files into
    the module-level ``train_path`` / ``val_path`` / ``test_path`` folders. Each
    folder needs ``image/`` and ``mask/`` sub-folders. The files are then loaded
    back as tensors.

    NOTE: files are copied by base name, so identically named files from
    different domains overwrite each other.
    """

    # Rule mapping a mask file name to its image file name, per domain folder.
    # None means the image has exactly the same file name as the mask.
    _MASK_TO_IMAGE_EXT = {
        'Domain1': None,
        'Domain2': ('png', 'jpg'),
        'Domain3': ('bmp', 'jpg'),
        'Domain4': ('bmp', 'jpg'),
    }
    # split_idx -> domain that is never copied into the train/val folders.
    _HELD_OUT_DOMAIN = {1: 'Domain4', 2: 'Domain3', 3: 'Domain2', 4: 'Domain1'}
    # Fraction of each domain's indexed samples that goes to the train folder.
    _TRAIN_FRACTION = 0.8

    def __init__(self, path, transform, transform_mask, split_idx):
        self.PATH = path
        self.transform = transform
        self.transform_mask = transform_mask
        self.split_idx = split_idx
        dirsImg, images, dirsMask, masks = [], [], [], []
        for domain_dirs in self.PATH:
            for mask_dir in domain_dirs:
                for root, _folders, files in os.walk(mask_dir):
                    # Sorted so the index, and therefore the train/val split, is
                    # reproducible regardless of file-system listing order.
                    for file in sorted(files):
                        check_mask = cv2.imread(os.path.join(root, file), cv2.IMREAD_GRAYSCALE)
                        # Skip unreadable / non-image files and empty masks.
                        if check_mask is None or not check_mask.any():
                            continue
                        images.append(self._image_name(root, file))
                        masks.append(file)
                        dirsMask.append(root + '/')
                        dirsImg.append(root.replace('mask', 'image') + '/')
        self.path_df = pd.DataFrame({'direcImg': dirsImg, 'images': images,
                                     'direcMask': dirsMask, 'masks': masks})

    # ------------------------------------------------------------------ helpers
    @classmethod
    def _domain_of(cls, path):
        """Return the known 'DomainN' folder name that is a component of ``path``, or None."""
        for part in os.path.normpath(path).split(os.sep):
            if part in cls._MASK_TO_IMAGE_EXT:
                return part
        return None

    @classmethod
    def _image_name(cls, mask_dir, mask_file):
        """Map a mask file name to its image file name using the domain's rule."""
        domain = cls._domain_of(mask_dir)
        if domain is None:
            raise ValueError(f'Cannot tell which domain mask directory {mask_dir!r} belongs to')
        rule = cls._MASK_TO_IMAGE_EXT[domain]
        return mask_file if rule is None else mask_file.replace(*rule)

    @staticmethod
    def _stem(file_name):
        """File name without its extension; image and mask of a pair share it."""
        return os.path.splitext(file_name)[0]

    def _row_paths(self, i):
        """Return (image path, mask path) for row ``i`` of ``path_df``."""
        row = self.path_df.loc[i]
        return (row['direcImg'][:-1] + '/' + row['images'],
                row['direcMask'][:-1] + '/' + row['masks'])

    @staticmethod
    def _copy_pair(image_direc, mask_direc, dest):
        """Copy one image/mask pair into ``dest``/image/ and ``dest``/mask/."""
        shutil.copy(image_direc, dest + "/image/")
        shutil.copy(mask_direc, dest + "/mask/")

    def _load_image(self, direc):
        """Read an image with cv2 (BGR order) and apply ``transform``."""
        return self.transform(cv2.imread(direc))

    def _load_mask(self, direc):
        """Read a grayscale mask, apply the mask transform and round it.

        With a ToTensor-style transform (values in [0, 1]), rounding binarises
        the mask.
        """
        mask_transform = self.transform_mask if self.transform_mask is not None else self.transform
        return torch.round(mask_transform(cv2.imread(direc, cv2.IMREAD_GRAYSCALE)))

    def _load_folder(self, split_path):
        """Load all images and masks under ``split_path`` as ``[images, masks]``.

        Files in both sub-folders are sorted by stem, so the i-th image and the
        i-th mask belong together. Image and mask share a stem and differ only
        in extension; plain ``os.listdir`` order is arbitrary and could mis-pair
        them.
        """
        imgs, msks = [], []
        for folder in sorted(os.listdir(split_path)):
            folder_path = split_path + '/' + folder
            is_mask = 'mask' in folder
            for file in sorted(os.listdir(folder_path), key=self._stem):
                direc = folder_path + '/' + file
                if is_mask:
                    msks.append(self._load_mask(direc))
                else:
                    imgs.append(self._load_image(direc))
        return [imgs, msks]

    @staticmethod
    def _to_tensor_dataset(split):
        """Stack an ``[images, masks]`` split into a TensorDataset."""
        fundus, mask = split
        return TensorDataset(torch.stack(fundus), torch.stack(mask))

    # --------------------------------------------------------------- public API
    def split_traindataset(self):
        """Copy every non-held-out domain into the train/val folders, then load them.

        Each domain is split on its own. The first 80% of its (sorted) samples
        go to ``train_path`` and the rest to ``val_path``, so every source domain
        appears in validation. Splitting the concatenated index at 80% would put
        only the tail of the last domain into validation.

        Returns:
            ``trainGetitem()``'s result.
        """
        held_out = self._HELD_OUT_DOMAIN.get(self.split_idx)
        if held_out is not None:
            per_domain = {}
            for i in range(len(self.path_df)):
                image_direc, mask_direc = self._row_paths(i)
                domain = self._domain_of(image_direc)
                if domain != held_out:
                    per_domain.setdefault(domain, []).append((image_direc, mask_direc))
            for pairs in per_domain.values():
                n_train = int(self._TRAIN_FRACTION * len(pairs))
                for k, (image_direc, mask_direc) in enumerate(pairs):
                    self._copy_pair(image_direc, mask_direc,
                                    train_path if k < n_train else val_path)
        return self.trainGetitem()

    def split_testdataset(self):
        """Copy every indexed pair into ``test_path`` and return ``testGetitem()``."""
        for i in range(len(self.path_df)):
            image_direc, mask_direc = self._row_paths(i)
            self._copy_pair(image_direc, mask_direc, test_path)
        return self.testGetitem()

    def __len__(self):
        return len(self.path_df)

    def trainGetitem(self):
        """Return ``[[trainImg, trainMask], [valImg, valMask]]`` loaded from disk."""
        return [self._load_folder(train_path), self._load_folder(val_path)]

    def testGetitem(self):
        """Return ``[[testImg, testMask]]`` loaded from ``test_path``."""
        return [self._load_folder(test_path)]

    def direc(self):
        """Return (image paths, mask paths) in ``train_path``, paired by sorted stem."""
        FundusImg_direc, Mask_direc = [], []
        for folder in sorted(os.listdir(train_path)):
            folder_path = train_path + '/' + folder
            target = Mask_direc if 'mask' in folder else FundusImg_direc
            target.extend(folder_path + '/' + file
                          for file in sorted(os.listdir(folder_path), key=self._stem))
        return FundusImg_direc, Mask_direc

    def stack_train(self, batch_size=8):
        """Return (train TensorDataset, val TensorDataset).

        ``batch_size`` is unused and is kept for interface compatibility.
        """
        train_split, val_split = self.trainGetitem()
        return self._to_tensor_dataset(train_split), self._to_tensor_dataset(val_split)

    def stack_test(self, batch_size=8):
        """Return the test TensorDataset (``batch_size`` is unused)."""
        (test_split,) = self.testGetitem()
        return self._to_tensor_dataset(test_split)



class AugmDataSet():
    """Build an augmented TensorDataset from paired image/mask file paths.

    Args:
        direcs: ``(image_paths, mask_paths)``. The two lists are zipped, so they
            must be in matching order.
        transform_img: albumentations-style callable, invoked as
            ``transform_img(image=..., mask=...)``. It must return a dict whose
            ``'image'`` entry is a (C, H, W) tensor and whose ``'mask'`` entry is
            an (H, W, 1) tensor (e.g. an ``Augm.Compose`` ending in
            ``ToTensorV2()``); the mask is permuted to (1, H, W) here.
        transform_mask: stored for interface compatibility; not used.
    """

    def __init__(self, direcs, transform_img=None, transform_mask=None):
        self.images_direc = direcs[0]
        self.mask_direc = direcs[1]
        self.transform_mask = transform_mask
        self.transform_img = transform_img

    def __len__(self):
        return len(self.images_direc)

    def stack_dataset(self):
        """Read, augment and stack every pair into a float32 TensorDataset.

        Images are scaled from [0, 255] to [0, 1]. Masks are scaled the same way
        and then rounded, so labels are on the {0, 1} scale used by the
        non-augmented training set (ToTensor followed by ``torch.round``).
        Previously the masks were only rounded and stayed in [0, 255], giving
        mixed label scales once the datasets were concatenated.

        Raises:
            ValueError: if ``transform_img`` is None (nothing could be stacked).
            FileNotFoundError: if an image or mask file cannot be read.
        """
        if self.transform_img is None:
            raise ValueError('AugmDataSet.stack_dataset requires transform_img')
        images, labels = [], []
        for img_path, mask_path in zip(self.images_direc, self.mask_direc):
            image = cv2.imread(img_path)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or mask is None:
                raise FileNotFoundError(f'Could not read image {img_path!r} or mask {mask_path!r}')
            # The mask gets an explicit channel axis: (H, W) -> (H, W, 1).
            augmented = self.transform_img(image=image, mask=np.expand_dims(mask, axis=2))
            images.append(augmented['image'] / 255)
            mask_t = torch.permute(augmented['mask'], (2, 0, 1))  # (H, W, 1) -> (1, H, W)
            labels.append(torch.round(mask_t / 255))
        IMAGE = torch.stack(images).type(torch.float32)
        LABEL = torch.stack(labels).type(torch.float32)
        return TensorDataset(IMAGE, LABEL)


def get_AugmDataset(batch_size=16, MRImg_dataset=None, fraction=1.0):
    """Concatenate the leading samples of each dataset into one dataset.

    Args:
        batch_size: unused; kept for interface compatibility.
        MRImg_dataset: non-empty list of datasets (e.g. the raw training set
            followed by its augmented copies). Any number of datasets is
            accepted; previously exactly four were required.
        fraction: share of samples kept from every dataset, measured on the
            FIRST dataset's length. The default 1.0 keeps everything, matching
            the previous behaviour.

    Returns:
        A ``torch.utils.data.ConcatDataset`` of the per-dataset subsets.

    Raises:
        ValueError: if ``MRImg_dataset`` is None or empty.
    """
    if not MRImg_dataset:
        raise ValueError('MRImg_dataset must be a non-empty list of datasets')
    # The same subset length is used for every dataset, taken from the first
    # one; the augmented copies are expected to be at least that long.
    n_split = int(fraction * len(MRImg_dataset[0]))
    subsets = [torch.utils.data.Subset(ds, range(n_split)) for ds in MRImg_dataset]
    return torch.utils.data.ConcatDataset(subsets)


In [198]:
# Build the leave-one-domain-out data: train/val come from Domain1-3 and the
# test set is Domain4 (split_idx=1 excludes Domain4 from train/val).

# Recreate empty ./train, ./val and ./test working folders, each with image/ and
# mask/ sub-folders. Fundus_Dataset copies the selected files into them.
for path in [train_path, val_path, test_path]:
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)

for path in [train_path, val_path, test_path]:
    os.makedirs(os.path.join(path, "image"))
    os.makedirs(os.path.join(path, "mask"))

# Resize (shorter edge -> 64), crop to 64x64 and convert to a [0, 1] tensor.
# NOTE(review): this pipeline is applied to each image and each mask in separate
# calls, so RandomCrop draws independent offsets for them. They only stay
# aligned if Resize(64) already yields 64x64 (square inputs) -- confirm image sizes.
normalization = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(64),
    transforms.RandomCrop(64),
    transforms.ToTensor(),
])

# Flipping: resize to 64x64 (the 64x64 crop is then a no-op), always flip
# horizontally, and apply brightness/contrast jitter 10% of the time.
GEO = Augm.Compose([
    Augm.augmentations.geometric.resize.Resize(64, 64),
    Augm.RandomCrop(width=64, height=64),
    Augm.HorizontalFlip(p=1),
    Augm.RandomBrightnessContrast(p=0.1),
    ToTensorV2()
])

# Color distortion: ColorJitter on top of the same resize/crop.
COL = Augm.Compose([
    Augm.augmentations.geometric.resize.Resize(64, 64),
    Augm.RandomCrop(width=64, height=64),
    Augm.augmentations.transforms.ColorJitter(),
    Augm.RandomBrightnessContrast(p=0.1),
    ToTensorV2()
])

## PCA: FancyPCA colour augmentation on top of the same resize/crop.
PCA = Augm.Compose([
    Augm.augmentations.geometric.resize.Resize(64, 64),
    Augm.RandomCrop(width=64, height=64),
    Augm.augmentations.transforms.FancyPCA(),
    Augm.RandomBrightnessContrast(p=0.1),
    ToTensorV2()
])


## Raw Dataset
# split_idx on Test_class has no effect: only split_traindataset reads it.
Train_class = Fundus_Dataset([D1_Mask_dir,D2_Mask_dir,D3_Mask_dir], transform=normalization, transform_mask=normalization, split_idx=1)
Test_class = Fundus_Dataset([D4_Mask_dir], transform=normalization, transform_mask=normalization, split_idx=1)
# These calls copy the files and already return the loaded tensors. The return
# values are discarded, and stack_train/stack_test below re-read the copied files.
Train_class.split_traindataset()
Test_class.split_testdataset()
train_dataset, val_dataset = Train_class.stack_train()
test_dataset = Test_class.stack_test()

## Augmen Dataset
# Three augmented copies of the files in ./train only (validation is not augmented).
PCA_dataset = AugmDataSet(direcs=Train_class.direc(), transform_img=PCA).stack_dataset()
GEO_dataset = AugmDataSet(direcs=Train_class.direc(), transform_img=GEO).stack_dataset()
COL_dataset = AugmDataSet(direcs=Train_class.direc(), transform_img=COL).stack_dataset()
train_augm_dataset = get_AugmDataset(MRImg_dataset=[train_dataset, PCA_dataset, GEO_dataset, COL_dataset])

train_loader = torch.utils.data.DataLoader(train_augm_dataset, batch_size=8, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

## Test DataLoader (one image per batch)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)


print(f'Shape of image: {train_loader.dataset[0][0].shape}')
print(f'Number of training batches: {len(train_loader)}')
print(f'Number of validation batches: {len(val_loader)}')
print(f'Number of test batches: {len(test_loader)}')

Shape of image: torch.Size([3, 64, 64])
Number of training batches: 264
Number of validation batches: 17
Number of test batches: 400


In [None]:
############################ BELOW ARE ORIGINAL CODE ############################

class DGDataset(Dataset):
    """Image dataset over parallel lists of file paths and labels.

    Args:
        names: image file paths.
        labels: labels aligned index-by-index with ``names``.
        transformer: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, names, labels, transformer=None):
        self.names = names
        self.labels = labels
        self.transformer = transformer

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is converted to RGB, then transformed."""
        # The context manager closes the file handle deterministically instead
        # of leaving it to garbage collection.
        with Image.open(self.names[index]) as raw:
            img = raw.convert('RGB')
        if self.transformer is not None:
            img = self.transformer(img)
        return img, self.labels[index]


class FourierDGDataset(Dataset):
    """FACT-style training dataset that pairs every image with a sampled partner.

    Args:
        names: per-domain lists of image paths; ``names[d]`` belongs to domain ``d``.
        labels: per-domain label lists aligned with ``names``.
        transformer: callable applied to each loaded RGB PIL image before mixing.
            It is called unconditionally, so it must not be None.
        from_domain: where the partner image is drawn from:
            - ``'all'``: any domain
            - ``'inter'``: a different domain; needs at least two domains
            - ``'intra'``: the same domain
            Any other value makes ``__getitem__`` raise ValueError.
        alpha: forwarded to ``colorful_spectrum_mix``.
    """

    def __init__(self, names, labels, transformer=None, from_domain=None, alpha=1.0):
        self.names = names
        self.labels = labels
        self.transformer = transformer
        self.post_transform = get_post_transform()
        self.from_domain = from_domain
        self.alpha = alpha

        # Flattened view over all domains so a single index addresses any sample.
        self.flat_names = []
        self.flat_labels = []
        self.flat_domains = []
        for i in range(len(names)):
            self.flat_names += names[i]
            self.flat_labels += labels[i]
            self.flat_domains += [i] * len(names[i])
        assert len(self.flat_names) == len(self.flat_labels)
        assert len(self.flat_names) == len(self.flat_domains)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as raw:
            return raw.convert('RGB')

    def __len__(self):
        return len(self.flat_names)

    def __getitem__(self, index):
        """Return ``(imgs, labels, domains)`` for sample ``index``.

        ``imgs`` is ordered as [original, sampled partner, s2o mix, o2s mix].
        The two mixes are the outputs of ``colorful_spectrum_mix(original,
        partner)``. All four images go through ``post_transform``. ``labels``
        and ``domains`` follow the pattern [original, partner, original, partner].
        """
        label = self.flat_labels[index]
        domain = self.flat_domains[index]
        img_o = self.transformer(self._load_rgb(self.flat_names[index]))

        img_s, label_s, domain_s = self.sample_image(domain)
        img_s2o, img_o2s = colorful_spectrum_mix(img_o, img_s, alpha=self.alpha)
        img_o, img_s = self.post_transform(img_o), self.post_transform(img_s)
        img_s2o, img_o2s = self.post_transform(img_s2o), self.post_transform(img_o2s)
        img = [img_o, img_s, img_s2o, img_o2s]
        label = [label, label_s, label, label_s]
        domain = [domain, domain_s, domain, domain_s]
        return img, label, domain

    def sample_image(self, domain):
        """Draw a random partner image according to ``from_domain``.

        Returns:
            ``(transformed image, label, domain index)`` of the partner.

        Raises:
            ValueError: for an unknown ``from_domain``, or for ``'inter'`` when
                there is no other domain to sample from.
        """
        if self.from_domain == 'all':
            domain_idx = random.randint(0, len(self.names)-1)
        elif self.from_domain == 'inter':
            domains = [d for d in range(len(self.names)) if d != domain]
            if not domains:
                raise ValueError("from_domain='inter' needs at least two source domains")
            domain_idx = random.sample(domains, 1)[0]
        elif self.from_domain == 'intra':
            domain_idx = domain
        else:
            raise ValueError("Not implemented")
        img_idx = random.randint(0, len(self.names[domain_idx])-1)
        img_name_sampled = self.names[domain_idx][img_idx]
        img_sampled = self._load_rgb(img_name_sampled)
        label_sampled = self.labels[domain_idx][img_idx]
        return self.transformer(img_sampled), label_sampled, domain_idx


def get_dataset(path, train=False, image_size=224, crop=False, jitter=0, config=None):
    """Build a DGDataset from the datalist file at ``path``.

    A truthy ``config`` overrides ``image_size``, ``crop`` and ``jitter`` with
    its ``"image_size"``, ``"use_crop"`` and ``"jitter"`` entries. The image
    transform comes from ``get_img_transform(train, image_size, crop, jitter)``.
    """
    names, labels = dataset_info(path)
    if config:
        image_size, crop, jitter = config["image_size"], config["use_crop"], config["jitter"]
    return DGDataset(names, labels, get_img_transform(train, image_size, crop, jitter))


def get_fourier_dataset(path, image_size=224, crop=False, jitter=0, from_domain='all', alpha=1.0, config=None):
    """Build a FourierDGDataset from a list of per-domain datalist files.

    Each entry of ``path`` is read with ``dataset_info`` and becomes one
    domain. A truthy ``config`` overrides ``image_size``, ``crop``, ``jitter``,
    ``from_domain`` and ``alpha`` with its ``"image_size"``, ``"use_crop"``,
    ``"jitter"``, ``"from_domain"`` and ``"alpha"`` entries.
    """
    assert isinstance(path, list)
    per_domain = [dataset_info(p) for p in path]
    names = [domain_names for domain_names, _ in per_domain]
    labels = [domain_labels for _, domain_labels in per_domain]

    if config:
        image_size = config["image_size"]
        crop = config["use_crop"]
        jitter = config["jitter"]
        from_domain = config["from_domain"]
        alpha = config["alpha"]

    pre_transform = get_pre_transform(image_size, crop, jitter)
    return FourierDGDataset(names, labels, pre_transform, from_domain, alpha)


## fix tmr & load dataset from Google drive
import bisect
from data.ConcatDataset import ConcatDataset
from utils.tools import *


# Folder holding the "<domain>_{train,val,test}.txt" datalist files.
default_input_dir = '/content/gdrive/MyDrive/Colab Notebooks/ELEC4010N/Final Project/Project2/data/datalists'

# Domain names per benchmark (the lists below are the datalist prefixes).
digits_datset = ["mnist", "mnist_m", "svhn", "syn"]
pacs_dataset = ["art_painting", "cartoon", "photo", "sketch"]
officehome_dataset = ['Art', 'Clipart', 'Product', 'Real_World']
available_datasets = pacs_dataset + officehome_dataset + digits_datset


def get_datalists_folder(args=None):
    """Return ``args.input_dir`` when it is set, otherwise ``default_input_dir``."""
    if args is None or args.input_dir is None:
        return default_input_dir
    return args.input_dir


def get_train_dataloader(source_list=None, batch_size=64, image_size=224, crop=False, jitter=0, args=None, config=None):
    """Return a shuffled DataLoader over the concatenated ``<name>_train.txt`` datalists.

    ``args.source`` (when ``args`` is given) replaces ``source_list``. When
    ``config`` is given, ``config["batch_size"]`` replaces ``batch_size`` and
    ``config["data_opt"]`` is forwarded to ``get_dataset`` as its config.
    The last incomplete batch is dropped.
    """
    if args is not None:
        source_list = args.source
    data_config = None
    if config is not None:
        batch_size = config["batch_size"]
        data_config = config["data_opt"]
    assert isinstance(source_list, list)

    datasets = [
        get_dataset(path=os.path.join(get_datalists_folder(args), '%s_train.txt' % dname),
                    train=True,
                    image_size=image_size,
                    crop=crop,
                    jitter=jitter,
                    config=data_config)
        for dname in source_list
    ]
    return DataLoader(ConcatDataset(datasets), batch_size=batch_size, shuffle=True,
                      num_workers=2, pin_memory=False, drop_last=True)


def get_fourier_train_dataloader(
        source_list=None,
        batch_size=64,
        image_size=224,
        crop=False,
        jitter=0,
        args=None,
        from_domain='all',
        alpha=1.0,
        config=None
):
    """Return a shuffled DataLoader over one FourierDGDataset spanning all sources.

    Each source name maps to ``<datalists folder>/<name>_train.txt`` and becomes
    one domain of the dataset. ``args.source`` replaces ``source_list`` when
    ``args`` is given. ``config["batch_size"]`` and ``config["data_opt"]``
    override the batch size and the dataset options when ``config`` is given.
    The last incomplete batch is dropped.
    """
    if args is not None:
        source_list = args.source
    data_config = None
    if config is not None:
        batch_size = config["batch_size"]
        data_config = config["data_opt"]
    assert isinstance(source_list, list)

    paths = [os.path.join(get_datalists_folder(args), '%s_train.txt' % dname)
             for dname in source_list]
    dataset = get_fourier_dataset(path=paths,
                                  image_size=image_size,
                                  crop=crop,
                                  jitter=jitter,
                                  from_domain=from_domain,
                                  alpha=alpha,
                                  config=data_config)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, pin_memory=False, drop_last=True)


def get_val_dataloader(source_list=None, batch_size=64, image_size=224, args=None, config=None):
    """Return an unshuffled DataLoader over the concatenated ``<name>_val.txt`` datalists.

    ``args.source`` replaces ``source_list`` when ``args`` is given; ``config``
    supplies ``"batch_size"`` and the dataset options (``"data_opt"``).
    """
    if args is not None:
        source_list = args.source
    data_config = None
    if config is not None:
        batch_size = config["batch_size"]
        data_config = config["data_opt"]
    assert isinstance(source_list, list)

    datasets = []
    for dname in source_list:
        val_path_txt = os.path.join(get_datalists_folder(args), '%s_val.txt' % dname)
        datasets.append(get_dataset(path=val_path_txt, train=False,
                                    image_size=image_size, config=data_config))
    return DataLoader(ConcatDataset(datasets), batch_size=batch_size, shuffle=False,
                      num_workers=2, pin_memory=False)


def get_test_loader(target=None, batch_size=64, image_size=224, args=None, config=None):
    """Return an unshuffled DataLoader over the ``<target>_test.txt`` datalist.

    ``args.target`` replaces ``target`` when ``args`` is given; ``config``
    supplies ``"batch_size"`` and the dataset options (``"data_opt"``).
    """
    if args is not None:
        target = args.target
    data_config = None
    if config is not None:
        batch_size = config["batch_size"]
        data_config = config["data_opt"]

    test_txt = os.path.join(get_datalists_folder(args), '%s_test.txt' % target)
    test_dataset = get_dataset(path=test_txt, train=False, image_size=image_size, config=data_config)
    return DataLoader(ConcatDataset([test_dataset]), batch_size=batch_size, shuffle=False,
                      num_workers=2, pin_memory=False)

if __name__ == "__main__":
    # Demo: draw one FACT batch from three PACS source domains and pass it to
    # save_image_from_tensor_batch, which writes batch.jpg.
    # NOTE(review): inside a notebook __name__ is "__main__", so this also runs
    # whenever the cell is executed.
    batch_size = 16
    source = ["art_painting", "cartoon", "photo"]
    loader = get_fourier_train_dataloader(source, batch_size, image_size=224, from_domain='all', alpha=1.0)

    first_batch = next(iter(loader))
    # first_batch[0] is the list of 4 image batches (original, sampled partner
    # and the two spectrum mixes); concatenate them along the batch dimension.
    images = torch.cat(first_batch[0], dim=0)
    save_image_from_tensor_batch(images, batch_size, path='batch.jpg', device='cpu')