In [1]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from utils import *
from torch import optim
from torch.functional import F
# from torch.utils import tensorboard


In [2]:
# --- Reproducibility: fixed seeds and deterministic cuDNN kernels ------------
torch.manual_seed(1337)
np.random.seed(1337)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Hyper-parameters --------------------------------------------------------
BATCH_SIZE, EPOCHS = 128, 100
lr = 2e-4  # learning rate passed to every Adam optimizer in this notebook
N = 4      # number of generator-only updates per batch inside train()

# --- Compute device: first CUDA GPU when available, otherwise the CPU --------
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
# Source images: convert to a tensor, then normalise every channel with
# mean 0.5 / std 0.5.
source_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Target images: resize to 32x32 and replicate the grayscale channel three
# times before applying the same tensor conversion and normalisation.
# NOTE(review): presumably this makes target images match the source image
# shape — confirm against the dataset cell, which is not part of this view.
target_transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

In [5]:
class Generator(nn.Module):
    """Convolutional feature extractor feeding both classifiers.

    Three 5x5 convolutions (padding 2, so height/width are preserved), each
    followed by batch norm and ReLU, with a stride-2 max-pool after the first
    two. The result is flattened to 8192 values per sample and projected to a
    3072-dimensional feature vector.

    The flatten size is hard-coded: 8192 = 128 channels x 8 x 8, which is what
    the convolutional stack produces for a 3x32x32 input.
    """

    def __init__(self):
        super().__init__()

        # Shapes in the comments assume a 3x32x32 input.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=5, padding=2),          # 64 x 32 x 32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),    # 64 x 16 x 16
            nn.Conv2d(64, 64, kernel_size=5, padding=2),         # 64 x 16 x 16
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),    # 64 x 8 x 8
            nn.Conv2d(64, 128, kernel_size=5, padding=2),        # 128 x 8 x 8
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        self.generator = nn.Sequential(*conv_stack)

        # Fully connected projection of the flattened feature maps.
        projection = [
            nn.Linear(8192, 3072),
            nn.BatchNorm1d(3072),
            nn.ReLU(True),
            nn.Dropout(),
        ]
        self.linear = nn.Sequential(*projection)

    def forward(self, x):
        """Return a (batch, 3072) feature tensor for the image batch ``x``."""
        feature_maps = self.generator(x)
        flattened = feature_maps.view(feature_maps.size(0), 8192)
        return self.linear(flattened)

In [6]:
class Classifier(nn.Module):
    """Two-layer classification head on top of 3072-dimensional features.

    Maps a (batch, 3072) feature tensor to (batch, 10) raw class scores through
    Linear -> BatchNorm1d -> ReLU -> Linear. No softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        head = [
            nn.Linear(3072, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(True),
            nn.Linear(2048, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the 10 raw class scores for each feature vector in ``x``."""
        return self.classifier(x)

### Xavier Weights Initialization

In [7]:
def init_weights(m):
    """Apply Xavier (Glorot) normal initialisation to conv / linear weights.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module. Only modules whose type is exactly ``nn.Conv2d`` or
    ``nn.Linear`` are touched; biases and all other layers keep their default
    initialisation.

    Args:
        m: the sub-module currently visited by ``apply``.
    """
    # Exact type match (not isinstance): subclasses are deliberately skipped.
    if type(m) in (nn.Conv2d, nn.Linear):
        # xavier_normal_ is the in-place API; the un-suffixed xavier_normal
        # is a deprecated alias that emits a warning on every call.
        nn.init.xavier_normal_(m.weight)

In [8]:
# One feature generator shared by two classifier heads. Every network is run
# through init_weights and then moved to the selected device.
G = Generator()
G.apply(init_weights)
G.to(device)

# The heads are built one after the other (C1 first, then C2), each with its
# own freshly drawn weights.
C1, C2 = (Classifier().apply(init_weights).to(device) for _ in range(2))

### Defining optimizers

In [9]:
# Loss for the labelled source batches and the print frequency (in batches)
# used by train().
criterion = nn.CrossEntropyLoss()
log_interval = 100

# One Adam optimizer per network, all sharing the same learning rate and
# weight decay.
optimizer_G, optimizer_C1, optimizer_C2 = (
    optim.Adam(net.parameters(), lr=lr, weight_decay=5e-4)
    for net in (G, C1, C2)
)

In [10]:
# TensorBoard writer used by train() and test() to log losses and accuracies.
# NOTE(review): `tensorboard` is not imported explicitly in this notebook (the
# `from torch.utils import tensorboard` line in the import cell is commented
# out); presumably it is re-exported by `from utils import *` — confirm.
writter = tensorboard.SummaryWriter()

