In [3]:
## Standard libraries
import os
import math
import numpy as np
import time
from fastcore.all import *
from nbdev.showdoc import *

# Configure environment
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false' # Tells Jax not to hog all of the memory to this process.

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

## Progress bar
from tqdm.auto import tqdm, trange

import torch
import sys

sys.path.append('../src/')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Pullback Comparisons with the Affinity Matching AE
> Fit an affinity-matching autoencoder to points sampled from a hemisphere and compare the pullback metrics it induces.

**Hypothesis**: TODO — state the expected outcome of the pullback comparison before running the experiment.

# Machinery

Import and sample the hemisphere dataset.

In [20]:
from autometric.datasets import Hemisphere

In [21]:
# Seed the global NumPy and torch RNGs before sampling so the dataset is the same on
# every Restart & Run All. Hemisphere presumably samples its points at random, and
# which RNG it draws from is not visible here, so both are seeded. The value
# matches the training seed used later in this notebook.
DATA_SEED = 2024
np.random.seed(DATA_SEED)
torch.manual_seed(DATA_SEED)
hemisphere = Hemisphere(num_points = 2000, r = 1)

Set up the model and its hyperparameters.

In [22]:
from models.affinity_matching import AffinityMatching

In [31]:
# Hyperparameters unpacked into the AffinityMatching constructor.
model_hypers = dict(
    ambient_dimension=3,  # input points have 3 variables (see the PHATE log in the fit output)
    latent_dimension=2,
    model_type='affinity',
    loss_type='kl',
    activation='relu',
    layer_widths=[256, 128, 64],
    kernel_method='gaussian',
    kernel_alpha=1,
    kernel_bandwidth=1,
    knn=5,
    # The fit output shows PHATE auto-selecting t, so 0 presumably means
    # "choose automatically" — confirm against AffinityMatching.
    t=0,
    # NOTE(review): 5000 exceeds the 2000 points sampled above — confirm that
    # landmarking is meant to be effectively disabled.
    n_landmark=5000,
    # NOTE(review): PHATE progress logs still appear in the fit output despite this.
    verbose=False,
)

# Options unpacked into model.fit alongside the data.
training_hypers = dict(
    # NOTE(review): the data is a hemisphere, not 'randomtest' — confirm this label
    # is not used to name saved runs or checkpoints.
    data_name='randomtest',
    max_epochs=100,
    batch_size=64,
    lr=1e-3,
    shuffle=True,
    weight_decay=1e-5,
    monitor='val_loss',
    # NOTE(review): patience equals max_epochs, so early stopping on `monitor`
    # presumably never triggers — confirm this is intended.
    patience=100,
    seed=2024,
    log_every_n_steps=100,
    accelerator='auto',
    train_from_scratch=True,
    model_save_path='./affinity_matching',
)

Fit the model on the hemisphere points, then encode and decode them.

In [32]:
# Convert the sampled points (a torch tensor, given the original `.numpy()` call) to a
# NumPy array for model.fit. detach()/cpu() are no-ops for a CPU tensor without grad,
# but make the conversion work if the tensor requires grad or lives on a GPU.
X = hemisphere.X.detach().cpu().numpy()

In [37]:
# Build the affinity-matching autoencoder from the hyperparameters above and train it
# on the hemisphere points. percent_test=0.3 presumably holds out 30% of the points
# for validation — TODO confirm against AffinityMatching.fit.
model = AffinityMatching(**model_hypers)
model.fit(X, train_mask=None, percent_test=0.3, **training_hypers)

# Round-trip the data through the latent space and report the shapes as a sanity check.
Z = model.encode(X)
print('Encoded Z:', Z.shape)

X_hat = model.decode(Z)
print('Decoded X:', X_hat.shape)

Running PHATE on 983 observations and 3 variables.
Calculating graph and diffusion operator...
  Calculating KNN search...
  Calculating affinities...
Calculating optimal t...
  Automatically selected t = 27
Calculated optimal t in 0.12 seconds.
Calculating diffusion potential...
Calculated diffusion potential in 0.03 seconds.
Calculating metric MDS...
Calculated metric MDS in 0.36 seconds.
row_stochastic_matrix torch.Size([983, 983])
checking row sum: False
row sum:  tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001,
        1.0001, 1.0001])
Calculating optimal t...
  Automatically selected t = 27
Calculated optimal t in 0.12 seconds.
Running PHATE on 1405 observations and 3 variables.
Calculating graph and diffusion operator...
  Calculating KNN search...
  Calculating affinities...
Calculated graph and diffusion operator in 0.01 seconds.
Calculating optimal t...
  Automatical

# Results

# Conclusion