1. Load the packages

In [None]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data import random_split
from torch.utils.data import Dataset,DataLoader,TensorDataset
from torch.autograd.function import Function
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torchvision import datasets, transforms
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data as data
from PIL import Image
import os



2. Load the datasets

In [None]:
num_workers = 2
image_size = 32
# Portable path to ./data: the previous '.\data' literal relied on the invalid "\d"
# escape (a warning on recent Python versions) and was only a subfolder on Windows.
path = os.path.join('.', 'data')


def get_loader():
    """Build and return the SVHN and MNIST train/test datasets.

    Both domains are resized to ``image_size`` and normalised from [0, 1] to
    roughly [-1, 1]. SVHN is additionally reduced to one channel so that both
    domains match the single-channel network input.

    Returns:
        (svhn_train, svhn_test, mnist_train, mnist_test) torchvision datasets,
        downloaded into ``path`` if they are not already present.
    """
    transform_svhn = transforms.Compose([
        # `Scale` was a deprecated alias of `Resize` and has been removed from torchvision.
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # Applied to the normalised tensor; requires a torchvision version whose
        # Grayscale accepts tensors.
        transforms.Grayscale(num_output_channels=1)])
    transform_mnist = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    svhn_train = datasets.SVHN(root=path, split="train", download=True, transform=transform_svhn)
    svhn_test = datasets.SVHN(root=path, split="test", download=True, transform=transform_svhn)

    mnist_train = datasets.MNIST(root=path, train=True, download=True, transform=transform_mnist)
    mnist_test = datasets.MNIST(root=path, train=False, download=True, transform=transform_mnist)

    return svhn_train, svhn_test, mnist_train, mnist_test

