In [1]:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import time
import os
import torch.nn as nn
from torch.autograd import Function
import torchvision
import torch.utils.data as data
from PIL import Image
import os

In [2]:
class GRL(Function):
    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x.view_as(x) * constant
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.constant, None

In [3]:
class Dann(nn.Module):
    def __init__(self):
        super(Dann, self).__init__()
        self.f = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=5),
                nn.BatchNorm2d(64),
                nn.MaxPool2d(2,2),
                nn.ReLU(),
                #nn.Conv2d(32, 32, kernel_size=3,padding=1,stride=2),
                #nn.BatchNorm2d(32),
                nn.Conv2d(64, 50, kernel_size=5),
                nn.BatchNorm2d(50),
                nn.Dropout2d(),
                nn.MaxPool2d(2, 2),
                nn.ReLU(),

                #nn.Conv2d(32, 32, kernel_size=3,padding=1,stride=2),
                #nn.BatchNorm2d(32),
                #nn.ReLU(),
                #nn.Conv2d(32, 128, kernel_size=5,padding=2),
                #nn.AvgPool2d(7)
            )
        self.lc = nn.Sequential(
            nn.Linear(50*4*4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Dropout2d(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        self.dc = nn.Sequential(
            nn.Linear(50 * 4 * 4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 2),
        )
    def forward(self, x,alpha):
        x = self.f(x)
        x = x.view(-1, 50*4*4)
        y = GRL.apply(x, alpha)
        x = self.lc(x)
        y = self.dc(y)
        x=x.view(x.shape[0],-1)
        y=y.view(y.shape[0],-1)
        return x, y

In [4]:
source_train = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=512, shuffle=True,num_workers=12)
source_test = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=512, shuffle=False,num_workers=12)

In [5]:
class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data

In [6]:
image_size=28
img_transform = transforms.Compose([
    transforms.RandomCrop((image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),
                         std=(0.5,0.5,0.5))
])
train_list = os.path.join('mnist_m', 'mnist_m_train_labels.txt')
dataset_train_target = GetLoader(
    data_root=os.path.join('mnist_m', 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)
test_list = os.path.join('mnist_m', 'mnist_m_test_labels.txt')
dataset_test_target = GetLoader(
    data_root=os.path.join('mnist_m', 'mnist_m_test'),
    data_list=test_list,
    transform=img_transform
)
target_train = torch.utils.data.DataLoader(dataset_train_target,batch_size=512, shuffle=True,num_workers=12)
target_test = torch.utils.data.DataLoader(dataset_test_target,batch_size=512, shuffle=True,num_workers=12)

In [7]:
device='cuda'
model = Dann().to(device)
#optimizer = optim.Adam(model.parameters(),weight_decay=1e-5)
optimizer = optim.SGD(model.parameters(), lr= 0.01, momentum= 0.9)
#optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def optimizer_scheduler(optimizer, p):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75
    return optimizer

In [8]:
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in source_test:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(source_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(source_test.dataset),
        100. * correct / len(source_test.dataset)))

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_train:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_train.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_train.dataset),
        100. * correct / len(target_train.dataset)))
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_test:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_test.dataset),
        100. * correct / len(target_test.dataset)))

In [9]:
print(len(target_train),len(source_train))

116 118


In [10]:
allepoch=100

for epoch in range(allepoch):
    len_dataloader = min(len(source_train), len(target_train))
    total_steps = allepoch * len(source_train)
    i = 0
    model.train()
    for batch_idx, (data_source, data_target) in enumerate(zip(source_train, target_train)):
        start_time = time.time()
        s_img, s_label = data_source

        start_steps = epoch * len(source_train)

        p = float(i + start_steps) / total_steps
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        optimizer = optimizer_scheduler(optimizer, p)
        optimizer.zero_grad()


        batch_size = len(s_label)

        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()


        a,b = model(s_img.to(device),alpha)
        err_s_label = criterion(a, s_label.to(device))
        err_s_domain = criterion(b, domain_label.to(device))

        # training model using target data
        t_img, _ = data_target

        batch_size = len(t_img)

        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()



        _, b = model(t_img.to(device),alpha)
        err_t_domain = criterion(b, domain_label.to(device))
        err = err_s_label + err_s_domain + err_t_domain
        err.backward()
        optimizer.step()


        if(i % 1000 == 0):
            print('epoch:{},[{}/{}],s_label:{:.3f},s_domain:{:.3f},t_domain:{:.3f},time{}'.
                      format(epoch, i, len_dataloader, float(err_s_label), float(err_s_domain),
                             float(err_t_domain), time.time() - start_time))

        i += 1

    test(epoch)


