## Implements RevGrad:
-   [Unsupervised Domain Adaptation by Backpropagation, Ganin & Lempitsky (2014)](https://arxiv.org/abs/1409.7495)
-   [Domain-adversarial training of neural networks, Ganin et al. (2016)](https://arxiv.org/abs/1505.07818)

### Model Architecture
![](model_archs/Unsupervised_Domain_Adaptation_by_Backpropagation_model_arch.png)
Image borrowed from [Unsupervised Domain Adaptation by Backpropagation, Ganin & Lempitsky (2014)](https://arxiv.org/abs/1409.7495)

### Make Necessary Imports

In [1]:
import argparse

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

import config
from data import MNISTM
from models import Net
from utils import GrayscaleToRgb, GradientReversal

If a CUDA-enabled GPU isn't found, we run on the CPU.

In [2]:
# Select the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Set necessary hyperparameters

In [3]:
# Checkpoint whose weights are loaded into the network before adversarial
# training starts.
MODEL_FILE = 'trained_models/source.pt'

# Number of passes over the paired source/target loaders.
epochs = 15

# Samples per optimisation step; the training cell gives half of this to each
# of the two data loaders.
batch_size = 64

### Train RevGrad model 

In [4]:
# RevGrad training: starting from the weights stored in MODEL_FILE, jointly
# optimise a label classifier on source batches and a domain discriminator on
# mixed source/target batches, then checkpoint the model after every epoch.

model = Net().to(device)
# map_location makes the load work when the checkpoint was saved on a GPU but
# this notebook is running on the CPU fallback chosen for `device`.
model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
feature_extractor = model.feature_extractor
clf = model.classifier

# Domain discriminator: an MLP over the flattened features (320 values per
# sample, as fixed by the first Linear layer) that emits one logit per sample.
# NOTE(review): GradientReversal comes from utils; presumably it is the
# identity on the forward pass and negates gradients on the backward pass, as
# in the RevGrad paper -- confirm against utils.
discriminator = nn.Sequential(
    GradientReversal(),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 1)
).to(device)

# Each optimisation step sees half a batch from each domain.
half_batch = batch_size // 2
source_dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                      transform=Compose([GrayscaleToRgb(), ToTensor()]))
source_loader = DataLoader(source_dataset, batch_size=half_batch,
                           shuffle=True, num_workers=1, pin_memory=True)

# Target labels are never used below (they are discarded as `_`).
target_dataset = MNISTM(train=False)
target_loader = DataLoader(target_dataset, batch_size=half_batch,
                           shuffle=True, num_workers=1, pin_memory=True)

# One optimiser updates both the discriminator and the whole model.
optim = torch.optim.Adam(list(discriminator.parameters()) + list(model.parameters()))

for epoch in range(1, epochs+1):
    # zip stops at the shorter loader, so an epoch has
    # min(len(source_loader), len(target_loader)) steps.
    batches = zip(source_loader, target_loader)
    n_batches = min(len(source_loader), len(target_loader))

    total_domain_loss = total_label_accuracy = 0
    for (source_x, source_labels), (target_x, _) in tqdm(batches, leave=False, total=n_batches):
        # Stack source samples first, then target samples.
        x = torch.cat([source_x, target_x])
        x = x.to(device)
        # Domain targets: 1 for source samples, 0 for target samples.
        domain_y = torch.cat([torch.ones(source_x.shape[0]),
                              torch.zeros(target_x.shape[0])])
        domain_y = domain_y.to(device)
        label_y = source_labels.to(device)

        features = feature_extractor(x).view(x.shape[0], -1)
        # Drop only the trailing logit dimension: (N, 1) -> (N,), so the
        # batch dimension is kept whatever N is.
        domain_preds = discriminator(features).squeeze(1)
        # Only the source rows have labels, so only they reach the classifier.
        label_preds = clf(features[:source_x.shape[0]])

        domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
        label_loss = F.cross_entropy(label_preds, label_y)
        loss = domain_loss + label_loss

        optim.zero_grad()
        loss.backward()
        optim.step()

        total_domain_loss += domain_loss.item()
        total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()

    # Per-epoch averages over the steps actually run.
    mean_loss = total_domain_loss / n_batches
    mean_accuracy = total_label_accuracy / n_batches
    tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
               f'source_accuracy={mean_accuracy:.4f}')

    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), 'trained_models/revgrad.pt')

                                                                                                                                                    

EPOCH 001: domain_loss=0.4291, source_accuracy=0.9508


                                                                                                                                                    

EPOCH 002: domain_loss=0.3018, source_accuracy=0.9522


                                                                                                                                                    

EPOCH 003: domain_loss=0.2726, source_accuracy=0.9452


                                                                                                                                                    

EPOCH 004: domain_loss=0.2652, source_accuracy=0.9467


                                                                                                                                                    

EPOCH 005: domain_loss=0.2541, source_accuracy=0.9470


                                                                                                                                                    

EPOCH 006: domain_loss=0.2657, source_accuracy=0.9436


                                                                                                                                                    

EPOCH 007: domain_loss=0.2772, source_accuracy=0.9397


                                                                                                                                                    

EPOCH 008: domain_loss=0.2777, source_accuracy=0.9421


                                                                                                                                                    

EPOCH 009: domain_loss=0.2978, source_accuracy=0.9446


                                                                                                                                                    

EPOCH 010: domain_loss=0.3275, source_accuracy=0.9412


                                                                                                                                                    

EPOCH 011: domain_loss=0.3843, source_accuracy=0.9364


                                                                                                                                                    

EPOCH 012: domain_loss=0.4990, source_accuracy=0.9234


                                                                                                                                                    

EPOCH 013: domain_loss=0.5654, source_accuracy=0.9242


                                                                                                                                                    

EPOCH 014: domain_loss=0.5986, source_accuracy=0.9237


                                                                                                                                                    

EPOCH 015: domain_loss=0.5667, source_accuracy=0.9347
