In [1]:
import sys
sys.path.append("../")

from torch import nn
from data_utils import load_mnist, load_usps
from models import DomainAdaptationNetwork, ProjectorNetwork

import torch
from torch.nn import functional as F
import numpy as np
from train import train_domain_adaptation
from utils import test_network, plot_target_cross_domain_swapping, plot_tsne

In [2]:
import random

SEED = 0
# Seed every RNG the notebook may draw from: Python's `random`, NumPy's legacy global
# generator and torch (torch.manual_seed also seeds all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Trade cuDNN autotuning speed for reproducible convolution algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class Decoder(nn.Module):
    def __init__(self, latent_space_dim, conv_feat_size, nb_channels=3):
        super(Decoder, self).__init__()

        self.latent_space_dim = latent_space_dim
        self.conv_feat_size = conv_feat_size

        self.deco_dense = nn.Sequential(
            nn.Linear(in_features=latent_space_dim, out_features=1024),
            nn.ReLU(True),
            nn.Linear(in_features=1024, out_features=np.prod(self.conv_feat_size)),
            nn.ReLU(True),
        )

        self.deco_fetures = nn.Sequential(
            nn.Conv2d(in_channels=self.conv_feat_size[0], out_channels=75, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=75, out_channels=50, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=50, out_channels=1, kernel_size=5, padding=2),
            nn.Sigmoid()
        )

    def forward(self, z_share, z_spe):
        z = torch.cat([z_share, z_spe], 1)
        feat_encode = self.deco_dense(z)
        feat_encode = feat_encode.view(-1, *self.conv_feat_size)
        y = self.deco_fetures(feat_encode)

        return y


class Encoder(nn.Module):
    """Convolutional encoder producing a shared (task) code and two domain-specific codes.

    A conv trunk (three conv/ReLU/BatchNorm stages with two 2x max-pools) feeds a 1024-unit
    dense layer. Three independent linear heads then produce ReLU-activated codes of size
    ``latent_space_dim``: ``task`` (shared), ``source`` and ``target``.

    Args:
        latent_space_dim: size of each of the three latent codes.
        img_size: expected input size as (C, H, W). Only H and W are used to infer the
            flattened conv-feature size; the channel count comes from ``nb_channels``.
        nb_channels: number of channels the conv trunk consumes. Single-channel inputs are
            tiled to this many channels in ``forward``.
    """

    def __init__(self, latent_space_dim, img_size, nb_channels=3):
        super(Encoder, self).__init__()

        self.latent_space_dim = latent_space_dim
        self.nb_channels = nb_channels

        self.conv_feat = nn.Sequential(
            # The trunk must accept `nb_channels` inputs: `forward` tiles grayscale batches to
            # that many channels, so a hard-coded 1 here broke every nb_channels=3 encoder.
            nn.Conv2d(in_channels=nb_channels, out_channels=50, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(50),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=50, out_channels=75, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.BatchNorm2d(75),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=75, out_channels=100, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(100),
        )

        # Infer the conv-feature shape with a dummy forward pass. Run it in eval mode and without
        # autograd so the probe does not update the BatchNorm running statistics (in train mode
        # it would bump num_batches_tracked and pull running_mean/var toward the all-zero input).
        was_training = self.conv_feat.training
        self.conv_feat.eval()
        with torch.no_grad():
            probe = torch.zeros(1, nb_channels, *img_size[-2:])
            self.conv_feat_size = self.conv_feat(probe).shape[1:]
        self.conv_feat.train(was_training)
        self.dense_feature_size = int(np.prod(self.conv_feat_size))

        self.dense_feat = nn.Linear(in_features=self.dense_feature_size, out_features=1024)
        self.task_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)
        self.source_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)
        self.target_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)

    def forward(self, input_data, mode='all'):
        """Encode an image batch (assumed (N, C, H, W)).

        Args:
            input_data: image batch. A 1-channel batch is tiled to ``nb_channels`` channels.
            mode: ``'task'``, ``'source'`` or ``'target'`` returns only that code. Any other
                value (default ``'all'``) returns ``(z_task, z_source, z_target)``.
        """
        if input_data.shape[1] == 1 and self.nb_channels != 1:
            input_data = input_data.repeat(1, self.nb_channels, 1, 1)
        feat = self.conv_feat(input_data)
        feat = feat.view(-1, self.dense_feature_size)
        feat = F.relu(self.dense_feat(feat))

        if mode == 'task':
            return F.relu(self.task_feat(feat))
        if mode == 'source':
            return F.relu(self.source_feat(feat))
        if mode == 'target':
            return F.relu(self.target_feat(feat))

        z_task = F.relu(self.task_feat(feat))
        z_source = F.relu(self.source_feat(feat))
        z_target = F.relu(self.target_feat(feat))
        return z_task, z_source, z_target



