In [1]:
import torch
import torchvision
import torch.nn as nn

from pytorch_revgrad import RevGrad

import copy
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor

from pathlib import Path
from PIL import Image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Params

In [2]:
# Specifiy MNIST and BSDS500 Data Path

DATA_PATH = '/project/GutIntelligenceLab/ys5hd/MNIST/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Data

In [3]:
class GrayscaleToRgb:
    """Convert a grayscale image to rgb"""
    def __call__(self, image):
        image = np.array(image)
        image = np.dstack([image, image, image])
        return Image.fromarray(image)


class BSDS500(Dataset):
    def __init__(self):
        image_folder = Path(DATA_PATH+'BSDS500/BSR/BSDS500/data/images')
        self.image_files = list(map(str, image_folder.glob('*/*.jpg')))

    def __getitem__(self, i):
        image = cv2.imread(self.image_files[i], cv2.IMREAD_COLOR)
        tensor = torch.from_numpy(image.transpose(2, 0, 1))
        return tensor

    def __len__(self):
        return len(self.image_files)


class MNISTM(Dataset):

    def __init__(self, train=True):
        super(MNISTM, self).__init__()
        self.mnist = datasets.MNIST(DATA_PATH, train=train,
                                    download=False)
        self.bsds = BSDS500()
        # Fix RNG so the same images are used for blending
        self.rng = np.random.RandomState(42)

    def __getitem__(self, i):
        digit, label = self.mnist[i]
        digit = transforms.ToTensor()(digit)
        bsds_image = self._random_bsds_image()
        patch = self._random_patch(bsds_image)
        patch = patch.float() / 255
        blend = torch.abs(patch - digit)
        return blend, label

    def _random_patch(self, image, size=(28, 28)):
        _, im_height, im_width = image.shape
        x = self.rng.randint(0, im_width-size[1])
        y = self.rng.randint(0, im_height-size[0])
        return image[:, y:y+size[0], x:x+size[1]]

    def _random_bsds_image(self):
        i = self.rng.choice(len(self.bsds))
        return self.bsds[i]

    def __len__(self):
        return len(self.mnist)

### Model

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # CNN Feature Extractor
        self.fc_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Dropout(0.25)
        )
        
        
        # Classifier
        self.classifier = nn.Sequential(
             nn.Linear(6400, 32),
             nn.ReLU(),
             nn.Dropout(0.5),
             nn.Linear(32, 10),
        )
        
        # Domain Classifier
        self.domain_classifier = nn.Sequential(
             RevGrad(),
             nn.Linear(6400, 32),
             nn.ReLU(),
             nn.Linear(32, 1),
        )
        
        # Projection Layer
        self.domain_projection = nn.Sequential(
            nn.Linear(6400, 64),
            nn.ReLU(),
        )


    def forward(self, x):
        features = self.fc_extractor(x)
        features = features.view(x.shape[0], -1)
        domain_projection = self.domain_projection(features)
        prediction_class = self.classifier(features)
        prediction_domain = self.domain_classifier(features)
        return prediction_class, prediction_domain, domain_projection

### Data Loader

In [5]:
BATCH_SIZE = 64

In [6]:
source_dataset = MNIST(DATA_PATH, train=True, download=False,
                      transform=Compose([GrayscaleToRgb(), ToTensor()]))
source_loader = DataLoader(source_dataset, batch_size=BATCH_SIZE,
                           shuffle=True, drop_last=True)

source_dataset_test = MNIST(DATA_PATH, train=False, download=False,
                      transform=Compose([GrayscaleToRgb(), ToTensor()]))
source_loader_test = DataLoader(source_dataset_test, batch_size=BATCH_SIZE,
                           shuffle=False)

target_dataset = MNISTM(train=True)
target_loader = DataLoader(target_dataset, batch_size=BATCH_SIZE,
                           shuffle=True, drop_last=True)

target_dataset_test = MNISTM(train=False)
target_loader_test = DataLoader(target_dataset_test, batch_size=BATCH_SIZE,
                           shuffle=False)

### Base Model Training

In [7]:
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
epochs = 10
for epoch in range(epochs):
    running_loss = 0

    model.train()
    for images, labels in tqdm(source_loader):
        
        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        # Training pass
        optimizer.zero_grad()
        
        output, _, _ = model(images)
        loss = criterion(output.squeeze(), labels)
        
        #This is where the model learns by backpropagating
        loss.backward()
        
        #And optimizes its weights here
        optimizer.step()
        
        running_loss += loss.item()

    print("Epoch {} - Training loss: {}".format(epoch, running_loss/len(source_dataset)))
    
    model.eval()
    correct_prediction = 0
    for images, labels in tqdm(source_loader_test):

        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        with torch.no_grad():
            prediction, _, _ = model(images)
            correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()
    print("Epoch {} - Validation Accuracy: {}".format(epoch, correct_prediction/len(source_dataset_test)))

100%|██████████| 937/937 [00:08<00:00, 108.21it/s]


Epoch 0 - Training loss: 0.008400399645666282


100%|██████████| 157/157 [00:01<00:00, 130.27it/s]


Epoch 0 - Validation Accuracy: 0.9807


100%|██████████| 937/937 [00:08<00:00, 109.63it/s]


Epoch 1 - Training loss: 0.004837346463650465


100%|██████████| 157/157 [00:01<00:00, 115.38it/s]


Epoch 1 - Validation Accuracy: 0.9862


100%|██████████| 937/937 [00:08<00:00, 107.80it/s]


Epoch 2 - Training loss: 0.00403600363569955


100%|██████████| 157/157 [00:01<00:00, 131.47it/s]


Epoch 2 - Validation Accuracy: 0.988


100%|██████████| 937/937 [00:08<00:00, 109.11it/s]


Epoch 3 - Training loss: 0.0036848410570373136


100%|██████████| 157/157 [00:01<00:00, 116.45it/s]


Epoch 3 - Validation Accuracy: 0.9882


100%|██████████| 937/937 [00:09<00:00, 103.11it/s]


Epoch 4 - Training loss: 0.0034581702732170622


100%|██████████| 157/157 [00:01<00:00, 114.54it/s]


Epoch 4 - Validation Accuracy: 0.9904


100%|██████████| 937/937 [00:08<00:00, 109.29it/s]


Epoch 5 - Training loss: 0.00329492202103138


100%|██████████| 157/157 [00:01<00:00, 132.13it/s]


Epoch 5 - Validation Accuracy: 0.9909


100%|██████████| 937/937 [00:08<00:00, 113.07it/s]


Epoch 6 - Training loss: 0.0030344458530346553


100%|██████████| 157/157 [00:01<00:00, 117.04it/s]


