In [1]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import gpytorch
import numpy as np
import torch
import winsound
from matplotlib import pyplot as plt
from LDGD.model import LDGD, FastLDGD, VAE
from LDGD.visualization.vizualize_utils import plot_heatmap, plot_2d_scatter, plot_ARD_gplvm
from LDGD.visualization.vizualize_utils import plot_loss_gplvm, plot_scatter_gplvm
from gpytorch.likelihoods import GaussianLikelihood, BernoulliLikelihood

import json
%matplotlib inline
%load_ext autoreload
%autoreload 2
duration = 1000  # milliseconds
freq = 440  # Hz
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
def create_dataset(random_state, test_size, dataset='mnist', **kwargs):
    """Load a dataset and return it as flattened, scaled PyTorch tensors.

    Only ``dataset='mnist'`` is supported. The train/test split shipped with
    torchvision's MNIST is used as-is.

    Parameters
    ----------
    random_state, test_size
        Accepted for interface compatibility; not read by this function.
    dataset : str
        Name of the dataset to load. Must be ``'mnist'``.
    **kwargs
        Ignored.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train_one_hot, y_test_one_hot, y_train_labels,
        y_test_labels, orig_data)``. The first four are float32 tensors, the
        label tensors keep the integer class ids, and ``orig_data`` is ``None``.

    Raises
    ------
    ValueError
        If ``dataset`` is not ``'mnist'``. Previously this fell through to an
        ``UnboundLocalError`` on ``x_train``.
    """
    if dataset != 'mnist':
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'mnist' is available.")

    mnist_train = MNIST(root='./data', train=True, download=True, transform=ToTensor())
    mnist_test = MNIST(root='./data', train=False, download=True, transform=ToTensor())

    # Flatten every image into one row and scale by the split's own maximum value.
    x_train = mnist_train.data.view(mnist_train.data.size(0), -1).numpy()
    x_train = x_train/x_train.max()
    y_train = mnist_train.targets.numpy()

    # Same treatment for the test split.
    x_test = mnist_test.data.view(mnist_test.data.size(0), -1).numpy()
    x_test = x_test/x_test.max()
    y_test = mnist_test.targets.numpy()

    # One-hot encode both splits with the same width so the label matrices line up.
    num_classes = len(np.unique(y_train))
    y_one_hot_train = np.eye(num_classes)[y_train]
    y_one_hot_test = np.eye(num_classes)[y_test]

    orig_data = None  # No original data in the case of MNIST

    # Convert to PyTorch tensors
    X_train_tensor = torch.tensor(x_train, dtype=torch.float32)
    X_test_tensor = torch.tensor(x_test, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_one_hot_train, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_one_hot_test, dtype=torch.float32)
    y_train_labels_tensor = torch.tensor(y_train)
    y_test_labels_tensor = torch.tensor(y_test)

    return X_train_tensor, X_test_tensor, y_train_tensor, y_test_tensor, y_train_labels_tensor, y_test_labels_tensor, orig_data


