In [2]:
import ast
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split

from scripts.datasets import office_home, convert_bytes_to_images
from scripts.pre_rn18 import ResNetFinetuner, eigen_train, CustomDataset

In [None]:
# Load the Office-Home dataset by its hub identifier; a later cell reads the
# label names from ds['train'].
ds = load_dataset("flwrlabs/office-home")

# Project helper returning the class names and the data for the 'real' selection.
# NOTE(review): presumably 'real' selects the Real-World domain — confirm in scripts.datasets.
class_names, real_data = office_home('real')
print(f"Class names: {class_names}")

In [None]:
real_data_df = real_data.to_pandas()

# Split configuration. The validation fraction is applied to what is left after
# the test split, i.e. 0.125 * (1 - 0.2) = 0.10 of the full data.
TEST_FRACTION = 0.2
VAL_FRACTION_OF_REMAINDER = 0.125
SPLIT_SEED = 42

# Stratified split of the in-distribution data into training (70%),
# validation (10%), and testing (20%).
train_val_df, test_df = train_test_split(
    real_data_df,
    test_size=TEST_FRACTION,
    random_state=SPLIT_SEED,
    stratify=real_data_df['label'],
)
train_df, val_df = train_test_split(
    train_val_df,
    test_size=VAL_FRACTION_OF_REMAINDER,
    random_state=SPLIT_SEED,
    stratify=train_val_df['label'],
)

# The three splits must partition the original frame exactly.
assert len(train_df) + len(val_df) + len(test_df) == len(real_data_df)

# NOTE: from here on `Dataset` refers to the Hugging Face class and shadows the
# torch.utils.data.Dataset imported in the first cell.
from datasets import Dataset
train_data = Dataset.from_pandas(train_df)
val_data = Dataset.from_pandas(val_df)
test_data = Dataset.from_pandas(test_df)

print(f"In-Distribution Training Data: {len(train_data)}")
print(f"In-Distribution Validation Data: {len(val_data)}")
print(f"In-Distribution Test Data: {len(test_data)}")

In [4]:
# Label vocabulary is read from the label feature of the full dataset.
classes = ds['train'].features['label'].names
num_classes = len(set(classes))

# Pull the image payloads and labels out of each split.
train_images, train_labels = train_data['image'], train_data['label']
val_images, val_labels = val_data['image'], val_data['label']
test_images, test_labels = test_data['image'], test_data['label']

# Decode the stored image payloads with the project helper.
train_images, val_images, test_images = (
    convert_bytes_to_images(payloads)
    for payloads in (train_images, val_images, test_images)
)

# Shared preprocessing: fixed 224x224 size, tensor conversion, channel normalisation.
# NOTE(review): the mean/std below are CLIP's statistics, while the model trained
# in this notebook is ResNetFinetuner — confirm this is intentional.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

In [5]:
# Create custom datasets
train_dataset = CustomDataset(images=train_images, labels=train_labels, classes=classes, transform=transform)
val_dataset = CustomDataset(images=val_images, labels=val_labels, classes=classes, transform=transform)
test_dataset = CustomDataset(images=test_images, labels=test_labels, classes=classes, transform=transform)

# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [6]:
# Fine-tuning hyperparameters.
# NOTE(review): learning_rate is not passed to eigen_train in the cells shown
# here — confirm where (or whether) it is consumed.
learning_rate = 1e-5
# Number of training epochs; handed to eigen_train as num_epochs.
epochs = 5

Data Diversity

In [7]:
# Stock torchvision augmentation policies, each left at its default arguments
# and wrapped in a single-step Compose.
auto = transforms.Compose([transforms.AutoAugment()])
rand = transforms.Compose([transforms.RandAugment()])
augM = transforms.Compose([transforms.AugMix()])
trAu = transforms.Compose([transforms.TrivialAugmentWide()])