In [11]:
def discrepancy(out1, out2):
    """Mean absolute difference between two classifiers' class probabilities.

    Args:
        out1: raw class scores from the first classifier, shape (batch, classes).
        out2: raw class scores from the second classifier, same shape as ``out1``.

    Returns:
        A scalar tensor: the L1 distance between the two softmax outputs,
        averaged over every element.
    """
    # The class axis is given explicitly: calling F.softmax without `dim` is
    # deprecated and falls back to an implicit, rank-dependent axis choice.
    probs1 = F.softmax(out1, dim=1)
    probs2 = F.softmax(out2, dim=1)
    return torch.mean(torch.abs(probs1 - probs2))

def zero_grad():
    """Reset the gradients held by the generator and both classifier optimizers."""
    for optimizer in (optimizer_G, optimizer_C1, optimizer_C2):
        optimizer.zero_grad()

### Train / Test

In [12]:
def train():
    """Run one training epoch over paired source / target batches.

    Each iteration performs three steps:

    A. Update G, C1 and C2 to minimise the source classification loss.
    B. Update only C1 and C2 to minimise
       ``source loss - discrepancy(target predictions)``, i.e. keep the source
       fit while pushing the two classifiers apart on the target batch.
    C. Update only G, ``N`` times, to minimise the target discrepancy.

    Everything is read from module-level state: the networks, optimizers,
    ``criterion``, ``device``, ``N``, ``log_interval``, ``writter`` and the
    loaders ``source_train`` / ``target_train``. ``epoch`` and ``EPOCHS`` are
    also read from module scope, so the function must be called from the
    epoch loop that defines ``epoch``.
    """
    G.train()
    C1.train()
    C2.train()

    # zip() stops at the shorter loader, so this is the number of batches
    # actually processed; it is only used in the progress message.
    len_dataloader = min(len(source_train), len(target_train))
    paired_batches = zip(source_train, target_train)
    for batch_idx, ((img_s, labels), (img_t, _)) in enumerate(paired_batches):
        img_s = img_s.to(device)
        labels = labels.to(device)
        img_t = img_t.to(device)

        # ---- Step A: generator + both classifiers on labelled source data ----
        zero_grad()
        feature_s = G(img_s)
        out_s1 = C1(feature_s)
        out_s2 = C2(feature_s)

        loss_s1 = criterion(out_s1, labels)
        loss_s2 = criterion(out_s2, labels)
        loss_s = loss_s1 + loss_s2

        loss_s.backward()
        optimizer_G.step()
        optimizer_C1.step()
        optimizer_C2.step()

        # ---- Step B: classifiers only, maximise target discrepancy ----------
        # Only C1/C2 are stepped here, so the generator forward passes run
        # under no_grad: the features (and the classifier gradients) are the
        # same, but no backward pass through G is built or computed.
        zero_grad()
        with torch.no_grad():
            feature_s = G(img_s)
        out_s1 = C1(feature_s)
        out_s2 = C2(feature_s)
        with torch.no_grad():
            feature_t = G(img_t)
        out_t1 = C1(feature_t)
        out_t2 = C2(feature_t)

        loss_s1 = criterion(out_s1, labels)
        loss_s2 = criterion(out_s2, labels)
        loss_s = loss_s1 + loss_s2
        loss_discrepancy = discrepancy(out_t1, out_t2)
        loss = loss_s - loss_discrepancy

        loss.backward()
        optimizer_C1.step()
        optimizer_C2.step()

        zero_grad()

        # ---- Step C: generator only, minimise target discrepancy ------------
        for i in range(N):
            feature_t = G(img_t)
            out_t1 = C1(feature_t)
            out_t2 = C2(feature_t)
            loss_discrepancy = discrepancy(out_t1, out_t2)
            loss_discrepancy.backward()
            optimizer_G.step()
            zero_grad()

        # The reported source losses come from step B; the discrepancy is the
        # value from the last generator update.
        if (batch_idx+1) % log_interval == 0:
            print("Epoch: {}/{} [{}/{}]: C1 Source Loss ={:.5f}, C2 Source Loss={:.5f}, Discrepancy Loss={:.5f}"
                 .format(epoch+1, EPOCHS, batch_idx + 1, len_dataloader, loss_s1.item(), loss_s2.item(),loss_discrepancy.item()))

        # One TensorBoard sample per epoch, taken at the batch with index 20.
        if batch_idx == 20:
            writter.add_scalar('Train/Loss C1 Source',loss_s1.item(), epoch)
            writter.add_scalar('Train/Loss C2 Source',loss_s2.item(), epoch)
            writter.add_scalar('Train/Loss Discrepancy',loss_discrepancy.item(), epoch)
        

