In [42]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.datasets.folder import pil_loader
from torchvision.models import ResNet18_Weights

from semantic_transforms import (
    FFTSuppressAmplitude,
    GaussianBlur,
    CannyEdge,
    EdgeMaskedBlur,
    GrayScale,
    ColorSwap,
    RandomDropChannel,
    MyColorJitter,
    HistEqualization,
    RandomChannelNormalization,
    LowFrequencyNoiseInjection
)

In [2]:
seed = 123

# Seed every global RNG used by the imported libraries so that
# "Restart Kernel -> Run All" reproduces the same shuffling, augmentation
# draws and weight initialisation.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Maps a dataset name to its root directory, resolved against the current
# working directory of the kernel.
dataset_paths = dict(
    PACS=os.path.join(os.getcwd(), "PACS", "kfold"),
)

### Dataset Class

In [4]:
class DGDomainDataset(Dataset):
    """Image dataset over a ``root/<domain>/<class>/<image>`` directory tree.

    Every sample is returned as ``(image, domain_idx)``: the label is the index
    of the image's *domain* folder, not of its class folder. Inside each
    ``<domain>/<class>`` folder the first 85% of the name-sorted image files
    form the training split and the remaining files the validation split.

    Args:
        root: Directory whose immediate sub-directories are the domains.
        seed: Seed for the shuffle of the sample order. It is applied to
            Python's global ``random`` generator.
        train: If True the dataset exposes the training split, otherwise the
            validation split.
        transform: Optional callable applied to the loaded RGB PIL image.

    Attributes:
        domain_to_idx: Maps each domain folder name to its integer label
            (domains are numbered in sorted-name order).
        path_list: List of ``{"path", "domain", "class"}`` dicts for the
            selected split.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Fraction of every <domain>/<class> folder assigned to the training split.
    TRAIN_FRACTION = 0.85

    def __init__(self, root, seed=seed, train=True, transform=None):
        random.seed(seed)

        self.root = root
        self.transform = transform
        self.domain_to_idx = {}
        train_path_list = []
        validation_path_list = []

        # Traverse directory structure: root/<domain>/<class>/<image>
        domains = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
        for domain_idx, domain in enumerate(domains):
            self.domain_to_idx[domain] = domain_idx
            domain_path = os.path.join(root, domain)

            classes = sorted(c for c in os.listdir(domain_path) if os.path.isdir(os.path.join(domain_path, c)))
            for cls in classes:
                class_path = os.path.join(domain_path, cls)
                # os.listdir() returns entries in arbitrary order; sorting makes
                # the train/validation split identical across runs, machines,
                # and between the separately constructed train and validation
                # dataset objects.
                img_filenames = sorted(
                    f for f in os.listdir(class_path)
                    if f.lower().endswith(self.IMAGE_EXTENSIONS)
                )
                samples = [
                    {
                        "path": os.path.join(class_path, img),
                        "domain": domain,
                        "class": cls
                    }
                    for img in img_filenames
                ]
                train_val_split = int(len(samples) * self.TRAIN_FRACTION)
                train_path_list.extend(samples[:train_val_split])
                validation_path_list.extend(samples[train_val_split:])

        # Both lists are shuffled regardless of `train` so the random stream,
        # and therefore each split's order, does not depend on the flag.
        random.shuffle(train_path_list)
        random.shuffle(validation_path_list)
        self.path_list = train_path_list if train else validation_path_list

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.path_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, domain_idx)``."""
        sample = self.path_list[idx]
        domain_idx = self.domain_to_idx[sample["domain"]]

        # pil_loader opens the file inside a context manager and returns the
        # image converted to RGB, so no file handle is left open.
        image = pil_loader(sample["path"])
        if self.transform:
            image = self.transform(image)

        return image, domain_idx