epoch:0,[0/116],s_label:2.311,s_domain:0.722,t_domain:0.665,time0.13048028945922852
Test set: Average loss: 0.0006, Accuracy: 9511.0/10000 (95%)
Test set: Average loss: 0.0035, Accuracy: 24156.0/59001 (41%)

Test set: Average loss: 0.0036, Accuracy: 3649.0/9001 (41%)

epoch:1,[0/116],s_label:0.590,s_domain:0.686,t_domain:0.696,time0.08186531066894531
Test set: Average loss: 0.0003, Accuracy: 9581.0/10000 (96%)
Test set: Average loss: 0.0033, Accuracy: 24618.0/59001 (42%)

Test set: Average loss: 0.0034, Accuracy: 3775.0/9001 (42%)

epoch:2,[0/116],s_label:0.261,s_domain:0.688,t_domain:0.692,time0.09545540809631348
Test set: Average loss: 0.0002, Accuracy: 9677.0/10000 (97%)
Test set: Average loss: 0.0034, Accuracy: 24747.0/59001 (42%)

Test set: Average loss: 0.0034, Accuracy: 3756.0/9001 (42%)

epoch:3,[0/116],s_label:0.215,s_domain:0.681,t_domain:0.686,time0.08481979370117188
Test set: Average loss: 0.0002, Accuracy: 9771.0/10000 (98%)
Test set: Average loss: 0.0032, Accuracy: 26016.

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7facaefed7b8>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9845.0/10000 (98%)
Test set: Average loss: 0.0021, Accuracy: 40323.0/59001 (68%)

Test set: Average loss: 0.0021, Accuracy: 6245.0/9001 (69%)

epoch:22,[0/116],s_label:0.161,s_domain:0.625,t_domain:0.643,time0.08194780349731445
Test set: Average loss: 0.0001, Accuracy: 9796.0/10000 (98%)
Test set: Average loss: 0.0020, Accuracy: 40438.0/59001 (69%)

Test set: Average loss: 0.0020, Accuracy: 6179.0/9001 (69%)

epoch:23,[0/116],s_label:0.161,s_domain:0.678,t_domain:0.647,time0.0777430534362793
Test set: Average loss: 0.0001, Accuracy: 9799.0/10000 (98%)
Test set: Average loss: 0.0020, Accuracy: 41293.0/59001 (70%)

Test set: Average loss: 0.0020, Accuracy: 6406.0/9001 (71%)

epoch:24,[0/116],s_label:0.161,s_domain:0.645,t_domain:0.666,time0.10076546669006348
Test set: Average loss: 0.0001, Accuracy: 9794.0/10000 (98%)
Test set: Average loss: 0.0026, Accuracy: 38793.0/59001 (66%)

Test set: Average loss: 0.0026, Accuracy: 6040.0/9001 (67%)

epoch:

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7facaefed710>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9817.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 43567.0/59001 (74%)

Test set: Average loss: 0.0016, Accuracy: 6826.0/9001 (76%)

epoch:47,[0/116],s_label:0.163,s_domain:0.649,t_domain:0.658,time0.07735347747802734
Test set: Average loss: 0.0001, Accuracy: 9784.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 43642.0/59001 (74%)

Test set: Average loss: 0.0016, Accuracy: 6791.0/9001 (75%)

epoch:48,[0/116],s_label:0.113,s_domain:0.651,t_domain:0.668,time0.09015011787414551
Test set: Average loss: 0.0001, Accuracy: 9839.0/10000 (98%)
Test set: Average loss: 0.0018, Accuracy: 43840.0/59001 (74%)

Test set: Average loss: 0.0017, Accuracy: 6810.0/9001 (76%)

epoch:49,[0/116],s_label:0.164,s_domain:0.643,t_domain:0.662,time0.08941030502319336
Test set: Average loss: 0.0001, Accuracy: 9839.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 43697.0/59001 (74%)

Test set: Average loss: 0.0017, Accuracy: 6732.0/9001 (75%)

epoch

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7facaefed748>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9796.0/10000 (98%)
Test set: Average loss: 0.0018, Accuracy: 43857.0/59001 (74%)

Test set: Average loss: 0.0018, Accuracy: 6773.0/9001 (75%)

epoch:52,[0/116],s_label:0.111,s_domain:0.690,t_domain:0.664,time0.12810277938842773
Test set: Average loss: 0.0001, Accuracy: 9828.0/10000 (98%)
Test set: Average loss: 0.0018, Accuracy: 43480.0/59001 (74%)

