In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scprep
from tqdm import tqdm
import os
import sys

sys.path.append('../src')
from models.unified_model import GeometricAE
from models.affinity_matching import AffinityMatching
from data_script import hemisphere_data, sklearn_swiss_roll
# from geodesic import CondCurve

In [None]:
# Data
# Seed the global NumPy and torch RNGs so the train/test split, the random geodesic
# endpoints and the CurveNet initialisation are reproducible under Restart & Run All
# (and the data sample too, assuming hemisphere_data draws from NumPy's global RNG —
# not visible here). 2024 matches training_hypers['seed'] used for model fitting.
SEED = 2024
np.random.seed(SEED)
torch.manual_seed(SEED)

gt_X, X, _ = hemisphere_data(n_samples=1000, noise=0)
#gt_X, X, colors = sklearn_swiss_roll(n_samples=2000, noise=0.0)


# Train, Test
percent_test = 0.2
idxs = np.random.permutation(X.shape[0])
split_idx = int(X.shape[0] * (1-percent_test))
# Boolean mask: True for the first `split_idx` shuffled indices (training rows).
train_mask = np.zeros(X.shape[0], dtype=bool)
train_mask[idxs[:split_idx]] = True

X_train = X[train_mask]
X_test = X[~train_mask]

print(gt_X.shape, X.shape, X_train.shape, X_test.shape)

In [None]:
# Visualize the ground-truth point cloud (gt_X) in 3D
scprep.plot.scatter3d(gt_X, title='Ground Truth', c='b')

In [None]:
# Fit Affinity Matching Model
# Architecture / kernel settings, passed as keyword arguments to AffinityMatching(...).
model_hypers = dict(
    ambient_dimension=3,
    latent_dimension=2,
    model_type='affinity',
    loss_type='kl',
    activation='relu',
    layer_widths=[256, 128, 64],
    kernel_method='gaussian',
    kernel_alpha=1,
    kernel_bandwidth=1,
    knn=5,
    t=0,
    n_landmark=5000,
    verbose=False,
)
# Optimisation / checkpoint settings, passed as keyword arguments to model.fit(...).
training_hypers = dict(
    data_name='hemisphere',
    max_epochs=100,
    batch_size=64,
    lr=1e-3,
    shuffle=True,
    weight_decay=1e-5,
    monitor='val_loss',
    patience=100,
    seed=2024,
    log_every_n_steps=100,
    accelerator='auto',
    train_from_scratch=True,  # load or train from scratch
    model_save_path='./affinity_matching',
)

model = AffinityMatching(**model_hypers)
model.fit(X, train_mask=train_mask, percent_test=percent_test, **training_hypers)

In [None]:
# Encode the full dataset into the latent space, then decode it back to the ambient space.
X_tensor = torch.Tensor(X)
Z = model.encode(X_tensor)
print('Encoded Z:', Z.shape)

X_hat = model.decode(Z)
print('Decoded X:', X_hat.shape)

In [None]:
# Visualize: reconstruction (3D, left panel) next to the latent embedding (2D, right panel)
X_hat_np = X_hat.detach().numpy()
Z_np = Z.detach().numpy()

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 2, 1, projection='3d')
scprep.plot.scatter3d(X_hat_np, c='r', title='Reconstructed', ax=ax)
ax = fig.add_subplot(1, 2, 2)
scprep.plot.scatter2d(Z_np, c='r', title='Latent', ax=ax)



In [None]:
# Evaluate model.encoder_pullback at every data point and report the result's shape.
X_points = torch.Tensor(X)
metric = model.encoder_pullback(X_points)
print('Encoder Pullback:', metric.shape)

In [None]:
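# # CondCurve (commented out and unused; the CurveNet cell below defines the curve model this notebook actually trains)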
# class CondCurve(nn.Module):
#     def __init__(self, input_dim, hidden_dim, scale_factor=5, symmetric=False):
#         super().__init__()
#         self.mod_x0x1 = nn.Sequential(
#             nn.Linear((2 * hidden_dim) + 1, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.ReLU(),
#             nn.Linear(hidden_dim, input_dim),
#         )
#         self.scale_factor = scale_factor
#         self.symmetric = symmetric

#         if symmetric:
#             self.x0_x1_preemb = nn.Linear(2 * input_dim, 2 * hidden_dim)
#             self.x0_emb = nn.Linear(2 * hidden_dim, hidden_dim)
#             self.x1_emb = nn.Linear(2 * hidden_dim, hidden_dim)
#         else:
#             self.x0_emb = nn.Linear(input_dim, hidden_dim)
#             self.x1_emb = nn.Linear(input_dim, hidden_dim)

#     def forward(self, x0, x1, t):
#         '''
#             x0, x1: [B, input_dim]
#             t: [T]
#         '''
#         bs = x0.shape[0]
#         T = t.shape[0]
#         if self.symmetric:
#             x0x1_emb = (
#                 self.x0_x1_preemb(torch.cat([x0, x1], dim=-1))
#                 + self.x0_x1_preemb(torch.cat([x1, x0], dim=-1))
#             ) * 0.5
#             emb_x0 = self.x0_emb(x0x1_emb)
#             emb_x1 = self.x1_emb(x0x1_emb)
#         else:
#             emb_x0 = self.x0_emb(x0)
#             emb_x1 = self.x1_emb(x1) # [B, hidden_dim]
#         t = t.view(-1, 1)
#         avg = t * x1 + (1 - t) * x0 # [B, T, input_dim]
#         print('avg.shape', avg.shape)
#         enveloppe = self.scale_factor * (1 - (t * 2 - 1) ** 2)
#         # Tile t to [B, T, 1]
#         t = torch.tile(t, (bs, 1, 1))
#         # Tile embx1,x0 to [B, T, input_dim]
#         emb_x0 = torch.tile(emb_x0, (1, T, 1))
#         emb_x1 = torch.tile(emb_x1, (1, T, 1))
#         print(emb_x0.shape, emb_x1.shape, t.shape)

#         aug_state = torch.cat([emb_x0, emb_x1, t], dim=-1)
#         outs = self.mod_x0x1(aug_state) * enveloppe + avg

#         return outs
    
#     # return a function that only takes t as input
#     def forward_rspt_t(self, x0, x1, t):
#         return lambda t: self.forward(x0, x1, t)

# curve = CondCurve(input_dim=X.shape[1], hidden_dim=32, scale_factor=5, symmetric=False)

# optimizer = torch.optim.Adam(curve.parameters(), lr=0.001)

In [None]:
class CurveNet(nn.Module):
    """Parametric curve gamma(t) in the ambient space with fixed endpoints x0 and x1.

    With f a tanh MLP mapping t -> D + 1 values (D = input_dim):

        gamma(t) = (1 - t) * x0 + t * x1
                   + f(t)[:, :D] * variance * (1 - (2t - 1)^2) * f(t)[:, D]

    The envelope (1 - (2t - 1)^2) is zero at t = 0 and t = 1, so gamma(0) = x0 and
    gamma(1) = x1 whatever the network weights are; the last MLP output acts as a
    learned multiplier on that envelope.

    Args:
        input_dim: ambient dimension D.
        hidden_dim: width of the MLP hidden layers.
        x0, x1: endpoint tensors that broadcast against a (T, D) output, e.g. (1, D).
        variance: constant multiplier on the envelope (a plain scale factor, not a
            statistical variance).
    """

    def __init__(self, input_dim, hidden_dim, x0, x1, variance=0.15):
        super().__init__()
        self.mod = nn.Sequential(nn.Linear(1, hidden_dim),
                                 nn.Tanh(),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.Tanh(),
                                 nn.Linear(hidden_dim, hidden_dim),
                                 nn.Tanh(),
                                 nn.Linear(hidden_dim, input_dim + 1))
        self.x0 = x0
        self.x1 = x1
        self.variance = variance
        self.input_dim = input_dim

    def forward(self, t):
        """Evaluate the curve at times ``t`` of shape (T, 1); returns shape (T, input_dim)."""
        # Straight-line interpolation between the endpoints.
        mu = t * self.x1 + (1 - t) * self.x0

        # Evaluate the MLP a single time and slice it, instead of running it once for
        # the displacement and again for the envelope scale.
        raw = self.mod(t)
        displacement = raw[:, :self.input_dim]
        # Slice (not index) so the scale keeps shape (T, 1) and broadcasts over D.
        scale = raw[:, self.input_dim:self.input_dim + 1]

        envelope = self.variance * (1 - (t * 2 - 1) ** 2) * scale
        return displacement * envelope + mu

# Randomly pick a pair of *distinct* training points as the geodesic endpoints.
# Drawing both indices in one call with replace=False guarantees two different rows;
# two independent draws can return the same index and give a degenerate (zero-length) curve.
pair_idx = np.random.choice(X_train.shape[0], size=2, replace=False)
batch_x0 = torch.Tensor(X_train[pair_idx[:1]])  # shape (1, D)
batch_x1 = torch.Tensor(X_train[pair_idx[1:]])  # shape (1, D)

print('Batch X0:', batch_x0.shape)
print('Batch X1:', batch_x1.shape)

curve = CurveNet(input_dim=X.shape[1], hidden_dim=32, x0=batch_x0, x1=batch_x1, variance=0.15)
optimizer = torch.optim.Adam(curve.parameters(), lr=0.001)

In [None]:
# Inspect the chosen endpoints and show them (red) on top of the training data (blue).
print(batch_x0)
print(batch_x1)

fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
scprep.plot.scatter3d(X_train, c='b', title='Training Data', ax=ax)
scprep.plot.scatter3d(batch_x0, c='r', ax=ax)
scprep.plot.scatter3d(batch_x1, c='r', ax=ax)


In [None]:
# Optimise the curve to minimise its discretised energy
#     E(gamma) ~ mean_t  gamma'(t)^T  M(gamma(t))  gamma'(t)
# where M is model.encoder_pullback evaluated at the curve points.
N_T_TRAIN = 100  # number of time samples along the curve
N_STEPS = 500    # optimisation steps
t = torch.linspace(0, 1, N_T_TRAIN).view(-1, 1)
ones_t = torch.ones_like(t)
losses = []
for _ in tqdm(range(N_STEPS)):

    optimizer.zero_grad()

    # Velocity gamma'(t), shape (T, D). CurveNet maps each row of t independently, so the
    # full Jacobian d out / d t (T, D, T, 1) is zero off the t-diagonal, and one
    # Jacobian-vector product with a vector of ones yields exactly that diagonal.
    # This avoids functional.jacobian, which runs one backward pass per output element
    # (T * D per step) and materialises the whole (T, D, T, 1) tensor.
    # create_graph=True keeps the velocity differentiable w.r.t. the curve parameters.
    out, jac = torch.autograd.functional.jvp(curve, t, v=ones_t, create_graph=True)

    # NOTE(review): the curve points are detached (and copied to CPU) before the metric
    # is evaluated, so no gradient flows through M(gamma(t)) — only the velocities are
    # optimised. The full geodesic energy also depends on dM/dgamma; consider
    # `model.encoder_pullback(out)` if encoder_pullback is differentiable (not visible here).
    out_ = torch.Tensor(out.cpu().detach().numpy())
    m = model.encoder_pullback(out_)

    # Quadratic form jac_t^T m_t jac_t for every time sample t.
    prod = torch.einsum('ti,tij,tj->t', jac, m, jac)

    loss = prod.mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

In [None]:
# Sample the trained curve densely (1000 time points) for the plots that follow.
# No gradients are needed here, so evaluate without building an autograd graph.
t = torch.linspace(0, 1, 1000).view(-1,1)
with torch.no_grad():
    pred_geodesic = curve(t).numpy()
print('Pred Geodesic:', pred_geodesic.shape)

In [None]:
# Visualize pred geodesic on the latent space.
# pred_geodesic, batch_x0 and batch_x1 are ambient-space points (CurveNet outputs and
# rows of X_train), while Z is the latent embedding. Encode them first so every overlay
# lives in the same space; scatter2d would otherwise just plot their first two
# ambient coordinates on top of the latent points.
latent_geodesic = model.encode(torch.Tensor(pred_geodesic)).detach().numpy()
latent_x0 = model.encode(batch_x0).detach().numpy()
latent_x1 = model.encode(batch_x1).detach().numpy()

fig = plt.figure()
ax = fig.add_subplot(111)
scprep.plot.scatter2d(Z.detach().numpy(), c='r', title='Latent', ax=ax)
# start and end points
scprep.plot.scatter2d(latent_x0, c='g', ax=ax)
scprep.plot.scatter2d(latent_x1, c='g', ax=ax)
scprep.plot.scatter2d(latent_geodesic, c='b', title='Pred Geodesic', ax=ax)


In [None]:
# Visualize pred geodesic on the ambient space.
# All overlays use static scatter3d on one axes: rotate_scatter3d builds a separate
# rotating animation per call rather than adding points to the existing plot.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# X is the input data (not the reconstruction X_hat), so label it accordingly.
scprep.plot.scatter3d(X, c='r', title='Data', ax=ax)
# start and end points, both green so they stand out against the red data
scprep.plot.scatter3d(batch_x0, c='g', ax=ax)
scprep.plot.scatter3d(batch_x1, c='g', ax=ax)
scprep.plot.scatter3d(pred_geodesic, c='b', title='Pred Geodesic', ax=ax)