Epoch 6 - Validation Accuracy: 0.9914


100%|██████████| 937/937 [00:08<00:00, 111.91it/s]


Epoch 7 - Training loss: 0.0027968465963068108


100%|██████████| 157/157 [00:01<00:00, 131.02it/s]


Epoch 7 - Validation Accuracy: 0.9919


100%|██████████| 937/937 [00:08<00:00, 116.19it/s]


Epoch 8 - Training loss: 0.0026944371844020982


100%|██████████| 157/157 [00:01<00:00, 132.25it/s]


Epoch 8 - Validation Accuracy: 0.992


100%|██████████| 937/937 [00:08<00:00, 116.43it/s]


Epoch 9 - Training loss: 0.002626434665483733


100%|██████████| 157/157 [00:01<00:00, 132.31it/s]

Epoch 9 - Validation Accuracy: 0.9913





In [9]:
PATH = 'DANN.pt'
torch.save(model.state_dict(), PATH)

#### MNIST-M Test Accuracy

In [10]:
model.eval()
correct_prediction = 0
for images, labels in tqdm(target_loader_test):
    
    # Cuda
    images = images.to(device)
    labels = labels.to(device)
        
    with torch.no_grad():
        prediction, _, _ = model(images)
        correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()
        
print('Test Accuracy: {}'.format(correct_prediction/len(target_dataset_test)))

100%|██████████| 157/157 [00:31<00:00,  4.95it/s]

Test Accuracy: 0.53





### DANN Model Training

In [17]:
model = Net().to(device)

criterion = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [18]:
# PATH = 'DANN.pt'
# model.load_state_dict(torch.load(PATH))

In [19]:
domain_iter = iter(target_loader)

epochs = 40
for epoch in range(epochs):
    running_loss_class = 0
    running_loss_domain = 0    

    model.train()
    for images, labels in tqdm(source_loader):
        
        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        try:
            images_domain, _ = next(domain_iter)
        except StopIteration:
            domain_iter = iter(target_loader)
            images_domain, _ = next(domain_iter)
        
        images_domain = images_domain.to(device)
        
        # Combine Images
        images = torch.cat((images, images_domain))
        
        # Domain Target
        target_domain = torch.cat((torch.ones(BATCH_SIZE, 1), torch.zeros(BATCH_SIZE, 1))).to(device)        
        
        # Training pass
        optimizer.zero_grad()
                
        output, output_domain, _ = model(images)

        loss_class = criterion(output[:BATCH_SIZE].squeeze(), labels)
        loss_bce = criterion_bce(output_domain, target_domain)
                     
        loss = loss_class + loss_bce
        
        #This is where the model learns by backpropagating
        loss.backward()
        
        #And optimizes its weights here
        optimizer.step()
        
        running_loss_class += loss_class.item()
        running_loss_domain += loss_bce.item()

    print("Epoch {} - Training loss Classifier: {}, Training loss Domain: {}"\
          .format(epoch, running_loss_class/len(source_dataset), running_loss_domain/(2*len(source_dataset))))
    
    model.eval()
    correct_prediction = 0
    for images, labels in tqdm(source_loader_test):

        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        with torch.no_grad():
            prediction, _, _ = model(images)
            correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()
    print("Epoch {} - Validation Accuracy: {}".format(epoch, correct_prediction/len(source_dataset_test)))
    
    model.eval()
    correct_prediction = 0
    for images, labels in tqdm(target_loader_test):

        # Cuda
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            prediction, _, _ = model(images)
            correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()

    print('Test Accuracy: {}'.format(correct_prediction/len(target_dataset_test)))    

100%|██████████| 937/937 [03:24<00:00,  4.59it/s]


Epoch 0 - Training loss Classifier: 0.017157927534977593, Training loss Domain: 0.004095627647141616


100%|██████████| 157/157 [00:01<00:00, 131.65it/s]


Epoch 0 - Validation Accuracy: 0.9302


100%|██████████| 157/157 [00:32<00:00,  4.89it/s]


Test Accuracy: 0.5547


100%|██████████| 937/937 [03:24<00:00,  4.58it/s]


Epoch 1 - Training loss Classifier: 0.009538516893982888, Training loss Domain: 0.0035297243821124234


100%|██████████| 157/157 [00:01<00:00, 132.91it/s]


Epoch 1 - Validation Accuracy: 0.955


100%|██████████| 157/157 [00:32<00:00,  4.87it/s]


Test Accuracy: 0.5785


100%|██████████| 937/937 [03:27<00:00,  4.52it/s]


Epoch 2 - Training loss Classifier: 0.007879015571624041, Training loss Domain: 0.003305537870153785


100%|██████████| 157/157 [00:01<00:00, 132.17it/s]


Epoch 2 - Validation Accuracy: 0.9614


100%|██████████| 157/157 [00:32<00:00,  4.88it/s]


Test Accuracy: 0.683


100%|██████████| 937/937 [03:21<00:00,  4.64it/s]


Epoch 3 - Training loss Classifier: 0.00732632819438974, Training loss Domain: 0.0033115156778444847


100%|██████████| 157/157 [00:01<00:00, 134.32it/s]


Epoch 3 - Validation Accuracy: 0.9649


100%|██████████| 157/157 [00:32<00:00,  4.82it/s]


Test Accuracy: 0.6774


100%|██████████| 937/937 [03:18<00:00,  4.72it/s]


Epoch 4 - Training loss Classifier: 0.0069668562290569145, Training loss Domain: 0.0034416267986098923


100%|██████████| 157/157 [00:01<00:00, 132.76it/s]


Epoch 4 - Validation Accuracy: 0.9678


100%|██████████| 157/157 [00:31<00:00,  5.02it/s]


Test Accuracy: 0.6693


100%|██████████| 937/937 [03:18<00:00,  4.72it/s]


Epoch 5 - Training loss Classifier: 0.006699891543885072, Training loss Domain: 0.0035347421939174333


100%|██████████| 157/157 [00:01<00:00, 133.31it/s]


Epoch 5 - Validation Accuracy: 0.9698


100%|██████████| 157/157 [00:32<00:00,  4.87it/s]


Test Accuracy: 0.6412


100%|██████████| 937/937 [03:20<00:00,  4.68it/s]


Epoch 6 - Training loss Classifier: 0.006374832578003407, Training loss Domain: 0.0034236044793079295


100%|██████████| 157/157 [00:01<00:00, 133.66it/s]


Epoch 6 - Validation Accuracy: 0.9699


100%|██████████| 157/157 [00:31<00:00,  4.96it/s]


Test Accuracy: 0.5839


100%|██████████| 937/937 [03:21<00:00,  4.66it/s]