In [3]:
def create_LDGD_model(data_cont, data_cat, ldgd_settings, batch_shape, x_init='pca'):
    """Build an LDGD model with ARD-RBF kernels and default likelihoods.

    Parameters
    ----------
    data_cont
        Continuous observations, passed to ``LDGD`` as its first argument.
    data_cat
        Categorical targets; only ``data_cat.shape[-1]`` is read, as the
        number of classes.
    ldgd_settings : dict
        Keys read here: ``'use_gpytorch'``, ``'latent_dim'``,
        ``'num_inducing_points_cls'``, ``'num_inducing_points_reg'`` and
        ``'shared_inducing_points'``.
    batch_shape
        Batch shape handed to ``GaussianLikelihood``.
    x_init
        Forwarded unchanged to ``LDGD`` (default ``'pca'``).

    Returns
    -------
    The constructed ``LDGD`` instance, created on the notebook-level ``device``.

    Raises
    ------
    NotImplementedError
        If ``ldgd_settings['use_gpytorch'] is False``. No kernels are defined
        for that mode, so the original code failed with an
        ``UnboundLocalError`` when building the model.
    """
    if ldgd_settings['use_gpytorch'] is False:
        raise NotImplementedError(
            "create_LDGD_model only builds gpytorch kernels; "
            "ldgd_settings['use_gpytorch'] = False is not supported."
        )

    latent_dim = ldgd_settings['latent_dim']
    # Separate kernels (one lengthscale per latent dimension) for the two GP heads.
    kernel_reg = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=latent_dim))
    kernel_cls = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=latent_dim))

    likelihood_reg = GaussianLikelihood(batch_shape=batch_shape)
    likelihood_cls = BernoulliLikelihood()
    model = LDGD(data_cont,
                 kernel_reg=kernel_reg,
                 kernel_cls=kernel_cls,
                 num_classes=data_cat.shape[-1],
                 latent_dim=latent_dim,
                 num_inducing_points_cls=ldgd_settings['num_inducing_points_cls'],
                 num_inducing_points_reg=ldgd_settings['num_inducing_points_reg'],
                 likelihood_reg=likelihood_reg,
                 likelihood_cls=likelihood_cls,
                 use_gpytorch=ldgd_settings['use_gpytorch'],
                 shared_inducing_points=ldgd_settings['shared_inducing_points'],
                 use_shared_kernel=False,
                 x_init=x_init,
                 device=device)

    return model

def create_FastLDGD_model(data_cont, data_cat, batch_shape, ldgd_settings):
    """Build a FastLDGD model with ARD-RBF kernels and default likelihoods.

    ``data_cont`` is passed to ``FastLDGD`` as its first argument and
    ``data_cat.shape[-1]`` gives the number of classes. ``ldgd_settings``
    supplies ``'latent_dim'``, the two ``'num_inducing_points_*'`` values,
    ``'use_gpytorch'`` and ``'shared_inducing_points'``. ``batch_shape`` is
    handed to ``GaussianLikelihood``. The model is created on the
    notebook-level ``device`` and returned.
    """
    n_latent = ldgd_settings['latent_dim']

    def _ard_rbf_kernel():
        # Scaled RBF kernel with one lengthscale per latent dimension.
        return gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=n_latent))

    regression_kernel = _ard_rbf_kernel()
    classification_kernel = _ard_rbf_kernel()
    regression_likelihood = GaussianLikelihood(batch_shape=batch_shape)
    classification_likelihood = BernoulliLikelihood()

    return FastLDGD(
        data_cont,
        kernel_reg=regression_kernel,
        kernel_cls=classification_kernel,
        num_classes=data_cat.shape[-1],
        latent_dim=n_latent,
        num_inducing_points_cls=ldgd_settings['num_inducing_points_cls'],
        num_inducing_points_reg=ldgd_settings['num_inducing_points_reg'],
        likelihood_reg=regression_likelihood,
        likelihood_cls=classification_likelihood,
        use_gpytorch=ldgd_settings['use_gpytorch'],
        shared_inducing_points=ldgd_settings['shared_inducing_points'],
        use_shared_kernel=False,
        device=device,
    )

In [7]:
# Configuration for the 10-dimensional MNIST run. The cells below read 'latent_dim',
# the inducing-point counts, the epoch counts, 'batch_size', 'shared_inducing_points',
# 'use_gpytorch' and 'random_state'; the remaining keys are not read in this notebook.
model_settings = {
    'latent_dim': 10,
    'num_inducing_points_reg': 100,
    'num_inducing_points_cls': 100,
    'num_epochs_train': 20000,
    'num_epochs_test': 20000,
    'batch_size': 700,
    'load_trained_model': False,
    'load_tested_model': False,
    'shared_inducing_points': True,
    'use_gpytorch': True,
    'random_state': 65,
    'test_size': 0.2,
    'cls_weight': 1.0,
    'reg_weight': 1.0,
    'num_samples': 500,

}
# Seed both random number generators: NumPy for the index sampling done in later
# cells, PyTorch because the model is built and trained with torch.
np.random.seed(model_settings['random_state'])
torch.manual_seed(model_settings['random_state'])