class GaussianNoise(object):
    """Add Gaussian noise to an image and clamp the result to [0, 1].

    The perturbation added to every pixel is
    ``noise_ratio * (sigma * N(0, 1) + mean)``, so the effective standard
    deviation is ``sigma * noise_ratio`` and the effective offset is
    ``mean * noise_ratio``.

    Args:
        mean: Mean of the sampled noise before scaling by ``noise_ratio``.
        sigma: Standard deviation of the sampled noise before scaling.
        noise_ratio: Factor applied to the sampled noise before it is added.

    Calling the transform:
        * A floating-point ``torch.Tensor`` (values expected in [0, 1]) is
          perturbed directly and a tensor is returned.
        * Any other input is converted with ``transforms.ToTensor`` (PIL image
          or ndarray) and the result is returned as a PIL image.

    Raises:
        TypeError: If a non-floating-point tensor is passed, because the
            [0, 1] clamp would be meaningless for integer pixel values.
    """

    def __init__(self, mean=0.0, sigma=1.0, noise_ratio=1.0):
        self.mean = mean
        self.sigma = sigma
        self.noise_ratio = noise_ratio

    def __call__(self, img):
        is_tensor = isinstance(img, torch.Tensor)
        if is_tensor:
            if not img.is_floating_point():
                raise TypeError(
                    "GaussianNoise expects a floating-point tensor in [0, 1], "
                    f"got dtype {img.dtype}"
                )
            # Match the input's dtype and device so GPU / half tensors work.
            noise = torch.randn_like(img)
        else:
            # PIL image (or ndarray) -> float tensor scaled to [0, 1].
            img = transforms.ToTensor()(img)
            noise = torch.randn(img.size())

        noise = noise * self.sigma + self.mean

        # Add the scaled noise and keep pixels inside the valid range.
        noisy_img = torch.clamp(img + noise * self.noise_ratio, 0.0, 1.0)

        # Return the same kind of object the caller passed in.
        if is_tensor:
            return noisy_img
        return transforms.ToPILImage()(noisy_img)

    def __repr__(self):
        return f"GaussianNoise(mean={self.mean}, sigma={self.sigma}, noise_ratio={self.noise_ratio})"

# Noise-only pipelines (no resizing happens here). All three sample noise with
# sigma=0.1 and differ only in noise_ratio, the factor applied to the sampled
# noise before it is added to the image.
noise_01 = transforms.Compose([
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.1)
])


noise_03 = transforms.Compose([
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.3)
])


noise_05 = transforms.Compose([
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.5)
])

# AutoAugment followed by Gaussian noise.
auto1 = transforms.Compose([
    transforms.AutoAugment(),
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.5)
])

# AutoAugment, random horizontal and vertical flips, then Gaussian noise.
auto2 = transforms.Compose([
    transforms.AutoAugment(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.5)
])

# AutoAugment, random horizontal flip, elastic deformation, then Gaussian noise.
# The vertical flip used in auto2 is disabled here.
auto3 = transforms.Compose([
    transforms.AutoAugment(),
    transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomVerticalFlip(p=0.3),
    transforms.ElasticTransform(alpha=40.0, sigma=6.0, interpolation=transforms.InterpolationMode.BICUBIC),
    GaussianNoise(mean=0.0, sigma=0.1, noise_ratio=0.5)
])


def auto_transform(image):
    """Run `image` through the module-level `auto` pipeline and return the result."""
    return auto(image)

def rand_transform(image):
    """Run `image` through the module-level `rand` pipeline and return the result."""
    return rand(image)

def augM_transform(image):
    """Run `image` through the module-level `augM` pipeline and return the result."""
    return augM(image)

def trAu_transform(image):
    """Run `image` through the module-level `trAu` pipeline and return the result."""
    return trAu(image)

def noise_01_transform(image):
    """Run `image` through the module-level `noise_01` pipeline and return the result."""
    return noise_01(image)

def noise_03_transform(image):
    """Run `image` through the module-level `noise_03` pipeline and return the result."""
    return noise_03(image)

def noise_05_transform(image):
    """Run `image` through the module-level `noise_05` pipeline and return the result."""
    return noise_05(image)

def auto1_transform(image):
    """Run `image` through the module-level `auto1` pipeline and return the result."""
    return auto1(image)

def auto2_transform(image):
    """Run `image` through the module-level `auto2` pipeline and return the result."""
    return auto2(image)

def auto3_transform(image):
    """Run `image` through the module-level `auto3` pipeline and return the result."""
    return auto3(image)

In [None]:
# Experiment name -> augmentation applied to every training image of that run.
# The keys become the column names of the exported trace CSV.
dataset_div = {'Auto': auto_transform,
               'Rand': rand_transform,
               'AugM': augM_transform,
               'TrAu': trAu_transform,
               'Noize10': noise_01_transform,
               'Noize30': noise_03_transform,
               'Noize50': noise_05_transform,
               'Auto1': auto1_transform,
               'Auto2': auto2_transform,
               'Auto3': auto3_transform
               }

# Per-experiment initial models, fine-tuned models, and training traces.
model_div = {}
ft_model_div = {}
trace_div = {}

