Take a network of your choice and conduct experiments for domain adaptation between CIFAR10 $\leftrightarrow$ STL10. These datasets are available in PyTorch. Notice that for CIFAR10 $\leftrightarrow$ STL10 there are only 9 overlapping classes out of 10. So, exclude the classes which do not overlap before training. Check the dataset webpages for details

In [15]:
import torch, torchvision
import torch.nn.functional as F
import torchvision.transforms as T
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.datasets import STL10
from torch.utils.data import Dataset

In [2]:
class DALayer2d(nn.Module):
  def __init__(self, in_features):
    super(DALayer2d, self).__init__()
    self.in_features = in_features

    self.batchnormsource = nn.BatchNorm2d(self.in_features, affine=False)
    self.batchnormtarget = nn.BatchNorm2d(self.in_features, affine=False)
    self.gamma = nn.parameter.Parameter(torch.ones(self.in_features, 1, 1))
    self.beta = nn.parameter.Parameter(torch.zeros(self.in_features, 1, 1))

  def forward(self, x):
    if self.training:
      x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0)
      return torch.cat((self.batchnormsource(x_source), self.batchnormtarget(x_target)), dim=0) * self.gamma + self.beta
    else:
      return self.batchnormtarget(x) * self.gamma + self.beta


class DALayer1d(nn.Module):
    def __init__(self, in_features):
      super(DALayer1d, self).__init__()
      self.in_features = in_features

      self.batchnormsource = nn.BatchNorm1d(self.in_features, affine=False)
      self.batchnormtarget = nn.BatchNorm1d(self.in_features, affine=False)
      self.gamma = nn.parameter.Parameter(torch.ones(1, in_features))
      self.beta = nn.parameter.Parameter(torch.zeros(1, in_features))

    def forward(self, x):
      if self.training:
        x_source, x_target = torch.split(x, split_size_or_sections=x.shape[0] // 2, dim=0)
        return torch.cat((self.batchnormsource(x_source), self.batchnormtarget(x_target)), dim=0) * self.gamma + self.beta

      else:
        return self.batchnormtarget(x) * self.gamma + self.beta

class DIALNet_rev(nn.Module):

  def __init__(self):
    super(DIALNet_rev, self).__init__()
    self.dial = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, padding=2),
                   DALayer2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2),

                   nn.Conv2d(64, 64, kernel_size=5, padding=2),
                   DALayer2d(64),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2),

                   nn.Conv2d(64, 128, kernel_size=5, padding=2),
                   DALayer2d(128),
                   nn.ReLU(),

                   nn.Flatten(),

                   nn.Linear(6272, 3072),
                   DALayer1d(3072),
                   nn.ReLU(),
                   nn.Dropout(),

                   nn.Linear(3072, 2048),
                   DALayer1d(2048),
                   nn.ReLU(),
                   nn.Dropout(),

                   nn.Linear(2048, 10),
                   DALayer1d(10))


  def forward(self, x):

    return self.dial(x)

In [3]:
def source_loss():
  cost_function = torch.nn.CrossEntropyLoss()
  return cost_function

In [4]:
def target_loss(x):
  # Compute p_i
  p = F.softmax(x, dim=1)
  # Compute log p_i
  q = F.log_softmax(x, dim=1)

  b = p * q
  b = -1.0 * b.sum(-1).mean()
  return b

In [5]:
def train_one_epoch(model, source_loader, target_loader, optimizer, source_loss, entropy_loss_weights, device):
  source_samples = 0.
  target_samples = 0.

  cumulative_source_loss = 0.
  cumulative_target_loss = 0.
  cumulative_accuracy = 0.


  target_iter = iter(target_loader)

  model.train()

  for (x_source, y) in source_loader:
    # Gets target data. If the target iterator reaches the end, restarts it
    try:
      # if end of data is reached
      x_target, _ = next(target_iter)
    except:
      # restart data loader
      target_iter = iter(target_loader)
      # iterate again
      x_target, _ = next(target_iter)
    
    x = torch.cat((x_source, x_target), dim=0)

    # Load into GPU
    x = x.to(device)
    y = y.to(device)

    # Forward pass
    out = model(x)

    # Split source and target outputs

    source_y, target_y = torch.split(out,
                                     split_size_or_sections=out.shape[0] // 2,
                                     dim=0)

    # Apply losses
    sl = source_loss(source_y, y)
    tl = target_loss(target_y)

    loss = sl + entropy_loss_weights * tl

    # Backward pass
    loss.backward()

    # Update parameters
    optimizer.step()
    
    # Zeros the gradients
    optimizer.zero_grad()


    source_samples += x_source.shape[0]
    target_samples += x_target.shape[0]

    cumulative_source_loss += sl.item()
    cumulative_target_loss += tl.item()

    _, predicted = source_y.max(1)
    cumulative_accuracy += predicted.eq(y).sum().item()

  return cumulative_source_loss/source_samples, cumulative_target_loss/target_samples, cumulative_accuracy/source_samples*100

In [6]:
def test_one_epoch(model, loader, cost_function, device):
  samples = 0.
  cumulative_loss = 0.
  cumulative_accuracy = 0.

  model.eval()
  with torch.no_grad():
    for (x, y) in loader:
      x = x.to(device)
      y = y.to(device)

      out = model(x)

      loss = cost_function(out, y)

      samples += x.shape[0]
      cumulative_loss += loss.item()
      _, predicted = out.max(dim=1)
      cumulative_accuracy += predicted.eq(y).sum().item()
  
  return cumulative_loss/samples, cumulative_accuracy/samples*100

