Refactor the DialNet class so that, instead of declaring the DA layers as separate components:

  ```
  self.bns1 = nn.BatchNorm2d(64, affine=False)
  self.bnt1 = nn.BatchNorm2d(64, affine=False)
  self.gamma1 = nn.Parameter(torch.ones(64, 1, 1))
  self.beta1 = nn.Parameter(torch.zeros(64, 1, 1))
  ```

  it defines them as self-contained DALayer2d or DALayer1d modules:

  ```
  self.da1 = DALayer2d(64)
  ```

In [None]:
import torch, torchvision
import torch.nn.functional as F
import torchvision.transforms as T

from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

In [None]:
class DALayer2d(nn.Module):
  """Domain-alignment layer built on ``BatchNorm2d``.

  Holds one non-affine ``BatchNorm2d`` per domain (source and target) plus a
  single learnable scale/shift pair (``gamma``/``beta``) that is applied to
  the normalized output of whichever branch was used.
  """

  def __init__(self, in_features):
    """Create both batch-norm branches and the shared scale/shift.

    Args:
      in_features: channel count passed to both ``BatchNorm2d`` modules and
        used as the leading dimension of ``gamma`` and ``beta``.
    """
    super().__init__()
    self.in_features = in_features

    # affine=False: the scale/shift lives in gamma/beta below, so it is
    # shared by the two normalization branches.
    self.batchnormsource = nn.BatchNorm2d(in_features, affine=False)
    self.batchnormtarget = nn.BatchNorm2d(in_features, affine=False)
    # Shape (C, 1, 1) broadcasts against the channel axis of the
    # BatchNorm2d output.
    self.gamma = nn.Parameter(torch.ones(in_features, 1, 1))
    self.beta = nn.Parameter(torch.zeros(in_features, 1, 1))

  def forward(self, x):
    """Normalize ``x`` per domain, then apply the shared scale and shift.

    In training mode the batch is split along dim 0 into chunks of
    ``x.shape[0] // 2`` rows; the first chunk goes through the source branch
    and the second through the target branch. In eval mode the whole batch
    goes through the target branch only.
    """
    if not self.training:
      return self.batchnormtarget(x) * self.gamma + self.beta

    half = x.shape[0] // 2
    x_source, x_target = torch.split(x, split_size_or_sections=half, dim=0)
    aligned = torch.cat(
        (self.batchnormsource(x_source), self.batchnormtarget(x_target)),
        dim=0,
    )
    return aligned * self.gamma + self.beta


class DALayer1d(nn.Module):
  """Domain-alignment layer built on ``BatchNorm1d``.

  Holds one non-affine ``BatchNorm1d`` per domain (source and target) plus a
  single learnable scale/shift pair (``gamma``/``beta``) that is applied to
  the normalized output of whichever branch was used.
  """

  def __init__(self, in_features):
    """Create both batch-norm branches and the shared scale/shift.

    Args:
      in_features: feature count passed to both ``BatchNorm1d`` modules and
        used as the trailing dimension of ``gamma`` and ``beta``.
    """
    super().__init__()
    self.in_features = in_features

    # affine=False: the scale/shift lives in gamma/beta below, so it is
    # shared by the two normalization branches.
    self.batchnormsource = nn.BatchNorm1d(in_features, affine=False)
    self.batchnormtarget = nn.BatchNorm1d(in_features, affine=False)
    # Shape (1, F) broadcasts over the batch axis of a (N, F) activation.
    self.gamma = nn.Parameter(torch.ones(1, in_features))
    self.beta = nn.Parameter(torch.zeros(1, in_features))

  def forward(self, x):
    """Normalize ``x`` per domain, then apply the shared scale and shift.

    In training mode the batch is split along dim 0 into chunks of
    ``x.shape[0] // 2`` rows; the first chunk goes through the source branch
    and the second through the target branch. In eval mode the whole batch
    goes through the target branch only.
    """
    if not self.training:
      return self.batchnormtarget(x) * self.gamma + self.beta

    half = x.shape[0] // 2
    x_source, x_target = torch.split(x, split_size_or_sections=half, dim=0)
    aligned = torch.cat(
        (self.batchnormsource(x_source), self.batchnormtarget(x_target)),
        dim=0,
    )
    return aligned * self.gamma + self.beta

