In [None]:
from preprocessing import prepare_dataset, prepare_dataset_single
from models import DANN_Model

import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import datetime
from openTSNE import TSNE

In [None]:
# Source domain (labelled during training): MNIST. Target domain: SVHN.
# Each call returns a (train, test) pair of datasets from preprocessing.prepare_dataset_single.
(source_train_dataset, source_test_dataset) = prepare_dataset_single('MNIST')
(target_train_dataset, target_test_dataset) = prepare_dataset_single('SVHN')

In [None]:
# Per-component learning rates, passed to DANN_Model as a single 3-tuple.
# NOTE(review): the prefixes presumably mean label predictor (lp), domain
# classifier (dc) and feature extractor (fe) — confirm against models.DANN_Model.
lp_lr, dc_lr, fe_lr = 0.01, 0.1, 0.01
lr = lp_lr, dc_lr, fe_lr

# NOTE(review): MNIST is the source domain but the model expects (32, 32, 3)
# inputs, which assumes prepare_dataset_single resizes MNIST to SVHN's shape — verify.
model = DANN_Model(
    input_shape=(32, 32, 3),
    model_type='SVHN',
    run_name='mnist2svhn',
    lr=lr,
)

In [None]:
EPOCHS = 50

for epoch in range(EPOCHS):

    print(datetime.datetime.now())

    # One training pass over paired (labelled source, unlabelled target) batches.
    # NOTE(review): zip() stops at the shorter dataset, so any batches beyond the
    # length of the shorter one are never used in an epoch — confirm this is intended.
    for (source_images, class_labels), (target_images, _) in zip(source_train_dataset, target_train_dataset):
        model.train(source_images, class_labels, target_images)

    # Evaluate on the test splits and collect per-batch latent features.
    # Batches are gathered in lists and concatenated once afterwards, rather than
    # re-copying an ever-growing array on every batch (quadratic in #batches).
    source_batches = []
    target_batches = []
    for (test_images, test_labels), (target_test_images, target_test_labels) in zip(source_test_dataset, target_test_dataset):
        model.test_source(test_images, test_labels, target_test_images)
        model.test_target(target_test_images, target_test_labels)
        source_batches.append(model.return_latent_variables(test_images))
        target_batches.append(model.return_latent_variables(target_test_images))

    print('Epoch: {}'.format(epoch + 1))
    print(model.log())

    if not source_batches:
        # Fail with a clear message instead of an obscure PCA error on an empty array.
        raise ValueError('Test datasets yielded no batches; nothing to embed.')
    latent_source = np.concatenate(source_batches)
    latent_target = np.concatenate(target_batches)

    # Rows index[i]:index[i+1] of latent_variables belong to domain i (0 = MNIST, 1 = SVHN).
    # `index`, `latent_variables` and `epoch` are reused by the t-SNE cells below.
    index = [0, len(latent_source), len(latent_source) + len(latent_target)]
    latent_variables = np.concatenate([latent_source, latent_target])

    pca_embedding = PCA(n_components=2).fit_transform(latent_variables)

    fig, ax = plt.subplots()
    ax.set_title('Epoch #{}'.format(epoch + 1))
    for i, domain in enumerate(['MNIST', 'SVHN']):
        start, stop = index[i], index[i + 1]
        ax.plot(pca_embedding[start:stop, 0], pca_embedding[start:stop, 1], '.',
                alpha=0.5, color='C{}'.format(i), label=domain)
    ax.legend()
    plt.show()
    # Release the figure so 50 epochs of plots don't accumulate in memory.
    plt.close(fig)

In [None]:
# t-SNE of the latent features collected in the final training epoch
# (latent_variables, index and epoch are left over from the training cell).
TSNE_SEED = 42  # fixed so the embedding is reproducible on Restart & Run All

tsne = TSNE(n_components=2, initialization="pca", random_state=TSNE_SEED)

tsne_start = datetime.datetime.now()
print(tsne_start)

tsne_embedding = tsne.fit(latent_variables)

tsne_end = datetime.datetime.now()
print(tsne_end)
print('t-SNE took {}'.format(tsne_end - tsne_start))

fig, ax = plt.subplots()
ax.set_title('Epoch #{}'.format(epoch + 1))
for i, domain in enumerate(['MNIST', 'SVHN']):
    start, stop = index[i], index[i + 1]
    ax.plot(tsne_embedding[start:stop, 0], tsne_embedding[start:stop, 1], '.',
            alpha=0.5, color='C{}'.format(i), label=domain)
ax.legend()
plt.show()
plt.close(fig)

In [None]:
# Same t-SNE embedding with the draw order reversed (SVHN first), so that the
# MNIST points are drawn on top. Colours are pinned per domain (MNIST = C0,
# SVHN = C1, as in the un-reversed plot); relying on matplotlib's colour
# cycle would otherwise swap the colours along with the draw order.
domain_styles = [('MNIST', 'C0'), ('SVHN', 'C1')]

fig, ax = plt.subplots()
ax.set_title('Epoch #{}'.format(epoch + 1))
for i in reversed(range(len(domain_styles))):
    name, color = domain_styles[i]
    start, stop = index[i], index[i + 1]
    ax.plot(tsne_embedding[start:stop, 0], tsne_embedding[start:stop, 1], '.',
            alpha=0.5, color=color, label=name)
ax.legend()
plt.show()
plt.close(fig)