In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import scprep
from tqdm import tqdm
import os
import sys

sys.path.append('../src')
from models.unified_model import GeometricAE
from models.distance_matching import DistanceMatching
from data_script import hemisphere_data, sklearn_swiss_roll

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Data: `X` is the point cloud the autoencoder and geodesic ODE are trained on;
# `gt_X` is returned alongside it by the loader (presumably the noise-free
# ground truth — confirm against data_script).
data_name = 'swiss_roll'
DATA_LOADERS = {
    'swiss_roll': sklearn_swiss_roll,
    'hemisphere': hemisphere_data,
}
if data_name not in DATA_LOADERS:
    # Fail here with a clear message instead of a NameError on `gt_X` below.
    raise ValueError(f"Unknown data_name {data_name!r}; expected one of {sorted(DATA_LOADERS)}")
gt_X, X, _ = DATA_LOADERS[data_name](n_samples=1000, noise=0.0)
colors = None

print(gt_X.shape, X.shape)

In [None]:
# Distance-matching autoencoder: architecture and training configuration.
mode = 'encoder'  # one of 'encoder', 'decoder', 'end2end', 'separate'

model_hypers = dict(
    ambient_dimension=3,
    latent_dimension=2,
    model_type='distance',
    activation='relu',
    layer_widths=[256, 128, 64],
    knn=10,
    t='auto',
    n_landmark=5000,
    verbose=False,
)

training_hypers = dict(
    data_name=f'{data_name}',
    mode=mode,
    max_epochs=100,
    batch_size=64,
    lr=1e-3,
    shuffle=True,
    componentwise_std=False,
    weight_decay=1e-5,
    dist_mse_decay=1e-5,
    monitor='validation/loss',
    patience=100,
    seed=2024,
    log_every_n_steps=100,
    accelerator='auto',
    train_from_scratch=True,
    model_save_path=f'./{data_name}_distance_matching_{mode}/model',
)

model = DistanceMatching(**model_hypers)
model.fit(X, train_mask=None, percent_test=0.2, **training_hypers)

In [None]:
# Round-trip the data through the trained autoencoder and report the shapes.
X_tensor = torch.Tensor(X)
Z = model.encode(X_tensor)
print('Encoded Z:', Z.shape)
X_hat = model.decode(Z)
print('Decoded X:', X_hat.shape)

In [None]:
# Visualize: reconstruction | latent embedding | original data, side by side.
fig = plt.figure(figsize=(20, 6))

ax = fig.add_subplot(131, projection='3d')
scprep.plot.scatter3d(X_hat.detach().numpy(), c='b', title='Reconstructed', ax=ax)

# The latent panel is 2-D or 3-D depending on the latent dimension.
Z_np = Z.detach().numpy()
if Z.shape[-1] >= 3:
    ax = fig.add_subplot(132, projection='3d')
    scprep.plot.scatter3d(Z_np, c='b', title='Latent', ax=ax)
else:
    ax = fig.add_subplot(132)
    scprep.plot.scatter2d(Z_np, c='b', title='Latent', ax=ax)

ax = fig.add_subplot(133, projection='3d')
scprep.plot.scatter3d(X, c='b', title='Original', ax=ax)

In [None]:
# Evaluate the encoder's pullback quantity at every data point
# (presumably one metric matrix per point — check the printed shape).
X_query = torch.Tensor(X)
metric = model.encoder_pullback(X_query)
print('Encoder Pullback:', metric.shape)

