**CORAL APPROACH**
---

Sun, B., et al. "Correlation alignment for unsupervised domain adaptation." In Domain Adaptation in Computer Vision Applications (pp. 153-171). Springer, Cham, 2017.

In [1]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset

import matplotlib.pyplot as plt
import IPython.display
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image

In [2]:
# DEFINE DEVICE

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

cuda


In [10]:
# GET COVARIANCE MATRICES

#------------------------------------------------
# Compute the MNIST and SVHN mean and standard 
# deviation on the training set
#------------------------------------------------

def calculate_mean(loader, num_channels):
    """Calculate the mean for each channel across the dataset."""
    channel_sum, num_elements = torch.zeros(num_channels), 0

    for data, _ in loader:
        channel_sum += data.mean([0, 2, 3]) * data.size(0)
        num_elements += data.size(0)

    mean = channel_sum / num_elements
    return mean

def calculate_std(loader, num_channels, mean):
    """Calculate the standard deviation for each channel across the dataset."""
    channel_squared_sum, num_elements = torch.zeros(num_channels), 0

    for data, _ in loader:
        batch_size = data.size(0)
        num_elements += batch_size
        for i in range(num_channels):
            channel_squared_sum[i] += ((data[:, i, :, :] - mean[i])**2).sum()

    variance = channel_squared_sum / num_elements
    std = torch.sqrt(variance)
    return std


# Extract MNIST and SVHN datasets
MNIST_train = datasets.MNIST(root='./mnist_data/', train=True, transform = transforms.ToTensor(), download=True)
SVHN_train = datasets.SVHN(root='./svhn_data/', split='train', transform = transforms.ToTensor(), download=True)

# Create DataLoaders for MNIST and SVHN datasets
batch_size = 64  # Adjust as needed
MNIST_train_loader = DataLoader(MNIST_train, batch_size=batch_size, shuffle=False)
SVHN_train_loader = DataLoader(SVHN_train, batch_size=batch_size, shuffle=False)

# Compute mean and standard deviation for MNIST and SVHN datasets
mnist_mean = calculate_mean(MNIST_train_loader, 1)
mnist_std = calculate_std(MNIST_train_loader, 1, mnist_mean)
print(f"MNIST mean: {mnist_mean}, MNIST std: {mnist_std}")

svhn_mean = calculate_mean(SVHN_train_loader, 3)
svhn_std = calculate_std(SVHN_train_loader, 3, svhn_mean)
print(f"SVHN mean: {svhn_mean}, SVHN std: {svhn_std}")

#------------------------------------------------
# Normalise and flatten the MNIST and SVHN  dataset
# with real mean and standard deviation
#------------------------------------------------

class FlattenTransform:
    def __call__(self, x):
        return x.view(-1)
    
transform_mnist = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mnist_mean, mnist_std),  # Normalize with real mean and standard deviation
    FlattenTransform(),  # Flatten the images
])

transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(svhn_mean, svhn_std), # Normalize with real mean and standard deviation
    FlattenTransform(),  # Flatten the images
])

MNIST_train_flat = datasets.MNIST(root='./mnist_data/', train=True, transform = transform_mnist, download=True)
SVHN_train_flat = datasets.SVHN(root='./svhn_data/', split='train', transform = transform_svhn, download=True)

MNIST_train_loader_flat = DataLoader(MNIST_train_flat, batch_size=batch_size, shuffle=False)
SVHN_train_loader_flat = DataLoader(SVHN_train_flat, batch_size=batch_size, shuffle=False)

print("MNIST shape: ", MNIST_train_flat.data.shape)
print("SVHN shape: ", SVHN_train_flat.data.shape)




    

Using downloaded and verified file: ./svhn_data/train_32x32.mat
MNIST mean: tensor([0.1307]), MNIST std: tensor([8.6270])
SVHN mean: tensor([0.4377, 0.4438, 0.4728]), SVHN std: tensor([6.3370, 6.4325, 6.3052])
Using downloaded and verified file: ./svhn_data/train_32x32.mat
MNIST shape:  torch.Size([60000, 28, 28])
SVHN shape:  (73257, 3, 32, 32)


In [3]:
# DOWNLOAD DATA 

# DOWNLOAD, RESIZE & NORMALIZE MNIST DATASET  (32x32x3 instead of originl 28x28x1)

from torchvision import datasets, transforms
import torch

# Define the transform to resize the image to 32x32 and replicate to 3 channels
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # Resize to 32x32
    transforms.Grayscale(num_output_channels=3),  # Convert to RGB by replicating channels
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize each channel (assuming mean 0.5, std 0.5 for simplicity)
])

# Download and load the dataset with the defined transform
train_dataset_source = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset_source = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# DOWNLOAD, (RESIZE &) NORMALISE, SVHN DATASET (Stret View House Numbers)
train_dataset_target = datasets.SVHN(root='./svhn_data/', split='train', transform=transform, download=True) # transform to insure same shape and normalisation
test_dataset_target = datasets.SVHN(root='./svhn_data/', split='test', transform=transform, download=True)


Using downloaded and verified file: ./svhn_data/train_32x32.mat
Using downloaded and verified file: ./svhn_data/test_32x32.mat


In [None]:
##########################################
# FUNCTIONS TO APPLY CORAL METHOD
##########################################



def compute_covariance_matrix(data):
    """
    Compute the covariance matrix for the given data.
    
    :param data: 2D array where rows are samples and columns are features.
    :return: Covariance matrix.
    """
    scaler = StandardScaler(with_std=False)
    data = scaler.fit_transform(data)  # Mean centering the data
    covariance_matrix = np.cov(data, rowvar=False)
    return covariance_matrix

def coral(source_data, target_data):
    """
    Perform CORAL on the source data to match the target data.
    
    :param source_data: Source data (MNIST) as a 2D numpy array.
    :param target_data: Target data (SVHN) as a 2D numpy array.
    :return: Transformed source data as a 2D numpy array.
    """
    # Compute the covariance matrices
    source_cov = compute_covariance_matrix(source_data)
    target_cov = compute_covariance_matrix(target_data)

    # Compute the source data whitening matrix
    source_cov_eigvals, source_cov_eigvecs = np.linalg.eigh(source_cov)
    source_whitening_matrix = np.dot(source_cov_eigvecs, np.diag(1.0 / np.sqrt(source_cov_eigvals)))
    
    # Compute the target data coloring matrix
    target_cov_eigvals, target_cov_eigvecs = np.linalg.eigh(target_cov)
    target_coloring_matrix = np.dot(target_cov_eigvecs, np.diag(np.sqrt(target_cov_eigvals)))

    # Transform the source data
    source_data_whitened = np.dot(source_data, source_whitening_matrix)
    source_data_colored = np.dot(source_data_whitened, target_coloring_matrix)

    return source_data_colored

# Flatten the images and convert to numpy arrays
def extract_features_and_flatten(dataset):
    dataset_flattened = []
    for data, _ in dataset:
        # Flatten the image data and convert to a numpy array
        data = data.numpy().flatten()
        dataset_flattened.append(data)
    return np.array(dataset_flattened)

# Extract features from datasets
source_features = extract_features_and_flatten(train_dataset_source)
target_features = extract_features_and_flatten(train_dataset_target)

# Apply CORAL to align the source dataset to the target dataset
source_features_aligned = coral(source_features, target_features)

# Convert the transformed features back into PyTorch tensors
source_features_aligned_tensor = torch.tensor(source_features_aligned, dtype=torch.float32)

# Here you would reshape the tensor and train your classifier using the aligned data
# This part is not shown and would depend on your classifier training procedure.
