In [None]:
# Basic project to understand domain adaptation
# Method: feature alignment (MMD between source and target features)
# Source domain: MNIST -> target domain: SVHN
# Hardware: A100 GPU

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [4]:
# Seed torch's global RNG (CPU and CUDA) so random weight initialisation and the
# default DataLoader shuffling repeat across Restart & Run All. Bit-exact GPU
# results would additionally need deterministic cuDNN settings (not set here).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Datasets

In [None]:
# Both domains are mapped to 3-channel tensors scaled to [-1, 1] so one encoder
# can consume either. Resize(32) resizes the shorter side to 32 pixels; MNIST is
# additionally expanded to 3 identical grayscale channels to match SVHN's RGB.
transform_mnist = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: ToTensor's [0, 1] range -> [-1, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

transform_svhn = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

mnist_train = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_mnist
)
svhn_train = torchvision.datasets.SVHN(
    root='./data', split='train', download=True, transform=transform_svhn
)

# Source loader (labelled MNIST) and target loader (SVHN; its labels are not
# used for training).
mnist_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
svhn_loader = DataLoader(svhn_train, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 486kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.56MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/train_32x32.mat


100%|██████████| 182M/182M [00:02<00:00, 79.8MB/s]


## Model

### Loss function

In [None]:
def gaussian_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between two batches of feature vectors.

    The squared Euclidean distance is *averaged* over the feature dimension
    rather than summed, so the effective bandwidth grows with the feature width:

        K[i, j] = exp(-mean_d (x[i, d] - y[j, d])**2 / (2 * sigma**2))

    Args:
        x: 2-D tensor of shape (n, d).
        y: 2-D tensor of shape (m, d) with the same ``d`` as ``x``.
        sigma: kernel bandwidth.

    Returns:
        Tensor of shape (n, m) with entries in [0, 1].
    """
    dim = x.size(1)
    # Expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y so only an (n, m) matrix is
    # materialised, instead of the (n, m, d) pairwise-difference tensor
    # (128 x 128 x 4096 floats, about 268 MB, for this notebook's features).
    x_sq = x.pow(2).sum(dim=1, keepdim=True)        # (n, 1)
    y_sq = y.pow(2).sum(dim=1, keepdim=True).t()    # (1, m)
    # Rounding in the expansion can push near-zero distances slightly negative.
    sq_dist = (x_sq + y_sq - 2.0 * (x @ y.t())).clamp_min(0.0)
    return torch.exp(-(sq_dist / dim) / (2 * sigma ** 2))

In [None]:
def mmd_loss(source_features, target_features, sigma=1.0):
    """Squared Maximum Mean Discrepancy between two feature batches.

    Uses ``gaussian_kernel`` with bandwidth ``sigma``:

        MMD^2 = mean K(s, s) + mean K(t, t) - 2 * mean K(s, t)

    Each mean runs over every pair, including i == j, so this is the biased
    (V-statistic) estimate.

    Args:
        source_features: (n, d) features from the source domain.
        target_features: (m, d) features from the target domain.
        sigma: kernel bandwidth passed to ``gaussian_kernel``.

    Returns:
        Scalar tensor.
    """
    k_ss, k_tt, k_st = (
        gaussian_kernel(left, right, sigma)
        for left, right in (
            (source_features, source_features),
            (target_features, target_features),
            (source_features, target_features),
        )
    )
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()

### Feature Extractor & Classifier

In [None]:
class FeatureExtractor(nn.Module):
    """Small CNN encoder shared by the source and target domains.

    Two stages of Conv2d(3x3, padding 1) + ReLU + MaxPool2d(2), then flatten.
    For a (N, 3, 32, 32) input the output is (N, 64 * 8 * 8) = (N, 4096), which
    matches ``Classifier``'s default ``input_dim``.
    """

    def __init__(self, in_channels=3):
        """
        Args:
            in_channels: number of input image channels (3 in this notebook,
                since MNIST is converted with ``Grayscale(3)``).
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Map (N, C, H, W) images to (N, 64 * (H // 4) * (W // 4)) feature vectors."""
        # torch.flatten (unlike .view) also handles non-contiguous / channels_last
        # activations and empty batches.
        return torch.flatten(self.network(x), start_dim=1)

In [None]:
class Classifier(nn.Module):
    """Two-layer MLP head that maps flattened features to class logits.

    Layout: Linear(input_dim -> 256) -> ReLU -> Linear(256 -> num_classes).
    The default ``input_dim`` of 64 * 8 * 8 = 4096 is the flattened output size
    of ``FeatureExtractor`` for 32x32 inputs.
    """

    def __init__(self, input_dim=64*8*8, num_classes=10):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalised logits of shape (N, num_classes) for (N, input_dim) features."""
        logits = self.fc(x)
        return logits

## Train & Evaluate

In [None]:
def train_feature_alignment(feature_extractor, classifier, mnist_loader, svhn_loader, num_epochs=5, lambda_mmd=1.0,
                            lr=5e-4, optimizer=None, device=None):
    """Train encoder + classifier on labelled MNIST while aligning SVHN features with MMD.

    Each step takes one labelled source batch (MNIST) and one target batch (SVHN,
    whose labels are discarded) and minimises

        loss = CrossEntropy(classifier(f(source)), source_labels)
               + lambda_mmd * mmd_loss(f(source), f(target))

    where f is ``feature_extractor``. An epoch runs as many steps as the shorter
    loader has batches. Mean losses per epoch are printed.

    Args:
        feature_extractor: encoder whose features feed both the classifier and the MMD term.
        classifier: head producing class logits from encoder features.
        mnist_loader: source loader yielding (images, labels).
        svhn_loader: target loader yielding (images, labels); labels are ignored.
        num_epochs: number of passes over the paired loaders.
        lambda_mmd: weight of the MMD term.
        lr: learning rate of the NAdam optimizer created when ``optimizer`` is None.
        optimizer: optional existing optimizer over both modules' parameters. By
            default a new NAdam is created on every call, so its moment estimates
            restart; pass the same optimizer across calls to keep them.
        device: where batches are moved; defaults to the device of the models'
            first parameter (CPU if neither module has parameters).

    Raises:
        ValueError: if either loader yields no batches.
    """
    num_batches = min(len(mnist_loader), len(svhn_loader))
    if num_batches == 0:
        raise ValueError("mnist_loader and svhn_loader must each yield at least one batch")

    params = list(feature_extractor.parameters()) + list(classifier.parameters())
    if device is None:
        # Follow the models instead of relying on a notebook-global `device`.
        device = params[0].device if params else torch.device("cpu")
    if optimizer is None:
        optimizer = optim.NAdam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    feature_extractor.train()
    classifier.train()

    for epoch in range(1, num_epochs + 1):
        total_loss = total_cls_loss = total_mmd_loss = 0.0
        # zip stops when the shorter loader is exhausted, i.e. after num_batches steps.
        for (mnist_data, mnist_labels), (svhn_data, _) in zip(mnist_loader, svhn_loader):
            mnist_data, mnist_labels = mnist_data.to(device), mnist_labels.to(device)
            svhn_data = svhn_data.to(device)

            source_features = feature_extractor(mnist_data)
            target_features = feature_extractor(svhn_data)

            # Supervision comes from source labels only; target labels are never read.
            cls_loss = criterion(classifier(source_features), mnist_labels)
            mmd = mmd_loss(source_features, target_features)
            loss = cls_loss + lambda_mmd * mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_mmd_loss += mmd.item()

        print(f"Epoch [{epoch}/{num_epochs}], Loss: {total_loss/num_batches:.4f}, "
              f"Cls Loss: {total_cls_loss/num_batches:.4f}, MMD Loss: {total_mmd_loss/num_batches:.4f}")

In [None]:
def evaluate(feature_extractor, classifier, dataloader, device=None):
    """Top-1 accuracy of ``classifier(feature_extractor(x))`` over ``dataloader``.

    Both modules are switched to eval mode for the pass and gradients are
    disabled. Afterwards each module is restored to the train/eval mode it had
    on entry, so calling this does not leak eval mode into later cells.

    Args:
        feature_extractor: encoder module.
        classifier: head producing class logits from encoder features.
        dataloader: iterable of (images, integer labels) batches.
        device: where batches are moved; defaults to the device of the models'
            first parameter (CPU if neither module has parameters).

    Returns:
        Fraction of correctly classified samples, a float in [0, 1].

    Raises:
        ValueError: if ``dataloader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        params = [*feature_extractor.parameters(), *classifier.parameters()]
        device = params[0].device if params else torch.device("cpu")

    was_training = (feature_extractor.training, classifier.training)
    feature_extractor.eval()
    classifier.eval()

    correct, total = 0, 0
    try:
        with torch.no_grad():
            for imgs, labels in dataloader:
                imgs, labels = imgs.to(device), labels.to(device)
                preds = classifier(feature_extractor(imgs)).argmax(dim=1)
                total += labels.size(0)
                correct += (preds == labels).sum().item()
    finally:
        feature_extractor.train(was_training[0])
        classifier.train(was_training[1])

    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return correct / total


In [None]:
# Build the encoder first and then the head (construction order decides which
# random draws each module's initial weights receive), moving both to `device`.
feature_extractor, classifier = FeatureExtractor().to(device), Classifier().to(device)

In [None]:
# Phase 1: 5 epochs of joint cross-entropy + MMD training, MMD weighted by 0.5.
train_feature_alignment(
    feature_extractor,
    classifier,
    mnist_loader,
    svhn_loader,
    num_epochs=5,
    lambda_mmd=0.5,
)

Epoch [1/5], Loss: 0.1883, Cls Loss: 0.1708, MMD Loss: 0.0350
Epoch [2/5], Loss: 0.0514, Cls Loss: 0.0434, MMD Loss: 0.0159
Epoch [3/5], Loss: 0.0338, Cls Loss: 0.0281, MMD Loss: 0.0114
Epoch [4/5], Loss: 0.0246, Cls Loss: 0.0200, MMD Loss: 0.0093
Epoch [5/5], Loss: 0.0183, Cls Loss: 0.0142, MMD Loss: 0.0081


In [None]:
# Target-domain accuracy. SVHN labels are discarded during training, so this
# measures transfer even though it runs over the SVHN *train* split.
svhn_acc = evaluate(feature_extractor=feature_extractor, classifier=classifier, dataloader=svhn_loader)
print(f"Accuracy on SVHN: {svhn_acc:.4f}")

Accuracy on SVHN: 0.2886


In [None]:
# Source-domain accuracy. `mnist_loader` wraps the MNIST *training* split the
# classifier was fit on, so this reflects training fit, not held-out accuracy.
# Stored in its own name so the SVHN score from the previous cell is kept.
mnist_acc = evaluate(feature_extractor, classifier, mnist_loader)
print(f"Accuracy on MNIST: {mnist_acc:.4f}")

Accuracy on MNIST: 0.9996


In [None]:
# Phase 2: continue training the same models for 10 more epochs with the same
# MMD weight. The weights carry over; the optimizer is created anew by this call.
train_feature_alignment(
    feature_extractor,
    classifier,
    mnist_loader,
    svhn_loader,
    num_epochs=10,
    lambda_mmd=0.5,
)

Epoch [1/10], Loss: 0.0035, Cls Loss: 0.0018, MMD Loss: 0.0033
Epoch [2/10], Loss: 0.0025, Cls Loss: 0.0011, MMD Loss: 0.0027
Epoch [3/10], Loss: 0.0016, Cls Loss: 0.0005, MMD Loss: 0.0021
Epoch [4/10], Loss: 0.0015, Cls Loss: 0.0005, MMD Loss: 0.0019
Epoch [5/10], Loss: 0.0010, Cls Loss: 0.0002, MMD Loss: 0.0015
Epoch [6/10], Loss: 0.0008, Cls Loss: 0.0002, MMD Loss: 0.0012
Epoch [7/10], Loss: 0.0007, Cls Loss: 0.0002, MMD Loss: 0.0011
Epoch [8/10], Loss: 0.0006, Cls Loss: 0.0001, MMD Loss: 0.0009
Epoch [9/10], Loss: 0.0005, Cls Loss: 0.0001, MMD Loss: 0.0008
Epoch [10/10], Loss: 0.0005, Cls Loss: 0.0001, MMD Loss: 0.0007


In [None]:
# Target-domain accuracy after the additional epochs (SVHN train split; its
# labels are never used for training).
svhn_acc = evaluate(feature_extractor=feature_extractor, classifier=classifier, dataloader=svhn_loader)
print(f"Accuracy on SVHN: {svhn_acc:.4f}")

Accuracy on SVHN: 0.2710


In [None]:
# Source-domain accuracy after the additional epochs. `mnist_loader` wraps the
# MNIST *training* split, so this is training fit rather than held-out accuracy.
# Stored in its own name so the SVHN score from the previous cell is kept.
mnist_acc = evaluate(feature_extractor, classifier, mnist_loader)
print(f"Accuracy on MNIST: {mnist_acc:.4f}")

Accuracy on MNIST: 0.9996
