1. Load the packages

In [None]:
import os
import torch
import torchvision
import time
from torchvision import models
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.autograd.function import Function
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from  torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import random

2. Download Office31 dataset

In [None]:
# Idempotent download: -nc skips the fetch when office31.zip already exists,
# -n never overwrites files that were already extracted, and -q keeps the
# per-file extraction log out of the notebook output.
!wget -nc https://transferlearningdrive.blob.core.windows.net/teamdrive/dataset/office31.zip
!unzip -q -n office31.zip

--2022-03-11 16:20:12--  https://transferlearningdrive.blob.core.windows.net/teamdrive/dataset/office31.zip
Resolving transferlearningdrive.blob.core.windows.net (transferlearningdrive.blob.core.windows.net)... 20.150.17.228
Connecting to transferlearningdrive.blob.core.windows.net (transferlearningdrive.blob.core.windows.net)|20.150.17.228|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 79531208 (76M) [application/x-zip-compressed]
Saving to: ‘office31.zip’


2022-03-11 16:20:24 (6.58 MB/s) - ‘office31.zip’ saved [79531208/79531208]

Archive:  office31.zip
   creating: office31/
   creating: office31/amazon/
   creating: office31/amazon/back_pack/
  inflating: office31/amazon/back_pack/frame_0001.jpg  
  inflating: office31/amazon/back_pack/frame_0002.jpg  
  inflating: office31/amazon/back_pack/frame_0003.jpg  
  inflating: office31/amazon/back_pack/frame_0004.jpg  
  inflating: office31/amazon/back_pack/frame_0005.jpg  
  inflating: office31/amazon/back_pac

3. Load and pre-process images