In [None]:
# Randomly pick start/end point pairs from a point cloud (ambient or latent).
def random_pair(Z, num_endpoints=32, seed=None, verbose=True):
    """Sample start/end point pairs from a point cloud for geodesic training.

    One end point is drawn uniformly at random from the rows of ``Z``, and
    ``num_endpoints`` start points are drawn uniformly *with replacement*.
    Every start point is paired with the same end point.

    Args:
        Z: (N, D) array or tensor of points; the caller decides whether these
            are ambient or latent coordinates. Tensors are detached and moved
            to CPU before conversion to NumPy.
        num_endpoints: Number of start points (and hence pairs) to draw.
        seed: Optional integer seed for reproducible sampling. ``None``
            (default) draws from NumPy's global random state, as before.
        verbose: Print the shapes of the sampled arrays (default True).

    Returns:
        z_start: (num_endpoints, D) NumPy array of start points.
        z_end: (D,) NumPy array, the shared end point.
        batch_x0: (num_endpoints, D) float tensor of start points.
        batch_x1: (num_endpoints, D) float tensor, ``z_end`` repeated per row.

    Raises:
        ValueError: if ``Z`` is not a non-empty 2D array.
    """
    if isinstance(Z, torch.Tensor):
        # .cpu() so CUDA tensors can be converted as well.
        Z = Z.detach().cpu().numpy()
    Z = np.asarray(Z)
    if Z.ndim != 2 or Z.shape[0] == 0:
        raise ValueError(f'Z must be a non-empty 2D array, got shape {Z.shape}')

    # RandomState and the np.random module share the `randint` API; the draw
    # order (end point first, then start points) matches the original.
    rng = np.random if seed is None else np.random.RandomState(seed)
    z_end = Z[rng.randint(0, Z.shape[0]), :]
    z_start = Z[rng.randint(0, Z.shape[0], num_endpoints), :]

    batch_x0 = torch.Tensor(z_start)
    batch_x1 = torch.Tensor(np.repeat(z_end.reshape(1, -1), num_endpoints, axis=0))

    if verbose:
        print('z_start:', z_start.shape)
        print('z_end:', z_end.shape)
        print('batch_x0:', batch_x0.shape)
        print('batch_x1:', batch_x1.shape)

    return z_start, z_end, batch_x0, batch_x1

# Draw the start/end pairs from the ambient data X. To sample in latent space
# instead, set `pair_points = Z.detach().cpu().numpy()` and `pair_space = 'Latent'`.
pair_points = X
pair_space = 'Ambient'
z_start, z_end, batch_x0, batch_x1 = random_pair(pair_points, num_endpoints=32)

# Same end point, different start points: the geodesic ODE integrates x0 -> x1.
dataset = torch.utils.data.TensorDataset(batch_x0, batch_x1)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)

# Visualize start points (red) and the end point (blue) on top of the cloud
# they were sampled from (the same array, so the picture is consistent).
fig = plt.figure(figsize=(5, 5))
if z_start.shape[1] < 3:
    ax = fig.add_subplot(111)
    scatter = scprep.plot.scatter2d
else:
    ax = fig.add_subplot(111, projection='3d')
    scatter = scprep.plot.scatter3d
scatter(pair_points, c='gray', title=pair_space, ax=ax)
scatter(z_start, c='red', ax=ax)
scatter(z_end.reshape(1, -1), c='blue', ax=ax)

### Neural ODE vector field (used by the density-regularized GeoODE below)

In [None]:
adjoint = False
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