In [4]:
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from utils import show_decoded_images
from models import GradReverse

# Loss functions used by `train_domain_adaptation` (looked up as module globals at call time).
# Each is defined exactly once; the originals were partly duplicated.
criterion_reconstruction = nn.BCELoss()  # image reconstruction against sigmoid decoder outputs
criterion_classifier = nn.NLLLoss(reduction='mean')  # expects log-probabilities
criterion_weighted_classifier = nn.NLLLoss(reduction='none')  # per-sample, so it can be weighted
criterion_disentangle = nn.MSELoss()
criterion_distance = nn.MSELoss()
criterion_triplet = nn.TripletMarginLoss(margin=1)
# Older name kept as an alias in case other cells still refer to it (stateless, so sharing is safe).
disentangle_criterion = criterion_disentangle

def _schedule_value(schedule, epoch):
    """Return ``schedule[epoch]`` for a per-epoch sequence, or ``schedule`` itself for a scalar."""
    if np.ndim(schedule) == 0:
        return schedule
    return schedule[epoch]


def train_domain_adaptation(model, optimizer, random_projector, source_train_loader, target_train_loader, betas,
                            alpha=1, gamma=1, delta=0.1, epochs=30, show_images=False):
    """Jointly train ``model`` on labelled source batches and unlabelled target batches.

    NOTE: this notebook-local definition shadows ``train.train_domain_adaptation`` imported
    in the first cell.

    Each step draws one source and one target batch (``zip`` stops at the shorter loader) and
    truncates both to the same size. It then minimises the sum of the following terms.

    Target terms:
    * reconstruction (``alpha``);
    * task-code distance to a re-encoding of a source-styled decode (``betas``);
    * MSE to gradient-reversed ``random_projector`` targets (``gamma``);
    * a confidence-weighted pseudo-label loss on a style-swapped reconstruction (weight 0.1);
    * a triplet loss on the target-specific code (``delta``).

    Source terms: the same, plus the supervised classification loss, and a 0.1-weighted
    supervised loss on the style-swapped reconstruction.

    Target labels are only used to report accuracy, never in the loss.

    Args:
        model: network supporting ``model(x, mode=...)`` with modes 'all_target', 'all_source',
            'target' and 'source', plus ``decode``, ``encoder``, ``decoder_target`` and
            ``decoder_source``. Batches are moved with ``.cuda()``, so the model is assumed
            to live on the default CUDA device.
        optimizer: optimizer over the model parameters.
        random_projector: callable applied to gradient-reversed latent codes to build the
            disentanglement targets.
        source_train_loader, target_train_loader: iterables of ``(images, labels)`` batches.
        betas, delta: weights of the task-code distance and triplet terms. Each is either a
            scalar or a per-epoch sequence with at least ``epochs`` entries.
        alpha: reconstruction weight.
        gamma: disentanglement weight.
        epochs: number of passes over the zipped loaders.
        show_images: if True, show reconstructions and cross-domain decodes of the last batch
            after every epoch.

    Raises:
        ValueError: if a per-epoch schedule is shorter than ``epochs``, or the loaders yield
            no batches.
    """
    for name, schedule in (('betas', betas), ('delta', delta)):
        if np.ndim(schedule) > 0 and len(schedule) < epochs:
            raise ValueError(f'{name} has {len(schedule)} entries but epochs={epochs}')

    t = tqdm(range(epochs))
    for epoch in t:
        beta_e = _schedule_value(betas, epoch)
        delta_e = _schedule_value(delta, epoch)
        total_loss = 0.0
        n_batches = 0
        corrects_source, corrects_target = 0, 0
        total_source, total_target = 0, 0

        for (x_s, y_s), (x_t, y_t) in zip(source_train_loader, target_train_loader):
            loss = 0
            x_s, y_s, x_t, y_t = x_s.cuda(), y_s.cuda(), x_t.cuda(), y_t.cuda()
            # Trim both batches to a common size so they can be paired element-wise below.
            min_len = min(len(x_s), len(x_t))
            x_s, y_s, x_t, y_t = x_s[:min_len], y_s[:min_len], x_t[:min_len], y_t[:min_len]

            # ---------------- target batch ----------------
            xt_hat, yt_hat, (z_task, z_target), (pred_task, pred_spe) = model(x_t, mode='all_target')
            # Disentanglement targets: projected codes whose gradient is reversed.
            random_task = random_projector(GradReverse.grad_reverse(z_task))
            random_spe = random_projector(GradReverse.grad_reverse(z_target))

            # Decode the target task code with source style taken from x_s, then re-encode it.
            xts = model.decode(z_task, x_s, mode='source')
            z_s = model.encoder(xts.detach(), mode='task')
            # Reversing the batch pairs each task code with another sample's target-specific code.
            z_target_prime = model.encoder(torch.flip(x_t, [0]), mode='target')
            xt_prime = model.decoder_target(z_task, z_target_prime)
            yt_tilde, z_target_tilde = model(xt_prime, mode='target')

            # yt_hat holds log-probabilities; exp(max) is the pseudo-label confidence.
            confidence, predicted = yt_hat.max(1)
            corrects_target += predicted.eq(y_t).sum().item()
            total_target += y_t.size(0)

            loss += alpha * criterion_reconstruction(xt_hat, x_t)
            loss += beta_e * criterion_distance(z_task, z_s)
            loss += gamma * (criterion_disentangle(pred_task, random_task) + criterion_disentangle(pred_spe, random_spe))
            loss += 0.1 * torch.mean(torch.exp(confidence.detach()) * criterion_weighted_classifier(yt_tilde, predicted.detach()))
            loss += delta_e * criterion_triplet(z_target_tilde, z_target_prime, z_target)

            # ---------------- source batch ----------------
            xs_hat, ys_hat, (z_task, z_source), (pred_task, pred_spe) = model(x_s, mode='all_source')
            random_task = random_projector(GradReverse.grad_reverse(z_task))
            random_spe = random_projector(GradReverse.grad_reverse(z_source))

            # Decode the source task code with target style taken from x_t, then re-encode it.
            xst = model.decode(z_task, x_t, mode='target')
            z_t = model.encoder(xst.detach(), mode='task')
            z_source_prime = model.encoder(torch.flip(x_s, [0]), mode='source')
            xs_prime = model.decoder_source(z_task, z_source_prime)
            ys_tilde, z_source_tilde = model(xs_prime, mode='source')

            _, predicted = ys_hat.max(1)
            corrects_source += predicted.eq(y_s).sum().item()
            total_source += y_s.size(0)

            loss += criterion_classifier(ys_hat, y_s)
            loss += alpha * criterion_reconstruction(xs_hat, x_s)
            loss += beta_e * criterion_distance(z_task, z_t)
            loss += gamma * (criterion_disentangle(pred_task, random_task) + criterion_disentangle(pred_spe, random_spe))
            loss += 0.1 * criterion_classifier(ys_tilde, y_s)
            loss += delta_e * criterion_triplet(z_source_tilde, z_source_prime, z_source)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
            t.set_description(f'epoch:{epoch} current target accuracy:{round(corrects_target / total_target * 100, 2)}%')

        if n_batches == 0:
            raise ValueError('source_train_loader and target_train_loader must both yield at least one batch')

        # ===================log========================
        # Average over the batches actually processed (zip stops at the shorter loader).
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, total_loss / n_batches))
        print(f'accuracy source: {round(corrects_source / total_source * 100, 2)}%')
        print(f'accuracy target: {round(corrects_target / total_target * 100, 2)}%')
        if show_images:
            show_decoded_images(x_s[:16], xs_hat[:16], x_t[:16], xst[:16])
            show_decoded_images(x_t[:16], xt_hat[:16], x_s[:16], xts[:16])