In [None]:
def load_data(root_path, domain, batch_size, phase):
    """Create the image-folder dataset stored under ``root_path/domain``.

    Args:
        root_path: Directory that contains one sub-directory per domain.
        domain: Name of the domain sub-directory to read.
        batch_size: Accepted for signature compatibility; not used here.
        phase: ``'src'`` or ``'tar'``; selects the preprocessing pipeline
            (both pipelines are currently identical). Any other value
            raises ``KeyError``.

    Returns:
        A ``datasets.ImageFolder`` whose images are resized, converted to
        tensors and normalised with the mean/std constants below.
    """
    def build_pipeline():
        # NOTE(review): Resize is given a single int; per torchvision docs that
        # scales the shorter side only, so this presumably relies on the
        # images being square — confirm for every domain.
        steps = [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    pipelines = {name: build_pipeline() for name in ('src', 'tar')}
    domain_dir = os.path.join(root_path, domain)
    return datasets.ImageFolder(root=domain_dir, transform=pipelines[phase])

4. Define the function for sampling from the dataset

In [None]:
def split_class(input_set):
    """Group dataset indices into consecutive runs that share a label.

    Args:
        input_set: Indexable dataset where ``input_set[i][1]`` is the label of
            item ``i``. Items of one class are expected to be stored
            contiguously (as in a class-sorted image folder).

    Returns:
        A list of index lists, one per run of equal labels, in dataset order.
        An empty dataset yields an empty list.
    """
    # Prefer a ready-made label list (torchvision's ImageFolder exposes
    # `targets`) so that no image has to be loaded just to read its label.
    labels = getattr(input_set, 'targets', None)
    if labels is None or len(labels) != len(input_set):
        labels = [input_set[i][1] for i in range(len(input_set))]

    class_indexset = []
    current_label = None
    for i, label in enumerate(labels):
        # Start a new group on the first item and whenever the label changes;
        # tracking the actual label (not a +1 counter) tolerates label ids
        # that do not start at 0 or that skip values.
        if not class_indexset or label != current_label:
            class_indexset.append([])
            current_label = label
        class_indexset[-1].append(i)
    return class_indexset

def sample_method(a_list, samplenum, seed=10):
    """Randomly split one class's index list into drawn and remaining indices.

    Args:
        a_list: List of dataset indices belonging to one class. It is modified
            in place: the drawn indices are removed from it.
        samplenum: Number of indices to draw. Classes with 8 or fewer indices
            give up one fewer (``samplenum - 1``).
        seed: Value passed to ``random.seed`` before drawing (default 10, the
            previously hard-coded value). Pass ``None`` to keep the current
            state of the global ``random`` generator.

    Returns:
        ``(training_list, a_list)``: the drawn indices and the same ``a_list``
        object with those indices removed.

    Raises:
        ValueError: If more indices are requested than ``a_list`` holds.
    """
    if not a_list:
        # Nothing to draw from; avoids indexing into an empty list.
        return [], a_list
    if seed is not None:
        random.seed(seed)

    draw_count = samplenum if len(a_list) > 8 else samplenum - 1
    # Sample from the indices themselves rather than from
    # range(first, last + 1): identical picks for a contiguous list, and still
    # valid when the list has gaps.
    training_list = random.sample(a_list, draw_count)

    for i in training_list:
        a_list.remove(i)
    return training_list, a_list


def draw_samples(root_path, domain_src, domain_tar, batch_size):
    """Draw the labelled source/target training sets and the target test set.

    For every class, ``sample_method`` picks a few indices for training; for
    the target domain the indices that were not picked form the test set.

    Args:
        root_path: Directory holding one sub-directory per domain.
        domain_src: Source domain name; ``'amazon'`` draws 20 samples per
            class, any other source draws 8.
        domain_tar: Target domain name; 3 samples per class are drawn.
        batch_size: Forwarded to ``load_data``.

    Returns:
        ``(train_x_scr, train_y_scr, train_x_tar, train_y_tar, test_x_tar,
        test_y_tar)`` — lists of image arrays (``.numpy()`` of the dataset
        tensors) and their integer labels.
    """
    source = load_data(root_path, domain_src, batch_size, phase='src')
    target = load_data(root_path, domain_tar, batch_size, phase='tar')
    soucenum = 20 if domain_src == 'amazon' else 8
    target_number = 3

    train_x_scr, train_y_scr = [], []
    train_x_tar, train_y_tar = [], []
    test_x_tar, test_y_tar = [], []

    source_indexset = split_class(source)
    target_indexset = split_class(target)

    for each_srclist in source_indexset:
        training_list, _ = sample_method(each_srclist, soucenum)
        for i in training_list:
            # Fetch each item once: indexing the dataset decodes and
            # transforms the image, so [i][0] followed by [i][1] did it twice.
            image, label = source[i]
            train_x_scr.append(image.numpy())
            train_y_scr.append(label)

    for each_tarlist in target_indexset:
        training_list, testing_list = sample_method(each_tarlist, target_number)
        for i in training_list:
            image, label = target[i]
            train_x_tar.append(image.numpy())
            train_y_tar.append(label)
        for j in testing_list:
            image, label = target[j]
            test_x_tar.append(image.numpy())
            test_y_tar.append(label)

    return train_x_scr, train_y_scr, train_x_tar, train_y_tar, test_x_tar, test_y_tar
  


5. Self-defined layer for implementing center transfer loss (CTL)

In [None]:
class CTL(nn.Module):
    """Center transfer loss layer holding one learnable center per index.

    Args:
        num_classes: Number of rows in the ``centers`` parameter.
        feat_dim: Dimensionality of each center / flattened feature vector.
        size_average: If True the loss is divided by the batch size,
            otherwise by 1.
    """

    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CTL, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, flip_label, domain_label, feat):
        """Evaluate the loss for one batch.

        Args:
            label: Per-sample center index, forwarded to ``CenterlossFunc``.
            flip_label: Second per-sample center index, forwarded likewise.
            domain_label: Per-sample domain id, forwarded likewise.
            feat: Features; flattened to ``(batch, -1)`` before use.

        Returns:
            The value produced by ``CenterlossFunc.apply``.

        Raises:
            ValueError: If the flattened feature width differs from
                ``feat_dim``.
        """
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # To check the dim of centers and features
        if feat.size(1) != self.feat_dim:
            # Built from adjacent literals so that no source indentation
            # leaks into the message text.
            raise ValueError(
                "Center's dim: {0} should be equal to input feature's "
                "dim: {1}".format(self.feat_dim, feat.size(1)))
        # Normaliser as a tensor on the same device/dtype as the features.
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss = self.centerlossfunc(feat, label, flip_label, domain_label, self.centers, batch_size_tensor)
        return loss