Epoch 7 - Training loss Classifier: 0.0064538358464837076, Training loss Domain: 0.0037384666539728643


100%|██████████| 157/157 [00:01<00:00, 132.64it/s]


Epoch 7 - Validation Accuracy: 0.971


100%|██████████| 157/157 [00:32<00:00,  4.88it/s]


Test Accuracy: 0.6358


100%|██████████| 937/937 [03:19<00:00,  4.70it/s]


Epoch 8 - Training loss Classifier: 0.006532121153920889, Training loss Domain: 0.003950370573624969


100%|██████████| 157/157 [00:01<00:00, 133.23it/s]


Epoch 8 - Validation Accuracy: 0.9682


100%|██████████| 157/157 [00:31<00:00,  4.95it/s]


Test Accuracy: 0.6174


100%|██████████| 937/937 [03:20<00:00,  4.67it/s]


Epoch 9 - Training loss Classifier: 0.006512194087107976, Training loss Domain: 0.003542769214883447


100%|██████████| 157/157 [00:01<00:00, 133.96it/s]


Epoch 9 - Validation Accuracy: 0.97


100%|██████████| 157/157 [00:31<00:00,  5.00it/s]


Test Accuracy: 0.5818


100%|██████████| 937/937 [03:17<00:00,  4.74it/s]


Epoch 10 - Training loss Classifier: 0.006448134218653043, Training loss Domain: 0.003930688262730837


100%|██████████| 157/157 [00:01<00:00, 134.52it/s]


Epoch 10 - Validation Accuracy: 0.9736


100%|██████████| 157/157 [00:31<00:00,  4.94it/s]


Test Accuracy: 0.5894


100%|██████████| 937/937 [03:21<00:00,  4.65it/s]


Epoch 11 - Training loss Classifier: 0.006182883614301682, Training loss Domain: 0.003567539728929599


100%|██████████| 157/157 [00:01<00:00, 132.60it/s]


Epoch 11 - Validation Accuracy: 0.9729


100%|██████████| 157/157 [00:31<00:00,  4.93it/s]


Test Accuracy: 0.6105


100%|██████████| 937/937 [03:25<00:00,  4.56it/s]


Epoch 12 - Training loss Classifier: 0.006198473986983299, Training loss Domain: 0.0037087826754897833


100%|██████████| 157/157 [00:01<00:00, 132.41it/s]


Epoch 12 - Validation Accuracy: 0.9712


100%|██████████| 157/157 [00:31<00:00,  4.94it/s]


Test Accuracy: 0.5369


100%|██████████| 937/937 [03:21<00:00,  4.64it/s]


Epoch 13 - Training loss Classifier: 0.006188803090403478, Training loss Domain: 0.003782974289357662


100%|██████████| 157/157 [00:01<00:00, 133.50it/s]


Epoch 13 - Validation Accuracy: 0.9717


100%|██████████| 157/157 [00:32<00:00,  4.89it/s]


Test Accuracy: 0.6179


100%|██████████| 937/937 [03:20<00:00,  4.67it/s]


Epoch 14 - Training loss Classifier: 0.006084649345030387, Training loss Domain: 0.003623421535268426


100%|██████████| 157/157 [00:01<00:00, 134.72it/s]


Epoch 14 - Validation Accuracy: 0.9716


100%|██████████| 157/157 [00:31<00:00,  4.92it/s]


Test Accuracy: 0.5772


100%|██████████| 937/937 [03:20<00:00,  4.67it/s]


Epoch 15 - Training loss Classifier: 0.006117511348426342, Training loss Domain: 0.004027530299375455


100%|██████████| 157/157 [00:01<00:00, 133.95it/s]


Epoch 15 - Validation Accuracy: 0.9751


100%|██████████| 157/157 [00:32<00:00,  4.87it/s]


Test Accuracy: 0.6478


100%|██████████| 937/937 [03:19<00:00,  4.71it/s]


Epoch 16 - Training loss Classifier: 0.005991038441409667, Training loss Domain: 0.0037869250155985355


100%|██████████| 157/157 [00:01<00:00, 133.64it/s]


Epoch 16 - Validation Accuracy: 0.9765


100%|██████████| 157/157 [00:31<00:00,  4.98it/s]


Test Accuracy: 0.6666


100%|██████████| 937/937 [03:21<00:00,  4.66it/s]


Epoch 17 - Training loss Classifier: 0.005929243542005618, Training loss Domain: 0.004059640913332502


100%|██████████| 157/157 [00:01<00:00, 134.21it/s]


Epoch 17 - Validation Accuracy: 0.9746


100%|██████████| 157/157 [00:32<00:00,  4.78it/s]


Test Accuracy: 0.6512


100%|██████████| 937/937 [03:21<00:00,  4.65it/s]


Epoch 18 - Training loss Classifier: 0.005941567930579186, Training loss Domain: 0.003972113780677319


100%|██████████| 157/157 [00:01<00:00, 134.16it/s]


Epoch 18 - Validation Accuracy: 0.9766


100%|██████████| 157/157 [00:31<00:00,  4.94it/s]


Test Accuracy: 0.5973


100%|██████████| 937/937 [03:23<00:00,  4.60it/s]


Epoch 19 - Training loss Classifier: 0.005969757581253846, Training loss Domain: 0.004176823792234063


100%|██████████| 157/157 [00:01<00:00, 131.87it/s]


Epoch 19 - Validation Accuracy: 0.9748


100%|██████████| 157/157 [00:33<00:00,  4.76it/s]


Test Accuracy: 0.5635


100%|██████████| 937/937 [03:27<00:00,  4.50it/s]


Epoch 20 - Training loss Classifier: 0.00583722693870465, Training loss Domain: 0.0037830614605297644


100%|██████████| 157/157 [00:01<00:00, 130.02it/s]


Epoch 20 - Validation Accuracy: 0.9764


100%|██████████| 157/157 [00:32<00:00,  4.76it/s]


Test Accuracy: 0.5823


100%|██████████| 937/937 [03:26<00:00,  4.53it/s]


Epoch 21 - Training loss Classifier: 0.006054334650437037, Training loss Domain: 0.004142735792075594


100%|██████████| 157/157 [00:01<00:00, 129.98it/s]


Epoch 21 - Validation Accuracy: 0.9737


100%|██████████| 157/157 [00:32<00:00,  4.88it/s]


Test Accuracy: 0.5317


100%|██████████| 937/937 [03:24<00:00,  4.57it/s]


Epoch 22 - Training loss Classifier: 0.005920306529725592, Training loss Domain: 0.004181469435493151


100%|██████████| 157/157 [00:01<00:00, 131.21it/s]