In [22]:
def get_transform(**kwargs):
    """Build one semantic transform and a human-readable title describing it.

    The transform is selected by the presence of a *selector* key in
    ``kwargs`` (its value is ignored for the selection; ``None`` is the
    convention used by callers). The remaining keyword arguments carry the
    parameters of the selected transform.

    Selector -> transform class (required parameters), checked in this order:
        fourier             -> FFTSuppressAmplitude (ampl_scale)
        blur                -> GaussianBlur (blur_scale)
        edge                -> CannyEdge (edge_scale)
        edgeMaskedBlur      -> EdgeMaskedBlur (blur_scale, edge_scale)
        gray                -> GrayScale ()
        colorSwap           -> ColorSwap (seed)
        colorDrop           -> RandomDropChannel (drop_rate, seed)
        colorJitter         -> MyColorJitter (brightness, contrast, saturation, hue, seed)
        histEqualize        -> HistEqualization (clip_limit, tile_grid_size)
        rndChannelNormalize -> RandomChannelNormalization (mean_shift_range, scale_range, seed)
        lowFreqNoiseInject  -> LowFrequencyNoiseInjection (alpha, blur_kernel, seed)

    Args:
        **kwargs: One selector key plus the parameters of that transform.
            Keys that are neither the selector nor a required parameter are
            not passed to the transform but still appear in the title.

    Returns:
        tuple: ``(transform, title)``. ``title`` lists every keyword argument
        in call order as ``"key = value "`` (just ``"key "`` when the value is
        ``None``), e.g. ``"fourier ampl_scale = 0.5 "``.

    Raises:
        ValueError: If ``kwargs`` contains none of the selector keys.
        TypeError: If a parameter required by the selected transform is missing.
    """
    # (selector key, transform class, required parameter names). The order is
    # the priority used when several selector keys are present.
    registry = (
        ("fourier", FFTSuppressAmplitude, ("ampl_scale",)),
        ("blur", GaussianBlur, ("blur_scale",)),
        ("edge", CannyEdge, ("edge_scale",)),
        ("edgeMaskedBlur", EdgeMaskedBlur, ("blur_scale", "edge_scale")),
        ("gray", GrayScale, ()),
        ("colorSwap", ColorSwap, ("seed",)),
        ("colorDrop", RandomDropChannel, ("drop_rate", "seed")),
        ("colorJitter", MyColorJitter, ("brightness", "contrast", "saturation", "hue", "seed")),
        ("histEqualize", HistEqualization, ("clip_limit", "tile_grid_size")),
        ("rndChannelNormalize", RandomChannelNormalization, ("mean_shift_range", "scale_range", "seed")),
        ("lowFreqNoiseInject", LowFrequencyNoiseInjection, ("alpha", "blur_kernel", "seed")),
    )

    for selector, transform_cls, param_names in registry:
        if selector in kwargs:
            break
    else:
        known = ", ".join(name for name, _, _ in registry)
        raise ValueError(
            f"get_transform() needs one of the selector keys ({known}); "
            f"got {sorted(kwargs)}"
        )

    missing = [name for name in param_names if name not in kwargs]
    if missing:
        raise TypeError(
            f"get_transform({selector}=...) is missing required parameter(s): "
            + ", ".join(missing)
        )

    # One uniform title format for every transform: "key = value " per keyword
    # argument, omitting " = value" when the value is None.
    title = ''
    for key, value in kwargs.items():
        title += str(key)
        if value is not None:
            title += ' = ' + str(value)
        title += ' '

    tr = transform_cls(**{name: kwargs[name] for name in param_names})
    return tr, title

In [None]:
# Settings shared by every pipeline and loader built in this cell.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
NUM_WORKERS = 1
# Per-channel normalisation statistics (the standard ImageNet values).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def _build_pipeline(semantic_tr, augment):
    """Compose resize -> semantic transform -> [random flip] -> tensor -> normalise.

    Args:
        semantic_tr: The semantic transform returned by ``get_transform``.
        augment: If True a ``RandomHorizontalFlip`` is inserted; use False for
            evaluation so that validation inputs are not randomly flipped.
    """
    steps = [transforms.Resize(IMAGE_SIZE), semantic_tr]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD))
    return transforms.Compose(steps)


