In [88]:
# Autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
import os
import sys
import numpy as np
import torch
from glob import glob
import scipy.spatial as spatial
import pytorch_lightning as pl
import plotly.graph_objs as go

sys.path.append('../../src/')
from model2 import Autoencoder
from model2 import Discriminator as OldDiscriminator
from off_manifolder import offmanifolder_maker_new
from geodesic import GeodesicFM
from diffusionmap import DiffusionMap
from negative_sampling import make_hi_freq_noise

### Load GAGA Encoder

In [90]:
local=False
root_dir = '../../'

if local:
    # Load model from wandb run checkpoint
    run_path = os.path.join(root_dir, 'src/wandb/run-20240709_172630-pdsufcya')
else:
    entity = 'xingzhis'
    project = 'dmae'
    run_id = 'pzlwi6t6' # run id of the model. This can be found in the wandb url of the run.
    run_path = os.path.join(root_dir, 'src/wandb/')
    run_path = glob(f"{run_path}/*{run_id}")[0]
    print(run_path)

model_path = glob(f"{run_path}/files/*.ckpt")[0]
print(model_path)

ae_model = Autoencoder.load_from_checkpoint(model_path)
#print(ae_model)

# Load config
config = ae_model.hparams.cfg
# for k, v in config.items():
#     print(f'{k}: {v}')

local=False
if local == True:
    # Load model.
    wd_run_path = os.path.join(root_dir, 'src/wandb/run-20240713_163458-vkfmju0p/files')
else:
    # Or remotely load model from wandb
    entity = 'xingzhis'
    project = 'dmae'
    run_id = 'kafcutw4'
    wd_run_path = os.path.join(root_dir, 'src/wandb/')
    wd_run_path = glob(f"{wd_run_path}/*{run_id}")[0]
    print(wd_run_path)
print(wd_run_path)
wd_model_path = glob(f"{wd_run_path}/files/*.ckpt")[0]
print(wd_model_path)

old_wd_model = OldDiscriminator.load_from_checkpoint(wd_model_path)
#print(old_wd_model)

# Load config.
config = old_wd_model.hparams.cfg
# for k, v in config.items():
#     print(f'{k}: {v}')

Lightning automatically upgraded your loaded checkpoint from v1.9.5 to v2.3.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../src/wandb/run-20240502_001829-pzlwi6t6/files/epoch=24-step=48000.ckpt`

Attribute 'preprocessor' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['preprocessor'])`.