In [8]:
# Load the raw MNIST tensors. The split size is read from model_settings rather than
# repeating the literal 0.2, so the settings cell stays the single source of truth.
yn_train, yn_test, ys_train, ys_test, labels_train, labels_test, _ = create_dataset(
    random_state=model_settings['random_state'],
    test_size=model_settings['test_size'],
    dataset='mnist',
)
# Rescale each split by its own maximum (create_dataset has already scaled the pixel arrays).
yn_train = yn_train/yn_train.max()
yn_test = yn_test/yn_test.max()


In [12]:
# Switch: False trains from scratch and saves the weights; any other value loads them.
load_saved_result = False
# Batch shape holds a single entry: the size of yn_train's last axis.
batch_shape = torch.Size((yn_train.shape[-1],))
metric_fastldgd_list = []
model = create_LDGD_model(
    data_cont=yn_train,
    data_cat=ys_train,
    ldgd_settings=model_settings,
    batch_shape=batch_shape,
)

if load_saved_result is not False:
    # NOTE(review): the load name carries a ".pth" suffix that the save name lacks —
    # confirm that save_wights appends it.
    model.load_weights(path_save='./saved_models/', file_name=f"model_mnist_fast.pth")
else:
    losses, history_train = model.train_model(
        yn=yn_train,
        ys=ys_train,
        epochs=model_settings['num_epochs_train'],
        batch_size=model_settings['batch_size'],
        monitor_mse=False,
    )
    model.save_wights(path_save='./saved_models/', file_name=f"model_mnist_fast")

# Audible notification that this cell has finished.
winsound.Beep(freq, 3 * duration)


In [13]:
# Run model.evaluate on the test split for 'num_epochs_test' epochs, then beep.
# Note that the integer labels (labels_test) are passed as ys_test, not the one-hot ys_test tensor.
predictions, metrics, history_test = model.evaluate(
    yn_test=yn_test,
    ys_test=labels_test,
    epochs=model_settings['num_epochs_test'],
)
winsound.Beep(freq, 3 * duration)


In [None]:
# Display the metrics object returned by model.evaluate in the previous cell.
metrics

In [14]:

# Inverse lengthscales of the two kernels, used below as per-dimension ARD weights.
alpha_reg = 1 / model.kernel_reg.base_kernel.lengthscale.detach().cpu().numpy()
alpha_cls = 1 / model.kernel_cls.base_kernel.lengthscale.detach().cpu().numpy()

# Variational mean and softplus-transformed q_log_sigma of the training latents.
x = model.x.q_mu.detach().cpu().numpy()
std = torch.nn.functional.softplus(model.x.q_log_sigma).detach().cpu().numpy()

# Same two quantities for the test latents.
x_test = model.x_test.q_mu.detach().cpu().numpy()
std_test = torch.nn.functional.softplus(model.x_test.q_log_sigma).detach().cpu().numpy()

# Last recorded inducing points from the evaluation history, as a (reg, cls) pair.
inducing_points = (
    history_test['z_list_reg'][-1],
    history_test['z_list_cls'][-1],
)

latent_dim = x.shape[-1]
# Indices of the two latent dimensions with the largest classification ARD weight.
values, indices = torch.tensor(alpha_cls).topk(k=2)
l1, l2 = indices.numpy().ravel()[:2]

In [15]:
color_list = ['r', 'b', 'g', 'c', 'm', 'y', 'k', 'lime', 'navy', 'teal']
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(32, 20))

# Top row: training loss, then the regression and classification ARD weights.
plot_loss_gplvm(losses, ax=axs[0, 0])
plot_ARD_gplvm(model_settings['latent_dim'], alpha_cls, ax=axs[0, 2])
plot_ARD_gplvm(model_settings['latent_dim'], alpha_reg, ax=axs[0, 1])

# Bottom row, first two panels: train and test latents on dimensions l1/l2.
for latent_points, point_labels, point_std, target_ax in (
    (x, labels_train, std, axs[1, 0]),
    (x_test, labels_test, std_test, axs[1, 1]),
):
    plot_scatter_gplvm(latent_points, point_labels, l1=l1, l2=l2, ax=target_ax,
                       colors=color_list, show_errorbars=False, std=point_std)