# One (transform, title) pair per experiment configuration.
tr_list = [get_transform(fourier=None, ampl_scale=0),
           get_transform(fourier=None, ampl_scale=0.5),
           get_transform(blur=None, blur_scale=10),
           get_transform(blur=None, blur_scale=50),
           get_transform(edge=None, edge_scale=100),
           get_transform(edgeMaskedBlur=None, blur_scale=50, edge_scale=100),
           get_transform(edgeMaskedBlur=None, blur_scale=50, edge_scale=250),
           get_transform(gray=None),
           get_transform(colorSwap=None, seed=123),
           get_transform(colorSwap=None, seed=1234),
           get_transform(colorDrop=None, drop_rate=1, seed=123),
           get_transform(colorDrop=None, drop_rate=0.5, seed=123),
           get_transform(colorJitter=None, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3, seed=123),
           get_transform(colorJitter=None, brightness=0.5, contrast=0.5, saturation=0.5, hue=1, seed=123),
           get_transform(histEqualize=None, clip_limit=2.0, tile_grid_size=(8, 8)),
           get_transform(rndChannelNormalize=None, mean_shift_range=(-50, 50), scale_range=(0.5, 1.5), seed=123),
           get_transform(lowFreqNoiseInject=None, alpha=0.5, blur_kernel=51, seed=123),
           get_transform(lowFreqNoiseInject=None, alpha=1, blur_kernel=51, seed=123)
           ]

# Training pipelines (with random horizontal flip), keyed by experiment title.
tr_dict = {
    title : _build_pipeline(tr, augment=True)
    for tr, title in tr_list
}

# Evaluation pipelines: same steps without the random flip, so validation
# metrics do not depend on augmentation draws.
eval_tr_dict = {
    title : _build_pipeline(tr, augment=False)
    for tr, title in tr_list
}

train_datasets = {
    title : DGDomainDataset(root=dataset_paths["PACS"],
                            seed=seed,
                            train=True,
                            transform=tr)
    for title, tr in tr_dict.items()
}

validation_datasets = {
    title : DGDomainDataset(root=dataset_paths["PACS"],
                            seed=seed,
                            train=False,
                            transform=tr)
    for title, tr in eval_tr_dict.items()
}

train_loaders = {
    title : DataLoader(dataset,
                       batch_size=BATCH_SIZE,
                       shuffle=True,
                       num_workers=NUM_WORKERS)
    for title, dataset in train_datasets.items()
}

# Validation batches are served in a fixed order; shuffling is only needed
# for training.
validation_loaders = {
    title : DataLoader(dataset,
                       batch_size=BATCH_SIZE,
                       shuffle=False,
                       num_workers=NUM_WORKERS)
    for title, dataset in validation_datasets.items()
}

In [37]:
# Every training dataset is built from the same root, so the domain mapping of
# the first one gives the number of target labels.
num_domains = len(next(iter(train_datasets.values())).domain_to_idx)
print(f"Number of classes: {num_domains}")

Number of classes: 4


## Model, Losses, and Optimizer

In [43]:
# Model: ResNet-18 initialised with the IMAGENET1K_V1 weights; its final fully
# connected layer is replaced by a fresh `num_domains`-way linear head.
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=num_domains)
model = model.to(device=device)

# Loss & Optimizer: cross-entropy over the domain logits, Adam over all parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

## Training

In [None]:
def train_epoch(epoch):
    """Run one training epoch (loop body not implemented yet).

    Args:
        epoch: Identifier of the epoch being run (presumably its index;
            not used yet).

    Raises:
        NotImplementedError: Always, after the model has been switched to
            training mode, because the training loop has not been written.
    """
    model.train()
    running_loss = 0.0  # accumulator for the epoch's loss once the loop exists
    # TODO: iterate over the chosen training loader, run the forward/backward
    # pass with `criterion` and `optimizer`, and accumulate `running_loss`.
    raise NotImplementedError("train_epoch: the training loop is not implemented yet")

In [None]:
# Evaluation:
# - Classification accuracy, recall, precision, and F1 score: for the entire dataset, per class, and per domain.
# - Lists of image paths for classified and unclassified images, keyed by domain, e.g.
#   classified = {'domain1': [path1, path2], 'domain2': []}, unclassified = {'domain1': [path1, path2]}
# - A separate code snippet that creates folders to visualize the results.