In [None]:
%load_ext autoreload
%autoreload 2

import random

import torch

from imports import *
from models import *
from utils import *
from data import *
from configs import CONFIGS, EXP_CODES

In [None]:
# Select the experiment configuration and build its train/test datasets and
# dataloaders. No saved model is loaded in this cell: the trained autoencoder
# is looked up later through EXP_CODES in the reduction cell.
config_name = 'freq'

# The reduction cell unconditionally needs the companion '<name>_ae' entries,
# so check for them before spending time building the datasets.
assert config_name in CONFIGS, f"unknown config {config_name!r}"
assert config_name + '_ae' in CONFIGS, f"missing autoencoder config {config_name + '_ae'!r}"
assert config_name + '_ae' in EXP_CODES, f"missing experiment code for {config_name + '_ae'!r}"

# Deep-copied so any downstream mutation of `config` cannot alter the shared CONFIGS entry.
config = deepcopy(CONFIGS[config_name])

train_dataset, test_dataset, train_loader, test_loader = get_dataset_dataloader(config)

In [None]:
# Generate a random trajectory from the test dataset and prepare the
# ground-truth values and per-point colours used by the metrics and plots.
#
# Seed every RNG the dataset might draw from so Restart & Run All reproduces
# the same trajectory (and hence the same reported errors).
# NOTE(review): which RNG `generate_sample_trajectory` uses is not visible
# here; seeding `random`, NumPy and torch covers the common cases.
SEED = 0
TRAJECTORY_LENGTH = 1500

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

trajectory, gt_vs = test_dataset.generate_sample_trajectory(length=TRAJECTORY_LENGTH)

# Drop the singleton axis 1 (presumably a channel/batch axis — TODO confirm).
trajectory = trajectory.squeeze(dim=1)
gt_velocities = np.array(gt_vs)
# Per-point colours computed from the raw ground truth; reused by every scatter plot.
colors = np.array(test_dataset.cmap(gt_vs))

# Pad 1-D ground truth to shape (N, 2) with a zero second column. The zero
# column is sized from the flattened column itself, so a (1, N) input also works.
if gt_velocities.squeeze().ndim == 1:
    gt_column = gt_velocities.reshape(-1, 1)
    gt_velocities = np.hstack((gt_column, np.zeros((gt_column.shape[0], 1))))

In [None]:
# Reduce the trajectory with `reduced_dim=1` using three classical baselines
# (Isomap, UMAP, PCA) and the trained autoencoder registered for this config.
isomap_velocities = isomap_reduction(trajectory, reduced_dim=1)
umap_velocities = umap_reduction(trajectory, reduced_dim=1)
pca_velocities = pca_reduction(trajectory, reduced_dim=1)

ae_velocities = ae_reduction(
    tensor=trajectory,
    config=deepcopy(CONFIGS[config_name + '_ae']),
    # First entry of the EXP_CODES value for this config's autoencoder.
    model_code=EXP_CODES[config_name + '_ae'][0],
)

# Every method should yield one embedding per trajectory sample, so that the
# predictions line up with gt_velocities in the error metrics below.
assert all(
    len(reduced) == len(trajectory)
    for reduced in (isomap_velocities, umap_velocities, pca_velocities, ae_velocities)
), "a reduction returned a different number of points than the trajectory"

# MCNet reduction is currently disabled.
# NOTE(review): unlike the AE call above, this passes the whole EXP_CODES entry
# instead of `[0]`; confirm which form mcnet_reduction expects before re-enabling.
# mcnet_velocities = mcnet_reduction(
#     tensor=trajectory,
#     config=deepcopy(CONFIGS[config_name + '_mcnet']), 
#     model_code=EXP_CODES[config_name + '_mcnet'],
# )

In [None]:
# Score the Isomap embedding against the ground truth. The second return value
# is a transformed copy of the prediction (presumably aligned to the ground
# truth — confirm in compute_error_metric) that the plots below display.
isomap_error, transformed_isomap_velocities = compute_error_metric(
    true=gt_velocities,
    pred=isomap_velocities,
    num_clusters=30,
    show_knee_visualization=False,
)

# Optional plotly view (disabled); uses this cell's own error value.
#plotly_scatter(transformed_isomap_velocities, colors=colors, title="transformed isomap velocities, error: {:.4f}".format(isomap_error))

In [None]:
# Score the UMAP embedding against the ground truth. The second return value
# is a transformed copy of the prediction (presumably aligned to the ground
# truth — confirm in compute_error_metric) that the plots below display.
umap_error, transformed_umap_velocities = compute_error_metric(
    true=gt_velocities,
    pred=umap_velocities,
    num_clusters=30,
    show_knee_visualization=False,
)

# Optional plotly view (disabled); uses this cell's own error value.
#plotly_scatter(transformed_umap_velocities, colors=colors, title="transformed umap velocities, error: {:.4f}".format(umap_error))

