# Simulated Data across Methods

In [3]:
from mvlearn.embed.kcca_experimental import KCCA
from mvlearn.embed.dcca import DCCA
from mvlearn.embed.gcca import GCCA
from mvlearn.embed.mvmds import MVMDS
from mvlearn.datasets.GaussianMixture import GaussianMixture
from mvlearn.plotting.plot import crossviews_plot

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
import seaborn as sns

## Raw Data

In [21]:
## Make Latents
# Two-class Gaussian mixture in a 2-D latent space: class means (0, 1) and
# (0, -1), equal class probabilities, n samples per mixture.
n = 200
mu = [[0, 1], [0, -1]]
# One covariance matrix per class (same length as `mu`). The original
# `2*[np.eye(2), np.eye(2)]` was list repetition (a 4-element list of
# identities), not a scale factor, so the effective covariance was the
# identity. NOTE(review): if a variance of 2 was intended, use
# `[2 * np.eye(2) for _ in mu]` instead — confirm.
sigma = [np.eye(2) for _ in mu]
pi = [0.5, 0.5]
GM_train = GaussianMixture(n, mu, sigma, class_probs=pi)

## Test: an independent draw from the same mixture
GM_test = GaussianMixture(n, mu, sigma, class_probs=pi)

In [22]:
## Make one two-view dataset per transform, for both the train and test mixtures
n_noise = 2
transforms = ['linear', 'poly', 'sin']

Xs_train, Xs_test = [], []
# Pair each mixture with the list that collects its sampled views.
mixtures_and_sinks = ((GM_train, Xs_train), (GM_test, Xs_test))
for transform in transforms:
    # Resample the views of both mixtures under this transform first...
    for mixture, _ in mixtures_and_sinks:
        mixture.sample_views(transform=transform, n_noise=n_noise)
    # ...then collect the freshly sampled views (first element of get_Xy()).
    for mixture, sink in mixtures_and_sinks:
        sink.append(mixture.get_Xy()[0])

In [112]:
## Plotting parameters
import os

# Color by the first latent coordinate of the *test* mixture; these labels
# line up row-for-row only with data sampled from GM_test.
labels = GM_test.latent[:, 0]
# Diverging colormap: the first latent coordinate is continuous and, since
# both class means have first coordinate 0, centered around 0.
cmap = 'coolwarm'

# Directory for saved figures. Override with the PAPER_FIGURES_DIR environment
# variable on other machines. A trailing separator is enforced because later
# cells build file paths as `save_dir + filename`.
save_dir = os.path.join(
    os.environ.get(
        'PAPER_FIGURES_DIR',
        '/mnt/c/Users/Ronan Perry/Documents/JHU/jovo-lab/multiview/multiview/paper_figures/husl/',
    ),
    '',
)

# Passed to crossviews_plot as `context` (presumably a seaborn plotting context).
context = 'poster'
# When False, later cells save each figure to save_dir and close it instead of showing it.
show = False

scatter_kwargs = {'alpha': 1.0}

In [111]:
# Plot the raw train and test views for each transform. Each split must be
# colored by its OWN latents: the train rows come from GM_train, so using the
# test `labels` there would assign colors unrelated to the plotted points.
train_labels = GM_train.latent[:, 0]
raw_plot_kwargs = dict(
    dimensions=[0, 1],
    ax_ticks=False,
    ax_labels=False,
    equal_axes=True,
    context=context,
    cmap=cmap,
    show=show,
    scatter_kwargs=scatter_kwargs,
)
for i, transform in enumerate(transforms):
    for Xs, split_labels, split in (
        (Xs_train[i], train_labels, 'train'),
        (Xs_test[i], labels, 'test'),
    ):
        crossviews_plot(Xs, split_labels, **raw_plot_kwargs)
        if not show:
            plt.savefig(save_dir + f'{transform}_{split}.png')
            plt.close()

## Linear KCCA

