In [1]:
import sys
sys.path.append("../")

from torch import nn
from data_utils import load_synsigns, load_GTSRB
from models import DomainAdaptationNetwork, get_simple_classifier, ProjectorNetwork

import torch
from torch.nn import functional as F
import numpy as np
from train import train_domain_adaptation
from utils import test_network, plot_tsne

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-39yq17uj because the default path (/home/david.bertoin/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# Seed every random number generator the notebook relies on so that weight
# initialisation, data shuffling and dropout are repeatable across runs.
import random

import numpy as np

SEED = 1
random.seed(SEED)        # Python stdlib RNG: not covered by the numpy/torch seeds below
np.random.seed(SEED)
torch.manual_seed(SEED)
# Trade cuDNN autotuning speed for repeatable convolution algorithm selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class Decoder(nn.Module):
    """Decode a (shared code, domain-specific code) pair back into an image batch.

    The two codes are concatenated along dim 1 and expanded by two dense layers
    to ``prod(conv_feat_size)`` units. The result is reshaped to
    ``conv_feat_size`` and passed through three conv + ReLU + 2x-upsample stages
    and a final conv. A sigmoid bounds every output pixel to [0, 1].

    Args:
        latent_space_dim: width of the concatenated input, i.e. dim(z_share) + dim(z_spe).
        conv_feat_size: (channels, height, width) of the feature map decoding starts from.
        nb_channels: number of channels in the reconstructed image.
    """

    def __init__(self, latent_space_dim, conv_feat_size, nb_channels=3):
        super().__init__()

        self.latent_space_dim = latent_space_dim
        self.conv_feat_size = conv_feat_size

        hidden_units = 1024
        flat_units = np.prod(conv_feat_size)
        self.deco_dense = nn.Sequential(
            nn.Linear(latent_space_dim, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, flat_units),
            nn.ReLU(inplace=True),
        )

        # Three "conv -> ReLU -> x2 upsample" stages, then a projection to image channels.
        # 'same' padding is kernel // 2 for these odd kernel sizes.
        widths = [conv_feat_size[0], 128, 64, 32]
        kernels = [3, 3, 5]
        layers = []
        for c_in, c_out, k in zip(widths[:-1], widths[1:], kernels):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2),
            ]
        layers += [nn.Conv2d(32, nb_channels, kernel_size=5, padding=2), nn.Sigmoid()]
        # Attribute name (sic) is kept so existing state_dict keys still match.
        self.deco_fetures = nn.Sequential(*layers)

    def forward(self, z_share, z_spe):
        """Return the image decoded from ``cat([z_share, z_spe], dim=1)``."""
        dense_out = self.deco_dense(torch.cat((z_share, z_spe), dim=1))
        feature_map = dense_out.view(-1, *self.conv_feat_size)
        return self.deco_fetures(feature_map)


class Encoder(nn.Module):
    """Convolutional encoder producing a shared (task) code and two domain-specific codes.

    A 4-stage conv trunk is followed by a 1024-unit dense layer. That layer feeds
    three independent linear heads (``task_feat``, ``source_feat``,
    ``target_feat``), each followed by a ReLU.

    Args:
        latent_space_dim: size of each of the three latent codes.
        img_size: (channels, height, width) of the expected input. It is used once
            at construction to infer the size of the flattened conv features.
        nb_channels: number of input channels the conv trunk expects. When it is 3,
            single-channel inputs are tiled to 3 channels in ``forward``.
    """

    def __init__(self, latent_space_dim, img_size, nb_channels=3):
        super(Encoder, self).__init__()

        self.latent_space_dim = latent_space_dim
        self.nb_channels = nb_channels

        self.conv_feat = nn.Sequential(
            # The first conv must accept nb_channels inputs (a hard-coded 3 breaks nb_channels=1).
            nn.Conv2d(in_channels=nb_channels, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.InstanceNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.InstanceNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.InstanceNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.InstanceNorm2d(128),
        )

        # Infer the conv output shape with a dummy forward pass; no autograd graph is needed.
        with torch.no_grad():
            self.conv_feat_size = self.conv_feat(torch.zeros(1, *img_size)).shape[1:]
        self.dense_feature_size = int(np.prod(self.conv_feat_size))

        self.dense_feat = nn.Linear(in_features=self.dense_feature_size, out_features=1024)
        self.task_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)
        self.source_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)
        self.target_feat = nn.Linear(in_features=1024, out_features=latent_space_dim)

    def forward(self, input_data, mode='all'):
        """Encode a batch of images.

        Args:
            input_data: image batch with channels on dim 1.
            mode: 'task', 'source' or 'target' returns only that code. Any other
                value (default 'all') returns the tuple (task, source, target).
        """
        # Tile grayscale inputs so they can go through an RGB trunk.
        if input_data.shape[1] == 1 and self.nb_channels == 3:
            input_data = input_data.repeat(1, 3, 1, 1)
        feat = self.conv_feat(input_data)
        # Keep the batch dim explicit: a spatial-size mismatch then fails loudly in
        # dense_feat instead of being silently folded into the batch dimension.
        feat = torch.flatten(feat, start_dim=1)
        feat = F.relu(self.dense_feat(feat))

        heads = {'task': self.task_feat, 'source': self.source_feat, 'target': self.target_feat}
        if mode in heads:
            return F.relu(heads[mode](feat))
        return tuple(F.relu(heads[name](feat)) for name in ('task', 'source', 'target'))

