In [None]:
import torch
import torch.nn as nn
import numpy as np
import json
import matplotlib.pyplot as plt
%matplotlib inline

from input_pipeline import get_datasets
from network import Network

# https://github.com/DmitryUlyanov/Multicore-TSNE
from MulticoreTSNE import MulticoreTSNE as TSNE

# Get validation data

In [None]:
# Validation split (is_training=False); get_datasets yields exactly two sets: SVHN first, then MNIST.
(svhn, mnist) = get_datasets(is_training=False)

# Load feature extractor

In [None]:
embedder = Network(image_size=(32, 32), embedding_dim=64).cuda()
classifier = nn.Linear(64, 10).cuda()

# The checkpoint holds the weights of the whole embedder -> classifier pipeline,
# so the same Sequential must be rebuilt before its state dict can be loaded.
full_model = nn.Sequential(embedder, classifier)
full_model.load_state_dict(torch.load('models/svhn_source'))
full_model.eval()  # propagates to both submodules

# Only the embedding network (first stage of the pipeline) is used from here on.
model = full_model[0]

# Extract features

In [None]:
def predict(dataset, net=None, batch_size=256):
    """Run every sample of ``dataset`` through a network and collect the labels.

    Args:
        dataset: iterable of ``(image, label)`` pairs. Each ``image`` is a
            tensor that ``net`` accepts once a leading batch dimension is
            added; all images must share one shape so they can be stacked.
        net: module to run. When ``None``, the notebook-level ``model``
            (the embedding network) is used, as before.
        batch_size: number of images forwarded together. Batching assumes
            ``net`` is in eval mode, so its output for a sample does not
            depend on the other samples in the batch.

    Returns:
        ``(X, y)``: ``X`` is the row-wise concatenation of the network
        outputs (one row per sample, as a numpy array); ``y`` is
        ``np.stack`` of the labels in dataset order.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``dataset`` is
            empty (nothing to concatenate).
    """
    if net is None:
        net = model
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    # Follow the network's own device instead of hard-coding CUDA; a
    # parameterless module falls back to CUDA like the original code did.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    outputs, labels, pending = [], [], []

    def flush():
        # Forward the accumulated images as one batch and store the outputs on the CPU.
        batch = torch.stack(pending).to(device)
        outputs.append(net(batch).cpu().numpy())
        pending.clear()

    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for image, label in dataset:
            pending.append(image)
            labels.append(label)
            if len(pending) == batch_size:
                flush()
        if pending:
            flush()

    X = np.concatenate(outputs, axis=0)
    y = np.stack(labels)
    return X, y

In [None]:
# Embed both validation sets with the same network; SVHN is processed first, then MNIST.
(X_svhn, y_svhn), (X_mnist, y_mnist) = predict(svhn), predict(mnist)

# Plot t-SNE of the embeddings

In [None]:
TSNE_PERPLEXITY = 200.0
# Seed t-SNE's random initialisation so the figure can be regenerated.
# NOTE(review): multithreaded runs may still differ slightly in float summation order.
TSNE_SEED = 0

tsne = TSNE(perplexity=TSNE_PERPLEXITY, n_jobs=12, random_state=TSNE_SEED)
# Fit both domains jointly so they share a single 2-D coordinate system.
P = tsne.fit_transform(np.concatenate([X_svhn, X_mnist], axis=0))

# Rows of P follow the concatenation order: all SVHN rows, then all MNIST rows.
P_svhn, P_mnist = np.split(P, [len(X_svhn)])

In [None]:
fig, ax = plt.subplots(figsize=(15, 8))

# Source points coloured by class. vmin/vmax are pinned so that label k always maps
# to tab10 colour k, whatever subset of labels is present.
# NOTE(review): assumes labels are the integers 0..9 (classifier has 10 outputs) — confirm.
svhn_points = ax.scatter(P_svhn[:, 0], P_svhn[:, 1], c=y_svhn, cmap='tab10',
                         vmin=-0.5, vmax=9.5, marker='.', label='svhn')
# Target points: semi-transparent white squares with black edges, drawn over the source clusters.
ax.scatter(P_mnist[:, 0], P_mnist[:, 1], marker='s', c='w', edgecolors='k', label='mnist', alpha=0.3)

# Key for the class colours, which the legend alone cannot convey.
fig.colorbar(svhn_points, ax=ax, ticks=range(10), label='svhn class')
ax.set_title('source is svhn, target is mnist')
ax.legend();

# Plot loss curves (MNIST-source run)

In [None]:
# Training/validation metrics of the MNIST-source run (a different run from the
# SVHN-source checkpoint used for the t-SNE plot above).
with open('logs/mnist_source.json', mode='r') as log_file:
    logs = json.load(log_file)

In [None]:
fig, axes = plt.subplots(1, 3, sharex=True, figsize=(15, 5), dpi=100)
ax_cls, ax_da, ax_acc = axes.flatten()
fig.suptitle('source is MNIST, target is SVHN', fontsize='x-large', y=1.05)

# Training loss is logged every step; validation metrics only at logs['val_step'].
ax_cls.plot(logs['step'], logs['classification_loss'], label='train logloss', c='r')
ax_cls.plot(logs['val_step'], logs['svhn_logloss'], label='svhn val logloss', marker='o', c='k')
ax_cls.plot(logs['val_step'], logs['mnist_logloss'], label='mnist val logloss', marker='o', c='c')
ax_cls.set_title('classification losses')
ax_cls.set_ylabel('logloss')

ax_da.plot(logs['step'], logs['walker_loss'], label='walker loss')
ax_da.plot(logs['step'], logs['visit_loss'], label='visit loss')
ax_da.set_title('domain adaptation losses')
ax_da.set_ylabel('loss')

# Markers match the validation curves in the first panel (sparse val_step points).
ax_acc.plot(logs['val_step'], logs['svhn_accuracy'], label='svhn val', marker='o', c='k')
ax_acc.plot(logs['val_step'], logs['mnist_accuracy'], label='mnist val', marker='o', c='c')
ax_acc.set_title('accuracy')
ax_acc.set_ylabel('accuracy')

for ax in (ax_cls, ax_da, ax_acc):
    ax.set_xlabel('step')
    ax.legend()

fig.tight_layout();