# Spiral Experiment

We will perform the sphere reconstruction process with a spiral data-generating process.

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), '..'))

import torch
from pathlib import Path

from data.generation import SpiralRotation, Patches
from visualization_utils.spheres import visualize_spheres_side_by_side, scatter3d_sphere
from encoders import SphericalEncoder

from encoders import get_mlp

import matplotlib.pyplot as plt
from simclr.simclr import SimCLR
from spaces import NSphereSpace

# Pick the compute device: CUDA when torch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Directory that receives the figures saved by this notebook; created on
# first run and silently reused if it already exists.
figures_dir = Path('figures')
figures_dir.mkdir(exist_ok=True)

print('Using', device)

In [None]:
# Latent space and the spiral data-generating process used throughout the
# notebook (constructor arguments kept exactly as in the original cell).
sphere = NSphereSpace(3)
g_spiral = SpiralRotation(2)

# Small sample: latents shown next to their image under g_spiral.
z = sphere.uniform(1000)
z_spiral = g_spiral(z)
fig = visualize_spheres_side_by_side(plt, z, z_spiral)
fig.savefig(
    figures_dir / 'spiral_data_generating_process.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# Large sample: scatter plot of the transformed points ...
z = sphere.uniform(100000)
z_spiral = g_spiral(z)
fig = scatter3d_sphere(plt, z, z_spiral, s=10, a=.8)
fig.savefig(
    figures_dir / 'spiral_transformed.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# ... and of the untransformed latents, for reference.
fig = scatter3d_sphere(plt, z, z, s=10, a=.8)
fig.savefig(
    figures_dir / 'spiral_original.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

In [None]:
# tau is handed to SimCLR in the training cells; kappa (its reciprocal) is the
# concentration passed to sphere.sample_pair_vmf below.
tau = 0.3
kappa = 1 / tau

iterations = 10000
batch = 6144


def sample_pair_fixed(batch):
    """Draw `batch` pairs from sphere.sample_pair_vmf using the fixed kappa."""
    return sphere.sample_pair_vmf(batch, kappa)


def sample_uniform_fixed(batch):
    """Draw `batch` points from sphere.uniform."""
    return sphere.uniform(batch)


f = SphericalEncoder(hidden_dims=[128, 256, 256, 256, 256, 256, 256, 128])


def h(z):
    """Compose the data-generating process with the encoder: f(g_spiral(z))."""
    return f(g_spiral(z))


# --- Before training: small sample, side-by-side view -----------------------
z = sphere.uniform(1000)
z_enc = h(z)

fig = visualize_spheres_side_by_side(plt, z, z_enc)
fig.savefig(
    figures_dir / 'spiral_mlp_before_training_comparison.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# --- Before training: large sample, scatter plots ---------------------------
# The encoder input is moved back to the CPU before the forward pass, exactly
# as in the original cell.
z = sphere.uniform(100000).to(device)
z_enc = h(z.cpu())

fig = scatter3d_sphere(plt, z.cpu(), z_enc.cpu(), s=10, a=.8)
fig.savefig(
    figures_dir / 'spiral_mlp_before_training_encoded.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=0.8)
fig.savefig(
    figures_dir / 'spiral_mlp_before_training_original.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

In [None]:
from visualization_utils.scoring import plot_scores

# Shorter run than the value set in the previous cell.
iterations = 250

simclr_vmf = SimCLR(
    f,
    g_spiral,
    sample_pair_fixed,
    sample_uniform_fixed,
    tau,
    device,
)

# train() hands back the encoder together with the recorded scores.
f, scores = simclr_vmf.train(batch, iterations)


def h(z):
    """Compose the data-generating process with the trained encoder."""
    return f(g_spiral(z))


z = sphere.uniform(1000).to(device)
z_enc = h(z).to(device)

# Training curve.
fig_scores = plot_scores(plt, scores)
fig_scores.savefig(
    figures_dir / 'spiral_mlp_training_scores.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# Small sample: latents next to their encodings.
fig = visualize_spheres_side_by_side(plt, z.cpu(), z_enc.cpu())
fig.savefig(
    figures_dir / 'spiral_mlp_after_training_comparison.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# Large sample: scatter plots of encodings and of the original latents.
z = sphere.uniform(100000).to(device)
z_enc = h(z).to(device)

fig = scatter3d_sphere(plt, z.cpu(), z_enc.cpu(), s=10, a=.8)
fig.savefig(
    figures_dir / 'spiral_mlp_after_training_encoded.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=0.8)
fig.savefig(
    figures_dir / 'spiral_mlp_after_training_original.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

In [None]:
from experiment_utils.linear import linear_unrotation

# Uses the z / z_enc left over from the previous cell (the 100000-point sample).
z_unrotated = linear_unrotation(z, z_enc)

# Un-rotated encodings plotted against the latents; this figure is saved.
fig = scatter3d_sphere(plt, z.cpu(), z_unrotated, s=10, a=.8)
fig.savefig(
    figures_dir / 'spiral_mlp_unrotated.png',
    dpi=150,
    bbox_inches='tight',
)
plt.show()

# Reference plot of the latents themselves; displayed only, not written to disk.
fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=0.8)
plt.show()

In [None]:
from torch import nn
from encoders import InverseSpiralEncoder


# Second experiment: same sampling setup and tau as the MLP run, but with an
# InverseSpiralEncoder as the encoder.
simclr_biased = SimCLR(
    InverseSpiralEncoder(3, 3, 2), g_spiral, sample_pair_fixed, sample_uniform_fixed, tau, device
)

# Train the biased model built just above. The original cell called
# simclr_vmf.train here, which kept training the MLP from the earlier cell and
# left simclr_biased untouched, so every 'spiral_biased_*' figure showed the
# wrong encoder.
f, scores = simclr_biased.train(batch, iterations)

# Rebind h so that it composes g_spiral with the encoder returned by train().
h = lambda z: f(g_spiral(z))

z = sphere.uniform(1000).to(device)
z_enc = h(z).to(device)

# Training curve of the biased encoder.
fig_scores = plot_scores(plt, scores)
fig_scores.savefig(figures_dir / 'spiral_biased_training_scores.png', dpi=150, bbox_inches='tight')
plt.show()

# Small sample: latents next to their encodings.
fig = visualize_spheres_side_by_side(plt, z.cpu(), z_enc.cpu())
fig.savefig(figures_dir / 'spiral_biased_after_training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Large sample: scatter plots of the encodings and of the original latents.
z = sphere.uniform(100000).to(device)
z_enc = h(z).to(device)

fig = scatter3d_sphere(plt, z.cpu(), z_enc.cpu(), s=10, a=.8)
fig.savefig(figures_dir / 'spiral_biased_after_training_encoded.png', dpi=150, bbox_inches='tight')
plt.show()

fig = scatter3d_sphere(plt, z.cpu(), z.cpu(), s=10, a=0.8)
fig.savefig(figures_dir / 'spiral_biased_after_training_original.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"All figures saved to {figures_dir}/")