In [3]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import time
import os
import torch.nn as nn
from torch.autograd import Function
import torchvision
import torch.utils.data as data
from PIL import Image
import os
import matplotlib.pyplot as plt

### Gradient Reversal Layer

In [4]:
class GradientReverseLayer(Function):
    
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)*alpha
    
    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg()*ctx.alpha
        return output, None

### The three models: Feature Extractor, Class Classifier, Domain CLassifier

In [5]:
class DANN(nn.Module):
    def __init__(self):
        super(DANN, self).__init__()
        self.feature_extractor = nn.Sequential()
        self.feature_extractor.add_module('conv1', nn.Conv2d(3, 64, kernel_size=5)),
        self.feature_extractor.add_module('batchnorm1', nn.BatchNorm2d(64)),
        self.feature_extractor.add_module('maxpool1', nn.MaxPool2d(2)),
        self.feature_extractor.add_module('relu1', nn.ReLU(True)),
        self.feature_extractor.add_module('conv2', nn.Conv2d(64, 50, kernel_size=5)),
        self.feature_extractor.add_module('batchnorm2', nn.BatchNorm2d(50)),
        self.feature_extractor.add_module('drop1', nn.Dropout2d())
        self.feature_extractor.add_module('maxpool2', nn.MaxPool2d(2)),
        self.feature_extractor.add_module('relu2', nn.ReLU(True))
        
        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('fc1', nn.Linear(50*4*4, 100)),
        self.class_classifier.add_module('batchnorm1', nn.BatchNorm1d(100)),
        self.class_classifier.add_module('relu1', nn.ReLU(True)),
        self.class_classifier.add_module('drop1', nn.Dropout2d()),
        self.class_classifier.add_module('fc2', nn.Linear(100, 100)),
        self.class_classifier.add_module('batchnorm2', nn.BatchNorm1d(100)),
        self.class_classifier.add_module('relu2', nn.ReLU(True)),
        self.class_classifier.add_module('fc3', nn.Linear(100, 10)),
        # self.class_classifier.add_module('softmax', nn.LogSoftmax()),
        
        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('fc1', nn.Linear(50*4*4, 100)),
        self.domain_classifier.add_module('batchnorm1', nn.BatchNorm1d(100)),
        self.domain_classifier.add_module('relu1', nn.ReLU(True)),
        self.domain_classifier.add_module('fc2', nn.Linear(100, 2)),
        # self.domain_classifier.add_module('softmax', nn.LogSoftmax(dim=1))
        
    def forward(self, x, alpha):
        # x = x.expand(x.data.shape[0], 3, 28, 28)
        feature = self.feature_extractor(x)
        feature = feature.view(-1, 50*4*4)
        reverse = GradientReverseLayer.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse)
        class_output = class_output.view(class_output.shape[0], -1)
        domain_output = domain_output.view(domain_output.shape[0], -1)
        return class_output, domain_output

In [6]:
model = DANN()
model