In [None]:
# Download (if needed) and load the four dataset splits used by the cells below.
svhn_train, svhn_test, mnist_train, mnist_test = get_loader()



Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to .\data/train_32x32.mat


  0%|          | 0/182040794 [00:00<?, ?it/s]

Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to .\data/test_32x32.mat


  0%|          | 0/64275384 [00:00<?, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to .\data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting .\data/MNIST/raw/train-images-idx3-ubyte.gz to .\data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to .\data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting .\data/MNIST/raw/train-labels-idx1-ubyte.gz to .\data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to .\data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting .\data/MNIST/raw/t10k-images-idx3-ubyte.gz to .\data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .\data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting .\data/MNIST/raw/t10k-labels-idx1-ubyte.gz to .\data/MNIST/raw



3. Define functions for randomly selecting samples from the source and target domains

In [None]:

def get_index(labels,select_num,num_classes=10):
  """Randomly split sample indices into a per-class selection and the remainder.

  For every class id in ``range(num_classes)``, draw ``select_num`` indices
  without replacement (via ``np.random.choice``) from the positions where
  ``labels`` equals that id.

  Args:
    labels: 1-D NumPy array of integer class labels.
    select_num: number of indices to draw per class.
    num_classes: number of classes (ids 0..num_classes-1); defaults to 10 digits.

  Returns:
    (selected, remaining): integer index arrays. ``selected`` is grouped by
    class in draw order; ``remaining`` holds, class by class, the sorted
    indices that were not drawn.

  Raises:
    ValueError: from ``np.random.choice`` if a class has fewer than
    ``select_num`` samples.
  """
  selected_parts=[]
  remaining_parts=[]
  for cls in range(num_classes):
    full_list=np.where(labels==cls)[0]
    ran_select=np.random.choice(full_list,select_num,replace=False)
    selected_parts.append(ran_select)
    remaining_parts.append(np.setdiff1d(full_list,ran_select))

  if not selected_parts:
    # num_classes == 0: nothing to select.
    return np.array([],dtype=int),np.array([],dtype=int)
  # Concatenate once (instead of repeated hstack onto float arrays) and keep integer indices.
  return np.concatenate(selected_parts).astype(int),np.concatenate(remaining_parts).astype(int)


def decompose_dataset(datesett):
  """Materialise an indexable dataset of (tensor, label) pairs into NumPy arrays.

  Each item is fetched once with ``datesett[idx]`` and its tensor part is
  converted with ``.numpy()``.

  Returns:
    (x, y): ``x`` is ``np.array`` of the converted tensors (stacked along a new
    leading axis when they share a shape) and ``y`` is ``np.array`` of labels.
  """
  samples = [datesett[idx] for idx in range(len(datesett))]
  images = [image.numpy() for image, _ in samples]
  targets = [target for _, target in samples]
  return np.array(images), np.array(targets)
 


4. Center transfer loss (CTL) layer

In [None]:
class CTL(nn.Module):
    """Center transfer loss (CTL) layer holding a learnable matrix of centers.

    Args:
        num_classes: number of centers (rows of ``self.centers``).
        feat_dim: dimensionality of each center; flattened input features must match it.
        size_average: if True the loss is divided by the batch size, otherwise by 1.
    """

    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CTL, self).__init__()
        # Randomly initialised centers; their gradient comes from CenterlossFunc.backward.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.centerlossfunc = CenterlossFunc.apply
        self.feat_dim = feat_dim
        self.size_average = size_average

    def forward(self, label, flip_label, domain_label, feat):
        """Compute the CTL loss for a batch.

        Args:
            label: index of each sample's own center (used for the center update).
            flip_label: index of the center each feature is pulled towards.
            domain_label: per-sample domain id.
            feat: features; flattened to (batch, -1) before use.

        Returns:
            Scalar loss tensor produced by ``CenterlossFunc``.

        Raises:
            ValueError: if the flattened feature size differs from ``feat_dim``.
        """
        batch_size = feat.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs.
        feat = feat.reshape(batch_size, -1)
        if feat.size(1) != self.feat_dim:
            # Implicit string concatenation avoids the run of indentation spaces the
            # previous backslash continuation embedded in the message.
            raise ValueError("Center's dim: {0} should be equal to input feature's "
                             "dim: {1}".format(self.feat_dim, feat.size(1)))
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        return self.centerlossfunc(feat, label, flip_label, domain_label, self.centers, batch_size_tensor)