class DIALNet_rev(nn.Module):
  """Convolutional classifier with a DA layer after every conv/linear layer.

  The whole network is a single flat ``nn.Sequential`` stored in
  ``self.dial``: three 5x5 convolutions (3 -> 64 -> 64 -> 128 channels)
  followed by three linear layers (6272 -> 3072 -> 2048 -> 10). Each of them
  is followed by a ``DALayer2d`` / ``DALayer1d`` of matching width.
  """

  def __init__(self):
    super().__init__()

    # Kept as one flat list so the indices inside the Sequential (and hence
    # the state_dict keys) stay the same as a direct nn.Sequential(...) call.
    layers = [
        # Convolutional block 1
        nn.Conv2d(3, 64, kernel_size=5, padding=2),
        DALayer2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2),

        # Convolutional block 2
        nn.Conv2d(64, 64, kernel_size=5, padding=2),
        DALayer2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2),

        # Convolutional block 3 (no pooling)
        nn.Conv2d(64, 128, kernel_size=5, padding=2),
        DALayer2d(128),
        nn.ReLU(),

        nn.Flatten(),

        # Fully connected block 1; 6272 = 128 * 7 * 7, which matches 32x32
        # inputs after the two poolings above.
        nn.Linear(6272, 3072),
        DALayer1d(3072),
        nn.ReLU(),
        nn.Dropout(),

        # Fully connected block 2
        nn.Linear(3072, 2048),
        DALayer1d(2048),
        nn.ReLU(),
        nn.Dropout(),

        # Output layer: 10 values per sample, also passed through a DA layer
        nn.Linear(2048, 10),
        DALayer1d(10),
    ]
    self.dial = nn.Sequential(*layers)

  def forward(self, x):
    """Run ``x`` through the sequential stack and return its output."""
    return self.dial(x)

In [None]:
def source_loss():
  """Return a freshly constructed, default-configured ``CrossEntropyLoss``."""
  return torch.nn.CrossEntropyLoss()

In [None]:
def target_loss(x):
  """Return the mean entropy of the softmax distribution of ``x``.

  The softmax and log-softmax are taken along dim 1, the product
  ``p * log p`` is summed over the last dimension, and the negated result is
  averaged over the remaining elements.

  Args:
    x: tensor of unnormalized scores (logits).

  Returns:
    A scalar tensor.
  """
  probs = F.softmax(x, dim=1)
  log_probs = F.log_softmax(x, dim=1)

  # sum(p * log p) is the negative entropy of each row.
  neg_entropy = (probs * log_probs).sum(-1)
  return -1.0 * neg_entropy.mean()

