In [1]:
import sys
sys.path.append("../")

from torch import nn
from data_utils import load_mnist, load_svhn
from models import DisentangledDomainAdaptationNetwork, get_simple_classifier

import torch
from torch.nn import functional as F
import numpy as np
from utils import show_decoded_images, plot_results
from train import train_domain_adaptation
from utils import test_network

In [2]:
# Reproducibility: seed the torch and numpy global RNGs with one shared value
# and ask cuDNN for deterministic behaviour (autotuner benchmarking off).
import numpy as np

SEED = 0

torch.manual_seed(SEED)
np.random.seed(SEED)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Target-domain loaders are requested first, then the source-domain ones.
# Both are asked for 32x32 images, batches of 64, shuffling and two workers.
target_train_loader, target_test_loader = load_svhn(
    img_size=(32, 32),
    batch_size=64,
    split=1,
    shuffle=True,
    num_workers=2,
)
source_train_loader, source_test_loader = load_mnist(
    img_size=32,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

Using downloaded and verified file: ../data/train_32x32.mat


In [4]:
class Decoder(nn.Module):
    """Decode a pair of latent codes into an image.

    A two-layer MLP maps the concatenated codes to a flat vector, which is
    reshaped to ``conv_feat_size`` and pushed through a convolutional stack
    containing two ``Upsample(scale_factor=2)`` stages and a final sigmoid.

    Args:
        latent_space_dim: width of the concatenated ``[z_share, z_spe]`` input.
        conv_feat_size: ``(channels, height, width)`` of the feature map that
            enters the convolutional stack.
        nb_channels: number of channels of the produced image.
    """

    def __init__(self, latent_space_dim, conv_feat_size, nb_channels=3):
        super().__init__()

        self.latent_space_dim = latent_space_dim
        self.conv_feat_size = conv_feat_size

        flat_feature_size = np.prod(self.conv_feat_size)
        feature_channels = self.conv_feat_size[0]

        # Latent vector -> flat feature vector.
        dense_layers = [
            nn.Linear(in_features=latent_space_dim, out_features=1024),
            nn.ReLU(True),
            nn.Linear(in_features=1024, out_features=flat_feature_size),
            nn.ReLU(True),
        ]
        self.deco_dense = nn.Sequential(*dense_layers)

        # Feature map -> image; the attribute keeps its historical spelling
        # because it determines the parameter names in the state dict.
        conv_layers = [
            nn.Conv2d(feature_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=64, out_channels=nb_channels, kernel_size=5, padding=2),
            nn.Sigmoid(),
        ]
        self.deco_fetures = nn.Sequential(*conv_layers)

    def forward(self, z_share, z_spe):
        """Return the image decoded from ``z_share`` and ``z_spe`` (joined on dim 1)."""
        latent = torch.cat((z_share, z_spe), dim=1)
        flat_features = self.deco_dense(latent)
        feature_map = flat_features.view(-1, *self.conv_feat_size)
        return self.deco_fetures(feature_map)


class Encoder(nn.Module):
    """Convolutional encoder producing three latent codes per image.

    A shared convolutional + dense trunk feeds three parallel linear heads:
    a domain-shared code, a source-specific code and a target-specific code,
    each of width ``latent_space_dim`` and each followed by a ReLU.

    Args:
        latent_space_dim: width of every latent code.
        img_size: ``(channels, height, width)`` of the expected input; used
            once at construction time to size the first dense layer.
        nb_channels: number of input channels of the first convolution.

    Attributes:
        conv_feat_size: shape (without batch dim) of the convolutional output.
        dense_feature_size: flattened length of ``conv_feat_size`` (plain int).
    """

    def __init__(self, latent_space_dim, img_size, nb_channels=3):
        super().__init__()

        self.latent_space_dim = latent_space_dim
        self.nb_channels = nb_channels

        self.conv_feat = nn.Sequential(
            nn.Conv2d(nb_channels, out_channels=32, kernel_size=5, padding=2),
            nn.InstanceNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
        )

        # Probe the conv stack with a dummy image to learn its output shape;
        # no autograd graph is needed for a pure shape query.
        with torch.no_grad():
            self.conv_feat_size = self.conv_feat(torch.zeros(1, *img_size)).shape[1:]
        # Plain Python int so it can be used directly as a layer/view size.
        self.dense_feature_size = int(np.prod(self.conv_feat_size))

        self.dense_feat = nn.Sequential(
            nn.Linear(in_features=self.dense_feature_size, out_features=1024),
            nn.ReLU(True), )

        self.share_feat = nn.Sequential(
            nn.Linear(in_features=1024, out_features=latent_space_dim),
            nn.ReLU(True),
        )

        self.source_feat = nn.Sequential(
            nn.Linear(in_features=1024, out_features=latent_space_dim),
            nn.ReLU(True),
        )

        self.target_feat = nn.Sequential(
            nn.Linear(in_features=1024, out_features=latent_space_dim),
            nn.ReLU(True),
        )

    def _trunk_features(self, input_data):
        """Run the shared conv + dense trunk and return the 1024-d features.

        Single-channel inputs are tiled to three channels when the encoder
        was built for three-channel images.
        """
        if input_data.shape[1] == 1 and self.nb_channels == 3:
            input_data = input_data.repeat(1, 3, 1, 1)
        feat = self.conv_feat(input_data)
        feat = feat.view(-1, self.dense_feature_size)
        return self.dense_feat(feat)

    def forward(self, input_data):
        """Return ``(z_share, z_source, z_target)`` for a batch of images."""
        feat = self._trunk_features(input_data)
        z_share = self.share_feat(feat)
        z_source = self.source_feat(feat)
        z_target = self.target_feat(feat)
        return z_share, z_source, z_target

    def forward_share(self, input_data):
        """Return only the domain-shared code."""
        return self.share_feat(self._trunk_features(input_data))

    def forward_source(self, input_data):
        """Return only the source-specific code."""
        return self.source_feat(self._trunk_features(input_data))

    def forward_target(self, input_data):
        """Return only the target-specific code."""
        return self.target_feat(self._trunk_features(input_data))


In [5]:
def get_simple_classifier(latent_space_dim=1024):
    """Build the 10-way classification head applied to a latent code.

    The head is dropout (p=0.55), a linear layer to 10 outputs and a
    log-softmax over the last dimension, so its output is meant to be paired
    with ``nn.NLLLoss``.

    ``nn.Dropout`` is used rather than ``nn.Dropout2d``: the head receives
    2-D ``(batch, features)`` activations, for which channel-wise dropout
    degenerates to element-wise dropout and is deprecated by PyTorch. The
    softmax dimension is given explicitly to avoid the deprecated implicit
    choice.

    Args:
        latent_space_dim: number of input features.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, latent_space_dim)`` to
        ``(batch, 10)`` log-probabilities.
    """
    return nn.Sequential(
        nn.Dropout(0.55),
        nn.Linear(in_features=latent_space_dim, out_features=10),
        nn.LogSoftmax(dim=-1),
    )

In [16]:
class GradReverse(torch.autograd.Function):
    """Gradient reversal layer: identity forward, negated gradient backward."""

    @staticmethod
    def forward(ctx, x):
        """Return a view of ``x`` with unchanged values."""
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated upstream gradient.

        ``forward`` takes a single input, so exactly one gradient is returned.
        """
        return grad_output.neg()

    @staticmethod
    def grad_reverse(x):
        """Convenience wrapper: apply the reversal layer to ``x``."""
        return GradReverse.apply(x)


class DisentangledDomainAdaptationNetwork(nn.Module):
    """Encoder, two domain decoders and a classifier, plus disentanglement heads.

    Besides the four sub-networks passed in, the module owns three single
    linear projections (``random_source``, ``random_target``,
    ``random_share``) and two small MLP predictors (``spe_predictor``,
    ``share_predictor``). All five map a latent code to 50 values; their
    outputs are passed through ``cos`` in ``forward_s`` / ``forward_t``.

    Args:
        encoder: module exposing ``latent_space_dim``, ``__call__`` returning
            ``(z_share, z_source, z_target)`` and ``forward_share`` /
            ``forward_source`` / ``forward_target``.
        decoder_source: module called as ``decoder(z_share, z_source)``.
        decoder_target: module called as ``decoder(z_share, z_target)``.
        classifier: module applied to the shared code.
    """

    def __init__(self, encoder, decoder_source, decoder_target, classifier):
        super().__init__()
        self.encoder = encoder
        self.decoder_source = decoder_source
        self.decoder_target = decoder_target
        self.classifier = classifier
        code_dim = self.encoder.latent_space_dim

        def make_predictor():
            return nn.Sequential(
                nn.Linear(code_dim, 100),
                nn.ReLU(True),
                nn.Linear(100, 50))

        # Creation order is kept (source, target, share, then the predictors)
        # so parameter initialisation and state-dict layout do not change.
        for projection_name in ('random_source', 'random_target', 'random_share'):
            setattr(self, projection_name, nn.Sequential(nn.Linear(code_dim, 50)))

        self.spe_predictor = make_predictor()
        self.share_predictor = make_predictor()

    def _disentangle_heads(self, z_share, z_spe, random_spe_projection):
        """Compute the projection / predictor outputs shared by both domains.

        The predictors see gradient-reversed codes: ``spe_predictor`` is fed
        the reversed shared code and ``share_predictor`` the reversed specific
        code. The projections see the codes directly.

        Returns:
            ``((random_share, random_spe), (pred_share, pred_spe))``.
        """
        z_spe_rev = GradReverse.grad_reverse(z_spe)
        z_share_rev = GradReverse.grad_reverse(z_share)
        pred_spe = torch.cos(self.spe_predictor(z_share_rev))
        pred_share = torch.cos(self.share_predictor(z_spe_rev))
        random_spe = torch.cos(random_spe_projection(z_spe))
        random_share = torch.cos(self.random_share(z_share))
        return (random_share, random_spe), (pred_share, pred_spe)

    def forward(self, x):
        """Classify ``x`` from its shared code only."""
        return self.classifier(self.encoder.forward_share(x))

    def forward_s(self, x):
        """Full source pass: reconstruction, logits, codes and head outputs."""
        z_share, z_source, _ = self.encoder(x)
        reco_s = self.decoder_source(z_share, z_source)
        logits = self.classifier(z_share)
        randoms, preds = self._disentangle_heads(z_share, z_source, self.random_source)
        return reco_s, logits, (z_share, z_source), randoms, preds

    def forward_s_rand(self, z_share, x_rand):
        """Decode ``z_share`` as a source image using the source code of ``x_rand``."""
        z_source = self.encoder.forward_source(x_rand)
        return self.decoder_source(z_share, z_source)

    def forward_t(self, x):
        """Full target pass: reconstruction, logits, codes and head outputs."""
        z_share, _, z_target = self.encoder(x)
        reco_t = self.decoder_target(z_share, z_target)
        logits = self.classifier(z_share)
        randoms, preds = self._disentangle_heads(z_share, z_target, self.random_target)
        return reco_t, logits, (z_share, z_target), randoms, preds

    def forward_t_rand(self, z_share, x_rand):
        """Decode ``z_share`` as a target image using the target code of ``x_rand``."""
        z_target = self.encoder.forward_target(x_rand)
        return self.decoder_target(z_share, z_target)

    def forward_st(self, z_source, x):
        """Encode ``x`` to a shared code and decode it with the given source code.

        Returns ``(reconstruction, logits, z_share)``.
        """
        z_share = self.encoder.forward_share(x)
        reco_st = self.decoder_source(z_share, z_source)
        return reco_st, self.classifier(z_share), z_share

    def forward_ts(self, z_target, x):
        """Encode ``x`` to a shared code and decode it with the given target code.

        Returns ``(reconstruction, logits, z_share)``.
        """
        z_share = self.encoder.forward_share(x)
        reco_ts = self.decoder_target(z_share, z_target)
        return reco_ts, self.classifier(z_share), z_share

    def forward_target(self, x):
        """Return ``(logits, z_target)`` for ``x``."""
        z_share, _, z_target = self.encoder(x)
        return self.classifier(z_share), z_target

    def forward_source(self, x):
        """Return ``(logits, z_source)`` for ``x``."""
        z_share, z_source, _ = self.encoder(x)
        return self.classifier(z_share), z_source

In [17]:
from tqdm import tqdm
def train_domain_adaptation(model, optimizer, source_train_loader, target_train_loader,
                            epochs=30, beta_max=10, running_beta=20, alpha=1, show_images=False):
    """Train ``model`` on paired source/target batches with the combined objective.

    Every step draws one batch per domain (``zip`` stops at the shorter
    loader) and sums, for each domain:

    * reconstruction BCE, weighted by ``alpha``;
    * an MSE between the shared code and the shared code re-encoded from the
      cross-domain translation, weighted by the epoch's beta;
    * the two predictor-vs-projection MSE terms;
    * a classification term (weight 10) on the image re-decoded with the
      specific code of a reference image: true labels for the source,
      confidence-weighted pseudo-labels for the target;
    * a triplet loss (weight 10) on the specific codes.

    The source domain additionally receives the supervised NLL loss. Target
    labels are used only for the accuracy read-out, never in the loss.

    Args:
        model: network exposing ``encoder``, ``decoder_source``,
            ``decoder_target``, ``forward_s``, ``forward_t``,
            ``forward_s_rand``, ``forward_t_rand``, ``forward_source`` and
            ``forward_target``. Inputs are moved to the device of its first
            parameter.
        optimizer: optimizer stepping the parameters to train.
        source_train_loader: iterable of ``(images, labels)`` source batches.
        target_train_loader: iterable of ``(images, labels)`` target batches.
        epochs: number of passes over the zipped loaders.
        beta_max: final weight of the cross-domain consistency term.
        running_beta: number of epochs over which beta ramps linearly from 0
            to ``beta_max``; may exceed ``epochs``, in which case the ramp is
            simply cut short.
        alpha: weight of the reconstruction terms.
        show_images: when true, call ``show_decoded_images`` on the last
            batch of each epoch.
    """
    criterion_reconstruction = nn.BCELoss()
    criterion_classifier = nn.NLLLoss(reduction='mean')
    criterion_weighted_classifier = nn.NLLLoss(reduction='none')
    criterion_disentangle = nn.MSELoss()
    criterion_distance = nn.MSELoss()
    criterion_triplet = nn.TripletMarginLoss(margin=1)

    # Follow the model instead of assuming CUDA.
    device = next(model.parameters()).device

    # Linear warm-up of beta followed by a plateau; identical to a plain
    # linspace fill when running_beta <= epochs, truncated otherwise.
    ramp_epochs = max(int(running_beta), 0)
    ramp_used = min(ramp_epochs, epochs)
    betas = np.full(epochs, float(beta_max))
    betas[:ramp_used] = np.linspace(0, beta_max, ramp_epochs)[:ramp_used]

    # Make sure dropout is active even if an earlier evaluation left the
    # model in eval mode.
    model.train()

    progress = tqdm(range(epochs))
    for epoch in progress:
        total_loss = 0.0
        corrects_source = 0
        corrects_target = 0
        total_source = 0
        total_target = 0
        nb_batches = 0

        # random images used for disentanglement (one reference batch per epoch)
        xs_rand = next(iter(source_train_loader))[0].to(device)
        xt_rand = next(iter(target_train_loader))[0].to(device)

        for (x_s, y_s), (x_t, y_t) in zip(source_train_loader, target_train_loader):
            loss = 0
            x_s, y_s, x_t, y_t = x_s.to(device), y_s.to(device), x_t.to(device), y_t.to(device)

            # target batch
            xt_hat, yt_hat, (z_share, z_target), (random_share, random_spe), (pred_share, pred_spe) = model.forward_t(x_t)
            # translate to the source domain and re-encode the (detached) translation
            xts = model.forward_s_rand(z_share, xs_rand[:len(x_t)])
            z_s = model.encoder.forward_share(xts.detach())
            # swap in the target-specific code of the reference images
            z_target_prime = model.encoder.forward_target(xt_rand[:len(x_t)])
            xt_prime = model.decoder_target(z_share, z_target_prime)
            yt_tilde, z_target_tilde = model.forward_target(xt_prime)

            # w is the maximum classifier output per sample; exp(w) is used
            # below as the pseudo-label weight (a probability if the
            # classifier emits log-probabilities).
            w, predicted = yt_hat.max(1)
            corrects_target += predicted.eq(y_t).sum().item()
            total_target += y_t.size(0)

            loss += alpha * criterion_reconstruction(xt_hat, x_t)
            loss += betas[epoch] * criterion_distance(z_share, z_s)
            loss += criterion_disentangle(pred_share, random_share) + criterion_disentangle(pred_spe, random_spe)
            loss += 10 * torch.mean((torch.exp(w.detach()) * criterion_weighted_classifier(yt_tilde, predicted.detach())))
            loss += 10 * criterion_triplet(z_target_tilde, z_target_prime, z_target)

            # source batch
            xs_hat, ys_hat, (z_share, z_source), (random_share, random_spe), (pred_share, pred_spe) = model.forward_s(x_s)
            xst = model.forward_t_rand(z_share, xt_rand[:len(x_s)])
            z_t = model.encoder.forward_share(xst.detach())
            z_source_prime = model.encoder.forward_source(xs_rand[:len(x_s)])
            xs_prime = model.decoder_source(z_share, z_source_prime)
            ys_tilde, z_source_tilde = model.forward_source(xs_prime)

            _, predicted = ys_hat.max(1)
            corrects_source += predicted.eq(y_s).sum().item()
            total_source += y_s.size(0)

            loss += alpha * criterion_reconstruction(xs_hat, x_s)
            loss += criterion_classifier(ys_hat, y_s)
            loss += betas[epoch] * criterion_disentangle(z_share, z_t)
            loss += (criterion_disentangle(pred_share, random_share) + criterion_disentangle(pred_spe, random_spe))
            loss += 10 * criterion_classifier(ys_tilde, y_s)
            loss += 10 * criterion_triplet(z_source_tilde, z_source_prime, z_source)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            nb_batches += 1
            progress.set_description(
                f'epoch:{epoch} current target accuracy:{round(corrects_target / total_target * 100, 2)}%')
        # ===================log========================
        # Averages use the number of batches / samples actually processed;
        # max(..., 1) keeps the log line well-defined for empty loaders.
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epochs, total_loss / max(nb_batches, 1)))
        print(f'accuracy source: {round(corrects_source / max(total_source, 1) * 100, 2)}%')
        print(f'accuracy target: {round(corrects_target / max(total_target, 1) * 100, 2)}%')
        if show_images and nb_batches:
            show_decoded_images(x_s, xs_hat, xs_prime, xst)
            show_decoded_images(x_t, xt_hat, xt_prime, xts)


In [19]:
learning_rate = 5e-4

# One shared encoder with 200-d codes; each decoder takes a shared code
# concatenated with a domain-specific one, hence 400 inputs.
encoder = Encoder(latent_space_dim=200, img_size=(3, 32, 32), nb_channels=3)
conv_feat_size = encoder.conv_feat_size
decoder_source = Decoder(latent_space_dim=400, conv_feat_size=conv_feat_size, nb_channels=1)
decoder_target = Decoder(latent_space_dim=400, conv_feat_size=conv_feat_size, nb_channels=3)
classifier = get_simple_classifier(latent_space_dim=200)
model = DisentangledDomainAdaptationNetwork(encoder, decoder_source, decoder_target, classifier).cuda()

# Only these sub-modules are handed to Adam; the model's random_* projection
# layers are not listed, so this optimizer never updates them.
trained_modules = (
    model.encoder,
    model.decoder_source,
    model.decoder_target,
    model.classifier,
    model.spe_predictor,
    model.share_predictor,
)
optimizer = torch.optim.Adam(
    [{'params': module.parameters()} for module in trained_modules],
    lr=learning_rate,
    weight_decay=0.001,
)

train_domain_adaptation(
    model, optimizer, source_train_loader, target_train_loader,
    epochs=30, beta_max=1, running_beta=10,
)

epoch:0 current target accuracy:16.85%:   3%|▎         | 1/30 [01:14<35:49, 74.14s/it]

epoch [1/30], loss:11.3741
accuracy source: 80.29%
accuracy target: 16.85%


epoch:1 current target accuracy:19.76%:   7%|▋         | 2/30 [02:28<34:38, 74.22s/it]

epoch [2/30], loss:4.5395
accuracy source: 97.38%
accuracy target: 19.76%


epoch:2 current target accuracy:20.76%:  10%|█         | 3/30 [03:44<33:34, 74.63s/it]

epoch [3/30], loss:3.9021
accuracy source: 98.2%
accuracy target: 20.76%


epoch:3 current target accuracy:19.44%:  13%|█▎        | 4/30 [04:59<32:25, 74.81s/it]

epoch [4/30], loss:3.5301
accuracy source: 98.4%
accuracy target: 19.44%


epoch:4 current target accuracy:19.58%:  13%|█▎        | 4/30 [05:53<38:15, 88.28s/it]


KeyboardInterrupt: 

In [None]:
# Second training phase: a fresh Adam over *all* model parameters, with beta
# held at its maximum from the first epoch (running_beta=0).
# NOTE(review): the training entry point defined in this notebook is
# `train_domain_adaptation`; it has no `gamma` argument, so none is passed.
# Confirm that no gamma-dependent behaviour is expected here.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.001)
train_domain_adaptation(model, optimizer, source_train_loader, target_train_loader,
                        epochs=20, beta_max=10, running_beta=0)

In [None]:
# Run the project-local `test_network` helper (imported from utils) on the
# target-domain test loader.
test_network(model, target_test_loader)

In [None]:
import matplotlib.pyplot as plt
def plot_results2(model, source_train_loader, target_train_loader, nb_ex=5, nb_images=32):
    """Plot cross-domain translations for one batch of each domain.

    For each of the first ``nb_images`` images of a source batch, one figure
    shows the image (titled with the model's predicted class) followed by
    ``nb_ex`` translations into the target domain, each built from the
    image's shared code and the target-specific code of a reference batch.
    The same is then done for a target batch translated into the source
    domain.

    Images are drawn according to their own channel count (grayscale for one
    channel, colour otherwise), so the function does not depend on which
    domain is single-channel. The model is evaluated in eval mode without
    gradient tracking and its previous train/eval mode is restored afterwards.

    Args:
        model: network exposing ``forward_s``, ``forward_t``,
            ``forward_s_rand`` and ``forward_t_rand``.
        source_train_loader: iterable of ``(images, labels)`` source batches.
        target_train_loader: iterable of ``(images, labels)`` target batches.
        nb_ex: number of translations shown per image.
        nb_images: maximum number of images (figures) per domain.
    """
    device = next(model.parameters()).device

    def _show(image):
        """Draw one (C, H, W) tensor on the current axes, without axis ticks."""
        image = image.detach().cpu()
        if image.shape[0] == 1:
            plt.imshow(image[0], cmap='gray')
        else:
            plt.imshow(image.permute(1, 2, 0))
        plt.axis('off')

    def _plot_domain(content_loader, style_loader, content_forward, translate):
        """Plot one batch of `content_loader` next to its translations."""
        x, _ = next(iter(content_loader))
        outputs = content_forward(x.to(device))
        pred, z_share = outputs[1], outputs[2][0]
        _, predicted = pred.max(1)

        translations = []
        for _ in range(nb_ex):
            x_rand = next(iter(style_loader))[0].to(device)
            translations.append(translate(z_share, x_rand[:len(x)]))

        # Never index past the end of the batch.
        for i in range(min(nb_images, len(x))):
            plt.figure(figsize=(7, 4))
            plt.subplot(1, nb_ex + 1, 1)
            _show(x[i])
            plt.title(f'prediction: {int(predicted[i])}')
            for k, translated in enumerate(translations):
                plt.subplot(1, nb_ex + 1, k + 2)
                _show(translated[i])
            plt.tight_layout()

    was_training = model.training
    model.eval()  # dropout off, so the displayed predictions are deterministic
    try:
        with torch.no_grad():
            # source images -> target-domain translations
            _plot_domain(source_train_loader, target_train_loader,
                         model.forward_s, model.forward_t_rand)
            # target images -> source-domain translations
            _plot_domain(target_train_loader, source_train_loader,
                         model.forward_t, model.forward_s_rand)
    finally:
        model.train(was_training)

plot_results2(model, source_train_loader, target_train_loader)

In [None]:
# Take one target batch and split it into its shared and target-specific codes.
X, _ = next(iter(target_train_loader))
y, _, (z_share, z_spe),  _, _ = model.forward_t(X.cuda())
# First cell of the header row: an all-ones placeholder image.
plt.subplot(1,6,1)
plt.imshow(torch.ones((32,32,3)))
plt.axis('off')
plt.tight_layout()
# Rest of the header row: the first five batch images, which supply the
# specific ("style") codes. Only channel 0 is drawn, as grayscale.
for i in range(5):
    plt.subplot(1,6,i+2)
    plt.imshow(X[i].cpu().detach()[0], cmap='gray')
    plt.axis('off')
    plt.tight_layout()

# One figure per batch item 10..19: the item itself, then decodings that
# combine ITS shared code with the specific codes of the first five items.
for j in range(10, 20):
    plt.figure()
    plt.subplot(1,6,1)
    plt.imshow(X[j][0], cmap='gray')
    plt.axis('off')
    plt.tight_layout()

    # Broadcast item j's shared code over the whole batch so it can be paired
    # with every item's specific code in a single decoder call.
    z_x = torch.zeros_like(z_share)
    z_x[:] = z_share[j]
    y2  = model.decoder_target(z_x, z_spe)
    for i in range(5):
        plt.subplot(1,6,i+2)
        plt.imshow(y2[i].cpu().detach()[0], cmap='gray')
        plt.axis('off')
        plt.tight_layout()

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
def extract_features(train_loader, sample_count, model=None):
    """Collect the first output of an encoder for ``sample_count`` samples.

    Batches are read from ``train_loader`` until at least ``sample_count``
    samples have been seen; the feature width and the batch size are taken
    from the data rather than assumed.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        sample_count: number of rows in the returned arrays.
        model: module whose call returns a tuple with the features first.
            Defaults to the notebook-level ``encoder``.

    Returns:
        ``(features, labels)``: a float array of shape
        ``(sample_count, feature_dim)`` and an int array of shape
        ``(sample_count,)``. If the loader holds fewer than ``sample_count``
        samples the remaining rows stay zero; if it yields nothing,
        ``feature_dim`` is 0.
    """
    feature_extractor = encoder if model is None else model
    device = next(feature_extractor.parameters()).device

    feature_chunks = []
    label_chunks = []
    collected = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for x, labels_batch in train_loader:
            if collected >= sample_count:
                break
            features_batch = feature_extractor(x.to(device))[0]
            feature_chunks.append(features_batch.cpu().numpy())
            label_chunks.append(labels_batch.cpu().numpy())
            collected += len(x)

    if not feature_chunks:
        return np.zeros(shape=(sample_count, 0)), np.zeros(shape=(sample_count,), dtype=int)

    stacked_features = np.concatenate(feature_chunks)[:sample_count]
    stacked_labels = np.concatenate(label_chunks)[:sample_count]

    features = np.zeros(shape=(sample_count, stacked_features.shape[1]))
    labels = np.zeros(shape=(sample_count,))
    features[:len(stacked_features)] = stacked_features
    labels[:len(stacked_labels)] = stacked_labels
    return features, labels.astype(int)

In [None]:
f_s, s_labels = extract_features(source_train_loader, 640)
f_t, t_labels = extract_features(target_train_loader, 640)

# Stack source rows first, then target rows; the feature width is taken from
# the arrays themselves instead of being hard-coded.
n_source = len(f_s)
f = np.concatenate([f_s, f_t], axis=0)

# PCA to 50 dimensions, then t-SNE to 2-D for plotting.
pca = PCA(n_components=50)
X_pca = pca.fit_transform(f)
X_tsne = TSNE(n_components=2, perplexity=50).fit_transform(X_pca)
X_tsne_x = X_tsne[:, 0]
X_tsne_y = X_tsne[:, 1]

plt.figure(figsize=(16, 10))
color_map = plt.cm.rainbow(np.linspace(0, 1, 10))
# Colour is chosen by label value (not by position in np.unique), so a given
# class has the same colour in both domains even if one domain lacks a class.
for g in np.unique(s_labels):
    ix = np.where(s_labels == g)
    plt.scatter(x=X_tsne_x[:n_source][ix], y=X_tsne_y[:n_source][ix],
                color=color_map[g % len(color_map)], marker='o', alpha=0.5,
                label=f'source {g}')
for g in np.unique(t_labels):
    ix = np.where(t_labels == g)
    plt.scatter(x=X_tsne_x[n_source:][ix], y=X_tsne_y[n_source:][ix],
                color=color_map[g % len(color_map)], marker='+',
                label=f'target {g}')
plt.legend()

In [None]:
from advertorch.utils import predict_from_logits
from advertorch_examples.utils import get_mnist_test_loader
from advertorch_examples.utils import _imshow


In [None]:
from advertorch.attacks import LinfPGDAttack

# L-infinity PGD attack on the full model: 40 steps of size 0.01 inside an
# eps=0.15 ball with a random start, inputs clipped to [0, 1].
attack_loss = nn.CrossEntropyLoss(reduction="sum")
adversary = LinfPGDAttack(
    model,
    loss_fn=attack_loss,
    eps=0.15,
    nb_iter=40,
    eps_iter=0.01,
    rand_init=True,
    clip_min=0.0,
    clip_max=1.0,
    targeted=False,
)

In [None]:

# Untargeted attack on one source batch.
x, label = next(iter(source_train_loader))
x = x.cuda()
label = label.cuda()
# Pin the attack mode explicitly: a later cell sets `adversary.targeted = True`,
# and re-running this cell afterwards would otherwise run a targeted attack.
adversary.targeted = False
adv_untargeted = adversary.perturb(x, label)

In [None]:
# Targeted variant on the same batch: every sample is pushed towards class 3.
target = torch.full_like(label, 3)
adversary.targeted = True
adv_targeted = adversary.perturb(x, target)

In [None]:
# Model predictions on the clean batch and on both adversarial versions.
pred_cln = predict_from_logits(model(x))
pred_untargeted_adv = predict_from_logits(model(adv_untargeted))
pred_targeted_adv = predict_from_logits(model(adv_targeted))
# Number of examples (columns) shown in the figure, not the loader batch size.
batch_size=8
import matplotlib.pyplot as plt
# 3 x batch_size grid: row 1 clean, row 2 untargeted, row 3 targeted (class 3).
plt.figure(figsize=(10, 8))
for ii in range(batch_size):
    plt.subplot(3, batch_size, ii + 1)
    _imshow(x[ii])
    plt.title("clean \n pred: {}".format(pred_cln[ii]))
    plt.subplot(3, batch_size, ii + 1 + batch_size)
    _imshow(adv_untargeted[ii])
    plt.title("untargeted \n adv \n pred: {}".format(
        pred_untargeted_adv[ii]))
    plt.subplot(3, batch_size, ii + 1 + batch_size * 2)
    _imshow(adv_targeted[ii])
    plt.title("targeted to 3 \n adv \n pred: {}".format(
        pred_targeted_adv[ii]))

plt.tight_layout()
plt.show()