# torch.save does not create missing parent directories.
os.makedirs('loss_model', exist_ok=True)

for k, v in dataset_div.items():

    # Apply this run's augmentation first, then the shared resize / tensor /
    # normalise preprocessing. Without this step every run would train on
    # identical, un-augmented data.
    # NOTE(review): assumes CustomDataset calls `transform` on a PIL image — confirm.
    aug_transform = transforms.Compose([v, transform])

    # Dedicated names so the baseline train_dataset / train_dataloader defined
    # earlier are not silently replaced by an augmented variant.
    aug_dataset = CustomDataset(images=train_images, labels=train_labels, classes=classes, transform=aug_transform)
    aug_dataloader = DataLoader(aug_dataset, batch_size=32, shuffle=True)

    # Fresh model for this run.
    model_div[k] = ResNetFinetuner(num_classes=num_classes)

    # Save the initial (pre-fine-tuning) weights.
    torch.save(model_div[k].state_dict(), f'loss_model/model_div{k}.pth')

    # Fine-tune and collect the trace returned by eigen_train.
    ft_model_div[k], trace_div[k] = eigen_train(model_div[k], aug_dataloader, num_epochs=epochs)

    # Save the fine-tuned weights.
    torch.save(ft_model_div[k].state_dict(), f'loss_model/ft_model_div{k}.pth')

In [9]:
# One row, one column per experiment name; each cell holds that run's trace.
# A separate name keeps `trace_div` as the original dict, so re-running this
# cell does not wrap an already-converted DataFrame a second time.
trace_div_table = pd.DataFrame([trace_div])

trace_div_table.to_csv('loss_model/trace_div.csv', index=False)

Visualization

In [126]:
# Load the saved traces: baseline, dropout, weight decay, data diversity.
trace_bs_df, trace_dp_df, trace_wd_df, trace_div_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        'loss_model/trace_bs.csv',
        'loss_model/trace_dp.csv',
        'loss_model/trace_wd.csv',
        'loss_model/trace_div.csv',
    )
)

In [49]:
# Baseline trace: the first row of the baseline CSV as a NumPy array.
trace_bs = trace_bs_df.to_numpy()[0]

In [31]:
import ast

# Each CSV cell stores a trace as the text of a Python literal; parse the first
# row of every column back into a Python object. Columns are taken by position.

# Dropout runs (4 columns).
dp01, dp03, dp05, dp07 = (
    ast.literal_eval(trace_dp_df.iloc[:, col][0]) for col in range(4)
)

# Weight-decay runs (6 columns).
wd1e_5, wd5e_5, wd1e_4, wd5e_4, wd1e_3, wd5e_3 = (
    ast.literal_eval(trace_wd_df.iloc[:, col][0]) for col in range(6)
)

# Data-diversity runs (10 columns).
(div01, div02, div03, div04, div05,
 div06, div07, div08, div09, div10) = (
    ast.literal_eval(trace_div_df.iloc[:, col][0]) for col in range(10)
)


In [34]:
def smooth(scalars, weight):  # Weight between 0 and 1
    """Exponentially smooth a sequence of scalars.

    Each output is ``previous_output * weight + (1 - weight) * point``, with
    the running value seeded by the first element, so the first output equals
    the first input.

    Args:
        scalars: Indexable, sized sequence of numbers (list, NumPy array, ...).
        weight: Smoothing factor; 0 returns the input unchanged, values close
            to 1 smooth heavily.

    Returns:
        A list of smoothed values with the same length as ``scalars``; an
        empty list when ``scalars`` is empty.
    """
    # Guard: indexing scalars[0] below would raise on an empty sequence.
    if len(scalars) == 0:
        return []

    smoothed = []
    last = scalars[0]  # Seed the running average with the first timestep
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed

In [None]:
# Smoothed trace curves: baseline versus the four dropout settings.
colours = ["#000000", "#E69F00", "#0072B2", "#009E73", "#CC79A7"]
dropout_curves = [
    ('baseline', trace_bs),
    ('dropout_01', dp01),
    ('dropout_03', dp03),
    ('dropout_05', dp05),
    ('dropout_07', dp07),
]

fig, ax = plt.subplots(figsize=(10, 5))
for colour, (label, trace) in zip(colours, dropout_curves):
    # The x axis follows each trace's own length instead of a fixed 480, so the
    # cell keeps working if the number of epochs or the batch size changes.
    ax.plot(np.arange(len(trace)), smooth(trace, 0.98), label=label, c=colour)