class CenterlossFunc(Function):
    """Autograd function behind ``CTL`` with a hand-written backward pass.

    The forward value measures the squared distance between each feature and
    the center selected by ``flip_label``, weighted by the inverse size of the
    sample's domain within the batch. The backward pass returns the matching
    gradient for the features, while the value returned for ``centers`` is a
    per-index averaged ``center - feature`` difference selected by ``label``.
    """

    @staticmethod
    def _domain_size_per_sample(centers, domain_label):
        """Return, for every sample, how many batch samples share its domain."""
        totals = centers.new_zeros(2)
        unit = centers.new_ones(domain_label.size(0))
        # The small constant keeps the later division finite.
        totals = totals.scatter_add_(0, domain_label.long(), unit)+0.000001
        return torch.index_select(totals, 0, domain_label.int())

    @staticmethod
    def forward(ctx, feature, label, flip_label, domain_label, centers, batch_size):
        ctx.save_for_backward(feature, label, flip_label, domain_label, centers, batch_size)
        paired_centers = centers.index_select(0, flip_label.long())
        domain_size = CenterlossFunc._domain_size_per_sample(centers, domain_label)

        weighted = (feature - paired_centers).pow(2)/domain_size.view(-1, 1)
        return weighted.sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, flip_label, domain_label, centers, batch_size = ctx.saved_tensors

        # Feature gradient: derivative of the forward expression.
        paired_gap = centers.index_select(0, flip_label.long()) - feature
        domain_size = CenterlossFunc._domain_size_per_sample(centers, domain_label)

        # Center update: difference to the center indexed by `label`,
        # accumulated per center and divided by that center's sample count.
        own_gap = centers.index_select(0, label.long()) - feature
        # init every iteration
        per_center = centers.new_ones(centers.size(0))
        unit = centers.new_ones(label.size(0))
        centers_grad = centers.new_zeros(centers.size())
        per_center = per_center.scatter_add_(0, label.long(), unit)-1+0.00001

        row_index = label.unsqueeze(1).expand(feature.size()).long()
        centers_grad.scatter_add_(0, row_index, own_gap)
        centers_grad = centers_grad/per_center.view(-1, 1)

        feature_grad = - grad_output * (paired_gap/domain_size.view(-1, 1)) / batch_size
        # One entry per forward input; the index tensors and batch_size get None.
        return feature_grad, None, None, None, centers_grad / batch_size, None

#Testing function
def main(test_cuda=False):
    """Smoke-test ``CTL``: one forward and one backward pass on a toy batch.

    Args:
        test_cuda: Run on ``"cuda"`` when True, otherwise on ``"cpu"``.
    """
    print('-'*80)
    target_device = torch.device("cuda" if test_cuda else "cpu")
    center_loss = CTL(20, 2, size_average=True).to(target_device)

    # Five samples: center indices, domain ids and all-zero 2-d features.
    class_ids = torch.Tensor([0, 0, 2, 1, 3]).to(target_device)
    domain_ids = torch.Tensor([0, 0, 1, 1, 0]).to(target_device)
    features = torch.zeros(5, 2).to(target_device).requires_grad_()

    print(list(center_loss.parameters()))
    print(center_loss.centers.grad)

    # The same indices are passed as both label and flip_label.
    result = center_loss(class_ids, class_ids, domain_ids, features)
    print(result.item())
    result.backward()


if __name__ == '__main__':
    # Fixed seed so the randomly initialised centers printed by main() repeat.
    torch.manual_seed(999)
    # The smoke test is only executed on machines with a CUDA device.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        main(test_cuda=True)

--------------------------------------------------------------------------------
[Parameter containing:
tensor([[-0.2528,  1.4072],
        [ 0.2910,  1.0365],
        [-0.9816, -3.4219],
        [ 1.4910,  0.2422],
        [ 1.4832, -0.3704],
        [ 0.0941,  2.1528],
        [ 0.6271, -1.1666],
        [-0.7862,  0.0759],
        [-0.0086, -0.6568],
        [-1.0011,  0.2992],
        [ 0.6396, -1.0857],
        [-1.6153,  1.5635],
        [ 0.8194,  0.6117],
        [ 0.7602,  1.4788],
        [ 1.9647,  0.9414],
        [ 0.3883, -0.3957],
        [ 0.5920, -2.8563],
        [-0.4750, -0.9978],
        [ 0.0489,  0.9250],
        [-1.2278, -0.9470]], device='cuda:0', requires_grad=True)]
None
0.9039329290390015


6. Define the pre-trained VGG-16 net

