In [28]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from zipfile import ZipFile
import skimage.io
from PIL import Image
import seaborn as sns
from sklearn.manifold import TSNE


from mnist_generator import get_mnist_loaders
from mnistm_generator import get_mnistm_loaders
from DANN import *
from DA import *
from test import *
from train import *
from visualize import *
from util import *
#https://github.com/fungtion/DSN/blob/master/train.py#L223

In [29]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [30]:
mnist_train_loader,mnist_eval_loader, mnist_test_loader = get_mnist_loaders(batch_size=128)

In [31]:
mnistm_train_loader, mnistm_eval_loader,mnistm_test_loader = get_mnistm_loaders(batch_size=128)

In [282]:
class ReverseLayerF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)
    
    @staticmethod
    def backward(ctx, grad_output):
        output = ctx.alpha * grad_output.neg()
        return output, None
class GradualLayerF(torch.autograd.Function):
    @staticmethod
    def forward(ctx,x,alpha):
        ctx.aplha = alpha
    @staticmethod
    def backward(ctx,grad_output):
        output=ctx.alpha*grad_output

In [231]:
def DANNTrain(mnist_train, mnistm_train, mnist_eval, mnistm_eval, epochs,intervals):
    dann = DANN().to(device)

    criterion= nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(
        list(dann.parameters()),lr=0.01)
    domain_accs =[]
    source_accs = []
    target_accs = []
    total_steps = epochs* len(mnistm_train)
    for epoch in range(epochs):
        start_steps = epoch * len(mnist_train)
        
        for batch_idx, (source, target) in enumerate(zip(mnist_train, mnistm_train)):
            if (batch_idx >100):
                break
            source_image, source_label = source
            target_image, target_label = target

            p = float(batch_idx + start_steps)/ total_steps
            alpha = 2. / (1.+np.exp(-10*p))-1 

            source_image, source_label = source_image.to(device), source_label.to(device)
            target_image, target_label = target_image.to(device), target_label.to(device)

            # update learning rate
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.01/(1.+10*p)**0.75

            # clear the grad
            optimizer.zero_grad()
            
            source_yp_labels, source_yp_domains = dann(source_image,alpha)
            target_yp_labels, target_yp_domains = dann(target_image,alpha)

            
            source_labels_loss = criterion(source_yp_labels, source_label)
            source_domain_loss = criterion(source_yp_domains, torch.zeros(source_label.size()[0]).type(torch.LongTensor).to(device))
            target_domain_loss = criterion(target_yp_domains, torch.ones(target_label.size()[0]).type(torch.LongTensor).to(device))
            
            total_loss = source_labels_loss + source_domain_loss + target_domain_loss
            total_loss.backward()
            optimizer.step()
        if (epoch+1) % intervals == 0:
            with torch.no_grad():
                source_acc, target_acc, domain_acc = DANNAccuracy(dann, mnist_eval, mnistm_eval)
                domain_accs.append(domain_acc)
                source_accs.append(source_acc)
                target_accs.append(target_acc)
                print(f'{epoch+1}/{epochs}: source_acc: {source_acc},target_acc: {target_acc}, domain_acc: {domain_acc}')
    return source_accs, target_accs, domain_accs, dann

In [232]:
source_accs, target_accs, domain_accs, dann = DANNTrain(
      mnist_train_loader,
      mnistm_train_loader,
      mnist_eval_loader,
      mnistm_eval_loader,
      epochs=300,
      intervals = 30)

30/300: source_acc: 0.9636166666666667,target_acc: 0.5350166666666667, domain_acc: 0.5958416666666667
60/300: source_acc: 0.9658166666666667,target_acc: 0.6303333333333333, domain_acc: 0.6232
90/300: source_acc: 0.9651166666666666,target_acc: 0.69145, domain_acc: 0.6215583333333333
120/300: source_acc: 0.9673666666666667,target_acc: 0.7286666666666667, domain_acc: 0.606675
150/300: source_acc: 0.97015,target_acc: 0.75335, domain_acc: 0.6072583333333333
180/300: source_acc: 0.9686,target_acc: 0.7903, domain_acc: 0.6009833333333333
210/300: source_acc: 0.9709833333333333,target_acc: 0.8112166666666667, domain_acc: 0.599225
240/300: source_acc: 0.9725333333333334,target_acc: 0.8233833333333334, domain_acc: 0.58615
270/300: source_acc: 0.97235,target_acc: 0.8329166666666666, domain_acc: 0.580875
300/300: source_acc: 0.9711833333333333,target_acc: 0.8392, domain_acc: 0.5718416666666667