In [None]:
# Score the PCA embedding against the ground truth. The second return value
# is a transformed copy of the prediction (presumably aligned to the ground
# truth — confirm in compute_error_metric).
pca_error, transformed_pca_velocities = compute_error_metric(
    true=gt_velocities,
    pred=pca_velocities,
    num_clusters=30,
    show_knee_visualization=False,
)

# Optional plotly view (disabled); uses this cell's own error value.
#plotly_scatter(transformed_pca_velocities, colors=colors, title="transformed pca velocities, error: {:.4f}".format(pca_error))

In [None]:
# Score the autoencoder embedding against the ground truth. The second return
# value is a transformed copy of the prediction (presumably aligned to the
# ground truth — confirm in compute_error_metric) that the plots below display.
ae_error, transformed_ae_velocities = compute_error_metric(
    true=gt_velocities,
    pred=ae_velocities,
    num_clusters=30,
    show_knee_visualization=False,
)

# Optional plotly view (disabled); uses this cell's own error value.
#plotly_scatter(transformed_ae_velocities, colors=colors, title="transformed ae velocities, error: {:.4f}".format(ae_error))

In [None]:
# 2-D scatter of the transformed UMAP embedding, coloured by the ground truth,
# with the error metric in the title and axis ticks removed.
plt.rcParams.update({'font.size': 22})

fig, ax = plt.subplots()

# Index columns explicitly: unpacking `arr.T` would silently pass a third
# column to scatter as marker sizes.
ax.scatter(transformed_umap_velocities[:, 0], transformed_umap_velocities[:, 1], c=colors)

ax.set_xlim([-0.02, 0.02])
ax.set_ylim([-0.02, 0.02])
ax.set_xlabel('dim 1')
ax.set_ylabel('dim 2')
ax.set_title('UMAP Error: {:.2f}'.format(umap_error))

# Empty tick lists remove both the tick marks and their labels.
ax.set_xticks([])
ax.set_yticks([])

# Render explicitly so the cell does not echo the last call's return value.
plt.show()

In [None]:
# 2-D scatter of the transformed Isomap embedding, coloured by the ground
# truth, with the error metric in the title and axis ticks removed.
plt.rcParams.update({'font.size': 22})

fig, ax = plt.subplots()

# Index columns explicitly: unpacking `arr.T` would silently pass a third
# column to scatter as marker sizes.
ax.scatter(transformed_isomap_velocities[:, 0], transformed_isomap_velocities[:, 1], c=colors)

ax.set_xlim([-0.02, 0.02])
ax.set_ylim([-0.02, 0.02])
ax.set_xlabel('dim 1')
ax.set_ylabel('dim 2')
ax.set_title('Isomap Error: {:.2f}'.format(isomap_error))

# Empty tick lists remove both the tick marks and their labels.
ax.set_xticks([])
ax.set_yticks([])

# Render explicitly so the cell does not echo the last call's return value.
plt.show()

In [None]:
# 2-D scatter of the transformed autoencoder embedding, coloured by the ground
# truth, with the error metric in the title and axis ticks removed.
plt.rcParams.update({'font.size': 22})

fig, ax = plt.subplots()

# Index columns explicitly: unpacking `arr.T` would silently pass a third
# column to scatter as marker sizes.
ax.scatter(transformed_ae_velocities[:, 0], transformed_ae_velocities[:, 1], c=colors)

# Narrower x-range than the other 2-D plots (presumably hand-tuned for this
# embedding's spread — confirm).
ax.set_xlim([-0.007, 0.007])
ax.set_ylim([-0.02, 0.02])
ax.set_xlabel('dim 1')
ax.set_ylabel('dim 2')
ax.set_title('Autoencoder Error: {:.2f}'.format(ae_error))

# Empty tick lists remove both the tick marks and their labels.
ax.set_xticks([])
ax.set_yticks([])

# Render explicitly so the cell does not echo the last call's return value.
plt.show()

In [None]:
# 3-D scatter of the transformed Isomap embedding, coloured by the ground
# truth, with tick marks hidden and axis labels pulled in close to the axes.
# NOTE(review): the 2-D Isomap cell draws this same array with ±0.02 limits and
# two columns. If it has only two columns, scatter3D puts every point at z = 0,
# and the ±2.25 box below may not match its scale — confirm this cell still
# fits the current config.
plt.rcParams.update({'font.size': 22})

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter3D(*transformed_isomap_velocities.T, c=colors)

# Symmetric ±2.25 box on all three axes, plus labels and title.
ax.set(
    xlim=[-2.25, 2.25],
    ylim=[-2.25, 2.25],
    zlim=[-2.25, 2.25],
    xlabel='dim 1',
    ylabel='dim 2',
    zlabel='dim 3',
    title='Isomap Error: {:.2f}'.format(isomap_error),
)

# Hide tick labels, then drop the ticks themselves on every axis.
ax.tick_params(axis='x', labelbottom=False)
ax.tick_params(axis='y', labelleft=False)
ax.set(xticks=[], yticks=[], zticks=[])

# Negative padding pulls each axis label in toward its (tick-less) axis.
ax.xaxis.labelpad = ax.yaxis.labelpad = ax.zaxis.labelpad = -11.5