In [None]:
def train_one_epoch(model, source_loader, target_loader, optimizer, source_loss, entropy_loss_weights, device):
  """Train ``model`` for one pass over ``source_loader``.

  Every source batch is concatenated with a target batch along dim 0 (source
  first), pushed through the model in a single forward pass, and the output
  is split back into a source half and a target half. The objective is
  ``source_loss(source_out, y) + entropy_loss_weights * target_loss(target_out)``.

  The target loader is cycled: when it runs out before the source loader
  does, it is restarted from the beginning.

  Args:
    model: network to train; switched to training mode here.
    source_loader: iterable of ``(inputs, labels)`` source batches. It drives
      the length of the epoch.
    target_loader: iterable of ``(inputs, _)`` target batches; the second
      element of each item is ignored.
    optimizer: optimizer stepping the model parameters.
    source_loss: callable ``(outputs, labels) -> scalar tensor`` for the
      labelled source half.
    entropy_loss_weights: scalar weight applied to the target loss.
    device: device the concatenated batch and the labels are moved to.

  Returns:
    A tuple ``(source_loss_avg, target_loss_avg, source_accuracy)`` where the
    two loss figures are the sum of the per-batch loss values divided by the
    number of samples seen, and the accuracy is a percentage over the source
    samples.

    NOTE(review): if the losses are already batch means, the loss figures
    are therefore scaled down by the batch size - confirm this is intended.
  """
  source_samples = 0.
  target_samples = 0.

  cumulative_source_loss = 0.
  cumulative_target_loss = 0.
  cumulative_accuracy = 0.

  target_iter = iter(target_loader)

  model.train()

  for (x_source, y) in source_loader:
    # Fetch the next target batch, restarting the target loader once it is
    # exhausted. Only StopIteration is caught so that genuine data-loading
    # errors and KeyboardInterrupt are not silently retried.
    try:
      x_target, _ = next(target_iter)
    except StopIteration:
      target_iter = iter(target_loader)
      x_target, _ = next(target_iter)

    # Source samples first, target samples second: the output is split at
    # the midpoint further down.
    x = torch.cat((x_source, x_target), dim=0)

    # Move the batch and the source labels to the compute device
    x = x.to(device)
    y = y.to(device)

    # Forward pass
    out = model(x)

    # Split source and target outputs
    source_y, target_y = torch.split(out,
                                     split_size_or_sections=out.shape[0] // 2,
                                     dim=0)

    # Supervised loss on the source half, entropy loss on the target half
    sl = source_loss(source_y, y)
    tl = target_loss(target_y)

    loss = sl + entropy_loss_weights * tl

    # Backward pass
    loss.backward()

    # Update parameters
    optimizer.step()

    # Zero the gradients so they do not accumulate into the next batch
    optimizer.zero_grad()

    source_samples += x_source.shape[0]
    target_samples += x_target.shape[0]

    cumulative_source_loss += sl.item()
    cumulative_target_loss += tl.item()

    # Accuracy is measured on the labelled (source) half only
    _, predicted = source_y.max(1)
    cumulative_accuracy += predicted.eq(y).sum().item()

  return cumulative_source_loss/source_samples, cumulative_target_loss/target_samples, cumulative_accuracy/source_samples*100

In [None]:
def test_one_epoch(model, loader, cost_function, device):
  samples = 0.
  cumulative_loss = 0.
  cumulative_accuracy = 0.

  model.eval()
  with torch.no_grad():
    for (x, y) in loader:
      x = x.to(device)
      y = y.to(device)

      out = model(x)

      loss = cost_function(out, y)

      samples += x.shape[0]
      cumulative_loss += loss.item()
      _, predicted = out.max(dim=1)
      cumulative_accuracy += predicted.eq(y).sum().item()
  
  return cumulative_loss/samples, cumulative_accuracy/samples*100

In [None]:
def get_data(batch_size, test_batch_size=256):
  """Build the SVHN (source) and MNIST (target) data loaders.

  SVHN images are converted to tensors and normalized with mean 0.5 and
  std 0.5 per channel. MNIST images are additionally padded by 2 pixels on
  every side and repeated to 3 channels before the same normalization, so
  that both datasets go through the same 3-channel normalization.

  Datasets are stored under ``./data`` and downloaded when missing.

  Args:
    batch_size: batch size of both training loaders (shuffled, incomplete
      last batch dropped).
    test_batch_size: batch size of the MNIST test loader (not shuffled).

  Returns:
    ``(source_train_loader, target_train_loader, target_test_loader)``.
  """
  channel_stats = [0.5, 0.5, 0.5]

  svhn_pipeline = T.Compose([
      T.ToTensor(),
      T.Normalize(mean=channel_stats, std=channel_stats),
  ])

  mnist_pipeline = T.Compose([
      T.ToTensor(),
      # Pad the last two dimensions by 2 on each side
      T.Lambda(lambda x: F.pad(x, (2, 2, 2, 2), 'constant')),
      # Repeat the tensor 3 times along its first dimension
      T.Lambda(lambda x: x.repeat(3, 1, 1)),
      T.Normalize(mean=channel_stats, std=channel_stats),
  ])

  # Source domain: SVHN training split
  svhn_train = torchvision.datasets.SVHN(
      './data/svhn', split='train', transform=svhn_pipeline, download=True)

  # Target domain: MNIST training and test splits
  mnist_train = torchvision.datasets.MNIST(
      './data/mnist', train=True, transform=mnist_pipeline, download=True)
  mnist_test = torchvision.datasets.MNIST(
      './data/mnist', train=False, transform=mnist_pipeline, download=True)

  make_loader = torch.utils.data.DataLoader
  source_train_loader = make_loader(svhn_train, batch_size, shuffle=True, drop_last=True)
  target_train_loader = make_loader(mnist_train, batch_size, shuffle=True, drop_last=True)
  target_test_loader = make_loader(mnist_test, test_batch_size, shuffle=False)

  return source_train_loader, target_train_loader, target_test_loader