In [4]:
def get_simple_classifier(latent_space_dim=1024, num_classes=43):
    """Build the label classifier applied to a latent code.

    Note: this definition shadows the ``get_simple_classifier`` imported from
    ``models`` in the first cell, so the notebook uses this local version.

    Args:
        latent_space_dim: size of the input code.
        num_classes: number of output classes. The default 43 is presumably the
            GTSRB label count; confirm before reusing on other datasets.

    Returns:
        ``nn.Sequential`` mapping ``(batch, latent_space_dim)`` codes to
        log-probabilities of shape ``(batch, num_classes)``.
    """
    return nn.Sequential(
        # The input is a 2-D (batch, features) tensor. nn.Dropout2d is meant for
        # (N, C, H, W) feature maps, and PyTorch deprecates passing it 2-D input,
        # so use plain element-wise dropout with the same p.
        nn.Dropout(p=0.5),
        nn.Linear(in_features=latent_space_dim, out_features=num_classes),
        nn.LogSoftmax(dim=1),
    )

In [None]:
# Stage 1: domain-adaptation training (SynSigns source -> GTSRB target).
target_loader = load_GTSRB(img_size=64, batch_size=128, shuffle=True, num_workers=4)
source_loader = load_synsigns(img_size=64, batch_size=128, shuffle=True, num_workers=4)

learning_rate = 5e-4
epochs = 30
latent_dim = 75            # size of each code produced by the encoder
beta_max = 10.0            # plateau value of the per-epoch `betas` schedule
beta_warmup_epochs = 10    # epochs over which betas ramps linearly from 0 to beta_max

encoder = Encoder(latent_space_dim=latent_dim, img_size=(3, 64, 64), nb_channels=3)
conv_feat_size = encoder.conv_feat_size
# Decoders consume the concatenation of a shared and a specific code, hence 2 * latent_dim.
decoder_source = Decoder(latent_space_dim=2 * latent_dim, conv_feat_size=conv_feat_size, nb_channels=3)
decoder_target = Decoder(latent_space_dim=2 * latent_dim, conv_feat_size=conv_feat_size, nb_channels=3)
classifier = get_simple_classifier(latent_space_dim=latent_dim)
model = DomainAdaptationNetwork(encoder, decoder_source, decoder_target, classifier).cuda()
random_projector = ProjectorNetwork(latent_dim=latent_dim).cuda()

# One beta per epoch, sized from `epochs` so the two cannot drift apart:
# linear warm-up from 0 to beta_max, then held constant.
betas = np.full(epochs, beta_max)
warmup = min(beta_warmup_epochs, epochs)
betas[:warmup] = np.linspace(0, beta_max, warmup)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)

train_domain_adaptation(model, optimizer, random_projector, source_loader, target_loader, betas=betas,
                        epochs=epochs, alpha=1, gamma=0.5, delta=0.5, show_images=False)

epoch:0 current target accuracy:46.62%:   3%|▎         | 1/30 [01:41<48:50, 101.06s/it]

epoch [1/30], loss:0.9654
accuracy source: 85.21%
accuracy target: 46.72%


epoch:1 current target accuracy:84.53%:   7%|▋         | 2/30 [03:22<47:18, 101.36s/it]

epoch [2/30], loss:0.7472
accuracy source: 98.71%
accuracy target: 84.72%


epoch:2 current target accuracy:92.34%:  10%|█         | 3/30 [05:04<45:41, 101.53s/it]

epoch [3/30], loss:0.6136
accuracy source: 99.51%
accuracy target: 92.54%


epoch:3 current target accuracy:94.63%:  13%|█▎        | 4/30 [06:46<44:01, 101.59s/it]

epoch [4/30], loss:0.5687
accuracy source: 99.52%
accuracy target: 94.84%