In [13]:
def test(max):
    G.eval()
    C1.eval()
    C2.eval()
    test_loss = 0
    correct = 0
    for (imgs, labels) in target_val:
        imgs = imgs.to(device)
        labels = labels.to(device)
        features = G(imgs)
        output1 = C1(features)
        output2 = C2(features)
        output_of_two = output1 + output2
        pred_of_two = output_of_two.data.max(1)[1]
        correct += pred_of_two.eq(labels.data).cpu().sum()
    
    size = len(target_val.dataset)
    acc_mnist = 100. * correct / size
    if acc_mnist > max:
        max = acc_mnist
        torch.save(G.state_dict(),'Weights/MCD/Generator_epoch{}_{:.2f}.pth'.format(epoch,acc_mnist))
        torch.save(C1.state_dict(),'Weights/MCD/C1_epoch{}_{:.2f}.pth'.format(epoch,acc_mnist))
        torch.save(C2.state_dict(),'Weights/MCD/C2_epoch{}_{:.2f}.pth'.format(epoch,acc_mnist))
    
    correct = 0
    for (imgs, labels) in source_train:
        imgs = imgs.to(device)
        labels = labels.to(device)
        features = G(imgs)
        output1 = C1(features)
        output2 = C2(features)
        output_of_two = output1 + output2
        pred_of_two = output_of_two.data.max(1)[1]
        correct += pred_of_two.eq(labels.data).cpu().sum()
    
    size = len(source_train.dataset)
    acc_svhn_train = 100. * correct / size
    
    correct = 0
    for (imgs, labels) in source_val:
        imgs = imgs.to(device)
        labels = labels.to(device)
        features = G(imgs)
        output1 = C1(features)
        output2 = C2(features)
        output_of_two = output1 + output2
        pred_of_two = output_of_two.data.max(1)[1]
        correct += pred_of_two.eq(labels.data).cpu().sum()
    
    size = len(source_val.dataset)
    acc_svhn_test = 100. * correct / size
        
    print(
        '\nTest: Accuracy on MNIST: {:.0f}% Accuracy on SVHN train: {:.0f}% Accuracy on SVHN test: {:.0f}% \n'.format(
            acc_mnist, acc_svhn_train, acc_svhn_test))
    writter.add_scalar('Test/Accuracy MNIST Test', acc_mnist, epoch)
    writter.add_scalar('Test/Accuracy SVHN Train', acc_svhn_train, epoch)
    writter.add_scalar('Test/Accuracy SVHN Test', acc_svhn_test, epoch)
    writter.flush()
    return max

In [14]:
# Best target-domain accuracy so far; test() saves a checkpoint whenever it is
# beaten and returns the updated value. The variable is deliberately not called
# `max`, which would shadow the builtin for the rest of the kernel session.
best_acc = 0
# `epoch` must stay a module-level name: train() and test() read it directly.
for epoch in range(EPOCHS):
    train()
    best_acc = test(best_acc)

Epoch: 1/100 [100/469]: C1 Source Loss =0.42135, C2 Source Loss=0.41960, Discrepancy Loss=0.01512
Epoch: 1/100 [200/469]: C1 Source Loss =0.30488, C2 Source Loss=0.30933, Discrepancy Loss=0.00883
Epoch: 1/100 [300/469]: C1 Source Loss =0.39869, C2 Source Loss=0.38143, Discrepancy Loss=0.00825
Epoch: 1/100 [400/469]: C1 Source Loss =0.47868, C2 Source Loss=0.47859, Discrepancy Loss=0.00660

Test: Accuracy on MNIST: 76% Accuracy on SVHN train: 62% Accuracy on SVHN test: 61% 

Epoch: 2/100 [100/469]: C1 Source Loss =0.34077, C2 Source Loss=0.35216, Discrepancy Loss=0.00695
Epoch: 2/100 [200/469]: C1 Source Loss =0.42092, C2 Source Loss=0.42545, Discrepancy Loss=0.00463
Epoch: 2/100 [300/469]: C1 Source Loss =0.32008, C2 Source Loss=0.32773, Discrepancy Loss=0.00407
Epoch: 2/100 [400/469]: C1 Source Loss =0.27736, C2 Source Loss=0.27654, Discrepancy Loss=0.00591

Test: Accuracy on MNIST: 76% Accuracy on SVHN train: 69% Accuracy on SVHN test: 69% 

Epoch: 3/100 [100/469]: C1 Source Loss =0.

## t-SNE Embeddings visualization

The embedding plots below were built with TensorBoard ("DA" stands for domain adaptation).

Before DA (Pre-trained on SVHN)                  |  After DA
:-----------------------------------------------:|:-------------------------:
![](imgs/t-SNE_bfDA.png)                         |  ![](imgs/t-SNE_MCD.jpg) 