# Bottom-right panel: heatmap over the training latents with the inducing points.
plot_heatmap(
    x, labels_train, model, alpha_cls,
    cmap='winter',
    range_scale=1.2,
    file_name='latent_heatmap_train',
    inducing_points=inducing_points,
    ax1=axs[1, 2],
    fig=fig,
)

plt.tight_layout()
# Save the same figure as PNG and SVG.
for figure_path in ("ARD_synthetic4.png", "ARD_synthetic4.svg"):
    fig.savefig(figure_path)

In [None]:
# Pass test latents 1..44 through model.regress_x and show entry 14 of the first
# return value reshaped to 28x28.
rec_img, predictions_std = model.regress_x(x_test[1:45])
# To inspect the second return value instead, reshape predictions_std the same way.
plt.imshow(rec_img.view(-1, 28, 28).detach().cpu().numpy()[14], cmap='gray')


In [None]:
# Show nine randomly chosen reconstructions in a 3x3 grid.
num_images = rec_img.size(0)
random_indices = np.random.choice(num_images, 9, replace=False)

# Convert the tensor once: the reshape and host copy do not depend on the loop index.
rec_img_np = rec_img.view(-1, 28, 28).cpu().detach().numpy()

# Create a 3x3 grid of subplots
fig, axes = plt.subplots(3, 3, figsize=(9, 9))

for idx, ax in zip(random_indices, axes.ravel()):
    img = rec_img_np[idx]
    ax.imshow(img, cmap='gray')
    ax.axis('off')  # Turn off axis

plt.tight_layout()
plt.savefig("generation_mnist.png")
plt.show()

In [None]:
# Show nine randomly chosen reconstructions side by side in a single row.
num_images = rec_img.size(0)
random_indices = np.random.choice(num_images, 9, replace=False)

# Convert the tensor once: the reshape and host copy do not depend on the loop index.
rec_img_np = rec_img.view(-1, 28, 28).cpu().detach().numpy()

# Create a 1x9 row of subplots
fig, axes = plt.subplots(1, 9, figsize=(50, 9))

for idx, ax in zip(random_indices, axes.ravel()):
    img = rec_img_np[idx]
    ax.imshow(img, cmap='gray')
    ax.axis('off')  # Turn off axis

plt.tight_layout()
plt.savefig("generation_mnist2.png")
plt.show()

In [None]:
# Additional imports for the t-SNE comparison in the cells below.
# np and plt are already imported in the first cell; they are re-imported here unchanged.
from time import time

import numpy as np
import pandas as pd


# For plotting
from matplotlib import offsetbox
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import seaborn as sns
import plotly.graph_objects as go

%matplotlib inline
# Global seaborn theme and default figure size for the plots that follow.
sns.set(style='white', context='notebook', rc={'figure.figsize':(14,10)})

# For standardising the data
from sklearn.preprocessing import StandardScaler

# t-SNE embedding
from sklearn.manifold import TSNE

# Ignore warnings (this silences every warning category for the rest of the session)
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Reload MNIST for the t-SNE comparison.
mnist_train = MNIST(root='./data', train=True, download=True, transform=ToTensor())
mnist_test = MNIST(root='./data', train=False, download=True, transform=ToTensor())

# Training images flattened to one row each; pixel values are left unscaled here.
# Note: this rebinds `x`, which earlier held the latent means of the trained model.
x = mnist_train.data.reshape(len(mnist_train.data), -1).numpy()
y = mnist_train.targets.numpy()

# Test images flattened the same way and scaled by their maximum value.
X_test = mnist_test.data.reshape(len(mnist_test.data), -1).numpy()
X_test = X_test / X_test.max()
y_test = mnist_test.targets.numpy()

In [None]:
# Standardise the full training matrix and report its shape.
# The cells that follow work on the raw subset below, not on standardized_data.
standardized_data = StandardScaler().fit_transform(x)
print(standardized_data.shape)

# Keep the first 10,000 images and labels.
x_subset, y_subset = x[:10000], y[:10000]