In [18]:
learning_rate = 5e-4
latent_dim = 150  # size of each encoder code (task / source / target)

source_train_loader, source_test_loader = load_usps(img_size=32, augment=True, batch_size=64, shuffle=True, num_workers=4)
target_train_loader, target_test_loader = load_mnist(img_size=32, augment=True, batch_size=64, shuffle=True, num_workers=4)

encoder = Encoder(latent_space_dim=latent_dim, img_size=(1, 32, 32), nb_channels=1)

# Decoders consume cat([z_task, z_domain], dim=1), hence twice the code size.
conv_feat_size = encoder.conv_feat_size
decoder_source = Decoder(latent_space_dim=2 * latent_dim, conv_feat_size=conv_feat_size, nb_channels=1)
decoder_target = Decoder(latent_space_dim=2 * latent_dim, conv_feat_size=conv_feat_size, nb_channels=1)

# Classifier on the flat (N, latent_dim) task code. nn.Dropout rather than nn.Dropout2d:
# Dropout2d on 2-D input is deprecated and amounts to element-wise dropout anyway.
classifier = nn.Sequential(nn.Dropout(),
                           nn.Linear(in_features=latent_dim, out_features=10),
                           nn.LogSoftmax(dim=1))

model = DomainAdaptationNetwork(encoder, decoder_source, decoder_target, classifier).cuda()
# Fixed, parameter-free projection applied to the gradient-reversed codes to build disentanglement targets.
random_projector = torch.cos
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=0.001)

