In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..'))

import torch
from invertible_network_utils import construct_invertible_mlp
from pathlib import Path

import matplotlib.pyplot as plt
from simclr.simclr import SimCLR
from spaces import NSphereSpace

from visualization_utils.spheres import visualize_spheres_side_by_side, scatter3d_sphere

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create figures directory
figures_dir = Path('figures')
figures_dir.mkdir(exist_ok=True)

print('Using', device)

In [None]:
sphere = NSphereSpace(3)
g_mlp = construct_invertible_mlp(n=3, n_layers=3, act_fct="leaky_relu")

z = sphere.uniform(1000)

# Map the latent samples through the generator once. Nothing is trained in
# this cell, so skip building an autograd graph.
with torch.no_grad():
    x_obs = g_mlp(z)

# Report the largest Euclidean norm of the generated points (the norm itself,
# not its square, to match the label).
print('MAX NORM:', x_obs.norm(dim=1).max().item())

fig = visualize_spheres_side_by_side(plt, z, x_obs)
fig.savefig(figures_dir / 'mlp_data_generating_process.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
from encoders import SphericalEncoder

# Contrastive-learning hyperparameters. The vMF concentration used to sample
# pairs is tied to the temperature: kappa = 1 / tau.
tau = 0.3
kappa = 1 / tau

iterations = 10000
batch = 6144


def sample_pair_fixed(batch):
    """Sample `batch` pairs via `sphere.sample_pair_vmf` with the fixed concentration `kappa`."""
    return sphere.sample_pair_vmf(batch, kappa)


def sample_uniform_fixed(batch):
    """Sample `batch` points uniformly from `sphere`."""
    return sphere.uniform(batch)


f = SphericalEncoder()


def h(z):
    """Composite map f(g_mlp(z)); reads the global `f` at call time."""
    return f(g_mlp(z))


# The rest of this cell only visualizes the untrained encoder. Run the forward
# passes without autograd (the 100k-point pass would otherwise keep a graph).
with torch.no_grad():
    z = sphere.uniform(1000)
    z_enc = h(z)

# Before training visualization
fig = visualize_spheres_side_by_side(plt, z, z_enc)
fig.savefig(figures_dir / 'mlp_before_training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

with torch.no_grad():
    z = sphere.uniform(100000)
    z_enc = h(z.cpu()).cpu()

fig = scatter3d_sphere(plt, z.cpu(), z_enc, s=10, a=.8)
fig.savefig(figures_dir / 'mlp_before_training_encoded.png', dpi=150, bbox_inches='tight')
plt.show()

fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=.8)
fig.savefig(figures_dir / 'mlp_before_training_original.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
from visualization_utils.scoring import plot_scores
from experiment_utils.linear import linear_unrotation

# Instantiate a fresh encoder so each run of this cell trains from a new
# initialization rather than continuing from a previously trained `f`.
f = SphericalEncoder()

simclr_vmf = SimCLR(
    f, g_mlp, sample_pair_fixed, sample_uniform_fixed, tau, device
)

f, scores = simclr_vmf.train(batch, iterations)


def h(z):
    """Composite map f(g_mlp(z)) using the trained encoder (global `f`)."""
    return f(g_mlp(z))


# Evaluation only: disable autograd so forward passes (notably the 100k-point
# one below) do not retain activation graphs on `device`.
with torch.no_grad():
    z = sphere.uniform(1000).to(device)
    z_enc = h(z)

fig_scores = plot_scores(plt, scores)
fig_scores.savefig(figures_dir / 'mlp_training_scores.png', dpi=150, bbox_inches='tight')
plt.show()

fig = visualize_spheres_side_by_side(plt, z.cpu(), z_enc.cpu())
fig.savefig(figures_dir / 'mlp_after_training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

with torch.no_grad():
    z = sphere.uniform(100000).to(device)
    z_enc = h(z)

fig = scatter3d_sphere(plt, z.cpu(), z_enc.cpu(), s=10, a=.8)
fig.savefig(figures_dir / 'mlp_after_training_encoded.png', dpi=150, bbox_inches='tight')
plt.show()

fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=.8)
fig.savefig(figures_dir / 'mlp_after_training_original.png', dpi=150, bbox_inches='tight')
plt.show()

# Unrotation. Pass host copies, like every other array handed to the plotting
# helpers, so the result is not left on the accelerator.
# NOTE(review): assumes linear_unrotation accepts CPU tensors — confirm.
z_unrotated = linear_unrotation(z.cpu(), z_enc.cpu())
fig = scatter3d_sphere(plt, z.cpu(), z_unrotated, s=10, a=.8)
fig.savefig(figures_dir / 'mlp_after_training_unrotated.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"All figures saved to {figures_dir}/")