In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import time
import os
import torch.nn as nn
from torch.autograd import Function
import torchvision
import torch.utils.data as data
from PIL import Image
import os

In [2]:
class GRL(Function):
    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x.view_as(x) * constant
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.constant, None

In [3]:
class Dann(nn.Module):
    def __init__(self):
        super(Dann, self).__init__()
        self.f = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=5),
                nn.BatchNorm2d(64),
                #nn.MaxPool2d(2,2),
                nn.Conv2d(64,64,kernel_size=3,stride=2,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                #nn.Conv2d(32, 32, kernel_size=3,padding=1,stride=2),
                #nn.BatchNorm2d(32),
                nn.Conv2d(64, 50, kernel_size=5),
                nn.BatchNorm2d(50),
                nn.Dropout2d(),
                #nn.MaxPool2d(2, 2),
                nn.Conv2d(50, 50, kernel_size=3, stride=2,padding=1),
                nn.BatchNorm2d(50),
                nn.ReLU(),

                #nn.Conv2d(32, 32, kernel_size=3,padding=1,stride=2),
                #nn.BatchNorm2d(32),
                #nn.ReLU(),
                #nn.Conv2d(32, 128, kernel_size=5,padding=2),
                #nn.AvgPool2d(7)
            )
        self.lc = nn.Sequential(
            nn.Linear(50*4*4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Dropout2d(),
            nn.Linear(100, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        self.dc = nn.Sequential(
            nn.Linear(50 * 4 * 4, 100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(100, 2),
        )
    def forward(self, x,alpha):
        x = self.f(x)
        x = x.view(-1, 50*4*4)
        y = GRL.apply(x, alpha)
        x = self.lc(x)
        y = self.dc(y)
        x=x.view(x.shape[0],-1)
        y=y.view(y.shape[0],-1)
        return x, y

In [4]:
source_train = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=512, shuffle=True,num_workers=12)
source_test = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=512, shuffle=False,num_workers=12)

In [5]:
class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data

In [6]:
image_size=28
img_transform = transforms.Compose([
    transforms.RandomCrop((image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),
                         std=(0.5,0.5,0.5))
])
train_list = os.path.join('mnist_m', 'mnist_m_train_labels.txt')
dataset_train_target = GetLoader(
    data_root=os.path.join('mnist_m', 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)
test_list = os.path.join('mnist_m', 'mnist_m_test_labels.txt')
dataset_test_target = GetLoader(
    data_root=os.path.join('mnist_m', 'mnist_m_test'),
    data_list=test_list,
    transform=img_transform
)
target_train = torch.utils.data.DataLoader(dataset_train_target,batch_size=512, shuffle=True,num_workers=12)
target_test = torch.utils.data.DataLoader(dataset_test_target,batch_size=512, shuffle=True,num_workers=12)

In [7]:
device='cuda'
model = Dann().to(device)
#optimizer = optim.Adam(model.parameters(),weight_decay=1e-5)
optimizer = optim.SGD(model.parameters(), lr= 0.01, momentum= 0.9)
#optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def optimizer_scheduler(optimizer, p):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75
    return optimizer

In [8]:
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in source_test:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(source_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(source_test.dataset),
        100. * correct / len(source_test.dataset)))

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_train:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_train.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_train.dataset),
        100. * correct / len(target_train.dataset)))
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_test:
            data, target = data.to(device), target.to(device)
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_test.dataset),
        100. * correct / len(target_test.dataset)))

In [9]:
print(len(target_train),len(source_train))

116 118


In [10]:
allepoch=100

for epoch in range(allepoch):
    len_dataloader = min(len(source_train), len(target_train))
    total_steps = allepoch * len(source_train)
    i = 0
    model.train()
    for batch_idx, (data_source, data_target) in enumerate(zip(source_train, target_train)):
        start_time = time.time()
        s_img, s_label = data_source

        start_steps = epoch * len(source_train)

        p = float(i + start_steps) / total_steps
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        optimizer = optimizer_scheduler(optimizer, p)
        optimizer.zero_grad()


        batch_size = len(s_label)

        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()


        a,b = model(s_img.to(device),alpha)
        err_s_label = criterion(a, s_label.to(device))
        err_s_domain = criterion(b, domain_label.to(device))

        # training model using target data
        t_img, _ = data_target

        batch_size = len(t_img)

        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()



        _, b = model(t_img.to(device),alpha)
        err_t_domain = criterion(b, domain_label.to(device))
        err = err_s_label + err_s_domain + err_t_domain
        err.backward()
        optimizer.step()


        if(i % 1000 == 0):
            print('epoch:{},[{}/{}],s_label:{:.3f},s_domain:{:.3f},t_domain:{:.3f},time{}'.
                      format(epoch, i, len_dataloader, float(err_s_label), float(err_s_domain),
                             float(err_t_domain), time.time() - start_time))

        i += 1

    test(epoch)