# Ramp the task-code distance (betas) and triplet (delta) weights linearly from 0 to 5.
epochs = 50
betas = np.linspace(0, 5, epochs)
delta = np.linspace(0, 5, epochs)
train_domain_adaptation(model, optimizer, random_projector, source_train_loader, target_train_loader,
                        epochs=epochs, betas=betas, gamma=0.5, delta=delta)

epoch:0 current target accuracy:37.32%:   2%|▏         | 1/50 [00:10<08:40, 10.63s/it]

epoch [1/50], loss:3.9241
accuracy source: 80.43%
accuracy target: 37.32%


epoch:1 current target accuracy:62.01%:   4%|▍         | 2/50 [00:21<08:29, 10.62s/it]

epoch [2/50], loss:3.6966
accuracy source: 95.87%
accuracy target: 62.01%


epoch:2 current target accuracy:68.14%:   6%|▌         | 3/50 [00:31<08:18, 10.60s/it]

epoch [3/50], loss:3.5387
accuracy source: 96.91%
accuracy target: 68.14%


epoch:3 current target accuracy:51.83%:   8%|▊         | 4/50 [00:42<08:06, 10.59s/it]

epoch [4/50], loss:3.5233
accuracy source: 97.28%
accuracy target: 51.83%


epoch:4 current target accuracy:35.55%:  10%|█         | 5/50 [00:52<07:55, 10.56s/it]

epoch [5/50], loss:3.4338
accuracy source: 97.02%
accuracy target: 35.55%


epoch:5 current target accuracy:64.09%:  12%|█▏        | 6/50 [01:03<07:44, 10.56s/it]

epoch [6/50], loss:3.2401
accuracy source: 97.87%
accuracy target: 64.09%


epoch:6 current target accuracy:88.19%:  14%|█▍        | 7/50 [01:14<07:35, 10.58s/it]

epoch [7/50], loss:3.1034
accuracy source: 98.4%
accuracy target: 88.19%


epoch:7 current target accuracy:89.92%:  16%|█▌        | 8/50 [01:24<07:24, 10.58s/it]

epoch [8/50], loss:3.1451
accuracy source: 98.23%
accuracy target: 89.92%


epoch:8 current target accuracy:88.51%:  18%|█▊        | 9/50 [01:35<07:13, 10.58s/it]

epoch [9/50], loss:3.0766
accuracy source: 98.55%
accuracy target: 88.51%


epoch:9 current target accuracy:86.46%:  20%|██        | 10/50 [01:45<07:03, 10.58s/it]

epoch [10/50], loss:3.0663
accuracy source: 98.49%
accuracy target: 86.46%


epoch:10 current target accuracy:88.6%:  22%|██▏       | 11/50 [01:56<06:52, 10.57s/it] 

epoch [11/50], loss:3.0564
accuracy source: 98.93%
accuracy target: 88.6%


epoch:11 current target accuracy:90.54%:  24%|██▍       | 12/50 [02:06<06:41, 10.57s/it]

epoch [12/50], loss:3.0670
accuracy source: 98.74%
accuracy target: 90.54%


epoch:12 current target accuracy:92.54%:  26%|██▌       | 13/50 [02:17<06:31, 10.58s/it]

epoch [13/50], loss:3.0563
accuracy source: 98.93%
accuracy target: 92.54%


epoch:13 current target accuracy:92.77%:  28%|██▊       | 14/50 [02:28<06:21, 10.59s/it]

