In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import time
import os
import torch.nn as nn
from torch.autograd import Function
import torchvision
import torch.utils.data as data
from PIL import Image
import os

In [2]:
class GRL(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant
        return tensor * constant
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.constant, None

In [3]:
class Dann(nn.Module):
    def __init__(self):
        super(Dann, self).__init__()
        self.f = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=5,padding=2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3,padding=1,stride=2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3,padding=1,stride=2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3,padding=1),
                nn.AvgPool2d(7)
            )

        self.lc = nn.Conv2d(64,10,1,1)
        self.dc = nn.Conv2d(64,2,1,1)
    def forward(self, x):
        x = self.f(x)
        y = GRL.apply(x, 0.5)
        x = self.lc(x)
        y = self.dc(y)
        x=x.view(x.shape[0],-1)
        y=y.view(y.shape[0],-1)
        return x, y

In [4]:
source_train = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=25, shuffle=True)
source_test = torch.utils.data.DataLoader(
        datasets.MNIST('./', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,)),
                           transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                       ])),
        batch_size=25, shuffle=False)

In [5]:
class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data

In [6]:
image_size=28
img_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.40824443831346119, 0.46209181664543181, 0.45790558913705282), 
                         std=(0.25870560685997501, 0.23677806687215142, 0.25194991442789383))
])
train_list = os.path.join('mnist_m', 'mnist_m_train_labels.txt')
dataset_target = GetLoader(
    data_root=os.path.join('mnist_m', 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)
target_train = torch.utils.data.DataLoader(dataset_target,batch_size=25, shuffle=True)

In [7]:
#device='cuda'
model = Dann()#.to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [8]:
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in source_test:
            data, target = data, target
            output,_ = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(source_test.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(source_test.dataset),
        100. * correct / len(source_test.dataset)))
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_train:
            output,_ = model(data)
            test_loss += criterion(output, target).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(target_train.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_train.dataset),
        100. * correct / len(target_train.dataset)))

In [None]:
len(dataset_target),len(source_train)

(59001, 2400)

In [None]:
for epoch in range(100):
    len_dataloader = min(len(source_train), len(dataset_target))
    data_source_iter = iter(source_train)
    data_target_iter = iter(target_train)
    i = 0
    model.train()
    while i < len_dataloader:
        data_source = data_source_iter.next()
        s_img, s_label = data_source

        optimizer.zero_grad()
        batch_size = len(s_label)

        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()
        

        a,b = model(s_img)
        err_s_label = criterion(a, s_label)
        err_s_domain = criterion(b, domain_label)

        # training model using target data
        data_target = data_target_iter.next()
        t_img, _ = data_target

        batch_size = len(t_img)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()



        _, b = model(t_img)
        err_t_domain = criterion(b, domain_label)
        err = err_t_domain + err_s_domain + err_s_label
        err.backward()
        optimizer.step()

        i += 1

        print ('epoch: %d, [iter: %d / all %d], err_s_label: %f, err_s_domain: %f, err_t_domain: %f' \
              % (epoch, i, len_dataloader, err_s_label.cpu().data.numpy(),
                 err_s_domain.cpu().data.numpy(), err_t_domain.cpu().data.numpy()))

    test(epoch)

epoch: 0, [iter: 1 / all 2400], err_s_label: 2.321313, err_s_domain: 0.779572, err_t_domain: 0.615857
epoch: 0, [iter: 2 / all 2400], err_s_label: 2.248077, err_s_domain: 0.809288, err_t_domain: 0.596789
epoch: 0, [iter: 3 / all 2400], err_s_label: 2.324873, err_s_domain: 0.852448, err_t_domain: 0.569518
epoch: 0, [iter: 4 / all 2400], err_s_label: 2.231906, err_s_domain: 0.926140, err_t_domain: 0.522205
epoch: 0, [iter: 5 / all 2400], err_s_label: 2.291523, err_s_domain: 1.008388, err_t_domain: 0.489156
epoch: 0, [iter: 6 / all 2400], err_s_label: 2.257241, err_s_domain: 1.105358, err_t_domain: 0.444690
epoch: 0, [iter: 7 / all 2400], err_s_label: 2.230206, err_s_domain: 1.230088, err_t_domain: 0.398159
epoch: 0, [iter: 8 / all 2400], err_s_label: 2.207369, err_s_domain: 1.374983, err_t_domain: 0.370784
epoch: 0, [iter: 9 / all 2400], err_s_label: 2.180142, err_s_domain: 1.535119, err_t_domain: 0.339016
epoch: 0, [iter: 10 / all 2400], err_s_label: 2.520883, err_s_domain: 1.703441, er

epoch: 0, [iter: 81 / all 2400], err_s_label: 1.829764, err_s_domain: 0.368907, err_t_domain: 1.237989
epoch: 0, [iter: 82 / all 2400], err_s_label: 1.827657, err_s_domain: 0.379873, err_t_domain: 1.214653
epoch: 0, [iter: 83 / all 2400], err_s_label: 1.435142, err_s_domain: 0.408505, err_t_domain: 1.156642
epoch: 0, [iter: 84 / all 2400], err_s_label: 1.796778, err_s_domain: 0.418081, err_t_domain: 1.150382
