## Introduction

Adversarial Discriminative Domain Adaptation (ADDA)

In real-world applications, machine learning models are often trained on a dataset that is not perfectly representative of the domain where the model will be deployed. This difference between the source domain (where the model is trained) and the target domain (where the model is applied) is called **domain shift**, and it often leads to a significant drop in performance when models trained on the source domain are tested on the target domain.

<div style="text-align:center">
<img src="https://i.ibb.co/v3x80W9/Screenshot-2024-10-22-at-11-24-32-AM.png" alt="Adversarial Discriminative Domain Adaptation" border="0" width=400>
</div>

Adversarial Discriminative Domain Adaptation (ADDA) is a technique for **unsupervised domain adaptation**: it transfers the knowledge learned from a labeled source domain to an unlabeled target domain by aligning the feature distributions of the two domains (see the figure above). ADDA achieves this through a two-step process:

1. Feature learning on the source domain: a feature extractor (encoder) and a classifier are trained on the labeled source domain in a supervised manner.
2. Adversarial adaptation for the target domain: a separate target feature extractor, initialized from the source one, is trained adversarially so that its target-domain features match the source-domain features. This encourages domain-invariant representations, so the source classifier can be reused on the target features.

The method is inspired by the concept of Generative Adversarial Networks (GANs), where a domain discriminator attempts to distinguish between source and target domain features, while the target feature extractor learns to fool the discriminator by producing features that are indistinguishable between domains.
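
For reference, the ADDA paper writes the stages as the objectives below. $M_s$ and $M_t$ are the source and target encoders (`feature_extractor` in this notebook), $C$ is the source classifier (`classifier`), and $D$ is the discriminator, where $D(\cdot)$ is the predicted probability that a feature comes from the source domain. In the code, $D$ corresponds to `torch.sigmoid(discriminator(...))`, which `nn.BCEWithLogitsLoss` applies internally.

$$
\begin{aligned}
\min_{M_s,\,C}\ \mathcal{L}_{\text{cls}}(X_s, Y_s) &= -\,\mathbb{E}_{(x_s, y_s)\sim(X_s, Y_s)} \sum_{k=1}^{K} \mathbb{1}_{[k = y_s]} \log C\big(M_s(x_s)\big)_k \\
\min_{D}\ \mathcal{L}_{\text{adv}_D}(X_s, X_t, M_s, M_t) &= -\,\mathbb{E}_{x_s\sim X_s}\big[\log D\big(M_s(x_s)\big)\big] - \mathbb{E}_{x_t\sim X_t}\big[\log\big(1 - D\big(M_t(x_t)\big)\big)\big] \\
\min_{M_t}\ \mathcal{L}_{\text{adv}_M}(X_s, X_t, D) &= -\,\mathbb{E}_{x_t\sim X_t}\big[\log D\big(M_t(x_t)\big)\big]
\end{aligned}
$$

$M_s$ and $C$ are frozen after the first stage and $M_t$ starts from $M_s$. The last line is the inverted-label loss: instead of maximizing $\mathcal{L}_{\text{adv}_D}$, the target encoder is trained as if its features were labeled "source".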

Objectives:

- Learn domain-invariant features: Align the feature spaces of the source and target domains without requiring labels from the target domain.
- Improve generalization: Train a model that performs well on the target domain despite the domain shift.

This notebook will guide you through the implementation of ADDA, covering key steps such as:

- Pretraining a feature extractor on the source domain.
- Applying adversarial training for domain adaptation.
- Evaluating the performance on the target domain.

By the end of this notebook, you will understand how adversarial techniques can be used to adapt a model to a new, unlabeled domain.


## Model Overview

<div style="text-align:center">
    <img src="https://i.ibb.co/2sCRxmF/Screenshot-2024-10-22-at-11-29-59-AM.png" alt="Screenshot-2024-10-22-at-11-29-59-AM" border="0" width=800>
</div>

Paper: https://arxiv.org/pdf/1702.05464

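#### NOTE:
**Here, MNIST is the source domain and MNIST-M is the target domain.** MNIST-M is generated below by blending MNIST digits with random color patches from the BSDS500 image dataset.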

## Pretraining a feature extractor on the source domain (MNIST)

### Libraries 📚⬇

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
import random, math, cv2
from tqdm.notebook import tqdm

import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor

import matplotlib.pyplot as plt
import plotly.subplots  # explicit import: plotly.subplots.make_subplots is used below
import plotly.express as px
import plotly.graph_objects as go
import warnings
warnings.filterwarnings("ignore")

In [None]:
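def visualize_digits(dataset, k=80, mnistm=False, cmap=None, title=None):
    """Plot k randomly chosen images from `dataset` in a grid with 20 columns.

    k should be a multiple of 20 (at least 20); `mnistm` is currently unused.
    Images are expected as (C, H, W) tensors with values in [0, 1].
    """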
    ncols = 20
    indices = random.choices(range(len(dataset)), k=k)
    nrows = math.floor(len(indices)/ncols)
    
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols,nrows+0.4), gridspec_kw=dict(wspace=0.1, hspace=0.1), subplot_kw=dict(yticks=[], xticks=[]))
    axes_flat = axes.reshape(-1)
    fig.suptitle(title, fontsize=20)
    
    for list_idx, image_idx in enumerate(indices[:ncols*nrows]):
        ax = axes_flat[list_idx]
        image = dataset[image_idx][0]
        image = image.numpy().transpose(1, 2, 0)
        ax.imshow(image, cmap=cmap)

def set_requires_grad(model, requires_grad=True):
    for param in model.parameters():
        param.requires_grad = requires_grad

def loop_iterable(iterable):
    while True:
        yield from iterable

class GrayscaleToRgb:
    """Convert a grayscale image to rgb"""
    def __call__(self, image):
        image = np.array(image)
        image = np.dstack([image, image, image])
        return Image.fromarray(image)
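
`loop_iterable` restarts an iterable every time it runs out, so `zip(loop_iterable(a), loop_iterable(b))` never stops on its own; the adaptation loop later bounds it with a fixed number of `next()` calls. A toy sketch with plain lists standing in for the two loaders (the helper is repeated so the snippet runs on its own; the list contents are made up):

```python
def loop_iterable(iterable):
    # Same helper as above: restart the iterable every time it is exhausted
    while True:
        yield from iterable

# Hypothetical stand-ins for two loaders of different lengths
source_batches = ['s0', 's1', 's2']
target_batches = ['t0', 't1']

# zip() over two endless generators never terminates by itself,
# so the caller decides how many steps to take with next()
pairs = zip(loop_iterable(source_batches), loop_iterable(target_batches))
first_five = [next(pairs) for _ in range(5)]
print(first_five)
```

The shorter stream simply wraps around earlier. With a real `DataLoader(shuffle=True)`, each restart creates a fresh iterator and therefore a new shuffle order.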

In [None]:
MNIST_DATA_DIR = Path('/kaggle/working')
BSDS_DATA_DIR = Path('/kaggle/working/bsds500')
MODEL_FILE = Path('best_source_weights_mnist.pth')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
batch_size = 64
epochs = 5


### Download the data

In [None]:
!wget https://raw.githubusercontent.com/fgnt/mnist/master/train-images-idx3-ubyte.gz
!wget https://raw.githubusercontent.com/fgnt/mnist/master/train-labels-idx1-ubyte.gz
!wget https://raw.githubusercontent.com/fgnt/mnist/master/t10k-images-idx3-ubyte.gz
!wget https://raw.githubusercontent.com/fgnt/mnist/master/t10k-labels-idx1-ubyte.gz
!mkdir -p /kaggle/working/mnist/MNIST/raw
!mv train-images-idx3-ubyte.gz /kaggle/working/mnist/MNIST/raw/
!mv train-labels-idx1-ubyte.gz /kaggle/working/mnist/MNIST/raw/
!mv t10k-images-idx3-ubyte.gz /kaggle/working/mnist/MNIST/raw/
!mv t10k-labels-idx1-ubyte.gz /kaggle/working/mnist/MNIST/raw/

In [None]:
!curl -L -o bsds500.zip https://www.kaggle.com/api/v1/datasets/download/balraj98/berkeley-segmentation-dataset-500-bsds500
!mkdir -p bsds500
!unzip -q bsds500.zip -d bsds500/
!rm bsds500.zip

### Define Model

In [None]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.Dropout2d(),
        )
        
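        self.classifier = nn.Sequential(
            nn.Linear(320, 50),  # 20 channels x 4 x 4 after the two conv/pool stages
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
            nn.LogSoftmax(dim=1),  # log-probabilities, to be used with nn.NLLLoss
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view(x.shape[0], -1)  # flatten to (batch, 320)
        log_probs = self.classifier(features)
        return log_probs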
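
The `320` in `nn.Linear(320, 50)` (and in the discriminator's first layer) is the flattened size of the feature extractor's output for 28x28 inputs. A quick sketch of that arithmetic using the standard Conv2d/MaxPool2d output-size formula (dilation 1, `ceil_mode=False`); `conv2d_out` is a hypothetical helper, not a PyTorch function:

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output side length of Conv2d / MaxPool2d with dilation 1 and floor rounding
    return (size + 2 * padding - kernel) // stride + 1

size = 28                      # MNIST / MNIST-M images are 28x28
size = conv2d_out(size, 5)     # Conv2d(3, 10, kernel_size=5)
size = conv2d_out(size, 2, 2)  # MaxPool2d(2): stride defaults to the kernel size
size = conv2d_out(size, 5)     # Conv2d(10, 20, kernel_size=5)
size = conv2d_out(size, 2, 2)  # MaxPool2d(2)
flat_features = 20 * size * size  # 20 output channels of the last conv
print(size, flat_features)
```

If you make the network deeper, redo this calculation (or pass a dummy `torch.zeros(1, 3, 28, 28)` through `feature_extractor` and inspect the shape) and update both linear layers accordingly.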

### Get Dataset & Dataloaders

In [None]:
source_model = Net().to(device)
if MODEL_FILE.exists():
    # Resume from previously saved source weights, if any
    source_model.load_state_dict(torch.load(MODEL_FILE, map_location=device))

train_dataset = MNIST(MNIST_DATA_DIR / 'mnist', train=True, download=True, transform=Compose([GrayscaleToRgb(), ToTensor()]))
test_dataset = MNIST(MNIST_DATA_DIR / 'mnist', train=False, download=True, transform=Compose([GrayscaleToRgb(), ToTensor()]))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

source_optim = torch.optim.Adam(source_model.parameters(), lr=0.002)
criterion = nn.NLLLoss()

### Visualize MNIST Data

In [None]:
visualize_digits(dataset=train_dataset, k=120, cmap='gray', title='Sample MNIST Images')

### Train Source Model on MNIST

In [None]:
train_losses, train_accuracies, train_counter = [], [], []
test_losses, test_accuracies = [], []
test_counter = [idx*len(train_loader.dataset) for idx in range(0, epochs+1)]

def train(epoch):
    train_loss, train_accuracy = 0, 0
    source_model.train()
    tqdm_bar = tqdm(train_loader, desc=f'Training Epoch {epoch} ', total=int(len(train_loader)))
    for idx, (images, labels) in enumerate(tqdm_bar):
        images, labels = images.to(device), labels.to(device)
        source_optim.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = source_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            source_optim.step()
        train_loss += loss.item()
        train_losses.append(loss.item())
        preds = outputs.argmax(dim=1)
        train_batch_accuracy = (preds == labels).float().mean().item()
        train_accuracy += train_batch_accuracy
        train_accuracies.append(train_batch_accuracy)
        tqdm_bar.set_postfix(train_loss=(train_loss/(idx+1)), train_accuracy=train_accuracy/(idx+1))
        train_counter.append(idx*batch_size + images.size(0) + epoch*len(train_dataset))

def test():
    test_loss, test_accuracy = 0, 0
    source_model.eval()
    tqdm_bar = tqdm(test_loader, desc=f'Testing ', total=int(len(test_loader)))
    for idx, (images, labels) in enumerate(tqdm_bar):
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            outputs = source_model(images)
            loss = criterion(outputs, labels)
        test_loss += loss.item()
        preds = outputs.argmax(dim=1)
        test_accuracy += (preds == labels).float().mean().item()
        tqdm_bar.set_postfix(test_loss=(test_loss/(idx+1)), test_accuracy=test_accuracy/(idx+1))
    test_losses.append(test_loss/len(test_loader))
    test_accuracies.append(test_accuracy/len(test_loader))
    if np.argmax(test_accuracies) == len(test_accuracies)-1:
        torch.save(source_model.state_dict(), MODEL_FILE)
        
test()
for epoch in range(epochs):
    train(epoch)
    test()

#### Visualize Training & Testing Results 📈

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_losses, mode='lines', name='Train loss'))
fig.add_trace(go.Scatter(x=test_counter, y=test_losses, marker_symbol='star-diamond', 
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test loss'))
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Negative Log Likelihood loss")
fig.show()

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=train_counter, y=train_accuracies, mode='lines', name='Train Accuracy'))
fig.add_trace(go.Scatter(x=test_counter, y=test_accuracies, marker_symbol='star-diamond', 
                         marker_color='orange', marker_line_width=1, marker_size=9, mode='markers', name='Test Accuracy'))
fig.update_layout(
    width=1000,
    height=500,
    title="Train vs. Test Accuracy",
    xaxis_title="Number of training examples seen",
    yaxis_title="Accuracy")
fig.show()

## Adversarial training for domain adaptation

Next, we load the model pretrained on the source domain (MNIST) and adapt it to the target domain (MNIST-M) with adversarial training.
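
The adaptation step trains the target encoder with flipped labels (target features labeled as source) rather than directly maximizing the discriminator's loss. The ADDA paper motivates this inverted-label GAN loss by its stronger gradients. A minimal pure-Python sketch of the two derivatives with respect to a single discriminator logit `d` shows why; the function names are illustrative and not part of this notebook:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad_minimax(d):
    # Minimax: the encoder minimizes log(1 - sigmoid(d)) for target features;
    # d/dd log(1 - sigmoid(d)) = -sigmoid(d)
    return -sigmoid(d)

def grad_inverted(d):
    # Inverted labels (BCE with label 1 on target features): -log(sigmoid(d));
    # d/dd -log(sigmoid(d)) = -(1 - sigmoid(d))
    return -(1.0 - sigmoid(d))

# d is the discriminator logit for a target feature; a very negative d means
# the discriminator is confident the feature comes from the target domain
for d in (-6.0, 0.0, 6.0):
    print(f"d={d:+.0f}  minimax={grad_minimax(d):+.4f}  inverted={grad_inverted(d):+.4f}")
```

Both gradients push `d` in the same direction, but when the discriminator confidently rejects target features (typical early in adaptation) the minimax gradient nearly vanishes while the inverted-label gradient stays close to -1.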

In [None]:
# Hyperparameters

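batch_size = 64   # split into batch_size // 2 source + batch_size // 2 target images
iterations = 500  # adversarial rounds per epoch
epochs = 4
k_disc = 1        # discriminator updates per round
k_clf = 10        # target-encoder updates per round (the classifier head stays frozen)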

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
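
An "epoch" in the adaptation loop is not one pass over the data: it is `iterations` rounds, and each round calls `next()` on the zipped source/target iterator `k_disc + k_clf` times. A small sketch of the resulting budget; `adaptation_budget` is a hypothetical helper whose defaults mirror the cell above, and `train_size=60_000` assumes `MNISTM(train=True)` wraps the 60,000 MNIST training digits:

```python
def adaptation_budget(batch_size=64, iterations=500, k_disc=1, k_clf=10, train_size=60_000):
    """Count what one adaptation epoch consumes (defaults mirror the hyperparameters)."""
    half_batch = batch_size // 2
    paired_batches = iterations * (k_disc + k_clf)  # next(batch_iterator) calls
    target_images = paired_batches * half_batch     # the same number of source images is drawn
    return {
        'discriminator_updates': iterations * k_disc,
        'target_encoder_updates': iterations * k_clf,
        'target_images': target_images,
        'passes_over_target_train': target_images / train_size,
    }

budget = adaptation_budget()
print(budget)
```

Raising `k_clf` therefore increases both the number of encoder updates and the amount of data streamed per epoch, which is worth keeping in mind when tuning it.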

`MODEL_FILE` holds the path to the pretrained source weights; both the frozen source encoder and the trainable target encoder are initialized from it.

In [None]:
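source_model = Net().to(device)
assert MODEL_FILE.exists(), f'{MODEL_FILE} not found: run the source pretraining cells first'
source_model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
source_model.eval()
set_requires_grad(source_model, requires_grad=False)

# clf keeps the frozen source classifier head; its feature extractor is swapped
# for the adapted target encoder when evaluating on MNIST-M
clf = source_model
source_model = source_model.feature_extractor

# Target encoder: same architecture, initialized from the pretrained source weights
target_model = Net().to(device)
target_model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
target_model = target_model.feature_extractor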

discriminator = nn.Sequential(
    nn.Linear(320, 120),
    nn.ReLU(),
    nn.Linear(120, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
).to(device)

### Define Dataset Classes

In [None]:
class BSDS500(Dataset):
    """Natural images from BSDS500, used as MNIST-M backgrounds."""

    def __init__(self):
        image_folder = BSDS_DATA_DIR / 'images'
        # Sorted so that index -> file is stable across runs and filesystems
        self.image_files = sorted(map(str, image_folder.glob('*/*.jpg')))

    def __getitem__(self, i):
        image = cv2.imread(self.image_files[i], cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR
        tensor = torch.from_numpy(image.transpose(2, 0, 1))  # HWC -> CHW, uint8
        return tensor

    def __len__(self):
        return len(self.image_files)


class MNISTM(Dataset):
    """MNIST digits blended with random BSDS500 patches as |patch - digit|."""

    def __init__(self, train=True, seed=42):
        super().__init__()
        self.mnist = datasets.MNIST(MNIST_DATA_DIR / 'mnist', train=train,
                                    download=True)
        self.bsds = BSDS500()
        # Per-sample seeding: sample i always gets the same background patch,
        # across epochs and DataLoader worker processes; train/test seeds differ
        self.seed = seed if train else seed + 100_000

    def __getitem__(self, i):
        rng = np.random.RandomState(self.seed + i)
        digit, label = self.mnist[i]
        digit = transforms.ToTensor()(digit)  # (1, 28, 28) in [0, 1]
        bsds_image = self._random_bsds_image(rng)
        patch = self._random_patch(bsds_image, rng)
        patch = patch.float() / 255           # (3, 28, 28) in [0, 1]
        blend = torch.abs(patch - digit)      # digit broadcasts over the 3 channels
        return blend, label

    def _random_patch(self, image, rng, size=(28, 28)):
        _, im_height, im_width = image.shape
        x = rng.randint(0, im_width - size[1] + 1)  # upper bound is exclusive
        y = rng.randint(0, im_height - size[0] + 1)
        return image[:, y:y+size[0], x:x+size[1]]

    def _random_bsds_image(self, rng):
        i = rng.choice(len(self.bsds))
        return self.bsds[i]

    def __len__(self):
        return len(self.mnist)
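
`MNISTM` builds each target image as `|patch - digit|`: where the MNIST pixel is 0 the background patch shows through unchanged, and where the stroke is 1 each channel becomes `1 - patch`, so the digit appears in the inverted background color. A toy NumPy check with made-up 4x4 values:

```python
import numpy as np

# Toy 1-channel "digit" in [0, 1]: 1.0 on the stroke, 0.0 elsewhere
digit = np.zeros((1, 4, 4), dtype=np.float32)
digit[0, 1:3, 1:3] = 1.0

# Toy 3-channel background patch in [0, 1] (hypothetical values)
patch = np.full((3, 4, 4), 0.2, dtype=np.float32)
patch[0] = 0.8  # strong red channel

# Same rule as MNISTM.__getitem__: (3, H, W) - (1, H, W) broadcasts over channels
blend = np.abs(patch - digit)

print(blend.shape)
print(blend[:, 0, 0])  # background pixel: patch color unchanged
print(blend[:, 1, 1])  # stroke pixel: inverted patch color
```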

### Get Dataset & Dataloaders

In [None]:
half_batch = batch_size // 2

source_dataset = MNIST(MNIST_DATA_DIR/'mnist', train=True, download=True, transform=Compose([GrayscaleToRgb(), ToTensor()]))
source_loader = DataLoader(source_dataset, batch_size=half_batch, shuffle=True, num_workers=4, pin_memory=True)

target_train_dataset, target_test_dataset = MNISTM(train=True), MNISTM(train=False)
target_train_loader = DataLoader(target_train_dataset, batch_size=half_batch, shuffle=True, num_workers=4, pin_memory=True)
target_test_loader = DataLoader(target_test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

discriminator_optim = torch.optim.Adam(discriminator.parameters())
target_optim = torch.optim.Adam(target_model.parameters())
criterion_train = nn.BCEWithLogitsLoss()
criterion_test = nn.NLLLoss()

### Visualize MNIST-M & MNIST Data 🖼️

In [None]:
visualize_digits(dataset=target_train_dataset, k=200, mnistm=True, title='Sample MNIST-M Images')

In [None]:
visualize_digits(dataset=source_dataset, k=120, cmap='gray', title='Sample MNIST Images')

### Adversarial Discriminative Domain Adaptation

In [None]:
disc_losses, disc_accuracies, disc_train_counter = [], [], []
clf_disc_losses, clf_disc_train_counter = [], []
clf_losses, clf_accuracies = [], []
clf_test_counter = [idx*iterations*k_clf*target_train_loader.batch_size for idx in range(0, epochs+1)]

In [None]:
# Baseline: source-only model (no adaptation) evaluated on MNIST-M
test_loss, test_accuracy = 0, 0
clf.eval()
tqdm_bar = tqdm(target_test_loader, desc='Testing ', total=int(len(target_test_loader)))
for idx, (images, labels) in enumerate(tqdm_bar):
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        outputs = clf(images)
        loss = criterion_test(outputs, labels)
    test_loss += loss.item()
    preds = outputs.argmax(dim=1)
    test_accuracy += (preds == labels).float().mean().item()
    tqdm_bar.set_postfix(test_loss=(test_loss/(idx+1)), test_accuracy=test_accuracy/(idx+1))
clf_losses.append(test_loss/len(target_test_loader))
clf_accuracies.append(test_accuracy/len(target_test_loader))

In [None]:
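for epoch in range(epochs):
    # clf.eval() at the end of each epoch also puts target_model in eval mode,
    # so switch it back to train mode (Dropout2d active) before adapting
    target_model.train()
    batch_iterator = zip(loop_iterable(source_loader), loop_iterable(target_train_loader))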
    disc_loss, disc_accuracy = 0, 0
    clf_disc_loss = 0
    test_loss, test_accuracy = 0, 0
    tqdm_bar = tqdm(range(iterations), desc=f'Training Epoch {epoch} ', total=iterations)
    for iter_idx in tqdm_bar:
        # Train discriminator
        set_requires_grad(target_model, requires_grad=False)
        set_requires_grad(discriminator, requires_grad=True)
        for disc_idx in range(k_disc):
            (source_x, _), (target_x, _) = next(batch_iterator)
            source_x, target_x = source_x.to(device), target_x.to(device)
            source_features = source_model(source_x).view(source_x.shape[0], -1)
            target_features = target_model(target_x).view(target_x.shape[0], -1)
            discriminator_x = torch.cat([source_features, target_features])
            discriminator_y = torch.cat([torch.ones(source_x.shape[0], device=device), torch.zeros(target_x.shape[0], device=device)])
            preds = discriminator(discriminator_x).squeeze()
            loss = criterion_train(preds, discriminator_y)
            discriminator_optim.zero_grad()
            loss.backward()
            discriminator_optim.step()
            disc_loss += loss.item()
            disc_losses.append(loss.item())
            disc_batch_accuracy = ((preds > 0).long() == discriminator_y.long()).float().mean().item()
            disc_accuracy += disc_batch_accuracy
            disc_accuracies.append(disc_batch_accuracy)
            disc_train_counter.append((disc_idx+1)*source_x.size(0) + iter_idx*k_disc*target_train_loader.batch_size + epoch*iterations*k_disc*target_train_loader.batch_size)

        # Train target encoder against the frozen discriminator (the classifier head stays frozen)
        set_requires_grad(target_model, requires_grad=True)
        set_requires_grad(discriminator, requires_grad=False)
        for clf_idx in range(k_clf):
            _, (target_x, _) = next(batch_iterator)
            target_x = target_x.to(device)
            target_features = target_model(target_x).view(target_x.shape[0], -1)
            # Inverted labels: target features are labeled as source (1) so the encoder learns to fool the discriminator
            discriminator_y = torch.ones(target_x.shape[0], device=device)
            preds = discriminator(target_features).squeeze()
            loss = criterion_train(preds, discriminator_y)
            target_optim.zero_grad()
            loss.backward()
            target_optim.step()
            clf_disc_loss += loss.item()
            clf_disc_losses.append(loss.item())
            clf_disc_train_counter.append(target_x.size(0) + clf_idx*half_batch + iter_idx*k_clf*half_batch + epoch*iterations*k_clf*half_batch)
        tqdm_bar.set_postfix(disc_loss=disc_loss/((iter_idx+1)*k_disc), disc_accuracy=disc_accuracy/((iter_idx+1)*k_disc),
                             clf_disc_loss=clf_disc_loss/((iter_idx+1)*k_clf))

    # Test full target model
    test_loss, test_accuracy = 0, 0
    clf.feature_extractor = target_model
    clf.eval()
    tqdm_bar = tqdm(target_test_loader, desc=f'Testing Epoch {epoch} (Full Target Model)', total=int(len(target_test_loader)))
    for idx, (images, labels) in enumerate(tqdm_bar):
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            outputs = clf(images)
            loss = criterion_test(outputs, labels)
        test_loss += loss.item()
        preds = outputs.argmax(dim=1)
        test_accuracy += (preds == labels).float().mean().item()
        tqdm_bar.set_postfix(test_loss=(test_loss/(idx+1)), test_accuracy=test_accuracy/(idx+1))
    clf_losses.append(test_loss/len(target_test_loader))
    clf_accuracies.append(test_accuracy/len(target_test_loader))
    if np.argmax(clf_accuracies) == len(clf_accuracies)-1:
        torch.save(clf.state_dict(), 'adda_target_weights.pth')
        

#### Visualize Training & Testing Results 📈

In [None]:
fig = plotly.subplots.make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=disc_train_counter, y=disc_losses, mode='lines', name='Disc Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=disc_train_counter, y=disc_accuracies, mode='lines', name='Disc Accuracy', line_color='lightseagreen'), secondary_y=True)
fig.update_layout(
    width=1000,
    height=500,
    title="Discriminator Loss vs Accuracy")
fig.update_xaxes(title_text="Number of training examples seen")
fig.update_yaxes(title_text="Discriminator <b>Loss</b> (BCE)", secondary_y=False)
fig.update_yaxes(title_text="Discriminator <b>Accuracy</b>", secondary_y=True)
fig.show()

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=clf_disc_train_counter, y=clf_disc_losses, mode='lines', name='Clf-Disc Train Loss'))
fig.update_layout(
    width=1000,
    height=500,
    title="Clf-Disc Loss",
    xaxis_title="Number of training examples seen",
    yaxis_title="Binary Cross Entropy Loss")
fig.show()

In [None]:
fig = plotly.subplots.make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=clf_test_counter, y=clf_accuracies, marker_symbol='star-diamond', 
                         marker_line_color="orange", marker_line_width=1, marker_size=9, mode='lines+markers', 
                         name='Target Accuracy'), secondary_y=False)
fig.add_trace(go.Scatter(x=clf_test_counter, y=clf_losses, marker_symbol='star-square', 
                         marker_line_color="lightseagreen", marker_line_width=1, marker_size=9, mode='lines+markers',
                         name='Target Loss'), secondary_y=True)
fig.update_layout(
    width=1000,
    height=500,
    title="Full Target Model Loss vs Accuracy")
fig.update_xaxes(title_text="Number of training examples seen")
fig.update_yaxes(title_text="Target <b>Accuracy</b>", secondary_y=False)
fig.update_yaxes(title_text="Target <b>Loss</b> (NLLLoss)", secondary_y=True)
fig.show()

## How to improve the model?

Try the following:
- Modify the base model (see the `Net` class), e.g. try a deeper network.
- Tune the hyperparameters, in particular `k_disc` and `k_clf`.
- Modify the discriminator architecture.
- Add data augmentation to improve generalization. See:
    - MixUp: https://pytorch.org/vision/main/generated/torchvision.transforms.v2.MixUp.html
    - CutMix: https://pytorch.org/vision/main/generated/torchvision.transforms.v2.CutMix.html#torchvision.transforms.v2.CutMix
    - How to: https://pytorch.org/vision/main/auto_examples/transforms/plot_cutmix_mixup.html#sphx-glr-auto-examples-transforms-plot-cutmix-mixup-py
    - More augmentations: https://pytorch.org/vision/main/auto_examples/transforms/index.html
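
As a starting point for the MixUp suggestion, here is a minimal NumPy sketch of the idea, meant for the labeled source batches during pretraining: each image is mixed with a shuffled partner using a weight $\lambda \sim \text{Beta}(\alpha, \alpha)$, and the one-hot labels are mixed the same way. `mixup_batch`, `soft_nll` and `alpha=0.2` are assumptions for illustration, not part of this notebook or of torchvision. Because the labels become soft, `nn.NLLLoss` would need to be replaced by a soft-target cross-entropy such as `soft_nll`.

```python
import numpy as np

def mixup_batch(images, labels, num_classes=10, alpha=0.2, rng=None):
    """Hypothetical MixUp helper.

    images: float array (N, C, H, W); labels: int array (N,).
    Returns mixed images and soft (N, num_classes) label distributions.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)         # mixing weight in [0, 1]
    perm = rng.permutation(len(images))  # partner index for each example
    mixed = lam * images + (1.0 - lam) * images[perm]
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    soft = lam * one_hot + (1.0 - lam) * one_hot[perm]
    return mixed, soft

def soft_nll(log_probs, soft_targets):
    # Cross-entropy with soft targets; Net already outputs log-probabilities
    return float(-(soft_targets * log_probs).sum(axis=1).mean())

rng = np.random.default_rng(0)
images = rng.random((8, 3, 28, 28), dtype=np.float32)
labels = rng.integers(0, 10, size=8)
mixed, soft = mixup_batch(images, labels, rng=rng)
print(mixed.shape, soft.shape, soft.sum(axis=1))
```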