In [305]:
class DSN(nn.Module):
    def __init__(self):
        super(DSN, self).__init__()
        self.source_encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64,kernel_size=5),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(in_channels=64,out_channels=50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.MaxPool2d(kernel_size=2,stride=2)
        )
        self.target_encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(in_channels=64, out_channels=50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.MaxPool2d(kernel_size=2,stride=2)
        )
        self.shared_encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=64, kernel_size=5),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Conv2d(in_channels=64, out_channels=50, kernel_size=5),
            nn.BatchNorm2d(50),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.MaxPool2d(kernel_size=2)

        )
        self.class_classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=50*4*4, out_features=100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=100,out_features=100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=10),
            nn.LogSoftmax(dim=1)
        )
        self.discriminator = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=50*4*4,out_features=100),
            nn.BatchNorm1d(100),
            nn.ReLU(),
            nn.Linear(in_features=100,out_features=2),
            nn.LogSoftmax(dim=1)
        )
        self.shared_decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=50*4*4, out_features=588),
            nn.BatchNorm1d(588),
            nn.ReLU(),
            nn.Unflatten(dim=1,unflattened_size=(3,14,14)),
            nn.Conv2d(in_channels=3, out_channels=16,kernel_size=5,padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout2d(0.5),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5,padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=16, out_channels=16,kernel_size=3,padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=3,kernel_size=3,padding=1)
        )
    
    def forward(self, x,alpha,mode):
        x = x.expand(x.data.shape[0],3,28,28)
        if mode == 'source':
            private_f = self.source_encoder(x)
        elif mode == 'target':
            private_f = self.target_encoder(x)
        
        shared_f = self.shared_encoder(x)
        rev_shared_f = ReverseLayerF.apply(shared_f,alpha)
        
        domain_label = self.discriminator(rev_shared_f)
        
        class_label = self.class_classifier(shared_f)
        union_f = private_f + shared_f
        grad_union_f = GradualLayerF.apply(union_f,0.1)
        rec_img = self.shared_decoder(union_f)
        return private_f, shared_f, rec_img, domain_label, class_label

In [306]:
class DiffLoss(nn.Module):
    def __init__(self):
        super(DiffLoss, self).__init__()
        
    def forward(self, x1, x2):
        batch_size = x1.size(0)
        x1 = x1.view(batch_size,-1)
        x2 = x2.view(batch_size,-1)
        x1_norm = torch.norm(x1,dim=1,keepdim=True).detach()
        x1_l2 = x1.div(x1_norm.expand_as(x1)+1e-6)
        x2_norm = torch.norm(x2,dim=1,keepdim=True).detach()
        x2_l2 = x2.div(x2_norm.expand_as(x2)+1e-6)
        
        diff_loss = torch.mean((x1_l2.T.mm(x2))**2)
        return diff_loss

In [307]:
loss = DiffLoss()
x1=torch.tensor([[1,2,3],[1,2,3],[1,2,3]],dtype=torch.float)
x2=torch.tensor([[2,2,3],[2,2,3],[1,2,3]],dtype=torch.float)
print(loss(x1,x2))

tensor(15.7778)