Test set: Average loss: 0.0017, Accuracy: 6738.0/9001 (75%)

epoch:53,[0/116],s_label:0.148,s_domain:0.649,t_domain:0.678,time0.12303018569946289
Test set: Average loss: 0.0001, Accuracy: 9826.0/10000 (98%)
Test set: Average loss: 0.0020, Accuracy: 43372.0/59001 (74%)

Test set: Average loss: 0.0019, Accuracy: 6663.0/9001 (74%)

epoch:54,[0/116],s_label:0.148,s_domain:0.658,t_domain:0.687,time0.09293246269226074
Test set: Average loss: 0.0001, Accuracy: 9798.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 43307.0/59001 (73%)

Test set: Average loss: 0.0016, Accuracy: 6726.0/9001 (75%)

epoch

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7facaefed7b8>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9833.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 44392.0/59001 (75%)

Test set: Average loss: 0.0017, Accuracy: 6877.0/9001 (76%)

epoch:61,[0/116],s_label:0.115,s_domain:0.633,t_domain:0.673,time0.09624242782592773
Test set: Average loss: 0.0001, Accuracy: 9820.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 44884.0/59001 (76%)

Test set: Average loss: 0.0016, Accuracy: 6932.0/9001 (77%)

epoch:62,[0/116],s_label:0.124,s_domain:0.664,t_domain:0.667,time0.07787418365478516
Test set: Average loss: 0.0001, Accuracy: 9829.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 44767.0/59001 (76%)

Test set: Average loss: 0.0016, Accuracy: 6872.0/9001 (76%)

epoch:63,[0/116],s_label:0.135,s_domain:0.676,t_domain:0.679,time0.09090018272399902
Test set: Average loss: 0.0001, Accuracy: 9837.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 44642.0/59001 (76%)

Test set: Average loss: 0.0016, Accuracy: 6900.0/9001 (77%)

epoch

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7facab9683c8>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9854.0/10000 (99%)
Test set: Average loss: 0.0018, Accuracy: 44191.0/59001 (75%)

Test set: Average loss: 0.0018, Accuracy: 6846.0/9001 (76%)

epoch:66,[0/116],s_label:0.083,s_domain:0.664,t_domain:0.679,time0.09017610549926758
Test set: Average loss: 0.0001, Accuracy: 9850.0/10000 (98%)
Test set: Average loss: 0.0019, Accuracy: 43616.0/59001 (74%)

Test set: Average loss: 0.0019, Accuracy: 6752.0/9001 (75%)

epoch:67,[0/116],s_label:0.101,s_domain:0.652,t_domain:0.690,time0.14485454559326172
Test set: Average loss: 0.0001, Accuracy: 9829.0/10000 (98%)
Test set: Average loss: 0.0020, Accuracy: 43202.0/59001 (73%)

Test set: Average loss: 0.0019, Accuracy: 6665.0/9001 (74%)

epoch:68,[0/116],s_label:0.140,s_domain:0.664,t_domain:0.666,time0.11439847946166992
Test set: Average loss: 0.0001, Accuracy: 9842.0/10000 (98%)
Test set: Average loss: 0.0019, Accuracy: 43297.0/59001 (73%)

Test set: Average loss: 0.0019, Accuracy: 6674.0/9001 (74%)

epoch

Test set: Average loss: 0.0016, Accuracy: 6901.0/9001 (77%)

epoch:96,[0/116],s_label:0.074,s_domain:0.672,t_domain:0.669,time0.08373236656188965
Test set: Average loss: 0.0001, Accuracy: 9833.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 44888.0/59001 (76%)

Test set: Average loss: 0.0016, Accuracy: 6909.0/9001 (77%)

epoch:97,[0/116],s_label:0.087,s_domain:0.664,t_domain:0.675,time0.1050422191619873
Test set: Average loss: 0.0001, Accuracy: 9842.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 44247.0/59001 (75%)

Test set: Average loss: 0.0016, Accuracy: 6792.0/9001 (75%)

epoch:98,[0/116],s_label:0.087,s_domain:0.646,t_domain:0.650,time0.07918691635131836
Test set: Average loss: 0.0001, Accuracy: 9826.0/10000 (98%)
Test set: Average loss: 0.0017, Accuracy: 44107.0/59001 (75%)

Test set: Average loss: 0.0016, Accuracy: 6794.0/9001 (75%)

epoch:99,[0/116],s_label:0.088,s_domain:0.650,t_domain:0.691,time0.07661056518554688
Test set: Average loss: 0.0001, Accuracy: 983