class CenterlossFunc(Function):
    """Custom autograd Function implementing the CTL loss and its center update.

    Inputs (in the order CTL.forward passes them):
        feature: 2-D features, one row per sample (backward expands a
            (batch, 1) label tensor to ``feature.size()``).
        label: index of each sample's own center; drives the center update.
        flip_label: index of the center each feature is pulled towards.
        domain_label: per-sample domain id; must be 0 or 1 because only two
            domain counters are allocated.
        centers: learnable center matrix, one row per center.
        batch_size: 1-element tensor used as the final divisor.

    Forward returns
        sum_i ||feature_i - centers[flip_label_i]||^2 / n_dom(i) / 2 / batch_size
    where n_dom(i) is the number of samples in the batch that share sample i's
    domain_label (plus 1e-4 to avoid division by zero).

    Backward returns the exact gradient of that expression w.r.t. ``feature``
    (scaled by ``grad_output``). For ``centers`` it instead returns a
    center-loss style update: for each center j, the sum of
    ``centers[j] - feature_i`` over samples with ``label_i == j``, divided by
    (that count + 0.001) and by ``batch_size``. This center "gradient" uses
    ``label`` rather than ``flip_label``, ignores the domain normalisation and
    is not multiplied by ``grad_output``. It is therefore not the true gradient
    of the forward output.
    NOTE(review): presumably intentional (the center step size is then set only
    by the centers' optimizer learning rate) — confirm.
    """
    @staticmethod
    def forward(ctx, feature, label,flip_label, domain_label,centers, batch_size):
        ctx.save_for_backward(feature, label,flip_label, domain_label,centers, batch_size)
        # Center that each feature is pulled towards (selected by flip_label).
        centers_batch = centers.index_select(0, flip_label.long())

        # Count the batch samples in each domain (slots 0 and 1); +0.0001 avoids division by zero.
        domain_counts = centers.new_zeros(2)
        domain_ones = centers.new_ones(domain_label.size(0))
        domain_counts = domain_counts.scatter_add_(0, domain_label.long(), domain_ones)+0.0001
        # Expand to one value per sample: the size of that sample's domain in this batch.
        domain_counts = torch.index_select(domain_counts, 0, domain_label.int())  

        # Squared distances, each normalised by its domain count, summed, halved and divided by batch_size.
        return ((feature - centers_batch).pow(2)/domain_counts.view(-1, 1)).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature,label,flip_label,domain_label,centers,batch_size = ctx.saved_tensors
        # Feature gradient: recompute the flip-center differences and domain counts used in forward.
        flip_centers_batch = centers.index_select(0, flip_label.long())
        flip_diff=flip_centers_batch-feature

        domain_counts = centers.new_zeros(2)
        domain_ones = centers.new_ones(domain_label.size(0))
        domain_counts = domain_counts.scatter_add_(0, domain_label.long(), domain_ones)+0.0001
        domain_counts = torch.index_select(domain_counts, 0, domain_label.int())  



        # Center update: indexed by `label` (not flip_label), so each center moves towards its own samples.
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature
        # init every iteration
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())
        # Per-center sample count (counts start at 1, hence -1); +0.001 keeps unused centers from dividing by zero.
        counts = counts.scatter_add_(0, label.long(), ones)-1+0.001

        # Sum (center - feature) per center row, then divide by that center's count.
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers/counts.view(-1, 1)

        # One entry per forward input: feature, label, flip_label, domain_label, centers, batch_size.
        # grad_output scales only the feature gradient, so a weight applied to this loss upstream
        # does not scale the center update.
        return - grad_output * (flip_diff/domain_counts.view(-1, 1)) / batch_size, None, None,None,grad_centers / batch_size, None

# Smoke test for the CTL layer: one forward/backward pass on a toy batch.
def main(test_cuda=False):
    """Run CTL forward/backward on toy data and print centers, loss and gradients.

    Args:
        test_cuda: run on the GPU if True, otherwise on the CPU.
    """
    print('-'*80)
    device = torch.device("cuda" if test_cuda else "cpu")
    ct = CTL(20,2,size_average=True).to(device)
    y = torch.Tensor([0,0,2,1,3]).to(device)
    dola= torch.Tensor([0,0,1,1,0]).to(device)
    feat = torch.zeros(5,2).to(device).requires_grad_()
    print(list(ct.parameters()))

    print(ct.centers.grad)  # None: no backward pass has run yet
    out = ct(y,y,dola,feat)
    print(out.item())
    out.backward()
    # Show the gradients produced by the custom backward.
    print(ct.centers.grad)
    print(feat.grad)


if __name__ == '__main__':
    torch.manual_seed(999)
    # Use the GPU when available; otherwise still exercise the layer on the CPU.
    main(test_cuda=torch.cuda.is_available())

--------------------------------------------------------------------------------
[Parameter containing:
tensor([[-0.2528,  1.4072],
        [ 0.2910,  1.0365],
        [-0.9816, -3.4219],
        [ 1.4910,  0.2422],
        [ 1.4832, -0.3704],
        [ 0.0941,  2.1528],
        [ 0.6271, -1.1666],
        [-0.7862,  0.0759],
        [-0.0086, -0.6568],
        [-1.0011,  0.2992],
        [ 0.6396, -1.0857],
        [-1.6153,  1.5635],
        [ 0.8194,  0.6117],
        [ 0.7602,  1.4788],
        [ 1.9647,  0.9414],
        [ 0.3883, -0.3957],
        [ 0.5920, -2.8563],
        [-0.4750, -0.9978],
        [ 0.0489,  0.9250],
        [-1.2278, -0.9470]], device='cuda:0', requires_grad=True)]