In [308]:
def DSNAccuracy(dsn, mnist_gen, mnistm_gen):
    s_cor = 0
    t_cor = 0
    domain_cor = 0
    total_classify = 0
    total_diff = 0
    total_rec = 0
    total_sim = 0
    loss_classification = torch.nn.CrossEntropyLoss().to(device)
    loss_reconstruct = nn.MSELoss().to(device)
    loss_diff = DiffLoss().to(device)
    loss_similarity = torch.nn.CrossEntropyLoss().to(device)
    
    for batch_idx, (source, target) in enumerate(zip(mnist_gen, mnistm_gen)):
        p = float(batch_idx)/len(mnist_gen)
        alpha = 2. / (1.+np.exp(-10*p))-1 
        source_image, source_label = source
        target_image, target_label = target

        source_image, source_label = source_image.to(device), source_label.to(device)
        target_image, target_label = target_image.to(device), target_label.to(device)
        source_image = source_image.expand(source_image.data.shape[0],3,28,28)

        source_domain = torch.zeros(source_label.size()[0]).type(torch.LongTensor).to(device)
        target_domain = torch.ones(target_label.size()[0]).type(torch.LongTensor).to(device)
        
                
        source_private, source_shared, source_rec_image,source_pred_domain, source_pred_label = dsn(source_image,alpha,mode='source')
        loss_classify = loss_classification(source_pred_label,source_label)
        loss_source_diff = loss_diff(source_private,source_shared)
        loss_source_rec = loss_reconstruct(source_rec_image,source_image)
        loss_source_similarity = loss_similarity(source_pred_domain,source_domain)
        target_private, target_shared, target_rec_image,target_pred_domain, target_pred_label = dsn(target_image,alpha,mode='target')
        loss_target_diff = loss_diff(target_private,target_shared)
        loss_target_rec = loss_reconstruct(target_rec_image,target_image)
        loss_target_similarity = loss_similarity(target_pred_domain,target_domain)

        total_rec += (loss_source_rec + loss_target_rec)
        total_classify += (loss_classify)
        total_diff += (loss_source_diff+loss_target_diff)
        total_sim += (loss_source_similarity+loss_target_similarity)
        source_yp_labels = source_pred_label.data.max(1,keepdim=True)[1]
        s_cor += source_yp_labels.eq(source_label.data.view_as(source_yp_labels)).cpu().sum()

        target_yp_labels = target_pred_label.data.max(1,keepdim=True)[1]
        t_cor += target_yp_labels.eq(target_label.data.view_as(target_yp_labels)).cpu().sum()
        
        source_yp_domains = source_pred_domain.data.max(1,keepdim=True)[1]
        domain_cor += source_yp_domains.eq(source_domain.data.view_as(source_yp_domains)).cpu().sum()
        target_yp_domains = target_pred_domain.data.max(1,keepdim=True)[1]
        domain_cor += target_yp_domains.eq(target_domain.data.view_as(target_yp_domains)).cpu().sum()
    domain_acc = domain_cor.item()/(len(mnist_gen.dataset)+len(mnistm_gen.dataset))
    s_acc = s_cor.item()/len(mnist_gen.dataset)
    t_acc = t_cor.item()/len(mnistm_gen.dataset)
    rec_loss = total_rec.item()/(len(mnist_gen.dataset)+len(mnistm_gen.dataset))
    classify_loss = total_classify.item()/(len(mnist_gen.dataset)+len(mnistm_gen.dataset))
    diff_loss = total_diff.item()/(len(mnist_gen.dataset)+len(mnistm_gen.dataset))
    sim_loss = total_sim.item()/(len(mnist_gen.dataset)+len(mnistm_gen.dataset))
    return s_acc, t_acc, domain_acc, rec_loss, classify_loss, diff_loss, sim_loss