class ODEFunc(nn.Module):
    """Autonomous vector field dx/dt = f(x) parameterized by a small MLP.

    Architecture: Linear(in_dim -> hidden_dim) -> Tanh -> Linear(hidden_dim -> in_dim).
    Kept deliberately small; it could be swapped for a generic MLP class, and
    the tanh activation (suggested by torchdiffeq) may be worth tuning.

    Args:
        in_dim: Dimension of the ODE state x (input and output width).
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        # Layers are created in the same order as a literal Sequential, so
        # parameter initialization and state_dict keys (net.0.*, net.2.*) match.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, in_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, x):
        """Return the velocity at state ``x``.

        ``t`` is accepted only to match odeint's ``f(t, x)`` signature; the
        field does not depend on time.
        """
        return self.net(x)

In [None]:
# Sanity check: integrate an untrained vector field from the end points and
# confirm the trajectory layout is [T, B, D]. The state dimension comes from
# the data instead of being hard-coded to 3.
odefunc = ODEFunc(batch_x1.shape[-1], 32)
ts = torch.linspace(0, 1, 100)
with torch.no_grad():  # shape check only; no gradients needed
    path = odeint(odefunc, batch_x1, ts)  # [T, B, D]
print(path.shape)

### GeoODE with Density Loss

In [None]:
from torch.autograd import grad

def compute_jacobian_function(f, x, create_graph=True, retain_graph=True):
    """
    Compute the Jacobian of ``f`` with respect to ``x`` for a batch of points.

    Uses one reverse-mode ``grad`` call per output dimension (n calls). Each call
    sums column i of ``f(x)`` over the batch. This yields per-sample gradients
    only if ``f`` acts row-wise, i.e. output row b depends on ``x[b]`` alone.

    Args:
        f: The function to differentiate, f: (B, D) -> (B, n).
        x: (B, D) batch of points. ``x`` is cloned, not modified. If ``x`` is
            already part of an autograd graph (e.g. an ODE trajectory), the
            clone keeps that connection. With ``create_graph=True``, losses
            built from the Jacobian can then backpropagate into whatever
            produced ``x``.
        create_graph: Build a graph of the derivative so the returned Jacobian
            is itself differentiable (needed when it feeds a training loss).
        retain_graph: Keep the forward graph of ``f`` after the final ``grad``
            call. The graph is always retained between output dimensions.
    Returns:
        jacobian: (B, n, D) tensor, jacobian[b, i, j] = d f(x)[b, i] / d x[b, j],
            with the dtype of ``f(x)``.
    Raises:
        ValueError: if ``f(x)`` is not 2-dimensional.
    """
    x = x.clone()
    x.requires_grad_(True)
    output = f(x)
    if output.dim() != 2:
        raise ValueError(f'f(x) must have shape (B, n), got {tuple(output.shape)}')
    output_dim = output.shape[1]
    if output_dim == 0:
        return output.new_zeros((x.shape[0], 0, x.shape[-1]))

    rows = []
    for i in range(output_dim):
        # One-hot selector for output column i, matching output's dtype/device.
        grad_outputs = torch.zeros_like(output)
        grad_outputs[:, i] = 1.0
        (row,) = grad(
            outputs=output,
            inputs=x,
            grad_outputs=grad_outputs,
            create_graph=create_graph,
            # The graph must survive until the last column has been processed.
            retain_graph=retain_graph or i < output_dim - 1,
        )
        rows.append(row)  # [B, D]
    return torch.stack(rows, dim=1)  # [B, n, D]

In [None]:
# Sanity check the Jacobian helper on the trained encoder. The same probe
# points are used for the Jacobian and the forward pass so the two printed
# results describe the same inputs.
probe = torch.rand(10, X.shape[-1])
jac = compute_jacobian_function(model.encoder, probe)
print(jac.shape)
print(model.encoder(probe).mean())

In [None]:
import torch.nn.functional as F
import pytorch_lightning as pl
    
class GeodesicODEDensity(pl.LightningModule):
    """Learn geodesics as trajectories of a Neural ODE, with an optional density penalty.

    A vector field ``odefunc`` is trained so that integrating it from ``x0`` over
    the time grid ``t`` in [0, 1] ends at ``x1``. At the same time it minimizes
    the curve length measured by the pullback metric J^T J of ``fcn``, where J
    is the Jacobian of ``fcn``. Optionally, trajectories are kept close to the
    data through a k-nearest-neighbour distance penalty.

    Per-batch loss:
        length + lam * MSE(x_T, x1)
        [+ beta * mean(|x_T - x1| ** n_pow)]         if beta > 0
        [+ density_weight * mean top-k data distance] if density_weight > 0

    Args:
        fcn: Map (encoder/decoder) whose Jacobian defines the metric.
            NOTE: ``fcn.requires_grad_(False)`` is called on the object passed
            in, so the caller's module is frozen as a side effect.
        in_dim: Dimension of the ODE state.
        hidden_dim: Hidden width of the ODE vector field.
        n_tsteps: Number of time points discretizing [0, 1] (length evaluation).
        lam: Weight of the endpoint MSE loss.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        beta: Weight of the extra |x_T - x1| ** n_pow endpoint penalty (0 disables).
        n_pow: Power used by that penalty.
        data_pts: (N, in_dim) data points for the density loss (tensor or
            array-like). Required when ``density_weight > 0``.
        n_data_sample: If set and smaller than N, the density loss uses a fresh
            random subset of this many data points each step.
        n_topk: Number of nearest data points averaged by the density loss.
        density_weight: Weight of the density loss (0 disables).
        verbose: Print the individual loss terms on every step.
    """

    def __init__(self,
        fcn,  # encoder/decoder
        in_dim=3,
        hidden_dim=32,
        n_tsteps=1000,  # num of t steps for length evaluation
        lam=10,  # regularization for end point
        lr=1e-3,
        weight_decay=1e-5,
        beta=0.,
        n_pow=4,
        data_pts=None,
        n_data_sample=None,
        n_topk=5,
        density_weight=1.,
        verbose=False,
    ):
        super().__init__()

        self.odefunc = ODEFunc(in_dim, hidden_dim)

        self.pretraining = False  # when True, step() returns only the endpoint MSE
        # Buffers follow the module across devices (.to / Lightning placement).
        # The time grid is non-persistent so it is not written to checkpoints.
        self.register_buffer("t", torch.linspace(0, 1, n_tsteps), persistent=False)
        if data_pts is not None and not isinstance(data_pts, torch.Tensor):
            data_pts = torch.as_tensor(data_pts, dtype=torch.get_default_dtype())
        self.register_buffer("data_pts", data_pts)
        self.n_data_sample = n_data_sample
        self.n_topk = n_topk
        self.density_weight = density_weight
        self.fcn = fcn

        self.lam = lam
        self.lr = lr
        self.weight_decay = weight_decay
        self.beta = beta
        self.n_pow = n_pow
        self.verbose = verbose

        # freeze fcn: only the vector field is meant to be trained
        fcn.requires_grad_(False)

    def forward(self, x0):
        '''Integrate the vector field from ``x0`` over the time grid ``self.t``.

        Args:
            x0: [B, D] start points.
        Returns:
            x_t: [T, B, D] states at every time in ``self.t``.
        '''
        return odeint(self.odefunc, x0, self.t)

    def density_loss(self, x_t_flat, data_pts):
        '''Mean distance from each trajectory point to its nearest data points.

        Args:
            x_t_flat: [T*B, D] trajectory points.
            data_pts: [N, D] data points.
        Returns:
            Scalar tensor: mean over points of their ``n_topk`` smallest
            Euclidean distances to ``data_pts``. k is capped at N so small data
            sets do not make ``topk`` fail.
        '''
        k = min(self.n_topk, data_pts.shape[0])
        dists = torch.cdist(x_t_flat, data_pts)  # [T*B, N]
        nearest, _ = torch.topk(dists, k=k, dim=-1, largest=False, sorted=False)
        return nearest.mean()

    def length_loss(self, t, x):
        '''Average Riemannian speed along the trajectories.

        The speed at each sample is ||J(x) xdot||, which equals
        sqrt(xdot^T (J^T J) xdot) with J the Jacobian of ``fcn``. Averaging over
        the T uniformly spaced times in [0, 1] (and over the batch) approximates
        the curve length.

        Args:
            t: [T] time grid, forwarded to the vector field.
            x: [T, B, D] trajectory.
        Returns:
            Scalar tensor.
        '''
        x_flat = x.reshape(-1, x.shape[-1])  # [T*B, D]
        jac = compute_jacobian_function(self.fcn, x_flat)  # [T*B, n, D]

        xdot = self.odefunc(t, x)  # [T, B, D], the velocity
        xdot_flat = xdot.reshape(-1, xdot.shape[-1])  # [T*B, D]

        # ||J v|| is mathematically sqrt(v^T J^T J v). Computing it as a norm
        # avoids sqrt of a rounding-induced negative, and its gradient at zero
        # speed is 0 instead of NaN (d sqrt/dx is infinite at 0).
        jv = torch.einsum('nij,nj->ni', jac, xdot_flat)  # [T*B, n]
        speed = torch.linalg.vector_norm(jv, dim=-1)  # [T*B]
        return speed.mean()  # * (t[-1] - t[0]); t spans [0, 1]

    def step(self, batch, batch_idx):
        '''Total loss for one batch of (x0, x1) pairs, each [B, D].'''
        x0, x1 = batch
        x_t = self.forward(x0)  # [T, B, D]

        mse_loss = F.mse_loss(x_t[-1], x1)  # endpoint loss
        if self.pretraining:
            return mse_loss

        len_loss = self.length_loss(self.t, x_t)
        loss = len_loss + self.lam * mse_loss

        if self.beta > 0.:
            # abs() keeps the penalty non-negative for odd n_pow
            # (identical to the signed version for even n_pow).
            loss = loss + self.beta * (x_t[-1] - x1).abs().pow(self.n_pow).mean()

        dloss = None  # defined even when the density term is disabled
        if self.density_weight > 0.:
            if self.data_pts is None:
                raise ValueError('density_weight > 0 requires data_pts')
            data_pts = self.data_pts
            if self.n_data_sample is not None and self.n_data_sample < data_pts.size(0):
                indices = torch.randperm(data_pts.size(0), device=data_pts.device)[:self.n_data_sample]
                data_pts = data_pts[indices]
            x_t_flat = x_t.reshape(-1, x_t.shape[-1])
            dloss = self.density_loss(x_t_flat, data_pts)
            loss = loss + self.density_weight * dloss

        if self.verbose:
            print('len_loss:', len_loss, 'endpoint_loss:', self.lam * mse_loss, 'density_l:', dloss)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('train_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('test_loss', loss, prog_bar=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # fcn's parameters are frozen in __init__, so Adam effectively updates
        # only the vector field.
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

In [None]:
# Loss weights for the geodesic ODE.
endpoint_lambda = 10.0   # weight of the endpoint MSE
density_lambda = 100.0   # weight of the k-nearest-neighbour density penalty

geo_ode = GeodesicODEDensity(
    model.encoder, in_dim=X.shape[-1], hidden_dim=64, n_tsteps=100,
    lam=endpoint_lambda, lr=1e-3, weight_decay=0, beta=0., n_pow=4,
    data_pts=torch.Tensor(X), n_data_sample=None, n_topk=5, density_weight=density_lambda,
)

# `gpus=0` was removed from the Trainer in Lightning 2.0; accelerator/devices
# is the supported way to request CPU training.
trainer = pl.Trainer(max_epochs=300, accelerator='cpu', devices=1)
trainer.fit(geo_ode, dataloader)

In [None]:
# Integrate the trained vector field from every start point.
n_samples = 1
with torch.no_grad():  # inference only
    pred_geodesic = geo_ode(batch_x0).cpu().numpy()  # [T, B, D]
print('Pred Geodesic:', pred_geodesic.shape)

# First and last 10 time steps of the first `n_samples` trajectories.
dim = pred_geodesic.shape[-1]
print(pred_geodesic[:10, :n_samples, :].shape)
print(pred_geodesic[:10, :n_samples, :].reshape(-1, dim))
print(pred_geodesic[-10:, :n_samples, :].reshape(-1, dim))

# Integration starts at x0 and targets x1.
print('startpoint: ', batch_x0[:n_samples, :])
print('endpoint: ', batch_x1[:n_samples, :])

In [None]:
# Visualize predicted geodesics in ambient space: data (gray), start points
# (green), end point (red), trajectories (blue).
fig = plt.figure()

start_np = batch_x0.detach().numpy()
end_np = batch_x1.detach().numpy()

# Each trajectory is passed as a [T, D] array (one row per time step). scprep
# reads rows as samples, so reshaping to (1, -1) would draw a single point.
if pred_geodesic.shape[-1] < 3:
    ax = fig.add_subplot(111)
    scprep.plot.scatter2d(X, c='gray', ax=ax)
    scprep.plot.scatter2d(start_np, c='g', ax=ax, alpha=0.5)
    scprep.plot.scatter2d(end_np, c='r', ax=ax)
    for i in range(pred_geodesic.shape[1]):
        scprep.plot.scatter2d(pred_geodesic[:, i, :], c='b', ax=ax)
else:
    ax = fig.add_subplot(111, projection='3d')
    scprep.plot.scatter3d(X, c='gray', ax=ax, alpha=0.5)
    scprep.plot.scatter3d(start_np, c='g', ax=ax)
    scprep.plot.scatter3d(end_np, c='r', ax=ax)
    for i in range(pred_geodesic.shape[1]):
        scprep.plot.scatter3d(pred_geodesic[:, i, :], c='b', ax=ax)

    # Optional rotating animation (kept for reference):
    # anmi = scprep.plot.rotate_scatter3d(X, c='gray', ax=ax, alpha=0.5)
    # anmi.save(f'density{density_lambda}_{data_name}_geoODE.gif', writer='imagemagick', fps=60)

plt.show()

In [None]:
# Plotly 3D scatter plot
import plotly
plotly.offline.init_notebook_mode()

import plotly.express as px
import plotly.graph_objects as go

# Number of trajectories to draw, clamped to the batch size.
n_samples = min(20, pred_geodesic.shape[1])

# Plotly expects array-likes of equal length for x/y/z; convert tensors to
# NumPy and slice every coordinate with the same rows.
start_np = batch_x0.detach().cpu().numpy()[:n_samples]
end_np = batch_x1.detach().cpu().numpy()[:n_samples]

# (points [N, 3], marker style, legend name). Colors match the matplotlib
# plot: start points green, end point red.
layers = [
    (X, dict(size=5, color='gray', opacity=0.5), 'data'),
    (start_np, dict(size=5, color='green'), 'start (x0)'),
    (end_np, dict(size=5, color='red'), 'end (x1)'),
] + [
    (pred_geodesic[:, i, :], dict(size=5, color='blue'), f'geodesic {i}')
    for i in range(n_samples)
]

fig = go.Figure()
for points, marker, name in layers:
    fig.add_trace(go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers', marker=marker, name=name,
    ))

# save
plotly.offline.plot(fig, filename=f'density{density_lambda}_{data_name}_geoODE.html')