In [None]:
def run_model(batch_size=32, device='cuda:0', learning_rate=1e-2, weight_decay=1e-6, epochs=25, entropy_loss_weight=0.1):
  """Train a ``DIALNet_rev`` on the loaders from ``get_data`` and report progress.

  After every epoch the training statistics and the evaluation on the target
  test loader are printed.

  Args:
    batch_size: batch size of the source and target training loaders.
    device: device the model and the batches are placed on.
    learning_rate: Adam learning rate.
    weight_decay: Adam weight decay.
    epochs: number of training epochs.
    entropy_loss_weight: weight of the target loss in the training objective.
  """
  # Seed first, so that weight initialization - not only what happens after
  # model construction - draws from the seeded generator.
  torch.manual_seed(42)

  # Honour the batch_size argument (it used to be ignored in favour of a
  # hard-coded 64).
  source_train_loader, target_train_loader, target_test_loader = get_data(batch_size=batch_size)

  model = DIALNet_rev().to(device)

  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  cost_function = source_loss()

  for e in range(1, epochs+1):
    train_source_loss, train_target_loss, train_accuracy = train_one_epoch(model=model,
                                                                           source_loader=source_train_loader,
                                                                           target_loader=target_train_loader,
                                                                           optimizer=optimizer,
                                                                           source_loss=cost_function,
                                                                           entropy_loss_weights=entropy_loss_weight,
                                                                           device=device)
    # Evaluation uses the same criterion, applied to the target test loader
    test_loss, test_accuracy = test_one_epoch(model=model,
                                              loader=target_test_loader,
                                              cost_function=cost_function,
                                              device=device)

    print('Epoch: {:d}'.format(e))
    print('\t Train: Source loss {:.5f}, Target loss {:.2f}, Accuracy {:.2f}'.format(train_source_loss, train_target_loss, train_accuracy))
    print('\t Test: Source loss {:.5f}, Accuracy {:.2f}'.format(test_loss, test_accuracy))
    print('-----------------------------------------------------')

In [None]:
# Launch training and per-epoch evaluation with run_model's default arguments.
run_model()

Using downloaded and verified file: ./data/svhn/train_32x32.mat
Epoch: 1
	 Train: Source loss 0.01230, Target loss 0.01, Accuracy 74.49
	 Test: Source loss 0.00280, Accuracy 81.08
-----------------------------------------------------
Epoch: 2
	 Train: Source loss 0.00628, Target loss 0.00, Accuracy 87.83
	 Test: Source loss 0.00242, Accuracy 87.23
-----------------------------------------------------
Epoch: 3
	 Train: Source loss 0.00523, Target loss 0.00, Accuracy 89.84
	 Test: Source loss 0.00250, Accuracy 89.10
-----------------------------------------------------
Epoch: 4
	 Train: Source loss 0.00474, Target loss 0.00, Accuracy 90.82
	 Test: Source loss 0.00196, Accuracy 90.99
-----------------------------------------------------
Epoch: 5
	 Train: Source loss 0.00443, Target loss 0.00, Accuracy 91.52
	 Test: Source loss 0.00162, Accuracy 93.23
-----------------------------------------------------
Epoch: 6
	 Train: Source loss 0.00416, Target loss 0.00, Accuracy 92.06
	 Test: Source