In [103]:
ktype = 'linear'
# Same hyperparameters as the polynomial cell; degree/constant are presumably
# ignored by the linear kernel and kept only for parity — confirm.
kcca = KCCA(
    ktype=ktype,
    reg=0.1,
    degree=2.0,
    constant=0.1,
    n_components=2,
    test=True,
)
for i, transform in enumerate(transforms):
    # Learn the projection on the training views, then embed the held-out test views.
    fitted = kcca.fit(Xs_train[i])
    components = fitted.transform(Xs_test[i])

    crossviews_plot(
        components,
        labels,
        ax_ticks=False,
        ax_labels=False,
        equal_axes=True,
        context=context,
        cmap=cmap,
        show=show,
        scatter_kwargs=scatter_kwargs,
    )
    if show:
        continue
    # Figure was not displayed: write it to disk, then close it.
    plt.savefig(f'{save_dir}{ktype}-kcca_{transform}-test.png')
    plt.close()

## Polynomial KCCA

In [104]:
ktype = 'poly'
# Polynomial kernel; degree and constant presumably parameterize the kernel
# (x·y + constant)^degree — confirm against KCCA's docs.
kcca = KCCA(
    ktype=ktype,
    reg=0.1,
    degree=2.0,
    constant=0.1,
    n_components=2,
    test=True,
)
for i, transform in enumerate(transforms):
    # Learn the projection on the training views, then embed the held-out test views.
    fitted = kcca.fit(Xs_train[i])
    components = fitted.transform(Xs_test[i])

    crossviews_plot(
        components,
        labels,
        ax_ticks=False,
        ax_labels=False,
        equal_axes=True,
        context=context,
        cmap=cmap,
        show=show,
        scatter_kwargs=scatter_kwargs,
    )
    if show:
        continue
    # Figure was not displayed: write it to disk, then close it.
    plt.savefig(f'{save_dir}{ktype}-kcca_{transform}-test.png')
    plt.close()

## Gaussian KCCA

In [105]:
ktype = 'gaussian'
# Stronger regularization than the linear/poly cells; `sigma` here is a KCCA
# keyword (presumably the Gaussian kernel bandwidth — confirm), not the
# mixture covariance list defined earlier.
kcca = KCCA(
    ktype=ktype,
    reg=1.0,
    sigma=2.0,
    n_components=2,
    test=True,
)
for i, transform in enumerate(transforms):
    # Learn the projection on the training views, then embed the held-out test views.
    fitted = kcca.fit(Xs_train[i])
    components = fitted.transform(Xs_test[i])

    crossviews_plot(
        components,
        labels,
        ax_ticks=False,
        ax_labels=False,
        equal_axes=True,
        context=context,
        cmap=cmap,
        show=show,
        scatter_kwargs=scatter_kwargs,
    )
    if show:
        continue
    # Figure was not displayed: write it to disk, then close it.
    plt.savefig(f'{save_dir}{ktype}-kcca_{transform}-test.png')
    plt.close()

## DCCA

In [108]:
# Deep CCA: a fresh model per transform, since the network input sizes are
# taken from that transform's views.
for i, transform in enumerate(transforms):
    input_size1, input_size2 = Xs_train[i][0].shape[1], Xs_train[i][1].shape[1]
    # CCA cannot produce more components than the smaller view's dimension;
    # the original bound only considered view 1.
    outdim_size = min(input_size1, input_size2, 2)
    # Presumably hidden widths 256 and 512, followed by an output layer of
    # width outdim_size — confirm against DCCA's docs.
    layer_sizes1 = [256, 512, outdim_size]
    layer_sizes2 = [256, 512, outdim_size]
    dcca = DCCA(input_size1, input_size2, outdim_size, layer_sizes1, layer_sizes2, epoch_num=400)
    # Train on the training views, then embed the held-out test views.
    components = dcca.fit(Xs_train[i]).transform(Xs_test[i])

    crossviews_plot(components, labels, ax_ticks=False, ax_labels=False, equal_axes=True, context=context, cmap=cmap, show=show, scatter_kwargs=scatter_kwargs)
    if not show:
        plt.savefig(save_dir + f'dcca_{transform}-test.png')
        plt.close()