In [309]:
def DSNTrain(mnist_train, mnistm_train, mnist_eval, mnistm_eval, epochs=300,intervals=30,rec_weight = 0.5, diff_weight=0.075, sim_weight=1):
    dsn = DSN().to(device)
    #for p in dsn.parameters():
    #    p.requires_grad = True

    optimizer = optim.SGD(
        list(dsn.parameters()),lr=0.01)
    
    loss_classification = torch.nn.CrossEntropyLoss().to(device)
    loss_reconstruct = nn.MSELoss().to(device)
    loss_diff = DiffLoss().to(device)
    loss_similarity = torch.nn.CrossEntropyLoss().to(device)
    total_steps = epochs* len(mnistm_train)
    
    domain_accs =[]
    source_accs = []
    target_accs = []
    rec_losses = []
    
    for epoch in range(epochs):
        start_steps = epoch * len(mnist_train)
        for batch_idx, (source, target) in enumerate(zip(mnist_train, mnistm_train)):
            if (batch_idx >100):
                break
            p = float(batch_idx + start_steps)/ total_steps
            alpha = 2. / (1.+np.exp(-10*p))-1
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.01/(1.+10*p)**0.75
            source_image, source_label = source
            target_image, target_label = target

            source_image, source_label = source_image.to(device), source_label.to(device)
            target_image, target_label = target_image.to(device), target_label.to(device)
            source_image = source_image.expand(source_image.data.shape[0],3,28,28)
            
            optimizer.zero_grad()
            
            source_domain = torch.zeros(source_label.size()[0]).type(torch.LongTensor).to(device)
            target_domain = torch.ones(target_label.size()[0]).type(torch.LongTensor).to(device)
            
            
            source_private, source_shared, source_rec_image,source_pred_domain, source_pred_label = dsn(source_image,alpha,mode='source')
            loss_classify = loss_classification(source_pred_label,source_label)
            loss_source_diff = loss_diff(source_private,source_shared)
            loss_source_rec = loss_reconstruct(source_rec_image,source_image)
            loss_source_similarity = loss_similarity(source_pred_domain,source_domain)
            target_private, target_shared, target_rec_image,target_pred_domain, target_pred_label = dsn(target_image,alpha,mode='target')
            loss_target_diff = loss_diff(target_private,target_shared)
            loss_target_rec = loss_reconstruct(target_rec_image,target_image)
            loss_target_similarity = loss_similarity(target_pred_domain,target_domain)
         
            total_loss = (#(rec_weight*(loss_source_rec+loss_target_rec) +
                        #diff_weight*(loss_source_diff+loss_target_diff) + 
                        sim_weight*(loss_source_similarity+loss_target_similarity) +
                        loss_classify)
            total_loss.backward()
            optimizer.step()
        if (epoch+1) % intervals == 0:
            with torch.no_grad():
                source_acc, target_acc, domain_acc,rec_loss,classify_loss,diff_loss,sim_loss = DSNAccuracy(dsn, mnist_eval, mnistm_eval)
                domain_accs.append(domain_acc)
                source_accs.append(source_acc)
                target_accs.append(target_acc)
                rec_losses.append(rec_loss)
                print(f'{epoch+1}/{epochs}: source_acc: {source_acc},target_acc: {target_acc}, domain_acc: {domain_acc},rec_loss: {rec_loss}')
    return source_accs, target_accs, domain_accs, rec_losses, dsn

In [310]:
source_accs, target_accs, domain_accs, rec_losses, dsn = DSNTrain(mnist_train_loader,
                                                                 mnistm_train_loader,
                                                                 mnist_eval_loader,
                                                                 mnistm_eval_loader,intervals=1)

1/300: source_acc: 0.6044166666666667,target_acc: 0.2862, domain_acc: 0.5257916666666667,rec_loss: 0.0003552673657735189
2/300: source_acc: 0.7393166666666666,target_acc: 0.3576166666666667, domain_acc: 0.5427416666666667,rec_loss: 0.00035534451802571613
3/300: source_acc: 0.8172333333333334,target_acc: 0.38763333333333333, domain_acc: 0.5565333333333333,rec_loss: 0.00035541454950968425
4/300: source_acc: 0.85765,target_acc: 0.40671666666666667, domain_acc: 0.562225,rec_loss: 0.00035546477635701497
5/300: source_acc: 0.8814166666666666,target_acc: 0.41646666666666665, domain_acc: 0.5696666666666667,rec_loss: 0.0003556041717529297
6/300: source_acc: 0.8983666666666666,target_acc: 0.4325333333333333, domain_acc: 0.575275,rec_loss: 0.00035575440724690754
7/300: source_acc: 0.9101666666666667,target_acc: 0.4396833333333333, domain_acc: 0.5793416666666666,rec_loss: 0.00035588728586832683
8/300: source_acc: 0.9182666666666667,target_acc: 0.44966666666666666, domain_acc: 0.58005,rec_loss: 0.0

KeyboardInterrupt: 