Epoch 22 - Validation Accuracy: 0.976


100%|██████████| 157/157 [00:32<00:00,  4.82it/s]


Test Accuracy: 0.6738


100%|██████████| 937/937 [03:23<00:00,  4.61it/s]


Epoch 23 - Training loss Classifier: 0.0059296703211963175, Training loss Domain: 0.004439557929585378


100%|██████████| 157/157 [00:01<00:00, 129.76it/s]


Epoch 23 - Validation Accuracy: 0.9767


100%|██████████| 157/157 [00:32<00:00,  4.79it/s]


Test Accuracy: 0.6703


100%|██████████| 937/937 [03:26<00:00,  4.53it/s]


Epoch 24 - Training loss Classifier: 0.005847904785970847, Training loss Domain: 0.004118964093675216


100%|██████████| 157/157 [00:01<00:00, 131.97it/s]


Epoch 24 - Validation Accuracy: 0.9748


100%|██████████| 157/157 [00:33<00:00,  4.71it/s]


Test Accuracy: 0.5905


100%|██████████| 937/937 [03:26<00:00,  4.54it/s]


Epoch 25 - Training loss Classifier: 0.0057290376782417295, Training loss Domain: 0.0041535775644083815


100%|██████████| 157/157 [00:01<00:00, 131.29it/s]


Epoch 25 - Validation Accuracy: 0.9757


100%|██████████| 157/157 [00:33<00:00,  4.73it/s]


Test Accuracy: 0.5787


100%|██████████| 937/937 [03:24<00:00,  4.57it/s]


Epoch 26 - Training loss Classifier: 0.005567679720371962, Training loss Domain: 0.004132259136935075


100%|██████████| 157/157 [00:01<00:00, 130.90it/s]


Epoch 26 - Validation Accuracy: 0.9758


100%|██████████| 157/157 [00:32<00:00,  4.83it/s]


Test Accuracy: 0.5679


100%|██████████| 937/937 [03:21<00:00,  4.64it/s]


Epoch 27 - Training loss Classifier: 0.005497017104985813, Training loss Domain: 0.004079642495512963


100%|██████████| 157/157 [00:01<00:00, 129.89it/s]


Epoch 27 - Validation Accuracy: 0.9762


100%|██████████| 157/157 [00:32<00:00,  4.86it/s]


Test Accuracy: 0.5793


100%|██████████| 937/937 [03:17<00:00,  4.74it/s]


Epoch 28 - Training loss Classifier: 0.00552937698289752, Training loss Domain: 0.0042425764846305055


100%|██████████| 157/157 [00:01<00:00, 132.66it/s]


Epoch 28 - Validation Accuracy: 0.9785


100%|██████████| 157/157 [00:30<00:00,  5.14it/s]


Test Accuracy: 0.6543


100%|██████████| 937/937 [03:19<00:00,  4.69it/s]


Epoch 29 - Training loss Classifier: 0.005451229974006613, Training loss Domain: 0.00438250158727169


100%|██████████| 157/157 [00:01<00:00, 129.07it/s]


Epoch 29 - Validation Accuracy: 0.9775


100%|██████████| 157/157 [00:32<00:00,  4.85it/s]


Test Accuracy: 0.6467


100%|██████████| 937/937 [03:24<00:00,  4.59it/s]


Epoch 30 - Training loss Classifier: 0.00530622064769268, Training loss Domain: 0.003651499014099439


100%|██████████| 157/157 [00:01<00:00, 130.19it/s]


Epoch 30 - Validation Accuracy: 0.9807


100%|██████████| 157/157 [00:34<00:00,  4.50it/s]


Test Accuracy: 0.6135


100%|██████████| 937/937 [03:23<00:00,  4.59it/s]


Epoch 31 - Training loss Classifier: 0.005169641302277645, Training loss Domain: 0.0035906206528345743


100%|██████████| 157/157 [00:01<00:00, 129.68it/s]


Epoch 31 - Validation Accuracy: 0.9801


100%|██████████| 157/157 [00:32<00:00,  4.81it/s]


Test Accuracy: 0.6586


100%|██████████| 937/937 [03:27<00:00,  4.52it/s]


Epoch 32 - Training loss Classifier: 0.005016365986814102, Training loss Domain: 0.003456865968927741


100%|██████████| 157/157 [00:01<00:00, 130.57it/s]


Epoch 32 - Validation Accuracy: 0.9805


100%|██████████| 157/157 [00:33<00:00,  4.75it/s]


Test Accuracy: 0.6342


100%|██████████| 937/937 [03:26<00:00,  4.53it/s]


Epoch 33 - Training loss Classifier: 0.005078527855376403, Training loss Domain: 0.0034230374691387017


100%|██████████| 157/157 [00:01<00:00, 131.16it/s]


Epoch 33 - Validation Accuracy: 0.9813


100%|██████████| 157/157 [00:31<00:00,  4.97it/s]


Test Accuracy: 0.6517


100%|██████████| 937/937 [03:20<00:00,  4.68it/s]


Epoch 34 - Training loss Classifier: 0.005060752120365699, Training loss Domain: 0.0035185653245697417


100%|██████████| 157/157 [00:01<00:00, 131.80it/s]


Epoch 34 - Validation Accuracy: 0.9808


100%|██████████| 157/157 [00:33<00:00,  4.62it/s]


Test Accuracy: 0.6381


100%|██████████| 937/937 [03:21<00:00,  4.65it/s]


Epoch 35 - Training loss Classifier: 0.005045669099316001, Training loss Domain: 0.003195984027038018


100%|██████████| 157/157 [00:01<00:00, 130.11it/s]


Epoch 35 - Validation Accuracy: 0.9822


100%|██████████| 157/157 [00:31<00:00,  4.93it/s]


Test Accuracy: 0.6442


100%|██████████| 937/937 [03:23<00:00,  4.61it/s]


Epoch 36 - Training loss Classifier: 0.004931704368814826, Training loss Domain: 0.003497992189228535


100%|██████████| 157/157 [00:01<00:00, 132.06it/s]


Epoch 36 - Validation Accuracy: 0.9805


100%|██████████| 157/157 [00:32<00:00,  4.89it/s]


Test Accuracy: 0.6879


100%|██████████| 937/937 [03:21<00:00,  4.66it/s]


Epoch 37 - Training loss Classifier: 0.005061051781848073, Training loss Domain: 0.003340611740325888


100%|██████████| 157/157 [00:01<00:00, 130.07it/s]


Epoch 37 - Validation Accuracy: 0.9826


100%|██████████| 157/157 [00:32<00:00,  4.83it/s]


Test Accuracy: 0.762