None
0.9038917422294617


5. Define LeNet++

In [None]:
class Net(nn.Module):
    """LeNet++-style CNN returning a 256-d embedding and 10-class log-probabilities.

    Designed for 1-channel 32x32 inputs: three 2x2 max-pools shrink 32x32 to
    4x4, so the last conv stage yields 128 * 4 * 4 = 2048 features for ``ip1``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1_1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.prelu1_1 = nn.PReLU()
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
        self.prelu1_2 = nn.PReLU()
        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.prelu2_1 = nn.PReLU()
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        self.prelu2_2 = nn.PReLU()
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        self.prelu3_1 = nn.PReLU()
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
        self.prelu3_2 = nn.PReLU()
        self.preluip1 = nn.PReLU()
        # 128 channels x 4 x 4 spatial positions after three poolings of a 32x32 input.
        self.ip1 = nn.Linear(128*4*4, 256)
        self.ip2 = nn.Linear(256, 10, bias=True)

    def forward(self, x):
        """Return ``(ip1, log_probs)``: the PReLU-activated 256-d embedding and
        ``log_softmax`` over the 10 class scores.

        Inputs whose spatial size does not reduce to 4x4 now raise a RuntimeError
        in ``ip1`` instead of being silently reshaped across the batch dimension.
        """
        x = self.prelu1_1(self.conv1_1(x))
        x = self.prelu1_2(self.conv1_2(x))
        x = F.max_pool2d(x,2)
        x = self.prelu2_1(self.conv2_1(x))
        x = self.prelu2_2(self.conv2_2(x))
        x = F.max_pool2d(x,2)
        x = self.prelu3_1(self.conv3_1(x))
        x = self.prelu3_2(self.conv3_2(x))
        x = F.max_pool2d(x,2)
        # Keep the batch dimension intact; view(-1, 2048) would fold extra spatial
        # positions into the batch for non-32x32 inputs.
        x = torch.flatten(x, 1)

        ip1 = self.preluip1(self.ip1(x))
        ip2 = self.ip2(ip1)
        return ip1, F.log_softmax(ip2, dim=1)


In [None]:
from torchsummary import summary
# Fall back to the CPU when no GPU is present; later cells use this `device` global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)

# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on CPU.
summary(net, (1, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             832
             PReLU-2           [-1, 32, 32, 32]               1
            Conv2d-3           [-1, 32, 32, 32]          25,632
             PReLU-4           [-1, 32, 32, 32]               1
            Conv2d-5           [-1, 64, 16, 16]          51,264
             PReLU-6           [-1, 64, 16, 16]               1
            Conv2d-7           [-1, 64, 16, 16]         102,464
             PReLU-8           [-1, 64, 16, 16]               1
            Conv2d-9            [-1, 128, 8, 8]         204,928
            PReLU-10            [-1, 128, 8, 8]               1
           Conv2d-11            [-1, 128, 8, 8]         409,728
            PReLU-12            [-1, 128, 8, 8]               1
           Linear-13                  [-1, 256]         524,544
            PReLU-14                  [

6. Define the training loop for one epoch

In [None]:
# Define the training for one epoch.

def train(epoch,loss_weight,train_loader,model,nllloss,ctl,optimizer4nn,optimzer4center,test_loader):
  """Run one training epoch with softmax + CTL loss, then evaluate on the test set.

  Args:
    epoch: 1-based epoch index; a summary is printed only when it exceeds 45.
    loss_weight: trade-off weight applied to the CTL term.
    train_loader: yields (data, target, good_ord, flip_ord, domain_label) batches.
    model: network returning (embedding, log_probabilities).
    nllloss: classification criterion applied to the model's second output.
    ctl: CTL module, called as ctl(good_ord, flip_ord, domain_label, embedding).
    optimizer4nn: optimizer for the network parameters.
    optimzer4center: optimizer for the CTL centers.
    test_loader: yields (inputs, labels) evaluation batches.

  Returns:
    (model, epoch_acc): the same model (left in eval mode) and a 0-dim double
    tensor holding the test accuracy.
  """
  # Run on whichever device the model lives on instead of a global / hard-coded .cuda().
  device = next(model.parameters()).device
  # Evaluation below leaves the model in eval mode; switch back before training.
  model.train()
  center_total = 0.0
  nll_total = 0.0
  for data, target, good_ord, flip_ord, domain_label in train_loader:
    data, target = data.to(device), target.to(device)
    good_ord, flip_ord, domain_label = good_ord.to(device), flip_ord.to(device), domain_label.to(device)
    ip1, pred = model(data)

    # Compute each loss term once; the same (pre-update) values are reused for logging.
    softmax_loss = nllloss(pred, target)
    center_loss = ctl(good_ord, flip_ord, domain_label, ip1)
    loss = softmax_loss + loss_weight * center_loss

    optimizer4nn.zero_grad()
    optimzer4center.zero_grad()
    loss.backward()
    optimizer4nn.step()
    optimzer4center.step()

    center_total += center_loss.item()
    # Assuming a batch-mean criterion (CrossEntropyLoss default), weight by the real
    # batch size so the smaller last batch is averaged correctly.
    nll_total += softmax_loss.item() * data.size(0)

  model.eval()
  correct = torch.zeros((), dtype=torch.long, device=device)
  with torch.no_grad():  # no autograd graph is needed for evaluation
    for inputs, labels in test_loader:
      inputs, labels = inputs.to(device), labels.to(device)
      _, outputs = model(inputs)
      preds = outputs.argmax(dim=1)
      correct += (preds == labels).sum()
  epoch_acc = correct.double() / len(test_loader.dataset)
  if epoch>45:
    print("Training... Epoch = %d" % epoch)
    print('Softmax:', nll_total/len(train_loader.dataset))
    print('CTL:', center_total/len(train_loader))
    print('Testing acc:',epoch_acc.item())
  return model,epoch_acc



7. Define the training function

In [None]:
def trainss(transfer,weightss,center_step):
  """Train LeNet++ with softmax + CTL on one digit transfer task.

  Samples a per-class source subset plus 10 labelled target samples per class,
  trains for 50 epochs and prints the best target test accuracy with its
  (1-based) epoch.

  Args:
    transfer: 'MNISTtoSVHN' or 'SVHNtoMNIST'.
    weightss: trade-off weight (lambda) of the CTL term.
    center_step: SGD learning rate (alpha) of the CTL centers.

  Returns:
    The model after the final epoch (not necessarily the best-scoring one).

  Raises:
    ValueError: if ``transfer`` is not one of the supported tasks.
  """
  # Pick source/target domains and the number of source samples drawn per class.
  if transfer == 'MNISTtoSVHN':
    source_full, target_full, test_set = mnist_train, svhn_train, svhn_test
    src_label = np.array(mnist_train.targets)
    tar_label = np.array(svhn_train.labels)
    src_per_class = 5000
  elif transfer == 'SVHNtoMNIST':
    source_full, target_full, test_set = svhn_train, mnist_train, mnist_test
    src_label = np.array(svhn_train.labels)
    tar_label = np.array(mnist_train.targets)
    src_per_class = 3000
  else:
    # Previously any other value silently ran SVHN -> MNIST.
    raise ValueError("transfer must be 'MNISTtoSVHN' or 'SVHNtoMNIST', got {!r}".format(transfer))

  srctrain_index, _ = get_index(src_label, src_per_class)
  source_trainset = torch.utils.data.Subset(source_full, srctrain_index)
  tartrain_index, _ = get_index(tar_label, 10)
  target_trainset = torch.utils.data.Subset(target_full, tartrain_index)
  train_set = torch.utils.data.ConcatDataset([source_trainset, target_trainset])

  print(len(train_set))
  print(len(test_set))

  # ConcatDataset keeps all source samples first, followed by the target samples.
  train_x, train_y = decompose_dataset(train_set)
  train_x = torch.from_numpy(train_x).float()
  n_src = len(source_trainset)
  n_tar = len(target_trainset)
  src_y, tar_y = train_y[:n_src], train_y[n_src:]

  # CTL uses 20 centers: 0-9 for source classes, 10-19 for target classes.
  # good_order = each sample's own-domain center (indexes the center update);
  # flip_order = the same class's center in the other domain (the pull target).
  # Splitting at n_src rather than a hard-coded last 100 rows works for any target size.
  good_order = torch.from_numpy(np.hstack((src_y, tar_y + 10))).long()
  flip_order = torch.from_numpy(np.hstack((src_y + 10, tar_y))).long()
  domain_label = torch.from_numpy(np.hstack((np.zeros(n_src), np.ones(n_tar)))).long()
  train_y = torch.from_numpy(train_y).long()

  batch_size = 64
  train_dataset = TensorDataset(train_x, train_y, good_order, flip_order, domain_label)
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  # Evaluation order does not affect accuracy, so the test set is not shuffled.
  test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

  model = Net().to(device)

  # Net already outputs log-probabilities; CrossEntropyLoss re-applies log_softmax,
  # which leaves log-probabilities unchanged, so this equals NLLLoss on them.
  nllloss = nn.CrossEntropyLoss().to(device)

  ctl = CTL(20, 256).to(device)

  optimizer4nn = optim.Adam(model.parameters(), lr=0.00025)
  scheduler = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.8)
  optimzer4center = torch.optim.SGD(ctl.parameters(), lr=center_step)

  best_acc = 0.0
  best_epoch = 0
  for epoch in range(1, 51):
    model, accuracy = train(epoch, weightss, train_loader, model, nllloss, ctl, optimizer4nn, optimzer4center, test_loader)
    scheduler.step()
    if accuracy > best_acc:
      best_acc = float(accuracy)
      # Record the same 1-based epoch number that train() prints.
      best_epoch = epoch

  print("Epoch:", str(best_epoch), "Best acc:", str(best_acc))
  return model


Below is an example of the MNISTtoSVHN digit transfer task.

You may change the values of these variables to run the other experiments.



1.   MNISTtoSVHN
2.   SVHNtoMNIST



Note that models are trained from scratch. One repetition takes around 18 minutes on a GPU (GeForce RTX 3090).

Your results may differ slightly from those reported in the manuscript because of random initialization and random sample selection, but the gap should be small.

In [None]:
# Task to run: "MNISTtoSVHN" or "SVHNtoMNIST".
transfer = 'MNISTtoSVHN'
lambda_para = 0.75  # λ: weight of the CTL term
center_step = 0.5   # α: learning rate of the CTL centers
model = trainss(transfer, lambda_para, center_step)

50100
26032
Training... Epoch = 46
Softmax: 1.744815137550075e-06
CTL: 0.0034882187255893715
Testing acc: 0.6853488014751075
Training... Epoch = 47
Softmax: 0.0021979602598181745
CTL: 0.006607642894539992
Testing acc: 0.6926475107559925
Training... Epoch = 48
Softmax: 0.00041845343958661337
CTL: 0.0032615397200713918
Testing acc: 0.6919176398279041
Training... Epoch = 49
Softmax: 1.7930066338837883e-06
CTL: 0.0019091646390176696
Testing acc: 0.6946834665027658
Training... Epoch = 50
Softmax: 1.6224514521361457e-06
CTL: 0.0022163369166837248
Testing acc: 0.6818915181315304
Epoch: 48 Best acc: tensor(0.6947, device='cuda:0', dtype=torch.float64)
