# Domain Adaptation
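In this notebook we will carry out a simple domain adaptation experiment: we train on labelled SVHN digits (source domain) and adapt to unlabelled MNIST digits (target domain). As usual, let's start by importing the necessary libraries.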

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
import os

## Network definition
Each DA block consists of a domain-specific Batch Normalization layer followed by a domain-agnostic scale-shift operation. The BN layers are created with `affine=False`, so the learnable scale and shift are the `gamma`/`beta` parameters shared by both domains. Note that each domain-specific BN layer accumulates the first- and second-order statistics of its own domain, i.e., the running mean and variance. This is achieved by setting `track_running_stats=True` (the PyTorch default). Details about this implementation can be found in the [docs](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html). We will employ the architecture presented in detail [here](http://proceedings.mlr.press/v37/ganin15-supp.pdf). 

In [None]:
class DIALNet(nn.Module):
    """Digit classifier with domain-specific BatchNorm and shared scale/shift.

    Every conv / linear layer is followed by two BatchNorm layers built with
    ``affine=False`` -- ``bns*`` for the source domain and ``bnt*`` for the
    target domain -- and by a learnable scale (``gamma*``) and shift
    (``beta*``) that both domains share.

    Training mode: the batch must be a source half followed by a target half
    of equal size; each half is normalized by its own BN layer.
    Eval mode: the whole batch is normalized by the target BN layers only.
    """

    def __init__(self):
        super().__init__()

        # Block 1: 3 -> 64 channels, 5x5 conv (padding=2 keeps spatial size).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=2)
        self.bns1 = nn.BatchNorm2d(64, affine=False)
        self.bnt1 = nn.BatchNorm2d(64, affine=False)
        self.gamma1 = nn.Parameter(torch.ones(64, 1, 1))
        self.beta1 = nn.Parameter(torch.zeros(64, 1, 1))

        # Block 2: 64 -> 64 channels.
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        self.bns2 = nn.BatchNorm2d(64, affine=False)
        self.bnt2 = nn.BatchNorm2d(64, affine=False)
        self.gamma2 = nn.Parameter(torch.ones(64, 1, 1))
        self.beta2 = nn.Parameter(torch.zeros(64, 1, 1))

        # Block 3: 64 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.bns3 = nn.BatchNorm2d(128, affine=False)
        self.bnt3 = nn.BatchNorm2d(128, affine=False)
        self.gamma3 = nn.Parameter(torch.ones(128, 1, 1))
        self.beta3 = nn.Parameter(torch.zeros(128, 1, 1))

        # 6272 = 128 * 7 * 7: a 32x32 input shrinks to 7x7 after the two
        # 3x3 / stride-2 max pools in forward().
        self.fc4 = nn.Linear(6272, 3072)
        self.bns4 = nn.BatchNorm1d(3072, affine=False)
        self.bnt4 = nn.BatchNorm1d(3072, affine=False)
        self.gamma4 = nn.Parameter(torch.ones(1, 3072))
        self.beta4 = nn.Parameter(torch.zeros(1, 3072))

        self.fc5 = nn.Linear(3072, 2048)
        self.bns5 = nn.BatchNorm1d(2048, affine=False)
        self.bnt5 = nn.BatchNorm1d(2048, affine=False)
        self.gamma5 = nn.Parameter(torch.ones(1, 2048))
        self.beta5 = nn.Parameter(torch.zeros(1, 2048))

        # Output layer: 10 digit classes (its logits are also normalized).
        self.fc6 = nn.Linear(2048, 10)
        self.bns6 = nn.BatchNorm1d(10, affine=False)
        self.bnt6 = nn.BatchNorm1d(10, affine=False)
        self.gamma6 = nn.Parameter(torch.ones(1, 10))
        self.beta6 = nn.Parameter(torch.zeros(1, 10))

    def _domain_bn(self, h, bn_source, bn_target, gamma, beta):
        """Apply the domain-specific BN to `h`, then the shared gamma/beta.

        In training mode the first half of the batch goes through `bn_source`
        and the second half through `bn_target`. In eval mode the whole batch
        goes through `bn_target`.
        """
        if not self.training:
            return bn_target(h) * gamma + beta
        src, tgt = torch.split(h, split_size_or_sections=h.shape[0] // 2, dim=0)
        return torch.cat((bn_source(src), bn_target(tgt)), dim=0) * gamma + beta

    def forward(self, x):
        h = self._domain_bn(self.conv1(x), self.bns1, self.bnt1, self.gamma1, self.beta1)
        h = F.max_pool2d(F.relu(h), kernel_size=3, stride=2)

        h = self._domain_bn(self.conv2(h), self.bns2, self.bnt2, self.gamma2, self.beta2)
        h = F.max_pool2d(F.relu(h), kernel_size=3, stride=2)

        h = F.relu(self._domain_bn(self.conv3(h), self.bns3, self.bnt3, self.gamma3, self.beta3))

        h = h.view(h.shape[0], -1)
        h = self._domain_bn(self.fc4(h), self.bns4, self.bnt4, self.gamma4, self.beta4)
        h = F.dropout(F.relu(h), training=self.training)

        h = self._domain_bn(self.fc5(h), self.bns5, self.bnt5, self.gamma5, self.beta5)
        h = F.dropout(F.relu(h), training=self.training)

        return self._domain_bn(self.fc6(h), self.bns6, self.bnt6, self.gamma6, self.beta6)

## Cost function
For the source domain, since we are provided with **label information**, we can simply use the usual cross-entropy classification loss. For the **unlabelled** target domain, instead, we define an entropy loss meant to maximally separate the unlabelled data, i.e., to encourage confident predictions on the target samples. Please refer to the [original paper](https://arxiv.org/abs/1704.08082) for further details.

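$$L^{t}(\theta) = - \frac{1}{m} \displaystyle \sum^{m}_{i=1} \sum^{K}_{k=1} p_{i,k}\log{p_{i,k}}$$

where $m$ is the number of target samples in the mini-batch, $K$ the number of classes, and $p_{i,k}$ the softmax probability that target sample $i$ belongs to class $k$.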

In [None]:
def get_ce_cost_function():
  """Return the cross-entropy criterion used on the labelled source domain."""
  return torch.nn.CrossEntropyLoss()

def get_entropy_loss(x):
  """Mean Shannon entropy of the softmax predictions for the logits `x`.

  Softmax is taken over dim 1 (classes) and the per-sample entropies are
  averaged over the batch. Lower values mean more confident predictions.
  """
  log_probs = F.log_softmax(x, dim=1)
  probs = F.softmax(x, dim=1)
  per_sample = torch.sum(probs * log_probs, dim=-1)
  return -per_sample.mean()

## Optimizer
We will employ, as usual, a stochastic gradient descent (SGD) optimizer with momentum and weight decay.

In [None]:
def get_optimizer(net, lr, wd, momentum):
  """Build an SGD optimizer over all of `net`'s parameters.

  `wd` is the L2 weight-decay coefficient, `momentum` the SGD momentum.
  """
  return torch.optim.SGD(params=net.parameters(), lr=lr, momentum=momentum, weight_decay=wd)

## Training and test steps
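We define our training and test steps for the DA experiment. Each training step concatenates a source batch and a target batch into a single forward pass. At test time the network uses only the target-domain BN statistics.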

In [None]:
def training_step(net, source_data_loader, target_data_loader, optimizer, 
          get_ce_cost_function, entropy_loss_weight, device='cuda:0'):
  """Train `net` for one epoch on paired source/target mini-batches.

  Each source batch is concatenated with a target batch (source first) and
  fed through the network in one forward pass; the output is split in half.
  The labelled source half gets the classification loss, and the unlabelled
  target half gets the entropy loss computed by `get_entropy_loss`. The
  target loader is restarted whenever it runs out, so the epoch length is
  set by the source loader.

  Args:
    net: model that takes the concatenated (source, target) batch.
    source_data_loader: iterable of (inputs, labels) for the source domain.
    target_data_loader: iterable of (inputs, labels) for the target domain;
      the labels are ignored.
    optimizer: optimizer over `net`'s parameters.
    get_ce_cost_function: the classification criterion itself (e.g. the
      object returned by `get_ce_cost_function()`), called as
      criterion(logits, labels). Assumed to average over the batch
      (reduction='mean').
    entropy_loss_weight: weight of the target entropy loss in the total loss.
    device: device the batches are moved to.

  Returns:
    (mean CE loss per source sample, mean entropy per target sample,
     source accuracy in percent).

  Raises:
    ValueError: if a source batch and its paired target batch differ in size.
  """
  source_samples = 0.
  target_samples = 0.
  cumulative_ce_loss = 0.
  cumulative_en_loss = 0.
  cumulative_accuracy = 0.

  target_iter = iter(target_data_loader)

  # strictly needed if network contains layers which have different behaviours between train and test
  net.train()
  for inputs_source, targets in source_data_loader:

    # get target data; if the target iterator is exhausted, restart it.
    # Only StopIteration is caught, so genuine loading errors still surface.
    try:
      inputs_target, _ = next(target_iter)
    except StopIteration:
      target_iter = iter(target_data_loader)
      inputs_target, _ = next(target_iter)

    n_source = inputs_source.shape[0]
    n_target = inputs_target.shape[0]
    # the outputs are split by position below, so both domains must
    # contribute the same number of samples
    if n_source != n_target:
      raise ValueError('source and target batches must have the same size, '
                       'got {} and {}'.format(n_source, n_target))

    # load data into the device
    inputs = torch.cat((inputs_source, inputs_target), dim=0).to(device)
    targets = targets.to(device)

    # clear gradients before the forward/backward pass
    optimizer.zero_grad()

    # forward pass, then split into source and target outputs
    outputs = net(inputs)
    source_output, target_output = torch.split(outputs, n_source, dim=0)

    # apply the losses
    ce_loss = get_ce_cost_function(source_output, targets)
    en_loss = get_entropy_loss(target_output)
    loss = ce_loss + entropy_loss_weight * en_loss

    # backward pass and parameter update
    loss.backward()
    optimizer.step()

    # statistics: both losses are batch means, so weight them by batch size
    # to make the epoch-level values true per-sample averages
    source_samples += n_source
    target_samples += n_target
    cumulative_ce_loss += ce_loss.item() * n_source
    cumulative_en_loss += en_loss.item() * n_target
    _, predicted = source_output.max(1)
    cumulative_accuracy += predicted.eq(targets).sum().item()

  return cumulative_ce_loss/source_samples, cumulative_en_loss/target_samples, cumulative_accuracy/source_samples*100


def test_step(net, data_loader, cost_function, device='cuda:0'):
  samples = 0.
  cumulative_loss = 0.
  cumulative_accuracy = 0.

  # strictly needed if network contains layers which has different behaviours between train and test
  net.eval()

  with torch.no_grad():

    for batch_idx, (inputs, targets) in enumerate(data_loader):

      # load data into GPU
      inputs = inputs.to(device)
      targets = targets.to(device)
        
      # forward pass
      outputs = net(inputs)

      # apply the loss
      loss = cost_function(outputs, targets)

      # print statistics
      samples+=inputs.shape[0]
      cumulative_loss += loss.item() # Note: the .item() is needed to extract scalars from tensors
      _, predicted = outputs.max(1)
      cumulative_accuracy += predicted.eq(targets).sum().item()

  return cumulative_loss/samples, cumulative_accuracy/samples*100

## Data loading
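In this block we define the data loading utility for our experiment. SVHN is the labelled source domain. MNIST is the target domain; its images are padded to 32x32 and replicated to three channels to match SVHN.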

In [None]:
def get_data(batch_size, test_batch_size=256):
  """Build the SVHN (source) and MNIST (target) data loaders.

  MNIST digits are zero-padded from 28x28 to 32x32 and replicated to three
  channels so that they match SVHN's 3x32x32 images. Both domains are then
  normalized to [-1, 1].

  Returns:
    (source_train_loader, target_train_loader, target_test_loader). The two
    training loaders shuffle and drop the last incomplete batch; the test
    loader does neither.
  """
  # shared [0, 1] -> [-1, 1] normalization for both domains
  to_unit_range = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

  def pad_to_32(img):
    # pad 2 zeros on every side: 28x28 -> 32x32
    return F.pad(img, (2, 2, 2, 2), 'constant', 0)

  def gray_to_rgb(img):
    # replicate the single gray channel into three channels
    return img.repeat(3, 1, 1)

  transform_mnist = T.Compose([T.ToTensor(), T.Lambda(pad_to_32), T.Lambda(gray_to_rgb), to_unit_range])
  transform_svhn = T.Compose([T.ToTensor(), to_unit_range])

  # source domain: SVHN training split
  source_training_data = torchvision.datasets.SVHN('./data/svhn', split='train', transform=transform_svhn, download=True)

  # target domain: MNIST train (unlabelled during training) and test splits
  target_training_data = torchvision.datasets.MNIST('./data/mnist', train=True, transform=transform_mnist, download=True)
  target_test_data = torchvision.datasets.MNIST('./data/mnist', train=False, transform=transform_mnist, download=True)

  source_train_loader = torch.utils.data.DataLoader(source_training_data, batch_size=batch_size, shuffle=True, drop_last=True)
  target_train_loader = torch.utils.data.DataLoader(target_training_data, batch_size=batch_size, shuffle=True, drop_last=True)
  target_test_loader = torch.utils.data.DataLoader(target_test_data, batch_size=test_batch_size, shuffle=False)

  return source_train_loader, target_train_loader, target_test_loader

## Put everything together
We are now ready to wrap everything up into our main function, where we initialize our components and loop over multiple epochs.

In [None]:
def main(batch_size=32, 
         device='cuda:0', 
         learning_rate=0.01, 
         weight_decay=0.000001, 
         momentum=0.9, 
         epochs=50,
         entropy_loss_weight=0.1):
  '''Train DIALNet with SVHN as source and MNIST as target, printing metrics every epoch.

  Input arguments
    batch_size: Size of a mini-batch for each domain (every training step sees
      batch_size source images plus batch_size target images)
    device: device (e.g. GPU) where you want to train your network; it is used
      for the model and passed to both the training and the test step
    learning_rate: Learning rate for the SGD optimizer
    weight_decay: Weight decay co-efficient for regularization of weights
    momentum: Momentum for SGD optimizer
    epochs: Number of epochs for training the network
    entropy_loss_weight: Weight of the target entropy loss relative to the
      source cross-entropy loss
  '''
  source_train_loader, target_train_loader, target_test_loader = get_data(batch_size)

  net = DIALNet().to(device)

  optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)

  cost_function = get_ce_cost_function()

  for e in range(epochs):
    # pass `device` explicitly so the batches land on the same device as `net`
    train_ce_loss, train_en_loss, train_accuracy = training_step(net=net, source_data_loader=source_train_loader, 
                                                         target_data_loader=target_train_loader, 
                                                         optimizer=optimizer, get_ce_cost_function=cost_function,
                                                         entropy_loss_weight=entropy_loss_weight,
                                                         device=device)
    test_loss, test_accuracy = test_step(net, target_test_loader, cost_function, device=device)

    print('Epoch: {:d}'.format(e+1))
    print('\t Train: CE loss {:.5f}, Entropy loss {:.5f}, Accuracy {:.2f}'.format(train_ce_loss, train_en_loss, train_accuracy))
    print('\t Test: CE loss {:.5f}, Accuracy {:.2f}'.format(test_loss, test_accuracy))
    print('-----------------------------------------------------')

## Let's train!

In [None]:
# Run the experiment; the hyper-parameters are spelled out explicitly (they
# equal main's defaults) so they are visible and easy to tweak here.
main(batch_size=32,
     device='cuda:0',
     learning_rate=0.01,
     weight_decay=0.000001,
     momentum=0.9,
     epochs=50,
     entropy_loss_weight=0.1)