100%|██████████| 937/937 [03:17<00:00,  4.74it/s]


Epoch 38 - Training loss Classifier: 0.004912082757552465, Training loss Domain: 0.0035657565449674924


100%|██████████| 157/157 [00:01<00:00, 132.87it/s]


Epoch 38 - Validation Accuracy: 0.9816


100%|██████████| 157/157 [00:31<00:00,  5.06it/s]


Test Accuracy: 0.7229


100%|██████████| 937/937 [03:14<00:00,  4.81it/s]


Epoch 39 - Training loss Classifier: 0.005008983059848348, Training loss Domain: 0.0037227531768381597


100%|██████████| 157/157 [00:01<00:00, 132.70it/s]


Epoch 39 - Validation Accuracy: 0.9823


100%|██████████| 157/157 [00:31<00:00,  5.02it/s]

Test Accuracy: 0.7178





#### MNIST-M Test Accuracy

In [20]:
model.eval()
correct_prediction = 0
for images, labels in tqdm(target_loader_test):
    
    # Cuda
    images = images.to(device)
    labels = labels.to(device)
        
    with torch.no_grad():
        prediction, _, _ = model(images)
        correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()
        
print('Test Accuracy: {}'.format(correct_prediction/len(target_dataset_test)))

100%|██████████| 157/157 [00:31<00:00,  5.01it/s]

Test Accuracy: 0.7112





In [21]:
PATH = 'DANN_adv.pt'
torch.save(model.state_dict(), PATH)

### DANN+JSD Model Training
- We finetune the model pre-trained in previous part.

##### Loss Function