CIFAR10 classes:

- 0, Airplane
- 1, Automobile
- 2, Bird
- 3, Cat
- 4, Deer
- 5, Dog
- 6, Frog
- 7, Horse
- 8, Ship
- 9, Truck

STL10 classes:

- 0, Airplane
- 1, Bird
- 2, Automobile
- 3, Cat
- 4, Deer
- 5, Dog
- 6, Horse
- 7, Monkey
- 8, Ship
- 9, Truck

so, we need to remove monkeys and refactor the labels of each class.

In [121]:
transform_cifar = T.Compose([T.ToTensor(),
                             T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

transform_stl = T.Compose([T.Resize((32, 32)),
                             T.ToTensor(),
                             T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

class MyCIFAR(Dataset):
  def __init__(self):
    self.cifar = CIFAR10(root='./CIFAR',
                         download=True,
                         train=True,
                         transform=transform_cifar)

    self.labels = {0: 0,
                   1: 1,
                   2: 2,
                   3: 3,
                   4: 4,
                   5: 5,
                   7: 6,
                   8: 7,
                   9: 8}

    self.data = []

    #ignore 6, frog
    for idx, i in enumerate(self.cifar):
      if i[-1] != 6:
        x, y = self.cifar[idx]
        self.data.append((x, self.labels[y]))

  def __getitem__(self, index):
    data, target = self.data[index][0], self.data[index][1]
    return data, target

  def __len__(self):
    return len(self.data)




class MySTL(Dataset):
  def __init__(self, split='train'):

    self.stl = STL10(root='./STL',
                                    download=True,
                                    split=split,
                                    transform=transform_stl)
    self.labels = {0: 0,
                   1: 2,
                   2: 1,
                   3: 3,
                   4: 4,
                   5: 5,
                   6: 6,
                   8: 7,
                   9: 8}

    self.data = []

    #ignore 7
    for idx, i in enumerate(self.stl):
      if i[-1] != 7:
        x, y = self.stl[idx]
        self.data.append((x, self.labels[y]))
    
  def __getitem__(self, index):
      data, target = self.data[index][0], self.data[index][1]
      return data, target

  def __len__(self):
      return len(self.data)

In [122]:
def get_data(batch_size, test_batch_size=256):

  
  source_training_data = MyCIFAR()

  target_training_data = MySTL(split='train')
  target_test_data = MySTL(split='test')
    
  # Init DataLoaders
  source_train_loader = torch.utils.data.DataLoader(source_training_data, batch_size, shuffle=True, drop_last=True)
  target_train_loader = torch.utils.data.DataLoader(target_training_data, batch_size, shuffle=True, drop_last=True)

  target_test_loader = torch.utils.data.DataLoader(target_test_data, test_batch_size, shuffle=False)

  return source_train_loader, target_train_loader, target_test_loader

In [123]:
def run_model(batch_size=32, device='cuda:0', learning_rate=1e-2, weight_decay=1e-6, epochs=25, entropy_loss_weight=0.1):

  source_train_loader, target_train_loader, target_test_loader = get_data(batch_size=64)

  model = DIALNet_rev().to(device)

  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  cost_function = source_loss()

  torch.manual_seed(42)

  for e in range(1, epochs+1):
    train_source_loss, train_target_loss, train_accuracy = train_one_epoch(model=model,
                                                                           source_loader=source_train_loader,
                                                                           target_loader=target_train_loader,
                                                                           optimizer=optimizer,
                                                                           source_loss=cost_function,
                                                                           entropy_loss_weights=entropy_loss_weight,
                                                                           device=device)
    test_loss, test_accuracy = test_one_epoch(model=model,
                                              loader=target_test_loader,
                                              cost_function=cost_function,
                                              device=device)

    print('Epoch: {:d}'.format(e))
    print('\t Train: Source loss {:.5f}, Target loss {:.2f}, Accuracy {:.2f}'.format(train_source_loss, train_target_loss, train_accuracy))
    print('\t Test: Source loss {:.5f}, Accuracy {:.2f}'.format(test_loss, test_accuracy))
    print('-----------------------------------------------------')

In [124]:
run_model()

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Epoch: 1
	 Train: Source loss 0.01977, Target loss 0.02, Accuracy 54.10
	 Test: Source loss 0.00526, Accuracy 53.61
-----------------------------------------------------
Epoch: 2
	 Train: Source loss 0.01385, Target loss 0.01, Accuracy 68.55
	 Test: Source loss 0.00509, Accuracy 58.82
-----------------------------------------------------
Epoch: 3
	 Train: Source loss 0.01176, Target loss 0.01, Accuracy 73.51
	 Test: Source loss 0.00505, Accuracy 60.50
-----------------------------------------------------
Epoch: 4
	 Train: Source loss 0.01043, Target loss 0.01, Accuracy 76.54
	 Test: Source loss 0.00457, Accuracy 64.06
-----------------------------------------------------
Epoch: 5
	 Train: Source loss 0.00941, Target loss 0.01, Accuracy 78.95
	 Test: Source loss 0.00468, Accuracy 64.88
-----------------------------------------------------
Epoch: 6
	 Train: Source loss 0.0085