Lightning automatically upgraded your loaded checkpoint from v1.9.5 to v2.3.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../src/wandb/run-20240502_013119-kafcutw4/files/epoch=6-step=26880.ckpt`


../../src/wandb/run-20240502_001829-pzlwi6t6
../../src/wandb/run-20240502_001829-pzlwi6t6/files/epoch=24-step=48000.ckpt
../../src/wandb/run-20240502_013119-kafcutw4
../../src/wandb/run-20240502_013119-kafcutw4
../../src/wandb/run-20240502_013119-kafcutw4/files/epoch=6-step=26880.ckpt


In [91]:
# Load subset data.
rp = '../../data/eb_subset_all.npz'
subset_data = np.load(rp)
for f in subset_data.files:
    print(f, subset_data[f].shape)

# Load the existing wd data.
wd_data_path = '../../data/negative_sampling_toy_shell/False/eb.npz'
wd_data = np.load(wd_data_path)
print('Loaded data:')
for files in wd_data.files:
    print(files, wd_data[files].shape, wd_data[files].dtype)

# # Load full set data.
# import anndata
# adata = anndata.read_h5ad('../../data/eb_hv.h5ad')
# print(adata)
# def _str_category_to_idx(categories):
#     cat = np.unique(categories)
#     cat2idx = {c: i for i, c in enumerate(cat)}
#     return np.array([cat2idx[c] for c in categories])
# adata.obs['sample_classes'] = _str_category_to_idx(adata.obs['sample_labels'])
# full_data = {
#     'data': adata.obsm['X_pca'],
#     'colors': adata.obs['sample_classes'].values
# }
# print(full_data['data'].shape, full_data['colors'].shape)

data (3000, 50)
phate (3000, 2)
dist (3000, 3000)
colors (3000,)
is_train (3000,)
Loaded data:
data (5400, 50) float64
phate (5400, 2) float64
dist (5400, 5400) float64
colors (5400,) float64
is_train (5400,) float64
mask_x (5400,) float64
mask_d (5400, 5400) float32


### Negative Sampling

In [92]:
# Negative sampling based on the data.
def add_noise(data, n_samples, noise_rate=1, seed=42, noise='gaussian'):
    np.random.shuffle(data)
    np.random.seed(seed)
    #noise_rates = np.random.rand(n_samples, 1) * noise_rate
    noise_rates = np.ones((n_samples, 1)) * noise_rate
    if noise == 'gaussian':
        data_std = data.std()
        noise = np.random.randn(*data.shape)
        #data_noisy = data + noise * noise_rates * data_std
        data_noisy = data + noise * noise_rate
    elif noise == 'hi-freq':
        diff_map_op = DiffusionMap(n_components=3, t=3, random_state=seed).fit(data)
        data_noisy = data + make_hi_freq_noise(data, diff_map_op, noise_rate=noise_rates, add_data_mean=False)

    return data_noisy, noise_rates.flatten()


def neg_sampling(x, noise_levels=[10,], reject=True):
    n_samples = x.shape[0]
    print('n_samples: ', n_samples)
    
    x_noisy = []
    noise_rates = []

    for i, noise_level in enumerate(noise_levels):
        cur_noisy, cur_noise_rates = add_noise(x, n_samples=n_samples, noise_rate=noise_level, seed=42+i, noise='gaussian')

        if reject == True:
            # reject negative samples that are too close to the original data.
            k = 20
            dist =  spatial.distance_matrix(cur_noisy, x) # (n_samples, n_samples)
            dist_closest_k = np.sort(dist, axis=1)[:, 1:k] # (n_samples, k)

            dist_topk_mean = dist_closest_k.mean(axis=1) # (n_samples,)
            threshold = np.percentile(dist_topk_mean, 90) # closest 10% of the data.
            reject_idxs = np.where(dist_topk_mean < threshold)[0]
            cur_noisy = np.delete(cur_noisy, reject_idxs, axis=0)
            print('After rejection: ', cur_noisy.shape)

        x_noisy.extend(cur_noisy)
        noise_rates.extend(cur_noise_rates)

    x_noisy = np.array(x_noisy)
    noise_rates = np.array(noise_rates)
    print(x_noisy.shape)
    print('x std: ', x.std())
    print('x_noisy std: ', x_noisy.std())
    print('noise_rates.std(): ', noise_rates.std())

    return x_noisy

x_noisy = neg_sampling(x=subset_data['data'], noise_levels=[10,12,14], reject=True)

n_samples:  3000
After rejection:  (300, 50)
After rejection:  (300, 50)
After rejection:  (300, 50)
(900, 50)
x std:  1.5974236
x_noisy std:  14.320191687908592
noise_rates.std():  1.632993161855452


In [93]:
# Plot
x = subset_data['data']
fig = go.Figure()
fig.add_trace(go.Scatter(x=x[:, 0], y=x[:, 1], mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter(x=x_noisy[:, 0], y=x_noisy[:, 1], mode='markers', marker=dict(size=2, color='red', opacity=0.8)))


In [95]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.backends.mps.is_available():
    device = 'mps'
batch_size = 32

# Sanity check encoder and decoder on data.
def encode_pos_neg_data(ae_model, x, x_noisy):
    # Encode data
    encodings = []
    neg_encodings = []
    ae_model = ae_model.to(device)
    ae_model.eval()
    with torch.no_grad():
        for i in range(0, len(x), batch_size):
            x_batch = torch.tensor(x[i:i+batch_size], dtype=torch.float32).to(device)
            encodings.append(ae_model.encoder(x_batch).cpu().detach().numpy())
        for i in range(0, len(x_noisy), batch_size):
            neg_x_batch = torch.tensor(x_noisy[i:i+batch_size], dtype=torch.float32).to(device)
            neg_encodings.append(ae_model.encoder(neg_x_batch).cpu().detach().numpy())

    encodings = np.concatenate(encodings, axis=0)
    neg_encodings = np.concatenate(neg_encodings, axis=0)
    print('encodings.shape: ', encodings.shape)
    print('neg_encodings.shape: ', neg_encodings.shape)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=encodings[:, 0], y=encodings[:, 1], z=encodings[:, 2], mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=neg_encodings[:, 0], y=neg_encodings[:, 1], z=neg_encodings[:, 2], 
                            mode='markers', marker=dict(size=2, color='red', opacity=0.5)))
    fig.show()
    
x = data['data']
encode_pos_neg_data(ae_model, x, x_noisy)

encodings.shape:  (3000, 3)
neg_encodings.shape:  (900, 3)


### Train W-Discriminator

In [96]:
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class MLP(torch.nn.Module):
    def __init__(self, in_dim, out_dim, layer_widths=[64, 64, 64], activation='relu', 
                 batch_norm=False, dropout=0.0, use_spectral_norm=False):
        super().__init__()

        layers = []
        for i, width in enumerate(layer_widths):
            if i == 0:
                linear_layer = torch.nn.Linear(in_dim, width)
            else:
                linear_layer = torch.nn.Linear(layer_widths[i-1], width)

            # Conditionally apply spectral normalization
            if use_spectral_norm:
                linear_layer = spectral_norm(linear_layer)

            layers.append(linear_layer)

            if batch_norm:
                layers.append(torch.nn.BatchNorm1d(width))
            
            if activation == 'relu':
                layers.append(torch.nn.ReLU())
            elif activation == 'leaky_relu':
                layers.append(torch.nn.LeakyReLU())
            elif activation == 'tanh':
                layers.append(torch.nn.Tanh())
            else:
                raise ValueError(f'Invalid activation function: {activation}')

            if dropout > 0:
                layers.append(torch.nn.Dropout(dropout))
        
        # Adding the final layer
        final_linear_layer = torch.nn.Linear(layer_widths[-1], out_dim)
        if use_spectral_norm:
            final_linear_layer = spectral_norm(final_linear_layer)
        layers.append(final_linear_layer)

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Discriminator(pl.LightningModule):
    def __init__(self, in_dim, layer_widths=[64, 64, 64], activation='relu', data_pts=None, k=5, encoder=None,
                 batch_norm=False, dropout=0.0, use_spectral_norm=False, **kwargs):
        super().__init__()

        if data_pts is None:
            self.in_dim = in_dim
        else:
            self.in_dim = in_dim + k # Augment data with extra density features

        self.mlp = MLP(self.in_dim, 2, layer_widths, activation, batch_norm, dropout, use_spectral_norm)
        self.data_pts = data_pts # [N, in_dim]
        self.encoder = encoder
        self.k = int(k)
        self.lr = kwargs.get('lr', 1e-3)
        self.mean = torch.tensor(kwargs.get('mean', np.zeros(in_dim)), dtype=torch.float32, device=device)
        self.std = torch.tensor(kwargs.get('std', np.ones(in_dim)), dtype=torch.float32, device=device)
        
        print('self.mean: ', self.mean.shape)
        print('self.std: ', self.std.shape)

        self.train_step_outs = []
        self.val_step_outs = []
        self.test_step_outs = []

        self.train_ys = []
        self.val_ys = []
        self.test_ys = []

        # Freeze the encoder
        if self.encoder is not None:
            for param in self.encoder.parameters():
                param.requires_grad = False
    
    def augment_data(self, x):
        # Augment data with extra density features
        dists = torch.cdist(x, self.data_pts)
        topk, _ = torch.topk(dists, k=self.k, dim=1, largest=False, sorted=False) # [N, k]
        #topk = torch.mean(topk, dim=1).unsqueeze(1) # [N, 1]
        x = torch.cat([x, topk], dim=1)
    
        return x

    def forward(self, x, normalize=True):
        # Assume x is not yet normalized
        x = (x - self.mean) / self.std

        if self.data_pts is not None:
            if self.encoder is None:
                x = self.augment_data(x)
            else:
                x_embed = self.encoder(x)
                x = self.augment_data(x_embed)
                assert x.shape[1] == self.in_dim, f'x.shape: {x.shape}, self.in_dim: {self.in_dim}'

        #print('x.shape: ', x.shape)
        logits = self.mlp(x)
        return logits # [N, 2]
    
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt
    
    def training_step(self, batch, batch_idx):
        x, y = batch # x: [N, in_dim], y: [N]
        logits = self(x) # [N, 2]
        loss = F.cross_entropy(logits, y)

        self.train_step_outs.append(logits)
        self.train_ys.append(y)
        self.log('train_loss', loss, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        self.val_step_outs.append(logits)
        self.val_ys.append(y)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        self.test_step_outs.append(logits)
        self.test_ys.append(y)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def on_train_epoch_end(self):
        # Compute prediction accuracy
        train_logits = torch.cat(self.train_step_outs, dim=0) # [N, 2]
        true_classes = torch.cat(self.train_ys, dim=0) # [N]
        pred_classes = torch.argmax(train_logits, dim=1) # [N]

        train_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.train_step_outs.clear()
        self.train_ys.clear()

    def on_validation_epoch_end(self):
        # Compute prediction accuracy
        val_logits = torch.cat(self.val_step_outs, dim=0)
        true_classes = torch.cat(self.val_ys, dim=0)
        pred_classes = torch.argmax(val_logits, dim=1)

        val_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.val_step_outs.clear()
        self.val_ys.clear()

    def on_test_epoch_end(self):
        # Compute prediction accuracy
        test_logits = torch.cat(self.test_step_outs, dim=0)
        true_classes = torch.cat(self.test_ys, dim=0)
        pred_classes = torch.argmax(test_logits, dim=1)

        test_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('test_acc', test_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.test_step_outs.clear()
        self.test_ys.clear()
    
    def positive_prob(self, x):
        logits = self(x)
        # apply softmax to get probabilities
        softmax = torch.nn.Softmax(dim=1)
        return softmax(logits)[:, 1]

In [97]:
# Setup data.
def encode_data(x, encoder):
    batch_size = 256
    encodings = []
    encoder.eval()
    encoder.to(device)
    with torch.no_grad():
        for i in range(0, len(x), batch_size):
            x_batch = torch.tensor(x[i:i+batch_size], dtype=torch.float32).to(device)
            encodings.append(encoder(x_batch).cpu().detach().numpy())
    
    encodings = np.concatenate(encodings, axis=0)
    return encodings

def get_discriminator_loaders(x, x_noisy, batch_size=256, wd_data=None):
    combined_x = np.concatenate([x, x_noisy], axis=0)
    if wd_data is not None:
        combined_x = np.concatenate([x, x_noisy, wd_data['data']], axis=0)
    combined_y = np.concatenate([np.ones(x.shape[0]), np.zeros(x_noisy.shape[0])]).astype(np.int32)
    if wd_data is not None:
        combined_y = np.concatenate([np.ones(x.shape[0]), np.zeros(x_noisy.shape[0]), wd_data['mask_x']]).astype(np.int32)

    dataset = torch.utils.data.TensorDataset(torch.tensor(combined_x, dtype=torch.float32), torch.tensor(combined_y, dtype=torch.int64))
    train_size = int(0.7 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    batch_size = 256
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

embed = False
x_embed = None
x_noisy_embed = None
combined_x = None
combined_y = None

if embed == True:
    x_embed = encode_data(x, ae_model.encoder)
    x_noisy_embed = encode_data(x_noisy, ae_model.encoder)
    print('x_embed.shape: ', x_embed.shape)
    print('x_noisy_embed.shape: ', x_noisy_embed.shape)

    combined_x = np.concatenate([x_embed, x_noisy_embed], axis=0)
    combined_y = np.concatenate([np.ones(x_embed.shape[0]), np.zeros(x_noisy_embed.shape[0])]).astype(np.int32)

    train_loader, val_loader, test_loader = get_discriminator_loaders(x_embed, x_noisy_embed, batch_size=256)
    data_pts = torch.tensor(x_embed, dtype=torch.float32, device=device)
else:
    combined_x = np.concatenate([x, x_noisy], axis=0)
    combined_y = np.concatenate([np.ones(x.shape[0]), np.zeros(x_noisy.shape[0])]).astype(np.int32)
    train_loader, val_loader, test_loader = get_discriminator_loaders(x, x_noisy, batch_size=256, wd_data=wd_data)
    data_pts = torch.tensor(x, dtype=torch.float32, device=device)

print('combined_x.shape: ', combined_x.shape)
print('combined_y.shape: ', combined_y.shape)

combined_x.shape:  (3900, 50)
combined_y.shape:  (3900,)


In [98]:
hyperparams = {
    'dirpath': './wd_model',
    'name': 'wd_model',
    'in_dim': 3 if embed is True else x.shape[1],
    'layer_widths': [256, 128, 64],
    'activation': 'relu',
    'batch_norm': True,
    'dropout': 0.5,
    'use_spectral_norm': True,
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'max_epochs': 300,
    'patience': 50,
    'monitor': 'val_acc',
    'mean': combined_x.mean(axis=0),
    'std': combined_x.std(axis=0),
    'data_pts': data_pts,
    #'data_pts': encodings,
    #'encoder': ae_model.encoder,
    'k': 5,
}


wd_model = Discriminator(**hyperparams)
print(wd_model)

sx, sy = next(iter(train_loader))
print(sx.shape, sy.shape)

early_stopping = pl.callbacks.EarlyStopping(monitor=hyperparams['monitor'], patience=hyperparams['patience'])
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor=hyperparams['monitor'], save_top_k=1, mode='max', 
                                                dirpath=hyperparams['dirpath'], filename=hyperparams['name'])

trainer = pl.Trainer(max_epochs=hyperparams['max_epochs'], accelerator='auto', log_every_n_steps=17, 
                     callbacks=[early_stopping, model_checkpoint])
trainer.fit(model=wd_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /Users/danqiliao/Desktop/dmae/notebooks/flow_matching/wd_model exists and is not empty.


  | Name | Type | Params | Mode 
--------------------------------------
0 | mlp  | MLP  | 56.5 K | train
--------------------------------------
56.5 K    Trainable params
0         Non-trainable params
56.5 K    Total params
0.226     Total estimated model params size (MB)


self.mean:  torch.Size([50])
self.std:  torch.Size([50])
Discriminator(
  (mlp): MLP(
    (net): Sequential(
      (0): Linear(in_features=55, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.5, inplace=False)
      (8): Linear(in_features=128, out_features=64, bias=True)
      (9): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.5, inplace=False)
      (12): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)
torch.Size([256, 50]) torch.Size([256])
                                                                            


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.



Epoch 50: 100%|██████████| 26/26 [00:00<00:00, 77.96it/s, v_num=198, val_loss=0.312, val_acc=0.878, train_acc=0.930]


In [99]:
### Evaluate Old Discriminator on combined_x
old_acc = 0
old_wd_model.to(device)
old_wd_model.eval()
with torch.no_grad():
    for i in range(0, len(combined_x), batch_size):
        x_batch = torch.tensor(combined_x[i:i+batch_size], dtype=torch.float32).to(device)
        logits = old_wd_model(x_batch)
        pred_classes = torch.argmax(logits, dim=1)
        true_classes = torch.tensor(combined_y[i:i+batch_size], dtype=torch.int64).to(device)
        old_acc += torch.sum(pred_classes == true_classes).item()

old_acc /= len(combined_x)

print('Old Discriminator accuracy: ', old_acc)

### Evaluate New Discriminator on combined_x
wd_model.to(device)
wd_model.eval()
wd_acc = 0
with torch.no_grad():
    for i in range(0, len(combined_x), batch_size):
        x_batch = torch.tensor(combined_x[i:i+batch_size], dtype=torch.float32).to(device)
        logits = wd_model(x_batch)
        pred_classes = torch.argmax(logits, dim=1)
        true_classes = torch.tensor(combined_y[i:i+batch_size], dtype=torch.int64).to(device)
        wd_acc += torch.sum(pred_classes == true_classes).item()
wd_acc /= len(combined_x)

print('New Discriminator accuracy: ', wd_acc)

Old Discriminator accuracy:  0.9923076923076923
New Discriminator accuracy:  1.0


In [100]:
def visualize_probs(ae_model, wd_model, combined_x, combined_y):
    wd_encodings = []
    ae_model = ae_model.to(device)
    ae_model.eval()
    wd_batch_size = 32
    with torch.no_grad():
        for i in range(0, len(combined_x), wd_batch_size):
            x_batch = torch.tensor(combined_x[i:i+wd_batch_size], dtype=torch.float32).to(device)
            wd_encodings.append(ae_model.encoder(x_batch).cpu().detach().numpy())
    wd_encodings = np.concatenate(wd_encodings, axis=0)
    print('wd_encodings.shape: ', wd_encodings.shape)

    wd_model = wd_model.to(device)
    wd_model.eval()
    probs = []
    with torch.no_grad():
        for i in range(0, len(combined_x), wd_batch_size):
            x_batch = torch.tensor(combined_x[i:i+wd_batch_size], dtype=torch.float32).to(device)
            #prob = wd_model(x_batch)[:, 1] # probability of being real
            prob = wd_model.positive_prob(x_batch)
            probs.append(prob.cpu().detach().numpy())
    probs = np.concatenate(probs, axis=0)
    print('probs.shape: ', probs.shape)

    # Plot
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                            mode='markers', marker=dict(size=2, color=combined_y, colorscale='Viridis', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                                mode='markers', marker=dict(size=2, color=probs, colorscale='Viridis', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                                mode='markers', marker=dict(size=2, color=(probs > probs.mean()).astype(int), colorscale='Viridis', opacity=0.8)))
    fig.show()

    return wd_encodings, probs

_, probs = visualize_probs(ae_model, wd_model, combined_x, combined_y)

visualize_probs(ae_model, wd_model, wd_data['data'], wd_data['mask_x'])


wd_encodings.shape:  (3900, 3)
probs.shape:  (3900,)


wd_encodings.shape:  (5400, 3)
probs.shape:  (5400,)


(array([[ 0.2487267 ,  0.5307898 ,  0.14138453],
        [ 0.4516882 ,  0.5163434 , -0.33334905],
        [ 0.35367313,  0.4715617 , -0.28540057],
        ...,
        [ 1.3262057 ,  0.07980571, -0.68556845],
        [-0.7557104 ,  0.13464001, -0.16202168],
        [ 1.1320058 , -0.0158112 , -0.6627188 ]], dtype=float32),
 array([0.98819435, 0.98333883, 0.9916487 , ..., 0.7746303 , 0.14895447,
        0.24136136], dtype=float32))

In [101]:
encodings = encode_data(x, ae_model.encoder)
eb_labels = data['colors']
fig = go.Figure()
for i in range(5):
    idx = np.where(eb_labels == i)[0]
    print('idx: ', len(idx))
    xs = encodings[idx, :]
    print('xs.shape: ', xs.shape)
    fig.add_trace(go.Scatter3d(x=xs[:, 0], y=xs[:, 1], z=xs[:, 2], 
                            mode='markers', marker=dict(size=2, opacity=0.8)))
    
fig.show()

eb_labels = data['colors']
fig = go.Figure()
for i in range(5):
    idx = np.where(eb_labels == i)[0]
    print('idx: ', len(idx))
    xs = data['phate'][idx, :]
    print('xs.shape: ', xs.shape)
    fig.add_trace(go.Scatter(x=xs[:, 0], y=xs[:, 1],
                            mode='markers', marker=dict(size=2, opacity=0.8)))
    
fig.show()

idx:  419
xs.shape:  (419, 3)
idx:  758
xs.shape:  (758, 3)
idx:  598
xs.shape:  (598, 3)
idx:  652
xs.shape:  (652, 3)
idx:  573
xs.shape:  (573, 3)


idx:  419
xs.shape:  (419, 2)
idx:  758
xs.shape:  (758, 2)
idx:  598
xs.shape:  (598, 2)
idx:  652
xs.shape:  (652, 2)
idx:  573
xs.shape:  (573, 2)


### Select start/end points

In [106]:
def sample_indices_within_range(points, selected_idx=None, range_size=0.1, num_samples=20, seed=23):
    np.random.seed(seed)
    # Randomly select two points from the array
    if selected_idx is None:
        selected_indices = np.random.choice(points.shape[0], 2, replace=False)
        point1_idx, point2_idx = selected_indices[0], selected_indices[1]
        point1, point2 = points[point1_idx], points[point2_idx]
    else:
        point1_idx, point2_idx = selected_idx
        point1, point2 = points[point1_idx], points[point2_idx]    
    # Function to find indices of points within the range of a given point
    def _find_indices_within_range(point):
        distances = np.linalg.norm(points - point, axis=1)
        within_range_indices = np.where(distances <= range_size)[0]
        return within_range_indices
    
    # Find indices within range of point1 and point2
    indices_within_range1 = _find_indices_within_range(point1)
    indices_within_range2 = _find_indices_within_range(point2)
    
    # Randomly sample indices within the range
    if len(indices_within_range1) >= num_samples:
        sampled_indices_point1 = np.random.choice(indices_within_range1, num_samples, replace=False)
    else:
        sampled_indices_point1 = indices_within_range1
    
    if len(indices_within_range2) >= num_samples:
        sampled_indices_point2 = np.random.choice(indices_within_range2, num_samples, replace=False)
    else:
        sampled_indices_point2 = indices_within_range2
    
    return point1_idx, sampled_indices_point1, point2_idx, sampled_indices_point2

start_group = 0
end_group = 3
eb_labels = subset_data['colors']
print('eb_labels: ', eb_labels.shape)
start_idices = np.where(eb_labels == start_group)[0]
start_idx = np.random.choice(start_idices, 1)[0]
end_idices = np.where(eb_labels == end_group)[0]
end_idx = np.random.choice(end_idices, 1)[0]

start_idx = 736
end_idx = 2543

print('start_idx, end_idx: ', start_idx, end_idx)
encodings = encode_data(subset_data['data'], ae_model.encoder)
print('encodings.shape: ', encodings.shape)
point1_idx, sampled_indices_point1, point2_idx, sampled_indices_point2 = sample_indices_within_range(encodings, 
                                                                                                     selected_idx=(start_idx, end_idx),
                                                                                                     range_size=0.3, 
                                                                                                     seed=2024, num_samples=64)
point1 = x[point1_idx]
point2 = x[point2_idx]
samples_point1 = x[sampled_indices_point1]
samples_point2 = x[sampled_indices_point2]
print('point1, point2: ', point1.shape, point2.shape)
print('samples_point1, samples_point2: ', samples_point1.shape, samples_point2.shape)

print('point1, point2: ', point1_idx, point2_idx) # point1, point2: 736 2543 (0-3)
samples_z_point1 = ae_model.encoder(torch.tensor(samples_point1, dtype=torch.float32).to(device)).cpu().detach().numpy()
samples_z_point2 = ae_model.encoder(torch.tensor(samples_point2, dtype=torch.float32).to(device)).cpu().detach().numpy()
start_end_z = np.vstack([encodings[start_idx], encodings[end_idx]])
print('samples_z_point1, samples_z_point2: ', samples_z_point1.shape, samples_z_point2.shape)
print('start_end_z: ', start_end_z.shape)
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=start_end_z[:,0], y=start_end_z[:,1], z=start_end_z[:,2],
                            mode='markers', marker=dict(size=5, color='red', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=2, color='green', opacity=0.8)))


eb_labels:  (3000,)
start_idx, end_idx:  736 2543
encodings.shape:  (3000, 3)
point1, point2:  (50,) (50,)
samples_point1, samples_point2:  (64, 50) (64, 50)
point1, point2:  736 2543
samples_z_point1, samples_z_point2:  (64, 3) (64, 3)
start_end_z:  (2, 3)


In [107]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, x0, x1):
        self.x0 = x0
        self.x1 = x1

    def __len__(self):
        return max(len(self.x0), len(self.x1))

    def __getitem__(self, idx):
        return self.x0[idx % len(self.x0)], self.x1[idx % len(self.x1)]

def custom_collate_fn(batch):
    x0_batch = torch.stack([item[0] for item in batch])
    x1_batch = torch.stack([item[1] for item in batch])
    
    # Randomly permute the elements in the batch
    perm_x0 = torch.randperm(len(x0_batch))
    perm_x1 = torch.randperm(len(x1_batch))

    x0_batch = x0_batch[perm_x0]
    x1_batch = x1_batch[perm_x1]
    
    return x0_batch, x1_batch

# Create a dataloader
bs = 32
dataset = CustomDataset(x0=torch.tensor(samples_point1, dtype=torch.float32), 
                        x1=torch.tensor(samples_point2, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True, collate_fn=custom_collate_fn)

In [108]:
probs_mean = probs.mean()
probs_std = probs.std()

print('probs_mean: ', probs_mean)
print('probs_std: ', probs_std)

probs_mean:  0.76058894
probs_std:  0.41150856


In [109]:
ae_model = ae_model.to(device)
wd_model = wd_model.to(device)
old_wd_model = old_wd_model.to(device)

for param in ae_model.encoder.parameters():
    param.requires_grad = False
for param in wd_model.parameters():
    param.requires_grad = False
for param in old_wd_model.parameters():
    param.requires_grad = False

enc_func = lambda x: ae_model.encoder(x)
#disc_func = lambda x: (wd_model.positive_proba(x) - min_prob) / (max_prob - min_prob) # normalize the probability to [0, 1]
disc_func = lambda x: 1 - wd_model.positive_prob(x)
#disc_func = lambda x: 1 - old_wd_model.positive_proba(x)

ofm = offmanifolder_maker_new(enc_func, disc_func, disc_factor=5) # ofm encodes both on/off manifold points
# TODO: 1. cc_k default is 2, may want to increase it. 2. symmetric can be false. 
# 3. increase hidden 4. embed t as well. 5. decreasew weight decay 6. increase length loss weight 7. train DW in embedding space
gbmodel = GeodesicFM(
    func=ofm,
    encoder=enc_func,
    input_dim=x.shape[1],
    hidden_dim=64, 
    scale_factor=1, 
    symmetric=True, 
    num_layers=3, 
    n_tsteps=100, 
    lr=1e-3,
    weight_decay=1e-4,
    flow_weight=1,
    length_weight=1,
    cc_k=5,
)

gbmodel.lr=1e-3
early_stopping = pl.callbacks.EarlyStopping(monitor='train_loss_epoch', patience=50, mode='min')
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor='train_loss_epoch', save_top_k=1, mode='min', 
                                                dirpath='./eb_fm/checkpoints', filename='gbmodel')
trainer = pl.Trainer(
    # logger=logger,
    max_epochs=300,
    log_every_n_steps=20,
    accelerator=device,
    callbacks=[early_stopping, model_checkpoint]
)

trainer.fit(gbmodel, train_dataloaders=dataloader)
checkpoint_dir = './eb_fm/checkpoints'

#trainer.save_checkpoint(f"{checkpoint_dir}/{start_group}_{end_group}_gbmodel.ckpt")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.


Checkpoint directory /Users/danqiliao/Desktop/dmae/notebooks/flow_matching/eb_fm/checkpoints exists and is not empty.


  | Name       | Type             | Params | Mode 
--------------------------------------------------------
0 | cc         | CondCurveOverfit | 64.1 K | train
1 | flow_model | MLP              | 14.9 K | train
--------------------------------------------------------
79.0 K    Trainable params
0         Non-trainable params
79.0 K    Total params
0.316     Total estimated model params size (MB)

The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


The number of training batches (2) is smaller than the logging interval Trainer

Epoch 299: 100%|██████████| 2/2 [00:06<00:00,  0.29it/s, v_num=199, loss_length_step=4.240, fm_length_step=1.310, loss_step=5.540, train_loss_step=5.540, loss_length_epoch=4.220, fm_length_epoch=1.290, loss_epoch=5.510, train_loss_epoch=5.510]

`Trainer.fit` stopped: `max_epochs=300` reached.


Epoch 299: 100%|██████████| 2/2 [00:06<00:00,  0.29it/s, v_num=199, loss_length_step=4.240, fm_length_step=1.310, loss_step=5.540, train_loss_step=5.540, loss_length_epoch=4.220, fm_length_epoch=1.290, loss_epoch=5.510, train_loss_epoch=5.510]


In [None]:
# trainer = pl.Trainer(
#     # logger=logger,
#     max_epochs=300,
#     log_every_n_steps=20,
#     accelerator=device,
#     callbacks=[early_stopping, model_checkpoint]
# )
# trainer.fit(gbmodel, train_dataloaders=dataloader)
# checkpoint_dir = './eb_fm/checkpoints'

In [116]:
x0 = []
x1 = []
for x0_, x1_ in dataloader:
    x0.append(x0_)
    x1.append(x1_)
x0 = torch.cat(x0, dim=0)
x1 = torch.cat(x1, dim=0)

x0 = x0.to(device)
x1 = x1.to(device)

ids = torch.zeros((x0.size(0),1)) # Conditional id for each pair of x0 and x1, here is a dummy.
ids = ids.to(device)

gbmodel = gbmodel.to(device)
gbmodel.eval()
with torch.no_grad():
    x_traj = gbmodel(x0, x1, gbmodel.ts.to(device), ids) # [T, B, D]
print('Predicted trajectory shape: ', x_traj.shape)

ae_model.eval()
with torch.no_grad():
    z_traj = ae_model.encoder(x_traj.flatten(0,1)) # [T*B, D]

    z0 = ae_model.encoder(x0)
    z1 = ae_model.encoder(x1)

z_traj = z_traj.cpu().detach().numpy()
z_traj = z_traj.reshape(x_traj.size(0), x_traj.size(1), -1)
print('z_traj.shape: ', z_traj.shape)

z0 = z0.cpu().detach().numpy()
z1 = z1.cpu().detach().numpy()

Predicted trajectory shape:  torch.Size([100, 64, 50])
z_traj.shape:  (100, 64, 3)


In [117]:
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=5, color='green', opacity=0.8)))

for i in range(20):
    fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                               mode='lines', line=dict(width=2, color='blue')))

fig.show()

# save html
fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_geobridge.html")

In [118]:
import torchdiffeq
from torch import nn

adjoint = False
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

class ODEFuncWrapper(nn.Module):
    def __init__(self, flowmodel):
        super().__init__()
        self.flowmodel = flowmodel
    def forward(self, t, y):
        # Expand t to match the batch size and feature dimension
        t_expanded = t.view(1, 1).expand(y.size(0), 1)
        # Concatenate y and t along the feature dimension
        y_with_t = torch.cat((y, t_expanded), dim=-1)
        return self.flowmodel(y_with_t)

n_samples = min(100, x0.size(0))
sampled_starts = torch.tensor(x[sampled_indices_point1[:n_samples]], dtype=torch.float32).to('cpu')

print(f'Run ODE on {n_samples} samples: {sampled_starts.shape}')
flowfunc = ODEFuncWrapper(gbmodel.flow_model.to('cpu'))
with torch.no_grad():
    traj = odeint(flowfunc, sampled_starts.to('cpu'), gbmodel.ts.to('cpu'))

print('Flow Matching ODE Trajectory shape: ', traj.shape)

ae_model.eval()
with torch.no_grad():
    traj = traj.to(device)
    z_traj = ae_model.encoder(traj.flatten(0,1)) # [T*B, D]

z_traj = z_traj.cpu().detach().numpy()
z_traj = z_traj.reshape(traj.size(0), traj.size(1), -1)
print('z_traj.shape: ', z_traj.shape)


Run ODE on 64 samples: torch.Size([64, 50])
Flow Matching ODE Trajectory shape:  torch.Size([100, 64, 50])
z_traj.shape:  (100, 64, 3)


In [119]:
# Plot 2d
fig = go.Figure()
fig.add_trace(go.Scatter(x=encodings[:,0], y=encodings[:,1], mode='markers', marker=dict(size=2, color='gray')))
fig.add_trace(go.Scatter(x=samples_z_point1[:,0], y=samples_z_point1[:,1], mode='markers', marker=dict(size=5, color='blue')))
fig.add_trace(go.Scatter(x=samples_z_point2[:,0], y=samples_z_point2[:,1], mode='markers', marker=dict(size=5, color='green')))
for i in range(n_samples):
    fig.add_trace(go.Scatter(x=z_traj[:,i,0], y=z_traj[:,i,1], mode='lines', line=dict(width=2, color='blue')))
fig.show()

# save html
fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_2d.html")

In [120]:
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=5, color='green', opacity=0.8)))

for i in range(20):
    fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                                 mode='lines', line=dict(width=2, color='blue')))

fig.show()

# save the html
fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_3d.html")

In [1]:
# Plot positive_prob of these trajectories
traj.shape # [T, B, D]
traj_flat = traj.flatten(0,1) # [T*B, D]

bs = 32
traj_probs = []
wd_model = wd_model.to(device)
wd_model.eval()
with torch.no_grad():
    for i in range(0, traj_flat.shape[0], bs):
        x_batch = traj_flat[i:i+bs]
        prob = wd_model.positive_prob(x_batch)
        #print(prob)
        traj_probs.append(prob.cpu().detach().numpy())
traj_probs = np.concatenate(traj_probs, axis=0) # [T*B]
traj_probs = traj_probs.reshape(traj.size(0), traj.size(1)) # [T, B]
print('traj_probs.shape: ', traj_probs.shape)

fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=5, color='green', opacity=0.8)))

for i in range(30):    
    fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                                 mode='markers', marker=dict(size=2, color=traj_probs[:,i].mean(), colorscale='Viridis')))
    prob_texts = []
    for ti in range(traj_probs.shape[0]):
        prob_texts.append(f'Prob: {traj_probs[ti,i]:.2f}')
        #print(ti, ':', prob_texts)
    fig.update_traces(hovertext=prob_texts)

fig.show()




NameError: name 'traj' is not defined