ax.set_ylabel("Trace", fontsize=12)
ax.set_xlabel("Iteration", fontsize=12)
# NOTE(review): iterations 245 and 375 are marked without explanation —
# document what these markers denote.
ax.axvline(x=245, color='black', linestyle=':')
ax.axvline(x=375, color='black', linestyle=':')
ax.legend()

fig.savefig("loss_dp.png", dpi=300, bbox_inches='tight')

In [None]:
# Smoothed trace curves: baseline versus one representative run per family.
colours = ["#000000", "#E69F00", "#0072B2", "#009E73", "#CC79A7", "#D55E00"]
overview_curves = [
    ('baseline', trace_bs),
    ('dropout_01', dp01),
    ('weight decay_5e-3', wd5e_3),
    ('TrAu', div04),
    ('Noise_03', div06),
]

fig, ax = plt.subplots(figsize=(8, 5))
for colour, (label, trace) in zip(colours, overview_curves):
    # The x axis follows each trace's own length instead of a fixed 480, so the
    # cell keeps working if the number of epochs or the batch size changes.
    ax.plot(np.arange(len(trace)), smooth(trace, 0.98), label=label, c=colour)
ax.set_ylabel("Trace", fontsize=12)
ax.set_xlabel("Iteration", fontsize=12)
# NOTE(review): iterations 245 and 375 are marked without explanation —
# document what these markers denote.
ax.axvline(x=245, color='black', linestyle=':')
ax.axvline(x=375, color='black', linestyle=':')
ax.legend()

fig.savefig("loss_vis.png", dpi=300, bbox_inches='tight')

In [None]:
# Smoothed trace curves: baseline versus the six weight-decay settings.
colours = ["#000000", "#E69F00", "#0072B2", "#009E73", "#CC79A7", "#56B4E9", "#F0E442", "#D55E00", "#4B0092", "#A52A2A", "#008000"]
weight_decay_curves = [
    ('baseline', trace_bs),
    ('weight decay_1e_5', wd1e_5),
    ('weight decay_5e_5', wd5e_5),
    ('weight decay_1e_4', wd1e_4),
    ('weight decay_5e_4', wd5e_4),
    ('weight decay_1e_3', wd1e_3),
    ('weight decay_5e_3', wd5e_3),
]

fig, ax = plt.subplots(figsize=(10, 5))
for colour, (label, trace) in zip(colours, weight_decay_curves):
    # The x axis follows each trace's own length instead of a fixed 480, so the
    # cell keeps working if the number of epochs or the batch size changes.
    ax.plot(np.arange(len(trace)), smooth(trace, 0.98), label=label, c=colour)
ax.set_ylabel("Trace", fontsize=12)
ax.set_xlabel("Iteration", fontsize=12)
# NOTE(review): iterations 245 and 375 are marked without explanation —
# document what these markers denote.
ax.axvline(x=245, color='black', linestyle=':')
ax.axvline(x=375, color='black', linestyle=':')
ax.legend()

fig.savefig("loss_wd.png", dpi=300, bbox_inches='tight')

In [None]:
# Smoothed trace curves: baseline versus the ten data-diversity runs.
colours = ["#000000", "#E69F00", "#0072B2", "#009E73", "#CC79A7", "#56B4E9", "#F0E442", "#D55E00", "#4B0092", "#A52A2A", "#008000"]
diversity_curves = [
    ('baseline', trace_bs),
    ('Auto', div01),
    ('Rand', div02),
    ('AugM', div03),
    ('TrAu', div04),
    ('Noise10', div05),
    ('Noise30', div06),
    ('Noise50', div07),
    ('Auto-v1', div08),
    ('Auto-v2', div09),
    # Tenth column of the diversity CSV, drawn in its own colour so it is
    # distinguishable from Auto-v2.
    ('Auto-v3', div10),
]

fig, ax = plt.subplots(figsize=(10, 5))
for colour, (label, trace) in zip(colours, diversity_curves):
    # The x axis follows each trace's own length instead of a fixed 480, so the
    # cell keeps working if the number of epochs or the batch size changes.
    ax.plot(np.arange(len(trace)), smooth(trace, 0.98), label=label, c=colour)
ax.set_ylabel("Trace", fontsize=12)
ax.set_xlabel("Iteration", fontsize=12)
# NOTE(review): iterations 245 and 375 are marked without explanation —
# document what these markers denote.
ax.axvline(x=245, color='black', linestyle=':')
ax.axvline(x=375, color='black', linestyle=':')
ax.legend()

fig.savefig("loss_div.png", dpi=300, bbox_inches='tight')