In [7]:
# Infomax Loss =========================================
class GlobalDiscriminator(nn.Module):
    def __init__(self, sz):
        super(GlobalDiscriminator, self).__init__()
        self.l0 = nn.Linear(sz, 512)
        self.l1 = nn.Linear(512, 512)
        self.l2 = nn.Linear(512, 1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        h = F.relu(self.l0(x))
        h = F.relu(self.l1(h))
        return self.l2(h)
    
class DeepInfoMaxLoss(nn.Module):
    def __init__(self, sz=64, type="concat"):
        super().__init__()
        if type=="concat":
            self.global_d = GlobalDiscriminator(sz=2*sz)
        elif type=="dot":
            self.global_d = GlobalDiscriminatorDot(sz=sz)
        else:
            self.global_d = GlobalDiscriminatorConv(sz=2*sz)

    def forward(self, proto_label_pos, proto_label_neg, proto_unlabel_pos):
        Ej = -F.softplus(-self.global_d(proto_unlabel_pos, proto_label_pos)).mean()
        Em = F.softplus(self.global_d(proto_unlabel_pos, proto_label_neg)).mean()
        LOSS = (Em - Ej)
        
        return LOSS

In [8]:
# Load Model
PATH = 'DANN_adv.pt'

model = Net().to(device)
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [9]:
# Initialize
loss_fn = DeepInfoMaxLoss(sz=64, type="concat").to(device)

criterion = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss()

optimizer_fn = optim.Adam(loss_fn.parameters(), lr=1e-4)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [10]:
def get_pseudo_prediction(model, images):
    model.eval()
    with torch.no_grad():
        output_prob, _, _ = model(images)
    model.train()
    output_prob = torch.softmax(output_prob, 1)
    return output_prob

### EMA ###
class EMA(object):
    def __init__(self, model, alpha):
        self.step = 0
        self.model = copy.deepcopy(model)
        self.alpha = alpha

    def update(self, model):
        decay = min(1 - 1 / (self.step + 1), self.alpha)
        for ema_param, param in zip(self.model.parameters(), model.parameters()):
            ema_param.data = decay * ema_param.data + (1 - decay) * param.data
        self.step += 1

#### Training Loop

In [11]:
domain_iter = iter(target_loader)
ema_model = EMA(model, 0.99)

epochs = 40
for epoch in range(epochs):
    running_loss_class = 0
    running_loss_mi = 0    
    running_loss_adv = 0
    
    model.train()
    for images, labels in tqdm(source_loader):
        
        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        try:
            images_domain, _ = next(domain_iter)
        except StopIteration:
            domain_iter = iter(target_loader)
            images_domain, _ = next(domain_iter)
        
        images_domain = images_domain.to(device)
        
        # Training pass
        optimizer.zero_grad()
        optimizer_fn.zero_grad()
        
        pseudo_output = get_pseudo_prediction(ema_model.model, images_domain)        
                
        output, source_adversarial, source_projection = model(images)
        _, target_adversarial, target_projection = model(images_domain)
        
        source_pos_projection = torch.tensor(()).to(device)
        source_neg_projection = torch.tensor(()).to(device)
        target_pos_projection = torch.tensor(()).to(device)
        target_neg_projection = torch.tensor(()).to(device)                

        pseudo_output = torch.max(pseudo_output, 1)
        filter_index = pseudo_output[0]>0.9

        for cls in pseudo_output[1][filter_index].unique():
            num_max = min(sum(labels==cls), sum(pseudo_output[1][filter_index]==cls), \
                         sum(pseudo_output[1][filter_index]!=cls))

            target_pos_projection = torch.cat((target_pos_projection, \
                                               target_projection[(filter_index) & (pseudo_output[1]==cls)][:num_max]))
            target_neg_projection = torch.cat((target_neg_projection, \
                                               target_projection[(filter_index) & (pseudo_output[1]!=cls)][:num_max]))                        
            source_pos_projection = torch.cat((source_pos_projection, \
                                               source_projection[labels==cls][:num_max]))
            source_neg_projection = torch.cat((source_neg_projection, \
                                               source_projection[labels!=cls][:num_max]))                        
                
        loss_class = criterion(output.squeeze(), labels)
        loss_mi = loss_fn(source_pos_projection, source_neg_projection, target_pos_projection) +\
        loss_fn(target_pos_projection, target_neg_projection, source_pos_projection)

        adversarial_pred = torch.cat((source_adversarial, target_adversarial))
        label_domain = torch.cat((torch.ones(BATCH_SIZE, 1), torch.zeros(BATCH_SIZE, 1))).to(device)        
        loss_domain = criterion_bce(adversarial_pred, label_domain)
                     
        loss = loss_class + loss_mi + loss_domain
        
        #This is where the model learns by backpropagating
        loss.backward()
        
        #And optimizes its weights here
        optimizer.step()
        optimizer_fn.step()
        
        ema_model.update(model)
        
        running_loss_class += loss_class.item()
        running_loss_adv += loss_domain.item()
        running_loss_mi += loss_mi.item()

    print("Epoch {} - Training loss Classifier: {}, Training loss MI: {}, Training Loss Adv: {}"\
          .format(epoch, running_loss_class/len(source_dataset), running_loss_mi, running_loss_adv))
    
    model.eval()
    correct_prediction = 0
    for images, labels in tqdm(source_loader_test):

        # Cuda
        images = images.to(device)
        labels = labels.to(device)
        
        with torch.no_grad():
            prediction, _, _ = model(images)
            correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()
    print("Epoch {} - Validation Accuracy: {}".format(epoch, correct_prediction/len(source_dataset_test)))

100%|██████████| 937/937 [03:39<00:00,  4.28it/s]


Epoch 0 - Training loss Classifier: 0.004703488431870937, Training loss MI: 1277.1779759526253, Training Loss Adv: 390.79958939552307


100%|██████████| 157/157 [00:01<00:00, 132.03it/s]


Epoch 0 - Validation Accuracy: 0.9834


100%|██████████| 157/157 [00:31<00:00,  4.95it/s]


Test Accuracy: 0.7121


100%|██████████| 937/937 [03:42<00:00,  4.21it/s]


Epoch 1 - Training loss Classifier: 0.004443418185412884, Training loss MI: 392.23421109840274, Training Loss Adv: 406.7928552031517


100%|██████████| 157/157 [00:01<00:00, 133.03it/s]


Epoch 1 - Validation Accuracy: 0.9836


100%|██████████| 157/157 [00:30<00:00,  5.07it/s]


Test Accuracy: 0.7644


100%|██████████| 937/937 [03:38<00:00,  4.28it/s]


Epoch 2 - Training loss Classifier: 0.004431360670551658, Training loss MI: 236.73952055722475, Training Loss Adv: 390.8268015086651


100%|██████████| 157/157 [00:01<00:00, 133.47it/s]


Epoch 2 - Validation Accuracy: 0.9842


100%|██████████| 157/157 [00:30<00:00,  5.09it/s]


Test Accuracy: 0.77


100%|██████████| 937/937 [03:40<00:00,  4.25it/s]


Epoch 3 - Training loss Classifier: 0.004388084874426325, Training loss MI: 187.00355763640255, Training Loss Adv: 377.6227168738842


100%|██████████| 157/157 [00:01<00:00, 133.31it/s]


Epoch 3 - Validation Accuracy: 0.9843


100%|██████████| 157/157 [00:31<00:00,  5.03it/s]


Test Accuracy: 0.7732


100%|██████████| 937/937 [03:38<00:00,  4.29it/s]


Epoch 4 - Training loss Classifier: 0.004125361414750417, Training loss MI: 166.97324007190764, Training Loss Adv: 369.8016970306635


100%|██████████| 157/157 [00:01<00:00, 133.44it/s]


Epoch 4 - Validation Accuracy: 0.9859


100%|██████████| 157/157 [00:30<00:00,  5.07it/s]


Test Accuracy: 0.7685


100%|██████████| 937/937 [03:40<00:00,  4.25it/s]


Epoch 5 - Training loss Classifier: 0.004143906399607659, Training loss MI: 143.8940303111449, Training Loss Adv: 404.24805265665054


100%|██████████| 157/157 [00:01<00:00, 133.33it/s]


Epoch 5 - Validation Accuracy: 0.984


100%|██████████| 157/157 [00:31<00:00,  4.95it/s]


Test Accuracy: 0.8106


100%|██████████| 937/937 [03:43<00:00,  4.19it/s]


Epoch 6 - Training loss Classifier: 0.003989100226139029, Training loss MI: 126.58005247730762, Training Loss Adv: 406.7014862895012


100%|██████████| 157/157 [00:01<00:00, 133.18it/s]


Epoch 6 - Validation Accuracy: 0.9846


100%|██████████| 157/157 [00:30<00:00,  5.08it/s]


Test Accuracy: 0.7887


100%|██████████| 937/937 [03:41<00:00,  4.24it/s]


Epoch 7 - Training loss Classifier: 0.0038973362456386287, Training loss MI: 126.51844676956534, Training Loss Adv: 441.1064518392086


100%|██████████| 157/157 [00:01<00:00, 133.53it/s]


Epoch 7 - Validation Accuracy: 0.9828


100%|██████████| 157/157 [00:31<00:00,  5.06it/s]


Test Accuracy: 0.8047


100%|██████████| 937/937 [03:41<00:00,  4.22it/s]


Epoch 8 - Training loss Classifier: 0.004005455867263178, Training loss MI: 109.50113238301128, Training Loss Adv: 448.4490275681019


100%|██████████| 157/157 [00:01<00:00, 133.55it/s]


Epoch 8 - Validation Accuracy: 0.9844


100%|██████████| 157/157 [00:31<00:00,  4.95it/s]


Test Accuracy: 0.8088


100%|██████████| 937/937 [03:45<00:00,  4.16it/s]


Epoch 9 - Training loss Classifier: 0.003921537736058235, Training loss MI: 119.03685953794047, Training Loss Adv: 450.05296847224236


100%|██████████| 157/157 [00:01<00:00, 133.60it/s]


Epoch 9 - Validation Accuracy: 0.9823


100%|██████████| 157/157 [00:31<00:00,  4.96it/s]


Test Accuracy: 0.8228


100%|██████████| 937/937 [03:42<00:00,  4.21it/s]


Epoch 10 - Training loss Classifier: 0.0037166804789255064, Training loss MI: 108.69050883012824, Training Loss Adv: 445.24670815467834


100%|██████████| 157/157 [00:01<00:00, 131.94it/s]


Epoch 10 - Validation Accuracy: 0.9835


100%|██████████| 157/157 [00:31<00:00,  5.01it/s]


Test Accuracy: 0.7895


100%|██████████| 937/937 [03:43<00:00,  4.20it/s]


Epoch 11 - Training loss Classifier: 0.0035781634954735637, Training loss MI: 102.77961317147128, Training Loss Adv: 448.41741678118706


100%|██████████| 157/157 [00:01<00:00, 133.37it/s]


Epoch 11 - Validation Accuracy: 0.9842


100%|██████████| 157/157 [00:30<00:00,  5.08it/s]


Test Accuracy: 0.7988


100%|██████████| 937/937 [03:40<00:00,  4.24it/s]


Epoch 12 - Training loss Classifier: 0.0034411467341706158, Training loss MI: 99.1337780142203, Training Loss Adv: 436.19766983389854


100%|██████████| 157/157 [00:01<00:00, 132.41it/s]


Epoch 12 - Validation Accuracy: 0.9851


100%|██████████| 157/157 [00:30<00:00,  5.16it/s]


Test Accuracy: 0.8216


100%|██████████| 937/937 [03:39<00:00,  4.27it/s]


Epoch 13 - Training loss Classifier: 0.003255849094626804, Training loss MI: 92.87385610654019, Training Loss Adv: 444.35932341217995


100%|██████████| 157/157 [00:01<00:00, 133.63it/s]


Epoch 13 - Validation Accuracy: 0.9848


100%|██████████| 157/157 [00:31<00:00,  5.06it/s]


Test Accuracy: 0.8193


100%|██████████| 937/937 [03:44<00:00,  4.17it/s]


Epoch 14 - Training loss Classifier: 0.0029983237745240332, Training loss MI: 87.76763339038007, Training Loss Adv: 436.1696158647537


100%|██████████| 157/157 [00:01<00:00, 133.63it/s]


Epoch 14 - Validation Accuracy: 0.9863


100%|██████████| 157/157 [00:31<00:00,  4.96it/s]


Test Accuracy: 0.8237


100%|██████████| 937/937 [03:43<00:00,  4.20it/s]


Epoch 15 - Training loss Classifier: 0.00284229455034559, Training loss MI: 91.90924913005438, Training Loss Adv: 440.2604983448982


100%|██████████| 157/157 [00:01<00:00, 132.56it/s]


Epoch 15 - Validation Accuracy: 0.9853


100%|██████████| 157/157 [00:31<00:00,  4.99it/s]


Test Accuracy: 0.81


100%|██████████| 937/937 [03:43<00:00,  4.20it/s]


Epoch 16 - Training loss Classifier: 0.0029080196389307577, Training loss MI: 87.86073023709469, Training Loss Adv: 446.8633613586426


100%|██████████| 157/157 [00:01<00:00, 133.55it/s]


Epoch 16 - Validation Accuracy: 0.9852


100%|██████████| 157/157 [00:31<00:00,  4.95it/s]


Test Accuracy: 0.7815


100%|██████████| 937/937 [03:42<00:00,  4.22it/s]


Epoch 17 - Training loss Classifier: 0.0027743527201625207, Training loss MI: 85.58463312755339, Training Loss Adv: 440.77616119384766


100%|██████████| 157/157 [00:01<00:00, 133.53it/s]


Epoch 17 - Validation Accuracy: 0.9856


100%|██████████| 157/157 [00:31<00:00,  5.00it/s]


Test Accuracy: 0.8148


100%|██████████| 937/937 [03:45<00:00,  4.15it/s]


Epoch 18 - Training loss Classifier: 0.002685109530389309, Training loss MI: 82.58213689841796, Training Loss Adv: 454.56578144431114


100%|██████████| 157/157 [00:01<00:00, 132.01it/s]


Epoch 18 - Validation Accuracy: 0.9858


100%|██████████| 157/157 [00:31<00:00,  5.03it/s]


Test Accuracy: 0.7878


100%|██████████| 937/937 [03:43<00:00,  4.19it/s]


Epoch 19 - Training loss Classifier: 0.0025851261872487765, Training loss MI: 86.73677307972685, Training Loss Adv: 445.01895010471344


100%|██████████| 157/157 [00:01<00:00, 133.70it/s]


Epoch 19 - Validation Accuracy: 0.9863


100%|██████████| 157/157 [00:31<00:00,  5.04it/s]


Test Accuracy: 0.7984


100%|██████████| 937/937 [03:44<00:00,  4.17it/s]


Epoch 20 - Training loss Classifier: 0.0025346066403823596, Training loss MI: 78.64947671734262, Training Loss Adv: 453.6993329524994


100%|██████████| 157/157 [00:01<00:00, 131.99it/s]


Epoch 20 - Validation Accuracy: 0.9873


100%|██████████| 157/157 [00:31<00:00,  4.96it/s]


Test Accuracy: 0.8231


100%|██████████| 937/937 [03:42<00:00,  4.21it/s]


Epoch 21 - Training loss Classifier: 0.002392945841824015, Training loss MI: 83.68483478552662, Training Loss Adv: 451.4788531959057


100%|██████████| 157/157 [00:01<00:00, 132.66it/s]


Epoch 21 - Validation Accuracy: 0.9869


100%|██████████| 157/157 [00:31<00:00,  4.94it/s]


Test Accuracy: 0.8086


100%|██████████| 937/937 [03:42<00:00,  4.20it/s]


Epoch 22 - Training loss Classifier: 0.002359369501316299, Training loss MI: 79.95658482948784, Training Loss Adv: 443.6977757513523


100%|██████████| 157/157 [00:01<00:00, 133.32it/s]


Epoch 22 - Validation Accuracy: 0.9863


100%|██████████| 157/157 [00:31<00:00,  5.00it/s]


Test Accuracy: 0.8268


100%|██████████| 937/937 [03:40<00:00,  4.26it/s]


Epoch 23 - Training loss Classifier: 0.002273447309465458, Training loss MI: 77.22545896121301, Training Loss Adv: 452.7937404215336


100%|██████████| 157/157 [00:01<00:00, 133.10it/s]


Epoch 23 - Validation Accuracy: 0.9862


100%|██████████| 157/157 [00:31<00:00,  5.02it/s]


Test Accuracy: 0.8292


100%|██████████| 937/937 [03:44<00:00,  4.17it/s]


Epoch 24 - Training loss Classifier: 0.0022656633188948036, Training loss MI: 76.5501827315893, Training Loss Adv: 445.78842663764954


100%|██████████| 157/157 [00:01<00:00, 133.84it/s]


Epoch 24 - Validation Accuracy: 0.9887


100%|██████████| 157/157 [00:30<00:00,  5.10it/s]


Test Accuracy: 0.8424


100%|██████████| 937/937 [03:42<00:00,  4.22it/s]


Epoch 25 - Training loss Classifier: 0.002225995225397249, Training loss MI: 76.4274573699804, Training Loss Adv: 448.24137234687805


100%|██████████| 157/157 [00:01<00:00, 132.71it/s]


Epoch 25 - Validation Accuracy: 0.9862


100%|██████████| 157/157 [00:31<00:00,  5.05it/s]


Test Accuracy: 0.841


100%|██████████| 937/937 [03:42<00:00,  4.22it/s]


Epoch 26 - Training loss Classifier: 0.002258953795209527, Training loss MI: 70.71505854302086, Training Loss Adv: 454.4401537179947


100%|██████████| 157/157 [00:01<00:00, 133.05it/s]


Epoch 26 - Validation Accuracy: 0.9873


100%|██████████| 157/157 [00:30<00:00,  5.09it/s]


Test Accuracy: 0.8475


100%|██████████| 937/937 [03:43<00:00,  4.19it/s]


Epoch 27 - Training loss Classifier: 0.002155713393073529, Training loss MI: 71.80558956367895, Training Loss Adv: 444.6672441959381


100%|██████████| 157/157 [00:01<00:00, 132.33it/s]


Epoch 27 - Validation Accuracy: 0.9877


100%|██████████| 157/157 [00:31<00:00,  4.96it/s]


Test Accuracy: 0.8498


100%|██████████| 937/937 [03:39<00:00,  4.27it/s]


Epoch 28 - Training loss Classifier: 0.0021905536635934064, Training loss MI: 72.64272961113602, Training Loss Adv: 450.0061900317669


100%|██████████| 157/157 [00:01<00:00, 133.20it/s]


Epoch 28 - Validation Accuracy: 0.9879


100%|██████████| 157/157 [00:30<00:00,  5.14it/s]


Test Accuracy: 0.8397


100%|██████████| 937/937 [03:42<00:00,  4.22it/s]


Epoch 29 - Training loss Classifier: 0.0021570310349731396, Training loss MI: 69.88456711359322, Training Loss Adv: 445.2771665453911


100%|██████████| 157/157 [00:01<00:00, 132.07it/s]


Epoch 29 - Validation Accuracy: 0.9872


100%|██████████| 157/157 [00:30<00:00,  5.11it/s]


Test Accuracy: 0.8514


100%|██████████| 937/937 [03:42<00:00,  4.21it/s]


Epoch 30 - Training loss Classifier: 0.0021932911397268373, Training loss MI: 70.72784498299006, Training Loss Adv: 453.90391263365746


100%|██████████| 157/157 [00:01<00:00, 133.05it/s]


Epoch 30 - Validation Accuracy: 0.988


100%|██████████| 157/157 [00:30<00:00,  5.07it/s]


Test Accuracy: 0.8592


100%|██████████| 937/937 [03:44<00:00,  4.18it/s]


Epoch 31 - Training loss Classifier: 0.002172606764656181, Training loss MI: 67.38823838689132, Training Loss Adv: 451.0541552901268


100%|██████████| 157/157 [00:01<00:00, 132.57it/s]


Epoch 31 - Validation Accuracy: 0.9888


100%|██████████| 157/157 [00:31<00:00,  5.06it/s]


Test Accuracy: 0.8616


100%|██████████| 937/937 [03:41<00:00,  4.22it/s]


Epoch 32 - Training loss Classifier: 0.002059726353144894, Training loss MI: 62.75171203009086, Training Loss Adv: 446.05929574370384


100%|██████████| 157/157 [00:01<00:00, 132.86it/s]


Epoch 32 - Validation Accuracy: 0.988


100%|██████████| 157/157 [00:31<00:00,  5.05it/s]


Test Accuracy: 0.8608


100%|██████████| 937/937 [03:43<00:00,  4.19it/s]


Epoch 33 - Training loss Classifier: 0.0020654280140840757, Training loss MI: 59.68325674912194, Training Loss Adv: 446.30947175621986


100%|██████████| 157/157 [00:01<00:00, 133.74it/s]


Epoch 33 - Validation Accuracy: 0.9876


100%|██████████| 157/157 [00:31<00:00,  5.01it/s]


Test Accuracy: 0.8572


100%|██████████| 937/937 [03:42<00:00,  4.20it/s]


Epoch 34 - Training loss Classifier: 0.0020601508626403907, Training loss MI: 57.79828108800575, Training Loss Adv: 448.10130098462105


100%|██████████| 157/157 [00:01<00:00, 133.58it/s]


Epoch 34 - Validation Accuracy: 0.9889


100%|██████████| 157/157 [00:30<00:00,  5.18it/s]


Test Accuracy: 0.8584


100%|██████████| 937/937 [03:41<00:00,  4.23it/s]


Epoch 35 - Training loss Classifier: 0.002020519978739321, Training loss MI: 56.063964965054765, Training Loss Adv: 453.5189053416252


100%|██████████| 157/157 [00:01<00:00, 132.54it/s]


Epoch 35 - Validation Accuracy: 0.9883


100%|██████████| 157/157 [00:31<00:00,  5.05it/s]


Test Accuracy: 0.8641


100%|██████████| 937/937 [03:45<00:00,  4.15it/s]


Epoch 36 - Training loss Classifier: 0.0020439407722093166, Training loss MI: 57.5003871062072, Training Loss Adv: 453.5408822596073


100%|██████████| 157/157 [00:01<00:00, 133.70it/s]


Epoch 36 - Validation Accuracy: 0.9871


100%|██████████| 157/157 [00:31<00:00,  5.01it/s]


Test Accuracy: 0.8698


100%|██████████| 937/937 [03:45<00:00,  4.16it/s]


Epoch 37 - Training loss Classifier: 0.00209182325466536, Training loss MI: 58.859781658509746, Training Loss Adv: 452.0610482990742


100%|██████████| 157/157 [00:01<00:00, 133.25it/s]


Epoch 37 - Validation Accuracy: 0.9885


100%|██████████| 157/157 [00:32<00:00,  4.84it/s]


Test Accuracy: 0.8766


100%|██████████| 937/937 [03:43<00:00,  4.20it/s]


Epoch 38 - Training loss Classifier: 0.0021412837629516917, Training loss MI: 57.69343749422114, Training Loss Adv: 470.4938782155514


100%|██████████| 157/157 [00:01<00:00, 132.09it/s]


Epoch 38 - Validation Accuracy: 0.9878


100%|██████████| 157/157 [00:30<00:00,  5.12it/s]


Test Accuracy: 0.856


100%|██████████| 937/937 [03:45<00:00,  4.16it/s]


Epoch 39 - Training loss Classifier: 0.002064281786015878, Training loss MI: 50.07603239585296, Training Loss Adv: 473.45588034391403


100%|██████████| 157/157 [00:01<00:00, 133.54it/s]


Epoch 39 - Validation Accuracy: 0.9876


100%|██████████| 157/157 [00:31<00:00,  4.93it/s]

Test Accuracy: 0.8724





In [12]:
model.eval()
correct_prediction = 0
for images, labels in tqdm(target_loader_test):
    
    # Cuda
    images = images.to(device)
    labels = labels.to(device)
        
    with torch.no_grad():
        prediction, _, _ = model(images)
        correct_prediction += sum(torch.max(prediction.squeeze(), 1)[1]==labels).item()

print('Test Accuracy: {}'.format(correct_prediction/len(target_dataset_test)))

100%|██████████| 157/157 [00:30<00:00,  5.08it/s]

Test Accuracy: 0.8762





In [13]:
PATH = 'DANN_adv_mi.pt'
torch.save(model.state_dict(), PATH)