epoch [14/50], loss:3.0058
accuracy source: 98.99%
accuracy target: 92.77%


epoch:14 current target accuracy:92.17%:  30%|███       | 15/50 [02:38<06:10, 10.58s/it]

epoch [15/50], loss:3.0203
accuracy source: 98.96%
accuracy target: 92.17%


epoch:15 current target accuracy:89.84%:  32%|███▏      | 16/50 [02:49<05:59, 10.58s/it]

epoch [16/50], loss:2.9917
accuracy source: 99.11%
accuracy target: 89.84%


epoch:16 current target accuracy:87.88%:  34%|███▍      | 17/50 [02:59<05:48, 10.57s/it]

epoch [17/50], loss:3.0153
accuracy source: 98.96%
accuracy target: 87.88%


epoch:17 current target accuracy:90.11%:  36%|███▌      | 18/50 [03:10<05:37, 10.56s/it]

epoch [18/50], loss:3.0486
accuracy source: 98.71%
accuracy target: 90.11%


epoch:18 current target accuracy:93.02%:  38%|███▊      | 19/50 [03:20<05:27, 10.56s/it]

epoch [19/50], loss:3.0098
accuracy source: 99.25%
accuracy target: 93.02%


epoch:19 current target accuracy:93.5%:  40%|████      | 20/50 [03:31<05:17, 10.57s/it] 

epoch [20/50], loss:2.9882
accuracy source: 99.15%
accuracy target: 93.5%


epoch:20 current target accuracy:93.86%:  42%|████▏     | 21/50 [03:42<05:06, 10.57s/it]

epoch [21/50], loss:2.9969
accuracy source: 99.31%
accuracy target: 93.86%


epoch:21 current target accuracy:93.47%:  44%|████▍     | 22/50 [03:52<04:55, 10.57s/it]

epoch [22/50], loss:2.9942
accuracy source: 99.14%
accuracy target: 93.47%


epoch:22 current target accuracy:90.5%:  46%|████▌     | 23/50 [04:03<04:45, 10.56s/it] 

epoch [23/50], loss:2.9457
accuracy source: 99.19%
accuracy target: 90.5%


epoch:23 current target accuracy:90.85%:  48%|████▊     | 24/50 [04:13<04:34, 10.55s/it]

epoch [24/50], loss:2.9106
accuracy source: 99.36%
accuracy target: 90.85%


epoch:24 current target accuracy:92.15%:  50%|█████     | 25/50 [04:24<04:23, 10.55s/it]

epoch [25/50], loss:2.8954
accuracy source: 99.09%
accuracy target: 92.15%


epoch:25 current target accuracy:94.24%:  52%|█████▏    | 26/50 [04:34<04:13, 10.57s/it]

epoch [26/50], loss:2.8740
accuracy source: 99.15%
accuracy target: 94.24%


epoch:26 current target accuracy:94.93%:  54%|█████▍    | 27/50 [04:45<04:03, 10.58s/it]

epoch [27/50], loss:2.8502
accuracy source: 99.15%
accuracy target: 94.93%


epoch:27 current target accuracy:95.16%:  56%|█████▌    | 28/50 [04:56<03:52, 10.58s/it]

epoch [28/50], loss:2.8332
accuracy source: 99.41%
accuracy target: 95.16%


epoch:28 current target accuracy:95.38%:  58%|█████▊    | 29/50 [05:06<03:42, 10.60s/it]

epoch [29/50], loss:2.8009
accuracy source: 99.33%
accuracy target: 95.38%


epoch:29 current target accuracy:95.08%:  60%|██████    | 30/50 [05:17<03:31, 10.60s/it]

epoch [30/50], loss:2.7772
accuracy source: 99.48%
accuracy target: 95.08%


epoch:30 current target accuracy:94.76%:  62%|██████▏   | 31/50 [05:27<03:21, 10.59s/it]

epoch [31/50], loss:2.8451
accuracy source: 99.47%
accuracy target: 94.76%


epoch:31 current target accuracy:95.1%:  64%|██████▍   | 32/50 [05:38<03:10, 10.57s/it] 

epoch [32/50], loss:2.8530
accuracy source: 99.27%
accuracy target: 95.1%


epoch:32 current target accuracy:95.91%:  66%|██████▌   | 33/50 [05:48<02:59, 10.56s/it]