epoch:0,[0/116],s_label:2.322,s_domain:0.684,t_domain:0.702,time0.1462998390197754
Test set: Average loss: 0.0006, Accuracy: 9394.0/10000 (94%)
Test set: Average loss: 0.0036, Accuracy: 23400.0/59001 (40%)

Test set: Average loss: 0.0036, Accuracy: 3562.0/9001 (40%)

epoch:1,[0/116],s_label:0.502,s_domain:0.663,t_domain:0.695,time0.11793875694274902
Test set: Average loss: 0.0003, Accuracy: 9606.0/10000 (96%)
Test set: Average loss: 0.0036, Accuracy: 23301.0/59001 (39%)

Test set: Average loss: 0.0037, Accuracy: 3594.0/9001 (40%)

epoch:2,[0/116],s_label:0.229,s_domain:0.665,t_domain:0.677,time0.1772465705871582
Test set: Average loss: 0.0002, Accuracy: 9779.0/10000 (98%)
Test set: Average loss: 0.0032, Accuracy: 30049.0/59001 (51%)

Test set: Average loss: 0.0032, Accuracy: 4694.0/9001 (52%)

epoch:3,[0/116],s_label:0.135,s_domain:0.654,t_domain:0.661,time0.14191651344299316
Test set: Average loss: 0.0001, Accuracy: 9793.0/10000 (98%)
Test set: Average loss: 0.0028, Accuracy: 33607.0/

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fda10c42780>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9780.0/10000 (98%)
Test set: Average loss: 0.0020, Accuracy: 43007.0/59001 (73%)

Test set: Average loss: 0.0020, Accuracy: 6612.0/9001 (73%)

epoch:13,[0/116],s_label:0.150,s_domain:0.650,t_domain:0.665,time0.1572885513305664
Test set: Average loss: 0.0002, Accuracy: 9633.0/10000 (96%)
Test set: Average loss: 0.0020, Accuracy: 42782.0/59001 (73%)

Test set: Average loss: 0.0020, Accuracy: 6584.0/9001 (73%)

epoch:14,[0/116],s_label:0.176,s_domain:0.705,t_domain:0.664,time0.11926960945129395
Test set: Average loss: 0.0001, Accuracy: 9840.0/10000 (98%)
Test set: Average loss: 0.0018, Accuracy: 44027.0/59001 (75%)

Test set: Average loss: 0.0018, Accuracy: 6657.0/9001 (74%)

epoch:15,[0/116],s_label:0.127,s_domain:0.660,t_domain:0.682,time0.13114213943481445
Test set: Average loss: 0.0001, Accuracy: 9801.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 45188.0/59001 (77%)

Test set: Average loss: 0.0016, Accuracy: 6904.0/9001 (77%)

epoch:

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fda10c426a0>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9817.0/10000 (98%)
Test set: Average loss: 0.0014, Accuracy: 47525.0/59001 (81%)

Test set: Average loss: 0.0013, Accuracy: 7324.0/9001 (81%)

epoch:31,[0/116],s_label:0.073,s_domain:0.669,t_domain:0.666,time0.12107086181640625
Test set: Average loss: 0.0001, Accuracy: 9831.0/10000 (98%)
Test set: Average loss: 0.0014, Accuracy: 47355.0/59001 (80%)

Test set: Average loss: 0.0014, Accuracy: 7277.0/9001 (81%)

epoch:32,[0/116],s_label:0.133,s_domain:0.650,t_domain:0.689,time0.12311673164367676


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fda10c42828>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9842.0/10000 (98%)
Test set: Average loss: 0.0013, Accuracy: 47215.0/59001 (80%)

Test set: Average loss: 0.0014, Accuracy: 7192.0/9001 (80%)