epoch:4 current target accuracy:95.07%:  17%|█▋        | 5/30 [08:27<42:20, 101.63s/it]

epoch [5/30], loss:0.5651
accuracy source: 99.49%
accuracy target: 95.28%


epoch:5 current target accuracy:95.49%:  17%|█▋        | 5/30 [09:01<42:20, 101.63s/it]

In [None]:
# Score the adapted model on the target-domain loader with dropout switched off.
model.train(False)
test_network(model, target_loader)

In [None]:
# Stage 2: fine-tune at a 10x lower learning rate with beta held at its plateau value.
# The evaluation cell above left the model in eval mode; switch back so dropout is
# active while training (a no-op if train_domain_adaptation already calls model.train()).
model.train()
finetune_lr = 5e-5
betas = np.full(epochs, 10.0)  # one constant beta per epoch, sized from `epochs`
optimizer = torch.optim.Adam(model.parameters(), lr=finetune_lr, weight_decay=0.001)

train_domain_adaptation(model, optimizer, random_projector, source_loader, target_loader, betas=betas,
                        epochs=epochs, alpha=1, gamma=0.5, delta=0.5, show_images=False)

In [None]:
# Re-score on the target-domain loader after fine-tuning, with dropout switched off.
model.train(False)
test_network(model, target_loader)

In [None]:
import matplotlib.pyplot as plt
@torch.no_grad()
def plot_target_cross_domain_swapping(model, source_train_loader, target_train_loader,
                                      n_styles=5, content_indices=range(10, 20)):
    """Visualise content/style swapping into the target domain.

    One target batch and one source batch are encoded with ``model.forward_t``.
    Each selected source image's shared code is then paired with the
    domain-specific codes of the first ``n_styles`` target images and decoded
    with ``model.decoder_target``.

    Figures produced:
      * a header row: a blank cell, then the ``n_styles`` target images whose
        specific codes are reused;
      * one row per index in ``content_indices``: the source image, then its
        ``n_styles`` swapped reconstructions.

    Args:
        model: network exposing ``forward_t`` and ``decoder_target``.
            ``forward_t`` returns a 5-tuple whose third item is ``(z_share, z_spe)``.
            ``decoder_target`` is called as ``decoder_target(z_share, z_spe)``.
        source_train_loader: iterable yielding ``(images, labels)``. Its first
            batch must be longer than ``max(content_indices)``.
        target_train_loader: iterable yielding ``(images, labels)``. Its first
            batch must hold at least ``n_styles`` images.
        n_styles: number of target images whose specific code is reused.
        content_indices: indices in the source batch whose shared code is swapped in.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device

    def _show(ax, img):
        # Render a (C, H, W) tensor as an (H, W, C) image without axes.
        ax.imshow(img.detach().cpu().permute(1, 2, 0))
        ax.axis('off')

    # Target batch: its domain-specific codes provide the "styles".
    X_target, _ = next(iter(target_train_loader))
    _, _, (_, z_spe_target), _, _ = model.forward_t(X_target.to(device))
    # Source batch: its shared codes provide the "content".
    X_source, _ = next(iter(source_train_loader))
    _, _, (z_share_source, _), _, _ = model.forward_t(X_source.to(device))
    z_spe_styles = z_spe_target[:n_styles]

    # Header row: blank placeholder, then the target images whose styles are reused.
    fig, axes = plt.subplots(1, n_styles + 1, squeeze=False)
    row = axes[0]
    row[0].imshow(torch.ones((32, 32, 3)))
    row[0].axis('off')
    for i in range(n_styles):
        _show(row[i + 1], X_target[i])
    fig.tight_layout()

    for j in content_indices:
        fig, axes = plt.subplots(1, n_styles + 1, squeeze=False)
        row = axes[0]
        _show(row[0], X_source[j])
        # Only n_styles reconstructions are shown, so decode just those: repeat the
        # source code n_styles times and pair it with the first n_styles target codes.
        content = z_share_source[j]
        z_share_rep = content.expand(n_styles, *content.shape)
        swapped = model.decoder_target(z_share_rep, z_spe_styles)
        for i in range(n_styles):
            _show(row[i + 1], swapped[i])
        fig.tight_layout()

In [None]:
plot_target_cross_domain_swapping(model, source_train_loader=source_loader, target_train_loader=target_loader)
# Optional t-SNE of the latent codes. The trailing arguments are presumably the batch
# size (128) and latent dim (75); confirm against utils.plot_tsne before enabling.
# plot_tsne(model, source_loader, target_loader, 128, 75)