epoch [33/50], loss:2.8426
accuracy source: 99.38%
accuracy target: 95.91%


epoch:33 current target accuracy:94.69%:  68%|██████▊   | 34/50 [05:59<02:48, 10.56s/it]

epoch [34/50], loss:2.8656
accuracy source: 99.3%
accuracy target: 94.69%


epoch:34 current target accuracy:94.17%:  70%|███████   | 35/50 [06:10<02:38, 10.57s/it]

epoch [35/50], loss:2.8791
accuracy source: 99.4%
accuracy target: 94.17%


epoch:35 current target accuracy:94.42%:  72%|███████▏  | 36/50 [06:20<02:28, 10.58s/it]

epoch [36/50], loss:2.8619
accuracy source: 99.44%
accuracy target: 94.42%


epoch:36 current target accuracy:94.38%:  74%|███████▍  | 37/50 [06:31<02:17, 10.58s/it]

epoch [37/50], loss:2.8524
accuracy source: 99.4%
accuracy target: 94.38%


epoch:37 current target accuracy:94.46%:  76%|███████▌  | 38/50 [06:41<02:07, 10.59s/it]

epoch [38/50], loss:2.8959
accuracy source: 99.33%
accuracy target: 94.46%


epoch:38 current target accuracy:95.21%:  78%|███████▊  | 39/50 [06:52<01:56, 10.59s/it]

epoch [39/50], loss:2.9028
accuracy source: 99.3%
accuracy target: 95.21%


epoch:39 current target accuracy:92.73%:  80%|████████  | 40/50 [07:02<01:45, 10.57s/it]

epoch [40/50], loss:2.8895
accuracy source: 99.44%
accuracy target: 92.73%


epoch:40 current target accuracy:93.91%:  82%|████████▏ | 41/50 [07:13<01:35, 10.58s/it]

epoch [41/50], loss:2.9018
accuracy source: 99.4%
accuracy target: 93.91%


epoch:41 current target accuracy:95.47%:  84%|████████▍ | 42/50 [07:24<01:24, 10.59s/it]

epoch [42/50], loss:2.8966
accuracy source: 99.27%
accuracy target: 95.47%


epoch:42 current target accuracy:93.53%:  86%|████████▌ | 43/50 [07:34<01:14, 10.64s/it]

epoch [43/50], loss:2.9178
accuracy source: 99.27%
accuracy target: 93.53%


epoch:43 current target accuracy:92.4%:  88%|████████▊ | 44/50 [07:45<01:03, 10.62s/it] 

epoch [44/50], loss:2.9266
accuracy source: 99.27%
accuracy target: 92.4%


epoch:44 current target accuracy:93.4%:  90%|█████████ | 45/50 [07:56<00:52, 10.59s/it] 

epoch [45/50], loss:2.9225
accuracy source: 99.19%
accuracy target: 93.4%


epoch:45 current target accuracy:91.78%:  92%|█████████▏| 46/50 [08:06<00:42, 10.60s/it]

epoch [46/50], loss:2.9425
accuracy source: 99.23%
accuracy target: 91.78%


epoch:46 current target accuracy:94.6%:  94%|█████████▍| 47/50 [08:17<00:31, 10.60s/it] 

epoch [47/50], loss:2.9096
accuracy source: 99.47%
accuracy target: 94.6%


epoch:47 current target accuracy:94.61%:  96%|█████████▌| 48/50 [08:27<00:21, 10.61s/it]

epoch [48/50], loss:2.9357
accuracy source: 99.36%
accuracy target: 94.61%


epoch:48 current target accuracy:92.55%:  98%|█████████▊| 49/50 [08:38<00:10, 10.60s/it]

epoch [49/50], loss:2.9339
accuracy source: 99.37%
accuracy target: 92.55%


epoch:49 current target accuracy:93.17%: 100%|██████████| 50/50 [08:49<00:00, 10.58s/it]

epoch [50/50], loss:2.9355
accuracy source: 99.3%
accuracy target: 93.17%





In [21]:
# NOTE(review): the execution counts (In [21] placed before In [20]) and the score identical to
# In [22] suggest this cell was actually run after the fine-tuning cell, not before it.
model.eval()  # evaluation mode for the test pass (BatchNorm uses running stats, dropout off)
print(test_network(model, target_test_loader))
model.train();  # back to training mode; the trailing `;` hides the long model repr

0.9859