In [None]:
class TransferModel(nn.Module):
    """VGG-16 backbone followed by a 128-d feature layer and a classifier.

    Args:
        base_model: Backbone name; only ``'VGG16'`` is supported.
        pretrain: Forwarded to ``torchvision.models.vgg16(pretrained=...)``.
        n_class: Number of output classes of the final layer.

    Raises:
        ValueError: If ``base_model`` is not supported.
    """

    def __init__(self,
                base_model : str = 'VGG16',
                pretrain : bool = True,
                n_class : int = 31):
        super(TransferModel, self).__init__()
        self.base_model = base_model
        self.pretrain = pretrain
        self.n_class = n_class
        if self.base_model == 'VGG16':
            # Honour the `pretrain` flag instead of always loading weights.
            self.model = torchvision.models.vgg16(pretrained=self.pretrain)
            # Replace the last classifier layer so the backbone emits 1024-d.
            n_features = self.model.classifier[6].in_features
            fc = torch.nn.Linear(n_features, 1024)
            self.model.classifier[6] = fc
        else:
            # Fail at construction rather than later in forward(), where
            # `self.model` would be missing.
            raise ValueError(
                "Unsupported base_model: {0!r} (expected 'VGG16')".format(base_model))

        self.ip1 = nn.Linear(1024, 128)
        # Size the output layer from `n_class` (previously fixed at 31).
        self.ip2 = nn.Linear(128, n_class)

    def forward(self, x):
        """Return ``(feature, output)``: the 128-d features and class logits."""
        x = self.model(x)
        feature = self.ip1(x)
        # The reshaped tensor was previously discarded; keep it.
        feature = feature.view(-1, 128)
        output = self.ip2(feature)
        return feature, output

    def predict(self, x):
        """Alias of ``forward``."""
        return self.forward(x)


In [None]:
from torchsummary import summary

# `device` is also read later by finetune(), so it has to be defined here.
device = torch.device("cuda")

# Push one random 224x224 RGB image through the network as a smoke test.
RAND_TENSOR = torch.randn(1, 3, 224, 224).to(device)
net = TransferModel().to(device)
output = net(RAND_TENSOR)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

7. Define the function for fine-tuning the model pre-trained on ImageNet

In [None]:
def finetune(model, dataloaders, optimizer, loss_weight, n_epoch=100):
    """Fine-tune ``model`` with cross-entropy plus the weighted CTL term.

    Args:
        model: Network whose forward returns ``(feature, logits)``.
        dataloaders: Dict with ``'src'`` (training batches of ``inputs,
            labels, good_ord, flip_ord, domain_label``) and ``'tar'``
            (evaluation batches of ``inputs, labels``).
        optimizer: Optimizer for the model parameters.
        loss_weight: Multiplier applied to the CTL term.
        n_epoch: Number of training epochs (default 100, the previously
            hard-coded value).

    Returns:
        The best evaluation accuracy seen over all epochs, as a float.
    """
    # Follow the model's own device instead of relying on a notebook-global
    # `device` variable and hard-coded .cuda() calls.
    run_device = next(model.parameters()).device

    # 62 centers of width 128 — presumably 31 classes x 2 domains, matching
    # the 128-d feature layer; TODO confirm if either changes.
    ctl = CTL(62, 128).to(run_device)
    optimzer4center = torch.optim.SGD(ctl.parameters(), lr=0.5)
    criterion = nn.CrossEntropyLoss()

    phase = 'src'
    train_loader = dataloaders[phase]
    test_loader = dataloaders['tar']
    best_acc = 0.0

    for epoch in range(0, n_epoch):
        # ---- training pass ----
        model.train()
        total_loss, total_center = 0, 0
        for inputs, labels, good_ord, flip_ord, domain_label in train_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            good_ord, flip_ord = good_ord.to(run_device), flip_ord.to(run_device)
            domain_label = domain_label.to(run_device)

            optimizer.zero_grad()
            optimzer4center.zero_grad()
            ip1, outputs = model(inputs)
            softmax = criterion(outputs, labels)
            centerlss = ctl(good_ord, flip_ord, domain_label, ip1)
            loss = softmax + loss_weight * centerlss
            loss.backward()
            optimizer.step()
            optimzer4center.step()

            total_loss += softmax.item() * inputs.size(0)
            total_center += centerlss.item()
        epoch_loss = total_loss / len(train_loader.dataset)

        # ---- evaluation pass ----
        model.eval()
        correct = 0
        # No gradients are needed here; skipping graph construction saves
        # memory and time.
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(run_device), labels.to(run_device)
                ip1, outputs = model(inputs)
                preds = torch.max(outputs, 1)[1]
                correct += torch.sum(preds == labels).item()
        epoch_acc = correct / len(test_loader.dataset)
        if epoch_acc > best_acc:
            best_acc = epoch_acc

        # Report only the last four epochs (epochs 96-99 for n_epoch=100).
        if epoch > n_epoch - 5:
            print(f'Epoch: [{epoch:02d}/{n_epoch:02d}]---{phase}, softmax_loss: {epoch_loss:.6f}')
            print_center = total_center / len(train_loader.dataset)
            print('CTL:', print_center)
            print('Testing acc:', epoch_acc)

    return best_acc
        