DANN(
  (feature_extractor): Sequential(
    (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
    (batchnorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(64, 50, kernel_size=(5, 5), stride=(1, 1))
    (batchnorm2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (drop1): Dropout2d(p=0.5, inplace=False)
    (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (relu2): ReLU(inplace=True)
  )
  (class_classifier): Sequential(
    (fc1): Linear(in_features=800, out_features=100, bias=True)
    (batchnorm1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (drop1): Dropout2d(p=0.5, inplace=False)
    (fc2): Linear(in_features=100, out_features=100, bias=True)
    (

#### Dataset Loading

In [7]:
from torchvision.datasets import MNIST        
import torchvision.transforms as transforms 

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
print('Min Pixel Value: {} \nMax Pixel Value: {}'.format(trainset.data.min(), trainset.data.max()))
print('Mean Pixel Value {} \nPixel Values Std: {}'.format(trainset.data.float().mean(), trainset.data.float().std()))
print('Scaled Mean Pixel Value {} \nScaled Pixel Values Std: {}'.format(trainset.data.float().mean() / 255, trainset.data.float().std() / 255))

Min Pixel Value: 0 
Max Pixel Value: 255
Mean Pixel Value 33.31842041015625 
Pixel Values Std: 78.56748962402344
Scaled Mean Pixel Value 0.13066047430038452 
Scaled Pixel Values Std: 0.30810779333114624


In [8]:
# cuda = True
# lr = 1e-3
# image_size = 28


# img_transform_source = transforms.Compose([
#     transforms.Resize(image_size),
#     transforms.ToTensor(),
#     transforms.Normalize(mean = (0.1307,), std = (0.3081,))
# ])
# img_transform_target = transforms.Compose([
#     transforms.Resize(image_size),
#     transforms.ToTensor(),
#     transforms.Normalize(mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5))
# ])

In [6]:
device = torch.device("cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [26]:
def repeat_channels(x):
    return x.repeat(3, 1, 1)

source_train = torch.utils.data.DataLoader(
    datasets.MNIST('./', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,)),
                       transforms.Lambda(repeat_channels)
                   ])),
    batch_size=128, shuffle=True, num_workers=4)

source_test = torch.utils.data.DataLoader(
    datasets.MNIST('./', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5,), (0.5,)),
                       transforms.Lambda(repeat_channels)
                   ])),
    batch_size=128, shuffle=False, num_workers=4)

In [27]:
class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data

In [28]:
image_size=28
img_transform = transforms.Compose([
    transforms.RandomCrop((image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),
                         std=(0.5,0.5,0.5))
])
train_list = 'C:\\Users\\Dell\\Downloads\\mnist_m\\mnist_m\\mnist_m_train_labels.txt'
dataset_train_target = GetLoader(
    data_root='C:\\Users\\Dell\\Downloads\\mnist_m\\mnist_m\\mnist_m_train',
    data_list=train_list,
    transform=img_transform
)
test_list = 'C:\\Users\\Dell\\Downloads\\mnist_m\\mnist_m\\mnist_m_test_labels.txt'
dataset_test_target = GetLoader(
    data_root='C:\\Users\\Dell\\Downloads\\mnist_m\\mnist_m\\mnist_m_test',
    data_list=test_list,
    transform=img_transform
)
target_train = torch.utils.data.DataLoader(dataset_train_target,batch_size=512, shuffle=True,num_workers=4)
target_test = torch.utils.data.DataLoader(dataset_test_target,batch_size=512, shuffle=True,num_workers=4)

### Model Training

In [29]:
import os
import matplotlib.pyplot as plt
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchvision import transforms
from torchvision import datasets
from PIL import Image

In [30]:
device = 'cuda'
model = DANN()
optimizer = optim.SGD(model.parameters(), lr= 0.01, momentum= 0.9)
criterion = nn.CrossEntropyLoss()

def optimizer_scheduler(optimizer, p):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75
    return optimizer

loss_class = torch.nn.NLLLoss()
loss_domain = torch.nn.NLLLoss()

In [31]:
def test(epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in source_test:
            data, target = data, target
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(source_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(source_test.dataset),
        100. * correct / len(source_test.dataset)))

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_train:
            data, target = data, target
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_train.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_train.dataset),
        100. * correct / len(target_train.dataset)))
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in target_test:
            data, target = data, target
            output, _ = model(data,0.5)
            test_loss += float(criterion(output, target))  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += float(pred.eq(target.view_as(pred)).sum())

    test_loss /= len(target_test.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(target_test.dataset),
        100. * correct / len(target_test.dataset)))

In [32]:
print(len(target_train),len(source_train))

116 469


In [None]:
allepoch=100

for epoch in range(allepoch):
    len_dataloader = min(len(source_train), len(target_train))
    total_steps = allepoch * len(source_train)
    i = 0
    model.train()
    for batch_idx, (data_source, data_target) in enumerate(zip(source_train, target_train)):
        start_time = time.time()
        s_img, s_label = data_source

        start_steps = epoch * len(source_train)

        p = float(i + start_steps) / total_steps
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        optimizer = optimizer_scheduler(optimizer, p)
        optimizer.zero_grad()


        batch_size = len(s_label)

        domain_label = torch.zeros(batch_size)
        domain_label = domain_label.long()


        a,b = model(s_img,alpha)
        err_s_label = criterion(a, s_label)
        err_s_domain = criterion(b, domain_label)

        # training model using target data
        t_img, _ = data_target

        batch_size = len(t_img)

        domain_label = torch.ones(batch_size)
        domain_label = domain_label.long()



        _, b = model(t_img,alpha)
        err_t_domain = criterion(b, domain_label)
        err = err_s_label + err_s_domain + err_t_domain
        err.backward()
        optimizer.step()


        if(i % 1000 == 0):
            print('epoch:{},[{}/{}],s_label:{:.3f},s_domain:{:.3f},t_domain:{:.3f},time{}'.
                      format(epoch, i, len_dataloader, float(err_s_label), float(err_s_domain),
                             float(err_t_domain), time.time() - start_time))

        i += 1

    test(epoch)