DomainAdaptationNetwork(
  (encoder): Encoder(
    (conv_feat): Sequential(
      (0): Conv2d(1, 50, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(50, 75, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (5): ReLU(inplace=True)
      (6): BatchNorm2d(75, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(75, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (dense_feat): Linear(in_features=6400, out_features=1024, bias=True)
    (task_feat): Linear(in_features=1024, out_features=150, bias=True)
    (s

In [20]:
# Fine-tuning stage: fresh optimizer with a lower learning rate and a constant task-distance weight.
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, weight_decay=0.001)

epochs = 30
betas = np.ones(epochs) * 10
# Triplet weights: the first `epochs` values of the 50-epoch 0 -> 5 ramp used in the initial run.
# They are spelled out here so this cell does not depend on the `delta` array left over from that cell.
delta_finetune = np.linspace(0, 5, 50)[:epochs]

train_domain_adaptation(model, optimizer, random_projector, source_train_loader, target_train_loader, betas=betas,
                        epochs=epochs, alpha=1, delta=delta_finetune, gamma=0.1, show_images=False)

epoch:0 current target accuracy:93.75%:   3%|▎         | 1/30 [00:10<05:05, 10.53s/it]

epoch [1/30], loss:1.1302
accuracy source: 99.47%
accuracy target: 93.75%


epoch:1 current target accuracy:94.83%:   7%|▋         | 2/30 [00:21<04:55, 10.56s/it]

epoch [2/30], loss:1.0799
accuracy source: 99.59%
accuracy target: 94.83%


epoch:2 current target accuracy:96.24%:  10%|█         | 3/30 [00:31<04:44, 10.55s/it]

epoch [3/30], loss:1.0648
accuracy source: 99.62%
accuracy target: 96.24%


epoch:3 current target accuracy:96.38%:  13%|█▎        | 4/30 [00:42<04:34, 10.56s/it]

epoch [4/30], loss:1.0662
accuracy source: 99.62%
accuracy target: 96.38%


epoch:4 current target accuracy:95.87%:  17%|█▋        | 5/30 [00:52<04:24, 10.59s/it]

epoch [5/30], loss:1.0701
accuracy source: 99.75%
accuracy target: 95.87%


epoch:5 current target accuracy:96.17%:  20%|██        | 6/30 [01:03<04:14, 10.58s/it]

epoch [6/30], loss:1.0560
accuracy source: 99.62%
accuracy target: 96.17%


epoch:6 current target accuracy:96.43%:  23%|██▎       | 7/30 [01:14<04:03, 10.57s/it]

epoch [7/30], loss:1.0540
accuracy source: 99.67%
accuracy target: 96.43%


epoch:7 current target accuracy:96.39%:  27%|██▋       | 8/30 [01:24<03:52, 10.58s/it]

epoch [8/30], loss:1.0514
accuracy source: 99.67%
accuracy target: 96.39%


epoch:8 current target accuracy:96.75%:  30%|███       | 9/30 [01:35<03:42, 10.57s/it]

epoch [9/30], loss:1.0419
accuracy source: 99.73%
accuracy target: 96.75%


epoch:9 current target accuracy:96.85%:  33%|███▎      | 10/30 [01:45<03:31, 10.55s/it]

epoch [10/30], loss:1.0528
accuracy source: 99.74%
accuracy target: 96.85%


epoch:10 current target accuracy:96.71%:  37%|███▋      | 11/30 [01:56<03:20, 10.55s/it]

epoch [11/30], loss:1.0227
accuracy source: 99.64%
accuracy target: 96.71%


epoch:11 current target accuracy:97.08%:  40%|████      | 12/30 [02:06<03:10, 10.56s/it]

epoch [12/30], loss:1.0476
accuracy source: 99.64%
accuracy target: 97.08%


epoch:12 current target accuracy:97.23%:  43%|████▎     | 13/30 [02:17<02:59, 10.59s/it]

epoch [13/30], loss:1.0355
accuracy source: 99.66%
accuracy target: 97.23%


epoch:13 current target accuracy:97.04%:  47%|████▋     | 14/30 [02:28<02:49, 10.60s/it]

epoch [14/30], loss:1.0397
accuracy source: 99.7%
accuracy target: 97.04%


epoch:14 current target accuracy:97.48%:  50%|█████     | 15/30 [02:38<02:38, 10.60s/it]

epoch [15/30], loss:1.0347
accuracy source: 99.66%
accuracy target: 97.48%


epoch:15 current target accuracy:97.12%:  53%|█████▎    | 16/30 [02:49<02:28, 10.59s/it]

epoch [16/30], loss:1.0191
accuracy source: 99.73%
accuracy target: 97.12%


epoch:16 current target accuracy:97.27%:  57%|█████▋    | 17/30 [02:59<02:17, 10.58s/it]

epoch [17/30], loss:1.0410
accuracy source: 99.7%
accuracy target: 97.27%


epoch:17 current target accuracy:97.11%:  60%|██████    | 18/30 [03:10<02:06, 10.56s/it]

epoch [18/30], loss:1.0330
accuracy source: 99.7%
accuracy target: 97.11%


epoch:18 current target accuracy:97.5%:  63%|██████▎   | 19/30 [03:20<01:56, 10.55s/it] 

epoch [19/30], loss:1.0262
accuracy source: 99.68%
accuracy target: 97.5%


epoch:19 current target accuracy:97.22%:  67%|██████▋   | 20/30 [03:31<01:45, 10.57s/it]

epoch [20/30], loss:1.0352
accuracy source: 99.81%
accuracy target: 97.22%


epoch:20 current target accuracy:97.39%:  70%|███████   | 21/30 [03:42<01:35, 10.57s/it]

epoch [21/30], loss:1.0345
accuracy source: 99.81%
accuracy target: 97.39%


epoch:21 current target accuracy:97.6%:  73%|███████▎  | 22/30 [03:52<01:24, 10.56s/it] 

epoch [22/30], loss:1.0330
accuracy source: 99.75%
accuracy target: 97.6%


epoch:22 current target accuracy:97.81%:  77%|███████▋  | 23/30 [04:03<01:13, 10.56s/it]

epoch [23/30], loss:1.0178
accuracy source: 99.81%
accuracy target: 97.81%


epoch:23 current target accuracy:97.48%:  80%|████████  | 24/30 [04:13<01:03, 10.55s/it]

epoch [24/30], loss:1.0312
accuracy source: 99.79%
accuracy target: 97.48%


epoch:24 current target accuracy:97.39%:  83%|████████▎ | 25/30 [04:24<00:52, 10.56s/it]

epoch [25/30], loss:1.0221
accuracy source: 99.78%
accuracy target: 97.39%


epoch:25 current target accuracy:97.53%:  87%|████████▋ | 26/30 [04:34<00:42, 10.59s/it]

epoch [26/30], loss:1.0237
accuracy source: 99.66%
accuracy target: 97.53%


epoch:26 current target accuracy:97.79%:  90%|█████████ | 27/30 [04:45<00:31, 10.61s/it]

epoch [27/30], loss:1.0121
accuracy source: 99.75%
accuracy target: 97.79%


epoch:27 current target accuracy:97.57%:  93%|█████████▎| 28/30 [04:56<00:21, 10.60s/it]

epoch [28/30], loss:1.0049
accuracy source: 99.79%
accuracy target: 97.57%


epoch:28 current target accuracy:97.82%:  97%|█████████▋| 29/30 [05:06<00:10, 10.60s/it]

epoch [29/30], loss:1.0317
accuracy source: 99.81%
accuracy target: 97.82%


epoch:29 current target accuracy:97.54%: 100%|██████████| 30/30 [05:17<00:00, 10.58s/it]

epoch [30/30], loss:1.0198
accuracy source: 99.77%
accuracy target: 97.54%





In [22]:
# Final evaluation on the target (MNIST) test loader; `train(False)` is exactly what `eval()` calls.
model.train(False)
test_network(model, target_test_loader)

0.9859

In [None]:
# Cross-domain swapping visualisation from `utils` (behaviour not visible here; presumably decodes
# target task codes with swapped domain codes). The model is still in eval mode from the previous cell.
plot_target_cross_domain_swapping(model, source_train_loader, target_train_loader)

In [None]:
# Whole-model pickle, kept as before. Loading it requires the notebook-defined Encoder/Decoder
# classes to be importable under the same module path they were pickled from.
torch.save(model, "usps_to_mnist.pth")
# Also save the weights alone, the portable format PyTorch recommends. Restore them with
# model.load_state_dict(torch.load(...)) after rebuilding the architecture.
torch.save(model.state_dict(), "usps_to_mnist_state_dict.pth")