8. Define evaluation function

In [None]:
def trainandevaluation(domain_src, domain_tar, xx):
    """Sample the data, build the loaders and model, then run ``finetune``.

    Args:
        domain_src: Source domain folder name under ``office31``.
        domain_tar: Target domain folder name under ``office31``.
        xx: Random seed applied to Python's ``random`` and to torch.
    """
    data_folder = 'office31'
    batch_size = 32
    n_class = 31
    random.seed(xx) #amend if you want
    # Model initialisation and DataLoader shuffling draw from torch's RNG,
    # which was previously left unseeded.
    torch.manual_seed(xx)

    samples = draw_samples(data_folder, domain_src, domain_tar, batch_size)
    (train_x_scr, train_y_scr, train_x_tar,
     train_y_tar, test_x_tar, test_y_tar) = [np.array(part) for part in samples]

    # Domain id per training sample: 0 for source, 1 for target.
    domain_label = np.hstack((np.zeros(len(train_x_scr)), np.ones(len(train_x_tar))))
    x_train = np.vstack((train_x_scr, train_x_tar))
    y_train = np.hstack((train_y_scr, train_y_tar))

    # Center indices for CTL: in good_order source samples keep class c and
    # target samples use c + n_class; flip_order swaps the two offsets.
    good_order = np.hstack((train_y_scr, train_y_tar + n_class))
    flip_order = np.hstack((train_y_scr + n_class, train_y_tar))

    x_train = torch.from_numpy(x_train).float()
    y_train = torch.from_numpy(y_train).long()
    x_test = torch.from_numpy(test_x_tar).float()
    y_test = torch.from_numpy(test_y_tar).long()
    good_order = torch.from_numpy(good_order).long()
    flip_order = torch.from_numpy(flip_order).long()
    domain_label = torch.from_numpy(domain_label).long()

    # form the dataset
    train_dataset = TensorDataset(x_train, y_train, good_order, flip_order, domain_label)
    val_dataset = TensorDataset(x_test, y_test)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    # Accuracy does not depend on order, so the evaluation set is not shuffled.
    test_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    model = TransferModel(n_class=n_class).cuda()
    learning_rate = 0.001
    # NOTE(review): 5e-3 is an unusual momentum value — confirm it was not
    # meant to be a weight decay.
    momentum = 5e-3

    # The VGG classifier head and the new ip1/ip2 layers learn 10x faster
    # than the remaining backbone parameters.
    param_group = []
    for k, v in model.named_parameters():
        if 'classifier' in k or 'ip' in k:
            param_group += [{'params': v, 'lr': 10*learning_rate}]
        else:
            param_group += [{'params': v, 'lr': learning_rate}]

    optimizer = torch.optim.SGD(param_group, momentum=momentum)
    dataloaders = {'src': train_loader,
                   'tar': test_loader}
    finetune(model, dataloaders, optimizer, 0.1)

9. Example

The example below is the *Amazon to DSLR* domain adaptation task.
You may change the value of variables to get other experimental results.

Note that the model is fine-tuned anew on every run, starting from the ImageNet-pretrained VGG-16 weights. Training should take around **25 minutes** per repetition on a GPU (GeForce RTX 3090).

Although the results you get may differ slightly from those in the manuscript due to randomized initialization, the gap should be small.

In [None]:
# Available domains: 'amazon', 'webcam', 'dslr'
source_domain = 'amazon'
target_domain = 'dslr'
random_seed = 1

# Run one repetition of the source -> target adaptation experiment.
trainandevaluation(domain_src=source_domain,
                   domain_tar=target_domain,
                   xx=random_seed)

Epoch: [96/100]---src, softmax_loss: 0.021324
CTL: 0.018971398787156438
Testing acc: 0.9213759213759214
Epoch: [97/100]---src, softmax_loss: 0.019139
CTL: 0.020149789027523894
Testing acc: 0.9189189189189189
Epoch: [98/100]---src, softmax_loss: 0.019888
CTL: 0.02055117035213905
Testing acc: 0.9238329238329238
Epoch: [99/100]---src, softmax_loss: 0.019471
CTL: 0.020247976907362582
Testing acc: 0.9262899262899262
