In [1]:
import dann_model as dann
import training_helper as th
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import numpy as np
from PIL import Image
import torchvision.models as models
import copy
from torchvision.utils import save_image
import PIL
import time
from sklearn.manifold import TSNE

In [2]:
device = th.getCudaDevice(cudaNum = 1, torchSeed = 123)
batch_size = 128
n_epoch = 10
my_net = dann.CNNModel().to(device)
optimizer = optim.Adam(my_net.parameters(), lr=1e-3)
loss_class = torch.nn.NLLLoss().to(device)
loss_domain = torch.nn.NLLLoss().to(device)

Device used: cuda:1


In [3]:
img_transform = transforms.Compose([
    transforms.Resize(28),
    #transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
])

mnist_train = th.labelImgData(root='hw3_data/digits/mnistm/train', transform=img_transform)
mnist_test = th.labelImgData(root='hw3_data/digits/mnistm/test', transform=img_transform)
svhn_train = th.labelImgData(root='hw3_data/digits/svhn/train', transform=img_transform)
svhn_test = th.labelImgData(root='hw3_data/digits/svhn/test', transform=img_transform)
usps_train = th.labelImgData(root='hw3_data/digits/usps/train', transform=img_transform)
usps_test = th.labelImgData(root='hw3_data/digits/usps/test', transform=img_transform)

mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=0)
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=0)
svhn_train_loader = DataLoader(svhn_train, batch_size=batch_size, shuffle=True, num_workers=0)
svhn_test_loader = DataLoader(svhn_test, batch_size=batch_size, shuffle=False, num_workers=0)
usps_train_loader = DataLoader(usps_train, batch_size=batch_size, shuffle=True, num_workers=0)
usps_test_loader = DataLoader(usps_test, batch_size=batch_size, shuffle=False, num_workers=0)

In [5]:
def test(dataloader):
    my_net.eval()
    correct = 0
    for batch_idx, (img, label) in enumerate(dataloader):
        img, label = img.to(device), label.to(device)
        output, _ = my_net(img)
        _, pred = torch.max(output, 1)
        correct += (pred == label).sum()
    my_net.train()
    return correct.item() / len(dataloader.dataset) 

def train(filename):
    source_train, source_test = source
    target_train, target_test = target
    best_acc = 0
    for epoch in range(n_epoch):
        len_dataloader = min(len(source_train), len(target_train))
        data_source_iter, data_target_iter = iter(source_train), iter(target_train)
        for batch_idx in range(len_dataloader):
            # calculate lambda
            p = float(batch_idx + epoch * len_dataloader) / n_epoch / len_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # training model using source data
            s_img, s_label = data_source_iter.next()
            s_img, s_label = s_img.to(device), s_label.to(device)
            my_net.zero_grad()
            domain_label = torch.zeros(len(s_label)).long().to(device)
            class_output, domain_output = my_net(s_img, alpha=alpha)
            err_s_label = loss_class(class_output, s_label)
            err_s_domain = loss_domain(domain_output, domain_label)

            # training model using target data
            t_img, _ = data_target_iter.next()
            t_img = t_img.to(device)
            domain_label = torch.ones(len(t_img)).long().to(device)
            _, domain_output = my_net(t_img, alpha=alpha)
            err_t_domain = loss_domain(domain_output, domain_label)
            err = err_t_domain + alpha*(err_s_domain + err_s_label)
            err.backward()
            optimizer.step()
            if batch_idx % max(1, int(len_dataloader/3)) == 0:
                print('Epoch: {}/{}\t[Iter: {}/{} ({}%)]\tSource Labeling Loss: {:.4f}\tSource Domain Loss: {:.4f}\tTarget Domain Loss: {:.4f}'.format(
                      epoch+1, n_epoch, batch_idx, len_dataloader, int(batch_idx*100/len_dataloader), err_s_label, err_s_domain, err_t_domain))
        target_acc = test(target_test)
        print('*****Source Accuracy: {:.4f}\n*****Target Accuracy: {:.4f}'.format(test(source_test), target_acc),'\n')
        if target_acc > best_acc:
            best_acc = target_acc
            th.saveModel(filename+'_best.pth', my_net, optimizer)
    th.saveModel(filename+'_final.pth', my_net, optimizer)

In [6]:
# usps->mnistm
source = usps_train_loader, usps_test_loader
target = mnist_train_loader, mnist_test_loader
filename = 'p3_usps->mnistm'
th.loadModel(filename+'_final.pth', my_net, optimizer)
train(filename)

model loaded from p3_usps->mnistm_final.pth
*****Source Accuracy: 0.9377
*****Target Accuracy: 0.4359 

model saved to p3_usps->mnistm_best.pth
*****Source Accuracy: 0.9402
*****Target Accuracy: 0.4132 

*****Source Accuracy: 0.9442
*****Target Accuracy: 0.4235 

*****Source Accuracy: 0.9312
*****Target Accuracy: 0.4585 

model saved to p3_usps->mnistm_best.pth
*****Source Accuracy: 0.9238
*****Target Accuracy: 0.4448 

*****Source Accuracy: 0.9372
*****Target Accuracy: 0.4366 

*****Source Accuracy: 0.9352
*****Target Accuracy: 0.4128 

*****Source Accuracy: 0.9322
*****Target Accuracy: 0.4419 

*****Source Accuracy: 0.9342
*****Target Accuracy: 0.4482 

*****Source Accuracy: 0.9402
*****Target Accuracy: 0.4553 

model saved to p3_usps->mnistm_final.pth