In [215]:
def GridSearch():
    rec_weights = [0.01,0.05,0.1,0.3,0.6,0.9,1]
    diff_weights = [0.01,0.05,0.1,0.3,0.6,0.9,1]
    for rec_weight in rec_weights:
        for diff_weight in diff_weights:
            source_accs, target_accs, domain_accs, rec_losses, dsn = DSNTrain(mnist_train_loader,
                                                                 mnistm_train_loader,
                                                                 mnist_eval_loader,
                                                                 mnistm_eval_loader,intervals=300,rec_weight = rec_weight, diff_weight=diff_weight)
            print(f"rec_weight: {rec_weight}, diff_weight: {diff_weight}, target_accs: {mean(target_accs)}, domain_acc:{mean(domain_accs)},rec_loss: {mean(rec_losses)}")

In [216]:
GridSearch()

rec_weight: 0.01, diff_weight: 0.01, target_accs: 0.6619333333333334, domain_acc:0.5521833333333334,rec_loss: 0.00012644611994425456
rec_weight: 0.01, diff_weight: 0.05, target_accs: 0.5970333333333333, domain_acc:0.5671583333333333,rec_loss: 0.00012849899133046469
rec_weight: 0.01, diff_weight: 0.1, target_accs: 0.4876, domain_acc:0.5565833333333333,rec_loss: 0.00012662607034047446
rec_weight: 0.01, diff_weight: 0.3, target_accs: 0.17886666666666667, domain_acc:0.5096166666666667,rec_loss: 0.00011933932304382324
rec_weight: 0.01, diff_weight: 0.6, target_accs: 0.26918333333333333, domain_acc:0.5120166666666667,rec_loss: 0.00011965169111887614
rec_weight: 0.01, diff_weight: 0.9, target_accs: 0.19748333333333334, domain_acc:0.5104,rec_loss: 0.00012323212623596193
rec_weight: 0.01, diff_weight: 1, target_accs: 0.167, domain_acc:0.5031,rec_loss: 0.00012214375336964926
rec_weight: 0.05, diff_weight: 0.01, target_accs: 0.6818666666666666, domain_acc:0.5459916666666667,rec_loss: 0.0001071224

In [218]:
rec_weights = [0.05,0.075,0.01]
sim_weights = [0.1,0.5,1]
for rec_weight in rec_weights:
    for sim_weight in sim_weights:
        source_accs, target_accs, domain_accs, rec_losses, dsn = DSNTrain(mnist_train_loader,
                                                                 mnistm_train_loader,
                                                                 mnist_eval_loader,
                                                                 mnistm_eval_loader,intervals=300,rec_weight = rec_weight, sim_weight=sim_weight)
        print(f"rec_weight: {rec_weight}, sim_weight: {sim_weight}, target_accs: {mean(target_accs)}, domain_acc:{mean(domain_accs)},rec_loss: {mean(rec_losses)}")

rec_weight: 0.05, sim_weight: 0.1, target_accs: 0.52545, domain_acc:0.5838166666666667,rec_loss: 0.00010751407146453858
rec_weight: 0.05, sim_weight: 0.5, target_accs: 0.5583833333333333, domain_acc:0.5787333333333333,rec_loss: 0.00010754357973734538
rec_weight: 0.05, sim_weight: 1, target_accs: 0.5898666666666667, domain_acc:0.5696333333333333,rec_loss: 0.00010700028737386068
rec_weight: 0.075, sim_weight: 0.1, target_accs: 0.4706, domain_acc:0.58465,rec_loss: 0.00010387233893076579
rec_weight: 0.075, sim_weight: 0.5, target_accs: 0.5092166666666667, domain_acc:0.568125,rec_loss: 0.00010423275629679362
rec_weight: 0.075, sim_weight: 1, target_accs: 0.45458333333333334, domain_acc:0.5474583333333334,rec_loss: 0.0001055508534113566
rec_weight: 0.01, sim_weight: 0.1, target_accs: 0.45531666666666665, domain_acc:0.5716666666666667,rec_loss: 0.00012778352896372478
rec_weight: 0.01, sim_weight: 0.5, target_accs: 0.5272333333333333, domain_acc:0.5715583333333333,rec_loss: 0.00012677634557088