epoch:33,[0/116],s_label:0.075,s_domain:0.682,t_domain:0.667,time0.12757134437561035
Test set: Average loss: 0.0001, Accuracy: 9863.0/10000 (99%)
Test set: Average loss: 0.0013, Accuracy: 48122.0/59001 (82%)

Test set: Average loss: 0.0013, Accuracy: 7320.0/9001 (81%)

epoch:34,[0/116],s_label:0.132,s_domain:0.676,t_domain:0.673,time0.14115500450134277
Test set: Average loss: 0.0001, Accuracy: 9845.0/10000 (98%)
Test set: Average loss: 0.0016, Accuracy: 46635.0/59001 (79%)

Test set: Average loss: 0.0016, Accuracy: 7183.0/9001 (80%)

epoch:35,[0/116],s_label:0.134,s_domain:0.670,t_domain:0.674,time0.1454324722290039
Test set: Average loss: 0.0001, Accuracy: 9807.0/10000 (98%)
Test set: Average loss: 0.0014, Accuracy: 47566.0/59001 (81%)

Test set: Average loss: 0.0013, Accuracy: 7293.0/9001 (81%)

epoch:

Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fda10c42780>>
Traceback (most recent call last):
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 349, in __del__
    self._shutdown_workers()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 328, in _shutdown_workers
    self.worker_result_queue.get()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/coder.chenshicheng/anaconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/coder.chenshicheng

Test set: Average loss: 0.0001, Accuracy: 9860.0/10000 (99%)
Test set: Average loss: 0.0014, Accuracy: 47411.0/59001 (80%)

Test set: Average loss: 0.0014, Accuracy: 7223.0/9001 (80%)

epoch:50,[0/116],s_label:0.080,s_domain:0.649,t_domain:0.682,time0.11936235427856445
Test set: Average loss: 0.0001, Accuracy: 9836.0/10000 (98%)
Test set: Average loss: 0.0013, Accuracy: 47821.0/59001 (81%)

Test set: Average loss: 0.0013, Accuracy: 7292.0/9001 (81%)

epoch:51,[0/116],s_label:0.065,s_domain:0.663,t_domain:0.677,time0.11913299560546875
Test set: Average loss: 0.0001, Accuracy: 9860.0/10000 (99%)
Test set: Average loss: 0.0013, Accuracy: 47912.0/59001 (81%)

Test set: Average loss: 0.0014, Accuracy: 7298.0/9001 (81%)

epoch:52,[0/116],s_label:0.141,s_domain:0.666,t_domain:0.686,time0.1194753646850586
Test set: Average loss: 0.0001, Accuracy: 9855.0/10000 (99%)
Test set: Average loss: 0.0017, Accuracy: 46567.0/59001 (79%)

Test set: Average loss: 0.0016, Accuracy: 7195.0/9001 (80%)

epoch:

Test set: Average loss: 0.0013, Accuracy: 7442.0/9001 (83%)

epoch:80,[0/116],s_label:0.089,s_domain:0.680,t_domain:0.684,time0.14845705032348633
Test set: Average loss: 0.0001, Accuracy: 9856.0/10000 (99%)
Test set: Average loss: 0.0014, Accuracy: 48365.0/59001 (82%)

Test set: Average loss: 0.0014, Accuracy: 7435.0/9001 (83%)

epoch:81,[0/116],s_label:0.076,s_domain:0.686,t_domain:0.677,time0.13457345962524414
Test set: Average loss: 0.0001, Accuracy: 9857.0/10000 (99%)
Test set: Average loss: 0.0014, Accuracy: 48329.0/59001 (82%)

Test set: Average loss: 0.0014, Accuracy: 7462.0/9001 (83%)

epoch:82,[0/116],s_label:0.078,s_domain:0.688,t_domain:0.686,time0.17340993881225586
Test set: Average loss: 0.0001, Accuracy: 9850.0/10000 (98%)
Test set: Average loss: 0.0013, Accuracy: 48596.0/59001 (82%)

Test set: Average loss: 0.0014, Accuracy: 7463.0/9001 (83%)

epoch:83,[0/116],s_label:0.066,s_domain:0.667,t_domain:0.678,time0.11987948417663574
Test set: Average loss: 0.0001, Accuracy: 98