1. Load the packages

In [None]:
import random
import numpy as np
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from keras.utils.np_utils import to_categorical

2. Define the classes used to load the MNIST and USPS data splits and to create the training data pairs


You need to download the [MNIST-USPS data splits](https://github.com/samotiian/CCSA) generated in [1] to run this code. Then:


2.1. If you run the code on Colab, you will need to put these splits in the corresponding folder of your [Google Drive](https://drive.google.com/drive/u/0/my-drive).


2.2. If you run the code locally, you will need to put these splits in the corresponding folder of your device.



[1] Motiian, S., Piccirilli, M., Adjeroh, D. A., & Doretto, G. (2017). Unified deep supervised domain adaptation and generalization. In Proceedings of the IEEE international conference on computer vision (pp. 5715-5725).

In [None]:
# Mount Google Drive when running on Colab. On a local machine the
# `google.colab` package is not installed, so the mount is skipped instead of
# aborting "Run All" with an ImportError.
try:
    from google.colab import drive
except ImportError:
    print("google.colab is not available - skipping the Google Drive mount (local run).")
else:
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Change the working directory to the folder where you put the data.
# The path below is one particular Google Drive layout; replace it with the
# folder that contains the `row_data` directory, because the dataset classes
# in this notebook load their .npy files from the relative path './row_data/'.
%cd /content/drive/MyDrive/Colab Notebooks/domain_adaptation_master
!ls

/content/drive/MyDrive/Colab Notebooks/domain_adaptation_master
baseline  cad  fc_6  mnist_m.tar.gz  MNIST_to_USPS  office-31  row_data


In [None]:
# Define the class to read the data and create pairs for training

class TrainSet(Dataset):
    """Paired source/target training set for supervised domain adaptation.

    Loads the source and target training splits of one task / repetition /
    shot count from ``./row_data/`` and enumerates every (source, target)
    index pair. All same-label ("positive") pairs are kept; different-label
    ("negative") pairs are shuffled and truncated to at most ``neg_ratio``
    times the number of positive pairs. The final pair list is shuffled.

    Args:
        domain_adaptation_task: Prefix of the split file names,
            e.g. ``'MNIST_to_USPS'``.
        repetition: Identifier of the split, inserted verbatim in the file name.
        sample_per_class: Number of target samples per class, inserted
            verbatim in the file name.
        neg_ratio: Maximum number of negative pairs kept per positive pair.
            The default (3) is the ratio that used to be hard-coded.

    Raises:
        ValueError: If ``neg_ratio`` is negative.
    """

    def __init__(self, domain_adaptation_task, repetition, sample_per_class, neg_ratio=3):
        if neg_ratio < 0:
            raise ValueError("neg_ratio must be non-negative, got %r" % (neg_ratio,))

        suffix = ('_repetition_' + str(repetition)
                  + '_sample_per_class_' + str(sample_per_class) + '.npy')

        def split_path(array_name, domain):
            # e.g. ./row_data/<task>_X_train_source_repetition_<r>_sample_per_class_<n>.npy
            return ('./row_data/' + domain_adaptation_task + '_' + array_name
                    + '_train_' + domain + suffix)

        self.x_source = np.load(split_path('X', 'source'))
        self.y_source = np.load(split_path('y', 'source'))
        self.x_target = np.load(split_path('X', 'target'))
        self.y_target = np.load(split_path('y', 'target'))

        print("Source X : ", len(self.x_source), " Y : ", len(self.y_source))
        print("Target X : ", len(self.x_target), " Y : ", len(self.y_target))

        # Each entry is [source index, target index, same-class flag].
        positives, negatives = [], []
        for src_idx, src_label in enumerate(self.y_source):
            for tgt_idx, tgt_label in enumerate(self.y_target):
                if src_label == tgt_label:
                    positives.append([src_idx, tgt_idx, 1])
                else:
                    negatives.append([src_idx, tgt_idx, 0])
        print("Class P : ", len(positives), " N : ", len(negatives))

        # Keep every positive pair and a random subset of the negatives so the
        # negatives do not swamp the positives.
        random.shuffle(negatives)
        self.imgs = positives + negatives[:neg_ratio * len(positives)]
        random.shuffle(self.imgs)

    def __getitem__(self, idx):
        """Return ``(x_src, y_src, x_tgt, y_tgt)`` for the pair at ``idx``.

        Both images are converted to tensors and get a new leading dimension
        (presumably the channel axis; assumes each stored image is 2-D - TODO
        confirm against the .npy files).
        """
        src_idx, tgt_idx, _ = self.imgs[idx]  # the same-class flag is not returned

        x_src = torch.from_numpy(self.x_source[src_idx]).unsqueeze(0)
        x_tgt = torch.from_numpy(self.x_target[tgt_idx]).unsqueeze(0)

        return x_src, self.y_source[src_idx], x_tgt, self.y_target[tgt_idx]

    def __len__(self):
        """Number of (source, target) pairs in the set."""
        return len(self.imgs)


class TestSet(Dataset):
    """Target-domain test split of one task / repetition / shot count.

    The images and labels are read from two ``.npy`` files under
    ``./row_data/`` whose names are built from the three constructor
    arguments.
    """

    def __init__(self, domain_adaptation_task, repetition, sample_per_class):
        prefix = './row_data/' + domain_adaptation_task
        suffix = '_repetition_' + str(repetition) + '_sample_per_class_' + str(sample_per_class) + '.npy'

        self.x_test = np.load(prefix + '_X_test_target' + suffix)
        self.y_test = np.load(prefix + '_y_test_target' + suffix)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image tensor gets a new leading dimension."""
        image = torch.from_numpy(self.x_test[idx]).unsqueeze(0)
        return image, self.y_test[idx]

    def __len__(self):
        """Number of test samples."""
        return len(self.x_test)

3. Define LeNet++ for MNIST-USPS

In [None]:
class Network(nn.Module):
  """Convolutional classifier (LeNet++ style) with an 84-d embedding and 10 outputs.

  Three stages of two 5x5 convolutions (padding 2, so the spatial size is
  preserved) each followed by a 2x2 max-pool, then a fully connected embedding
  layer (``ip1``) and a bias-free classification layer (``ip2``).

  ``ip1`` expects ``128*2*2`` features, so single-channel inputs must shrink to
  2x2 after the three poolings (e.g. 16x16 images).
  """

  def __init__(self):
    super(Network, self).__init__()
    # Stage 1: 1 -> 32 channels
    self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
    self.prelu1_1 = nn.PReLU()
    self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
    self.prelu1_2 = nn.PReLU()
    # Stage 2: 32 -> 64 channels
    self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
    self.prelu2_1 = nn.PReLU()
    self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
    self.prelu2_2 = nn.PReLU()
    # Stage 3: 64 -> 128 channels
    self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
    self.prelu3_1 = nn.PReLU()
    self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
    self.prelu3_2 = nn.PReLU()
    # Embedding and classification head
    self.preluip1 = nn.PReLU()
    self.ip1 = nn.Linear(128*2*2, 84)
    self.ip2 = nn.Linear(84, 10, bias=False)


  def forward(self, x):
    """Return ``(log_probs, features)`` for a batch of images.

    Args:
      x: Tensor of shape ``(N, 1, H, W)``.

    Returns:
      ``log_probs``: ``(N, 10)`` log-softmax class scores.
      ``features``: ``(N, 84)`` output of the embedding layer ``ip1`` (after PReLU).

    Raises:
      RuntimeError: If the flattened feature size does not match ``ip1``.
    """
    x = self.prelu1_1(self.conv1_1(x))
    x = self.prelu1_2(self.conv1_2(x))
    x = F.max_pool2d(x,2)
    x = self.prelu2_1(self.conv2_1(x))
    x = self.prelu2_2(self.conv2_2(x))
    x = F.max_pool2d(x,2)
    x = self.prelu3_1(self.conv3_1(x))
    x = self.prelu3_2(self.conv3_2(x))
    x = F.max_pool2d(x,2)
    # Flatten per sample. Keeping the batch dimension fixed (instead of
    # view(-1, 128*2*2)) makes a wrong input size fail in ip1 rather than
    # silently producing a different number of rows than input images.
    x = x.view(x.size(0), -1)
    ip1 = self.preluip1(self.ip1(x))
    ip2 = self.ip2(ip1)
    return F.log_softmax(ip2, dim=1), ip1

4. Define d-SNE Loss

In [None]:
def dnse_loss(src_feature, src_label, target_feature, target_label, margin=1):
    """PyTorch implementation of the d-SNE loss.

    Original MXNet implementation: https://github.com/aws-samples/d-SNE.

    For every target sample the loss is
    ``relu(max_intra - min_inter + margin)`` where ``max_intra`` is the largest
    squared Euclidean distance to a source sample of the same class and
    ``min_inter`` is the smallest squared Euclidean distance to a source sample
    of a different class.

    Args:
        src_feature: Source embeddings, shape ``(Ns, D)``.
        src_label: Source class labels, shape ``(Ns,)``.
        target_feature: Target embeddings, shape ``(Nt, D)``.
        target_label: Target class labels, shape ``(Nt,)``.
        margin: Required gap between the two distances (default 1, the value
            that used to be hard-coded).

    Returns:
        Tensor of shape ``(Nt,)`` holding one unreduced loss value per target
        sample. ``Ns`` and ``Nt`` may differ.
    """
    # The original implementation provided an optional feature-normalisation (L2) here. We'll skip it

    # Pairwise squared distances via broadcasting: dists[t, s] = ||xt[t] - xs[s]||^2
    diffs = target_feature.unsqueeze(1) - src_feature.unsqueeze(0)
    dists = torch.sum(torch.square(diffs), dim=2)

    # Label agreement masks, same [t, s] layout as `dists`.
    yt_col = target_label.unsqueeze(1)
    ys_row = src_label.unsqueeze(0)
    y_same = torch.eq(yt_col, ys_row)
    y_diff = torch.ne(yt_col, ys_row)

    # Zero out the entries that do not belong to each group.
    intra_cls_dists = dists * y_same
    inter_cls_dists = dists * y_diff

    # Same-class entries are zero in `inter_cls_dists`; replace them with the
    # row maximum so they can never be selected by the row-wise minimum below.
    max_dists = torch.max(dists, dim=1, keepdim=True)[0].expand_as(dists)
    revised_inter_cls_dists = torch.where(y_same, max_dists, inter_cls_dists)

    max_intra_cls_dist, _ = torch.max(intra_cls_dists, dim=1)
    min_inter_cls_dist, _ = torch.min(revised_inter_cls_dists, dim=1)

    return torch.nn.functional.relu(max_intra_cls_dist - min_inter_cls_dist + margin)

5. Training function

In [None]:
domain_adaptation_task = 'MNIST_to_USPS'
sample_per_class = 2  # Labelled target samples per class, 1 to 7
repetition = 9        # Index of the data split to use, i.e., 0 to 9
batch = 256
epochs = 100
alpha = 0.25    # Trade-off: λ
seed = 0        # Fixed seed so pair sampling, weight initialisation and shuffling are repeatable

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

train_set = TrainSet(domain_adaptation_task, repetition, sample_per_class)
train_set_loader = DataLoader(train_set, batch_size=batch, shuffle=True, drop_last=True)
test_set = TestSet(domain_adaptation_task, repetition, sample_per_class)
# Evaluate on every test sample: dropping the last partial batch would leave
# some samples out of the accuracy, and shuffling is pointless for evaluation.
test_official_loader = DataLoader(test_set, batch_size=batch, shuffle=False, drop_last=False)
print("Dataset Length Train : ", len(train_set), " Test : ", len(test_set))

# Fall back to the CPU when no GPU is available instead of failing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Network().to(device)
ce_loss = nn.CrossEntropyLoss()
optim = torch.optim.Adam(net.parameters())


def train(net, loader):
    """Train ``net`` for one epoch and return the average loss.

    Uses the notebook-level ``device``, ``ce_loss``, ``dnse_loss``, ``alpha``
    and ``optim``. For each batch of (source image, source label, target image,
    target label) the objective is
    ``(1 - alpha) * cross_entropy(source) + alpha * d-SNE(source, target)``.

    Args:
        net: Model returning ``(class scores, features)``.
        loader: Iterable of paired batches.

    Returns:
        float: Mean over batches of the batch-averaged combined loss, or
        ``nan`` if ``loader`` yields no batch.
    """
    net.train()
    loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for src_img, src_label, target_img, target_label in loader:
        src_img, target_img = (x.to(device, dtype=torch.float) for x in [src_img, target_img])
        src_label, target_label = (x.to(device, dtype=torch.long) for x in [src_label, target_label])
        src_pred, src_feature = net(src_img)
        _, target_feature = net(target_img)

        ce = ce_loss(src_pred, src_label)
        dsne = dnse_loss(src_feature, src_label, target_feature, target_label)
        loss = (1 - alpha) * ce + alpha * dsne
        optim.zero_grad()
        # `loss` is not reduced to a scalar here; passing ones as the upstream
        # gradient back-propagates the sum of its elements.
        loss.backward(torch.ones_like(loss))
        optim.step()

        # Accumulate on the device; convert to a Python float once per epoch.
        loss_sum += loss.detach().mean()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return (loss_sum / num_batches).item()

def test(net, loader):
    """Return the classification accuracy of ``net`` on the batches of ``loader``.

    Uses the notebook-level ``device``. The accuracy is computed over the
    samples actually yielded by ``loader``, so it stays correct even when the
    loader skips part of its dataset (e.g. ``drop_last=True``).

    Args:
        net: Model returning ``(class scores, features)``.
        loader: Iterable of ``(image, label)`` batches.

    Returns:
        float: Fraction of correctly classified samples, or ``nan`` if
        ``loader`` yields no sample.
    """
    correct = 0
    seen = 0
    net.eval()
    with torch.no_grad():
        for img, label in loader:
            img = img.to(device, dtype=torch.float)
            label = label.to(device, dtype=torch.long)
            pred, _ = net(img)
            correct += (pred.argmax(dim=1) == label).sum().item()
            seen += label.size(0)
    if seen == 0:
        return float('nan')
    return correct / seen


Source X :  2000  Y :  2000
Target X :  20  Y :  20
Class P :  4000  N :  36000
Dataset Length Train :  16000  Test :  1800


6. Example

Below is an example with N=2, i.e., two samples per class from the target domain, on the MNIST-to-USPS task.

You may change the value of variables to get other experimental results.

Note that when N gets larger, the training time **quadratically** increases.


In [None]:
print('Task:', domain_adaptation_task)
print("Number of samples from target domain:", sample_per_class)
print("Repetiton n.o:",repetition)

# Only the last few epochs are reported to keep the output short. The
# threshold is derived from `epochs` so that something is still printed when
# the number of epochs is changed.
report_last = 4

for epoch in range(epochs):
    train_loss = train(net, train_set_loader)
    test_acc = test(net, test_official_loader)
    if epoch >= epochs - report_last:
        print('Train_loss:',train_loss)
        print("Epoch[%d] testing acc : %.4f"%(epoch, test_acc))

Task: MNIST_to_USPS
Number of samples from target domain: 2
Repetiton n.o: 9
Train_loss: None
Epoch[96] testing acc : 0.8694
Train_loss: None
Epoch[97] testing acc : 0.8689
Train_loss: None
Epoch[98] testing acc : 0.8694
Train_loss: None
Epoch[99] testing acc : 0.8694