In [6]:
# mnistm->svhn
source = mnist_train_loader, mnist_test_loader
target = svhn_train_loader, svhn_test_loader
train('p3_mnistm->svhn')

*****Source Accuracy: 0.7733
*****Target Accuracy: 0.3099 

model saved to p3_mnistm->svhn_best.pth
*****Source Accuracy: 0.9330
*****Target Accuracy: 0.4849 

model saved to p3_mnistm->svhn_best.pth
*****Source Accuracy: 0.9274
*****Target Accuracy: 0.4764 

*****Source Accuracy: 0.9307
*****Target Accuracy: 0.3914 

*****Source Accuracy: 0.9216
*****Target Accuracy: 0.4541 

*****Source Accuracy: 0.9292
*****Target Accuracy: 0.4506 

*****Source Accuracy: 0.8569
*****Target Accuracy: 0.3855 

*****Source Accuracy: 0.9335
*****Target Accuracy: 0.5071 

model saved to p3_mnistm->svhn_best.pth
*****Source Accuracy: 0.9342
*****Target Accuracy: 0.4712 

*****Source Accuracy: 0.9315
*****Target Accuracy: 0.4387 

model saved to p3_mnistm->svhn_final.pth


In [10]:
# svhn->usps
source = svhn_train_loader, svhn_test_loader
target = usps_train_loader, usps_test_loader
train('p3_svhn->usps')

*****Source Accuracy: 0.3654
*****Target Accuracy: 0.2770 

model saved to p3_svhn->usps_best.pth
*****Source Accuracy: 0.6040
*****Target Accuracy: 0.4544 

model saved to p3_svhn->usps_best.pth
*****Source Accuracy: 0.7359
*****Target Accuracy: 0.5571 

model saved to p3_svhn->usps_best.pth
*****Source Accuracy: 0.7537
*****Target Accuracy: 0.5282 

*****Source Accuracy: 0.7374
*****Target Accuracy: 0.5635 

model saved to p3_svhn->usps_best.pth
*****Source Accuracy: 0.7634
*****Target Accuracy: 0.5451 

*****Source Accuracy: 0.7074
*****Target Accuracy: 0.5605 

*****Source Accuracy: 0.7067
*****Target Accuracy: 0.5546 

*****Source Accuracy: 0.7176
*****Target Accuracy: 0.5964 

model saved to p3_svhn->usps_best.pth
*****Source Accuracy: 0.7515
*****Target Accuracy: 0.5092 

model saved to p3_svhn->usps_final.pth


In [8]:
def train_source(filename):
    source_train, source_test = source
    best_acc, n_epoch = 0, 10
    for epoch in range(n_epoch):
        for batch_idx, (img, label) in enumerate(source_train):
            img, label = img.to(device), label.to(device)
            my_net.zero_grad()
            output, _ = my_net(img)
            loss = loss_class(output, label)
            loss.backward()
            optimizer.step()
            if batch_idx % max(1, int(len(source_train)/3)) == 0:
                print('Epoch: {}/{}\t[Iter: {}/{} ({}%)]\tLoss: {:.4f}'.format(epoch+1, n_epoch, batch_idx,
                      len(source_train), int(batch_idx*100/len(source_train)), loss.item()))
        acc = test(source_test)
        print('*****Accuracy: {:.4f}'.format(acc),'\n')
        if acc > best_acc:
            best_acc = acc
            th.saveModel(filename+'_best.pth', my_net, optimizer)
    th.saveModel(filename+'_final.pth', my_net, optimizer)

In [16]:
# train single data
source = svhn_train_loader, svhn_test_loader
train_source('p3_svhn')

*****Accuracy: 0.8486 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8648 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8790 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8775 

*****Accuracy: 0.8835 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8931 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8972 

model saved to p3_svhn_best.pth
*****Accuracy: 0.8907 

*****Accuracy: 0.8978 

model saved to p3_svhn_best.pth
*****Accuracy: 0.9019 

model saved to p3_svhn_best.pth
model saved to p3_svhn_final.pth


In [17]:
# train single data
source = usps_train_loader, usps_test_loader
train_source('p3_usps')

*****Accuracy: 0.9332 

model saved to p3_usps_best.pth
*****Accuracy: 0.9432 

model saved to p3_usps_best.pth
*****Accuracy: 0.9532 

model saved to p3_usps_best.pth
*****Accuracy: 0.9562 

model saved to p3_usps_best.pth
*****Accuracy: 0.9557 

*****Accuracy: 0.9581 

model saved to p3_usps_best.pth
*****Accuracy: 0.9586 

model saved to p3_usps_best.pth
*****Accuracy: 0.9641 

model saved to p3_usps_best.pth
*****Accuracy: 0.9596 

*****Accuracy: 0.9651 

model saved to p3_usps_best.pth
model saved to p3_usps_final.pth


In [9]:
# train single data
source = mnist_train_loader, mnist_test_loader
train_source('p3_mnist')

*****Accuracy: 0.9485 

model saved to p3_mnist_best.pth
*****Accuracy: 0.9465 

*****Accuracy: 0.9645 

model saved to p3_mnist_best.pth
*****Accuracy: 0.9654 

model saved to p3_mnist_best.pth
*****Accuracy: 0.9612 

*****Accuracy: 0.9699 

model saved to p3_mnist_best.pth
*****Accuracy: 0.9735 

model saved to p3_mnist_best.pth
*****Accuracy: 0.9729 

*****Accuracy: 0.9714 

*****Accuracy: 0.9631 

model saved to p3_mnist_final.pth
