In [1]:
import dann_dual_model as dannd
import training_helper as th
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import numpy as np
from PIL import Image
import torchvision.models as models
import copy
from torchvision.utils import save_image
import PIL
import time

In [2]:
# Device selection comes first: th.getCudaDevice also receives the torch seed, so it
# presumably seeds torch's RNG and must run before the model weights are created.
device = th.getCudaDevice(cudaNum=0, torchSeed=123)

# Training hyper-parameters shared by every experiment below.
batch_size = 128
n_epoch = 10

# One network + optimiser are reused (and re-loaded from checkpoints) by all experiments.
my_net = dannd.CNNModel().to(device)
optimizer = optim.Adam(my_net.parameters(), lr=1e-4, betas=(0.5, 0.999))

# Two independent NLL criteria: one for the class head, one for the domain head.
# NLLLoss expects log-probabilities, so the model outputs are presumably log_softmax — confirm.
loss_class, loss_domain = (nn.NLLLoss().to(device) for _ in range(2))

Device used: cuda:0


In [3]:
# Shared preprocessing for every digit domain. No grayscale conversion is applied, so
# images keep whatever channel count labelImgData loads them with.
# NOTE(review): Normalize uses single-channel MNIST statistics; a 1-element mean/std is
# applied identically to every channel (MNIST-M / SVHN are presumably RGB) — confirm intended.
img_transform = transforms.Compose([
    transforms.Resize(28),  # rescales the shorter image side to 28 px
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

DIGITS_ROOT = 'hw3_data/digits'


def _digits_split(domain_dir, split):
    """Build the labelled-image dataset for ``<DIGITS_ROOT>/<domain_dir>/<split>``."""
    return th.labelImgData(root='{}/{}/{}'.format(DIGITS_ROOT, domain_dir, split),
                           transform=img_transform)


def _digits_loader(dataset, shuffle):
    """Wrap ``dataset`` in a single-process DataLoader using the global ``batch_size``."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)


# Note: the "mnist_*" variables point at the MNIST-M folder.
mnist_train = _digits_split('mnistm', 'train')
mnist_test = _digits_split('mnistm', 'test')
svhn_train = _digits_split('svhn', 'train')
svhn_test = _digits_split('svhn', 'test')
usps_train = _digits_split('usps', 'train')
usps_test = _digits_split('usps', 'test')

# Training loaders reshuffle every epoch; test loaders keep a fixed order.
mnist_train_loader = _digits_loader(mnist_train, shuffle=True)
mnist_test_loader = _digits_loader(mnist_test, shuffle=False)
svhn_train_loader = _digits_loader(svhn_train, shuffle=True)
svhn_test_loader = _digits_loader(svhn_test, shuffle=False)
usps_train_loader = _digits_loader(usps_train, shuffle=True)
usps_test_loader = _digits_loader(usps_test, shuffle=False)

In [4]:
def test(dataloader, domain= 'target'):
    my_net.eval()
    correct = 0
    for batch_idx, (img, label) in enumerate(dataloader):
        img, label = img.to(device), label.to(device)
        output, _ = my_net(img, domain = domain)
        _, pred = torch.max(output, 1)
        correct += (pred == label).sum()
    my_net.train()
    return correct.item() / len(dataloader.dataset) 

def train(filename, source_loaders=None, target_loaders=None):
    """Train ``my_net`` for ``n_epoch`` epochs on a source->target adaptation task.

    Each step uses one labelled source batch and one unlabelled target batch. The
    combined loss is
        cross-class loss + target-domain loss + alpha * (source-class + source-domain loss),
    where ``alpha`` ramps from 0 towards 1 over training and is also passed to the model.

    Args:
        filename: checkpoint prefix. ``<filename>_best.pth`` is saved whenever target
            test accuracy beats the best seen so far (starting from the accuracy of the
            weights at call time); ``<filename>_final.pth`` is saved at the end.
        source_loaders: ``(train_loader, test_loader)`` for the labelled source domain.
            Defaults to the notebook-global ``source`` (kept for backward compatibility
            with cells that assign it before calling ``train``).
        target_loaders: ``(train_loader, test_loader)`` for the target domain. Defaults
            to the notebook-global ``target``. Target labels are used only for
            evaluation, never in the loss.

    Uses the notebook globals ``my_net``, ``optimizer``, ``loss_class``,
    ``loss_domain``, ``n_epoch``, ``device`` and ``test``.

    NOTE(review): the "best" checkpoint is selected on the target *test* split, so the
    best reported accuracy is optimistically biased; a separate validation split would
    avoid that.
    """
    if source_loaders is None:
        source_loaders = source
    if target_loaders is None:
        target_loaders = target
    source_train, source_test = source_loaders
    target_train, target_test = target_loaders

    best_acc = test(target_test)
    print('*****Target Accuracy: {:.4f}'.format(best_acc), '\n')

    # An "epoch" is as many steps as the shorter training loader provides.
    len_dataloader = min(len(source_train), len(target_train))
    log_every = max(1, int(len_dataloader / 3))  # log roughly three times per epoch

    # Explicit, rather than relying on test() restoring training mode.
    my_net.train()
    for epoch in range(n_epoch):
        data_source_iter, data_target_iter = iter(source_train), iter(target_train)
        for batch_idx in range(len_dataloader):
            # Overall training progress p in [0, 1) -> alpha = 2 / (1 + e^{-10p}) - 1,
            # the usual DANN ramp. alpha presumably scales the gradient reversal inside
            # dannd.CNNModel (confirm); it also weights the source losses below.
            p = float(batch_idx + epoch * len_dataloader) / n_epoch / len_dataloader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Source batch: class loss + domain loss (source domain label = 0).
            # Built-in next(): DataLoader iterators have no .next() method in modern PyTorch.
            s_img, s_label = next(data_source_iter)
            s_img, s_label = s_img.to(device), s_label.to(device)
            my_net.zero_grad()
            domain_label_src = torch.zeros(len(s_label)).long().to(device)
            class_output, domain_output_src = my_net(s_img, alpha=alpha, domain='source')
            err_s_label = loss_class(class_output, s_label)
            err_s_domain = loss_domain(domain_output_src, domain_label_src)

            # Target batch: domain loss only (target domain label = 1); labels discarded.
            t_img, _ = next(data_target_iter)
            t_img = t_img.to(device)
            domain_label_tgt = torch.ones(len(t_img)).long().to(device)
            _, domain_output_tgt = my_net(t_img, alpha=alpha, domain='target')
            err_t_domain = loss_domain(domain_output_tgt, domain_label_tgt)

            # Cross term: source images through the 'target' path, supervised by source labels.
            class_output_x, _ = my_net(s_img, alpha=alpha, domain='target')
            err_x_label = loss_class(class_output_x, s_label)

            err = err_x_label + err_t_domain + alpha * (err_s_label + err_s_domain)
            err.backward()
            optimizer.step()
            if batch_idx % log_every == 0:
                print('Epoch: {}/{}\t[Iter: {}/{} ({}%)]\ts_d: {:.4f}\tt_d: {:.4f}\ts_l: {:.4f}\tx_l: {:.4f}'.format(
                      epoch + 1, n_epoch, batch_idx, len_dataloader, int(batch_idx * 100 / len_dataloader),
                      err_s_domain.item(), err_t_domain.item(), err_s_label.item(), err_x_label.item()))

        target_acc = test(target_test)
        source_acc = test(source_test, domain='source')
        print('*****Source Accuracy: {:.4f}\n*****Target Accuracy: {:.4f}'.format(source_acc, target_acc), '\n')
        if target_acc > best_acc:
            best_acc = target_acc
            th.saveModel(filename + '_best.pth', my_net, optimizer)
    th.saveModel(filename + '_final.pth', my_net, optimizer)

In [7]:
# Experiment: USPS (source) -> MNIST-M (target).
# Resumes from this task's *best* checkpoint (presumably restoring model and optimiser
# state), then trains further. NOTE(review): needs that file from an earlier run.
filename = 'p4_usps->mnistm'
source = (usps_train_loader, usps_test_loader)
target = (mnist_train_loader, mnist_test_loader)
resume_ckpt = filename + '_best.pth'
th.loadModel(resume_ckpt, my_net, optimizer)
train(filename)  # reads the globals `source` / `target` assigned above

model loaded from p4_usps->mnistm_best.pth
*****Target Accuracy: 0.5101 

*****Source Accuracy: 0.9581
*****Target Accuracy: 0.4904 

*****Source Accuracy: 0.9576
*****Target Accuracy: 0.4952 

*****Source Accuracy: 0.9557
*****Target Accuracy: 0.4852 

*****Source Accuracy: 0.9552
*****Target Accuracy: 0.4966 

*****Source Accuracy: 0.9542
*****Target Accuracy: 0.4877 

*****Source Accuracy: 0.9586
*****Target Accuracy: 0.4892 

*****Source Accuracy: 0.9596
*****Target Accuracy: 0.4960 

*****Source Accuracy: 0.9601
*****Target Accuracy: 0.4875 

*****Source Accuracy: 0.9571
*****Target Accuracy: 0.5122 

model saved to p4_usps->mnistm_best.pth
*****Source Accuracy: 0.9512
*****Target Accuracy: 0.5018 

model saved to p4_usps->mnistm_final.pth


In [5]:
# Experiment: SVHN (source) -> USPS (target).
# Unlike the other two experiments, this one resumes from the *final* checkpoint.
# NOTE(review): needs that file from an earlier run.
filename = 'p4_svhn->usps'
source = (svhn_train_loader, svhn_test_loader)
target = (usps_train_loader, usps_test_loader)
resume_ckpt = filename + '_final.pth'
th.loadModel(resume_ckpt, my_net, optimizer)
train(filename)  # reads the globals `source` / `target` assigned above

model loaded from p4_svhn->usps_final.pth
*****Target Accuracy: 0.5112 

*****Source Accuracy: 0.8605
*****Target Accuracy: 0.5690 

model saved to p4_svhn->usps_best.pth
*****Source Accuracy: 0.8490
*****Target Accuracy: 0.5152 

*****Source Accuracy: 0.8457
*****Target Accuracy: 0.6273 

model saved to p4_svhn->usps_best.pth
*****Source Accuracy: 0.8123
*****Target Accuracy: 0.5864 

*****Source Accuracy: 0.8108
*****Target Accuracy: 0.6079 

*****Source Accuracy: 0.8176
*****Target Accuracy: 0.5810 

*****Source Accuracy: 0.8453
*****Target Accuracy: 0.6378 

model saved to p4_svhn->usps_best.pth
*****Source Accuracy: 0.8424
*****Target Accuracy: 0.6831 

model saved to p4_svhn->usps_best.pth
*****Source Accuracy: 0.8344
*****Target Accuracy: 0.6522 

*****Source Accuracy: 0.8478
*****Target Accuracy: 0.5630 

model saved to p4_svhn->usps_final.pth


In [5]:
# Experiment: MNIST-M (source) -> SVHN (target).
# Resumes from this task's *best* checkpoint, then trains further.
# NOTE(review): needs that file from an earlier run.
filename = 'p4_mnistm->svhn'
source = (mnist_train_loader, mnist_test_loader)
target = (svhn_train_loader, svhn_test_loader)
resume_ckpt = filename + '_best.pth'
th.loadModel(resume_ckpt, my_net, optimizer)
train(filename)  # reads the globals `source` / `target` assigned above

model loaded from p4_mnistm->svhn_best.pth
*****Target Accuracy: 0.5290 

*****Source Accuracy: 0.9660
*****Target Accuracy: 0.5084 

*****Source Accuracy: 0.9589
*****Target Accuracy: 0.4386 

*****Source Accuracy: 0.9524
*****Target Accuracy: 0.4991 

*****Source Accuracy: 0.9533
*****Target Accuracy: 0.5105 

*****Source Accuracy: 0.9569
*****Target Accuracy: 0.5194 

*****Source Accuracy: 0.9620
*****Target Accuracy: 0.5479 

model saved to p4_mnistm->svhn_best.pth
*****Source Accuracy: 0.9574
*****Target Accuracy: 0.5211 

*****Source Accuracy: 0.9588
*****Target Accuracy: 0.5210 

*****Source Accuracy: 0.9625
*****Target Accuracy: 0.5294 

*****Source Accuracy: 0.9413
*****Target Accuracy: 0.5351 

model saved to p4_mnistm->svhn_final.pth