In [None]:
# 2-D t-SNE embedding of x_subset (the first 10,000 unscaled training images).
# NOTE(review): recent scikit-learn releases rename TSNE's `n_iter` argument to
# `max_iter` — confirm against the installed version.
tsne = TSNE(random_state = 42, n_components=2,verbose=0, perplexity=40, n_iter=300).fit_transform(x_subset)

In [None]:
# Scatter the two t-SNE coordinates, coloured by digit label.
plt.scatter(x=tsne[:, 0], y=tsne[:, 1], s=5, c=y_subset, cmap='Spectral')
plt.gca().set_aspect('equal', adjustable='datalim')
# Ten colour bins with boundaries at -0.5, 0.5, ..., 9.5 and ticks on the integers 0..9.
digit_colorbar = plt.colorbar(boundaries=np.arange(11) - 0.5)
digit_colorbar.set_ticks(np.arange(10))
plt.title('Visualizing MNIST through t-SNE', fontsize=24)

In [None]:
# Settings for a 2-D latent run on the 10,000-image subset (replaces the earlier
# model_settings dict).
model_settings = {
    'latent_dim': 2,
    'num_inducing_points_reg': 10,
    'num_inducing_points_cls': 10,
    'num_epochs_train': 2000,
    'num_epochs_test': 2000,
    'batch_size': 1000,
    'load_trained_model': False,
    'load_tested_model': False,
    'shared_inducing_points': True,
    'use_gpytorch': True,
    'random_state': 65,
    'test_size': 0.2,
    'cls_weight': 0.0,
    'reg_weight': 1.0,

}
# One-hot encode the subset labels.
y_one_hot_subset = np.zeros((y_subset.shape[0], len(np.unique(y_subset))))
y_one_hot_subset[np.arange(y_subset.shape[0]), y_subset] = 1

# Build the tensors once and in float32, the dtype create_dataset uses for the first
# experiment. The original built them twice and left them as float64.
yn_subset = torch.tensor(x_subset/x_subset.max(), dtype=torch.float32)
ys_subset = torch.tensor(y_one_hot_subset, dtype=torch.float32)

batch_shape = torch.Size([x_subset.shape[-1]])
metric_fastldgd_list = []
model = create_LDGD_model(data_cont=yn_subset, data_cat=ys_subset, ldgd_settings=model_settings, batch_shape=batch_shape, x_init=None)

losses, history_train = model.train_model(yn=yn_subset, ys=ys_subset,
                                          epochs=model_settings['num_epochs_train'],
                                          batch_size=model_settings['batch_size'], monitor_mse=False)


# Audible notification that training has finished.
winsound.Beep(freq, duration*3)


In [None]:
# Inverse lengthscales of the two kernels, used as per-dimension ARD weights.
alpha_reg = 1 / model.kernel_reg.base_kernel.lengthscale.detach().cpu().numpy()
alpha_cls = 1 / model.kernel_cls.base_kernel.lengthscale.detach().cpu().numpy()

# Variational mean and softplus-transformed q_log_sigma of the training latents.
latent_mean = model.x.q_mu.detach().cpu().numpy()
latent_std = torch.nn.functional.softplus(model.x.q_log_sigma).detach().cpu().numpy()

# Last recorded inducing points from the training history, as a (reg, cls) pair.
inducing_points = (
    history_train['z_list_reg'][-1],
    history_train['z_list_cls'][-1],
)

latent_dim = latent_mean.shape[-1]
# Indices of the two latent dimensions with the largest classification ARD weight.
values, indices = torch.tensor(alpha_cls).topk(k=2)
l1, l2 = indices.numpy().ravel()[:2]

In [None]:
# Scatter the LDGD latent means on dimensions l1/l2, coloured by digit label.
plt.scatter(latent_mean[:, l1], latent_mean[:, l2], s= 5, c=y_subset, cmap='Spectral')
plt.gca().set_aspect('equal', 'datalim')
# Ten colour bins with boundaries at -0.5, 0.5, ..., 9.5 and ticks on the integers 0..9.
plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))
# The title was copied from the t-SNE cell; this figure shows the LDGD latent space.
plt.title('Visualizing MNIST through LDGD latent space', fontsize=24)