In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

In [27]:
import os
import shutil
import sys
import numpy as np
import torch
from glob import glob
import scipy.spatial as spatial
import pytorch_lightning as pl
import plotly.graph_objs as go

sys.path.append('../../src/')
from model2 import Autoencoder
from model2 import Discriminator as OldDiscriminator
from off_manifolder import offmanifolder_maker_new, offmanifolder_density_maker
from geodesic import GeodesicFM
from diffusionmap import DiffusionMap
from negative_sampling import make_hi_freq_noise

### Load GAGA Encoder

In [3]:
local=False
root_dir = '../../'

if local:
    # Load model from wandb run checkpoint
    run_path = os.path.join(root_dir, 'src/wandb/run-20240709_172630-pdsufcya')
else:
    entity = 'xingzhis'
    project = 'dmae'
    run_id = 'pzlwi6t6' # run id of the model. This can be found in the wandb url of the run.
    run_path = os.path.join(root_dir, 'src/wandb/')
    run_path = glob(f"{run_path}/*{run_id}")[0]
    print(run_path)

model_path = glob(f"{run_path}/files/*.ckpt")[0]
print(model_path)

ae_model = Autoencoder.load_from_checkpoint(model_path)
#print(ae_model)

# Load config
config = ae_model.hparams.cfg
# for k, v in config.items():
#     print(f'{k}: {v}')

local=False
if local == True:
    # Load model.
    wd_run_path = os.path.join(root_dir, 'src/wandb/run-20240713_163458-vkfmju0p/files')
else:
    # Or remotely load model from wandb
    entity = 'xingzhis'
    project = 'dmae'
    run_id = 'kafcutw4'
    wd_run_path = os.path.join(root_dir, 'src/wandb/')
    wd_run_path = glob(f"{wd_run_path}/*{run_id}")[0]
    print(wd_run_path)
print(wd_run_path)
wd_model_path = glob(f"{wd_run_path}/files/*.ckpt")[0]
print(wd_model_path)

old_wd_model = OldDiscriminator.load_from_checkpoint(wd_model_path)
#print(old_wd_model)

# Load config.
config = old_wd_model.hparams.cfg
# for k, v in config.items():
#     print(f'{k}: {v}')

Lightning automatically upgraded your loaded checkpoint from v1.9.5 to v2.3.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../src/wandb/run-20240502_001829-pzlwi6t6/files/epoch=24-step=48000.ckpt`
Lightning automatically upgraded your loaded checkpoint from v1.9.5 to v2.3.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../src/wandb/run-20240502_013119-kafcutw4/files/epoch=6-step=26880.ckpt`


../../src/wandb/run-20240502_001829-pzlwi6t6
../../src/wandb/run-20240502_001829-pzlwi6t6/files/epoch=24-step=48000.ckpt
../../src/wandb/run-20240502_013119-kafcutw4
../../src/wandb/run-20240502_013119-kafcutw4
../../src/wandb/run-20240502_013119-kafcutw4/files/epoch=6-step=26880.ckpt


In [4]:
# Load subset data.
rp = '../../data/eb_subset_all.npz'
subset_data = np.load(rp)
for f in subset_data.files:
    print(f, subset_data[f].shape)

# Load the existing wd data.
wd_data_path = '../../data/negative_sampling_toy_shell/False/eb.npz'
wd_data = np.load(wd_data_path)
print('Loaded data:')
for files in wd_data.files:
    print(files, wd_data[files].shape, wd_data[files].dtype)

data (3000, 50)
phate (3000, 2)
dist (3000, 3000)
colors (3000,)
is_train (3000,)
Loaded data:
data (5400, 50) float64
phate (5400, 2) float64
dist (5400, 5400) float64
colors (5400,) float64
is_train (5400,) float64
mask_x (5400,) float64
mask_d (5400, 5400) float32


### Negative Sampling

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.backends.mps.is_available():
    device = 'mps'
batch_size = 32

# Negative sampling based on the data.
def add_noise(data, n_samples, noise_rate=1, seed=42, noise='gaussian'):
    np.random.shuffle(data)
    np.random.seed(seed)
    #noise_rates = np.random.rand(n_samples, 1) * noise_rate
    noise_rates = np.ones((n_samples, 1)) * noise_rate
    if noise == 'gaussian':
        #data_std = data.std()
        noise = np.random.randn(*data.shape)
        #data_noisy = data + noise * noise_rates * data_std
        data_noisy = data + noise * noise_rate
    elif noise == 'hi-freq':
        diff_map_op = DiffusionMap(n_components=3, t=3, random_state=seed).fit(data)
        data_noisy = data + make_hi_freq_noise(data, diff_map_op, noise_rate=noise_rates, add_data_mean=False)

    return data_noisy, noise_rates.flatten()


def neg_sampling(x, noise_levels=[10,], reject=False, noise_type='gaussian'):
    n_samples = x.shape[0]
    print('n_samples: ', n_samples)
    
    x_noisy = []
    noise_rates = []

    for i, noise_level in enumerate(noise_levels):
        cur_noisy, cur_noise_rates = add_noise(x, n_samples=n_samples, noise_rate=noise_level, seed=42+i, noise=noise_type)

        if reject == True:
            # reject negative samples that are too close to the original data.
            k = 20
            dist =  spatial.distance_matrix(cur_noisy, x) # (n_samples, n_samples)
            dist_closest_k = np.sort(dist, axis=1)[:, 1:k] # (n_samples, k)

            dist_topk_mean = dist_closest_k.mean(axis=1) # (n_samples,)
            #threshold = np.percentile(dist_topk_mean, 90) # closest 10% of the data.
            # threshold = dist_topk_mean.mean()*0.5
            threshold = .2
            reject_idxs = np.where(dist_topk_mean < threshold)[0]
            cur_noisy = np.delete(cur_noisy, reject_idxs, axis=0)
            print('After rejection: ', cur_noisy.shape)

        x_noisy.extend(cur_noisy)
        noise_rates.extend(cur_noise_rates)

    x_noisy = np.array(x_noisy)
    noise_rates = np.array(noise_rates)
    print('Generated Negative Samples: ', x_noisy.shape)
    print('x std: ', x.std())
    print('x_noisy std: ', x_noisy.std())
    print('noise_rates.std(): ', noise_rates.std())

    return x_noisy

# Setup data.
def encode_data(x, encoder):
    batch_size = 256
    encodings = []
    encoder.eval()
    encoder.to(device)
    with torch.no_grad():
        for i in range(0, len(x), batch_size):
            x_batch = torch.tensor(x[i:i+batch_size], dtype=torch.float32).to(device)
            encodings.append(encoder(x_batch).cpu().detach().numpy())
    
    encodings = np.concatenate(encodings, axis=0)
    return encodings

encodings = encode_data(subset_data['data'], ae_model.encoder)
noise_type = 'gaussian'
noise_levels = [0.2]
x_noisy = neg_sampling(x=encodings, noise_levels=noise_levels, reject=False, noise_type=noise_type)

# Plot
# x = encodings
# fig = go.Figure()
# fig.add_trace(go.Scatter(x=x[:, 0], y=x[:, 1], mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
# fig.add_trace(go.Scatter(x=x_noisy[:, 0], y=x_noisy[:, 1], mode='markers', marker=dict(size=2, color='red', opacity=0.8)))

# Sanity check encoder and decoder on data.
def visualize_pos_neg_data(encodings, neg_encodings):
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=encodings[:, 0], y=encodings[:, 1], z=encodings[:, 2], mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=neg_encodings[:, 0], y=neg_encodings[:, 1], z=neg_encodings[:, 2], 
                            mode='markers', marker=dict(size=2, color='red', opacity=0.5)))
    fig.show()

    fig.write_html(f'generated_neg/{noise_type}_{noise_levels}.html')
    
visualize_pos_neg_data(encodings, x_noisy)

n_samples:  3000
Generated Negative Samples:  (3000, 3)
x std:  0.7458885
x_noisy std:  0.7754367039992338
noise_rates.std():  5.551115123125783e-17


In [6]:
def forward_diffusion(x0, t, num_steps, beta_start, beta_end):
    '''
    Forward diffusion. q(x_t | x_(t-1)) = N(x_t | sqrt(1-beta_t) * x_(t-1), beta_t * I);
    With alpha_bar_t = cumprod(1-beta_t), 
    we have x_t = sqrt(alpha_bar_t) * x_0 + (1-alpha_bar_t) * epsilon_t
    where epsilon_t ~ N(0, 1).
    t has to be an integer, and less than num_steps.
    '''
    betas = np.linspace(beta_start, beta_end, num_steps) # [beta_0, beta_1, ..., beta_{T-1}]
    alpha_bars = np.cumprod(1-betas)

    x_t = np.sqrt(alpha_bars[t]) * x0 + (1-alpha_bars[t]) * np.random.randn(*x0.shape)
    print('sqrt(alpha_bar_t): ', np.sqrt(alpha_bars[t]), '1-alpha_bar_t: ', 1-alpha_bars[t])

    return x_t

def neg_sample_using_diffusion(x, ts, num_steps, beta_start, beta_end, seed=42):
    np.random.shuffle(x)
    np.random.seed(seed)

    x_noisy = []
    for t in ts:
        x_t = forward_diffusion(x, t, num_steps, beta_start, beta_end) # (n_samples, n_features)
        x_noisy.extend(x_t)

    x_noisy = np.array(x_noisy)

    return x_noisy

# Diffusion based negative sampling.
beta_start = 0.0001
beta_end = 0.01
num_steps = 1000
ts = [200]

x_noisy_diffusion = neg_sample_using_diffusion(encodings, ts, num_steps, beta_start, beta_end, seed=42)
print('Generated Negative Samples: ', x_noisy_diffusion.shape)

# Plot
visualize_pos_neg_data(encodings, x_noisy_diffusion)


sqrt(alpha_bar_t):  0.8960840190576405 1-alpha_bar_t:  0.1970334307895062
Generated Negative Samples:  (3000, 3)


In [7]:
# Sanity check encoder and decoder on data.
def visualize_pos_neg_data(encodings, neg_encodings):
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=encodings[:, 0], y=encodings[:, 1], z=encodings[:, 2], mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=neg_encodings[:, 0], y=neg_encodings[:, 1], z=neg_encodings[:, 2], 
                            mode='markers', marker=dict(size=2, color='red', opacity=0.2)))
    fig.show()

    fig.write_html(f'generated_neg/{noise_type}_{noise_levels}.html')
    
visualize_pos_neg_data(encodings, x_noisy)

### Train W-Discriminator

In [8]:
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

class MLP(torch.nn.Module):
    def __init__(self, in_dim, out_dim, layer_widths=[64, 64, 64], activation='relu', 
                 batch_norm=False, dropout=0.0, use_spectral_norm=False):
        super().__init__()

        layers = []
        for i, width in enumerate(layer_widths):
            if i == 0:
                linear_layer = torch.nn.Linear(in_dim, width)
            else:
                linear_layer = torch.nn.Linear(layer_widths[i-1], width)

            # Conditionally apply spectral normalization
            if use_spectral_norm:
                linear_layer = spectral_norm(linear_layer)

            layers.append(linear_layer)

            if batch_norm:
                layers.append(torch.nn.BatchNorm1d(width))
            
            if activation == 'relu':
                layers.append(torch.nn.ReLU())
            elif activation == 'leaky_relu':
                layers.append(torch.nn.LeakyReLU())
            elif activation == 'tanh':
                layers.append(torch.nn.Tanh())
            else:
                raise ValueError(f'Invalid activation function: {activation}')

            if dropout > 0:
                layers.append(torch.nn.Dropout(dropout))
        
        # Adding the final layer
        final_linear_layer = torch.nn.Linear(layer_widths[-1], out_dim)
        if use_spectral_norm:
            final_linear_layer = spectral_norm(final_linear_layer)
        layers.append(final_linear_layer)

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Discriminator(pl.LightningModule):
    def __init__(self, in_dim, layer_widths=[64, 64, 64], activation='relu', loss_type='bce', normalize=True,
                 data_pts=None, k=5, encoder=None,
                 batch_norm=False, dropout=0.0, use_spectral_norm=False, **kwargs):
        super().__init__()

        if data_pts is None:
            self.in_dim = in_dim
        else:
            self.in_dim = in_dim + k # Augment data with extra density features

        self.mlp = MLP(self.in_dim, 2, layer_widths, activation, batch_norm, dropout, use_spectral_norm)
        self.loss_type = loss_type
        self.normalize = normalize

        self.data_pts = data_pts # [N, in_dim]
        self.encoder = encoder

        self.k = int(k)
        self.lr = kwargs.get('lr', 1e-3)
        self.weight_decay = kwargs.get('weight_decay', 1e-5)
        self.mean = torch.tensor(kwargs.get('mean', np.zeros(in_dim)), dtype=torch.float32, device=device)
        self.std = torch.tensor(kwargs.get('std', np.ones(in_dim)), dtype=torch.float32, device=device)
        
        print('self.mean: ', self.mean.shape)
        print('self.std: ', self.std.shape)

        self.train_step_outs = []
        self.val_step_outs = []
        self.test_step_outs = []

        self.train_ys = []
        self.val_ys = []
        self.test_ys = []

        # Freeze the encoder
        if self.encoder is not None:
            for param in self.encoder.parameters():
                param.requires_grad = False
        
        if self.data_pts is not None:
            assert self.data_pts.shape[1] == self.in_dim, f'data_pts.shape: {self.data_pts.shape}, self.in_dim: {self.in_dim}'
    
    def augment_data(self, x):
        # Augment data with extra density features
        dists = torch.cdist(x, self.data_pts)
        topk, _ = torch.topk(dists, k=self.k, dim=1, largest=False, sorted=False) # [N, k]
        #topk = torch.mean(topk, dim=1).unsqueeze(1) # [N, 1]
        x = torch.cat([x, topk], dim=1)
    
        return x

    def forward(self, x):
        # Assume x is not yet normalized
        if self.normalize == True:
            x = (x - self.mean) / self.std

        if self.data_pts is not None:
            x = self.augment_data(x)
            assert x.shape[1] == self.in_dim, f'x.shape: {x.shape}, self.in_dim: {self.in_dim}'
        
        logits = self.mlp(x)

        return logits # [N, 2]
    
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return opt

    def step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)

        if self.loss_type == 'bce':
            loss = F.cross_entropy(logits, y)
        elif self.loss_type == 'margin':
            score = logits[0]
            neg_score = score[y==0]
            pos_score = score[y==1]
            loss = -(torch.mean(pos_score) - torch.mean(neg_score)) # maximize the difference between positive and negative scores

        return loss, logits, y

    def training_step(self, batch, batch_idx):
        loss, logits, y = self.step(batch, batch_idx)

        self.train_step_outs.append(logits)
        self.train_ys.append(y)
        self.log('train_loss', loss, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, logits, y = self.step(batch, batch_idx)

        self.val_step_outs.append(logits)
        self.val_ys.append(y)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        loss, logits, y = self.step(batch, batch_idx)

        self.test_step_outs.append(logits)
        self.test_ys.append(y)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def on_train_epoch_end(self):
        # Compute prediction accuracy
        train_logits = torch.cat(self.train_step_outs, dim=0) # [N, 2]
        true_classes = torch.cat(self.train_ys, dim=0) # [N]
        pred_classes = torch.argmax(train_logits, dim=1) # [N]

        train_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.train_step_outs.clear()
        self.train_ys.clear()

    def on_validation_epoch_end(self):
        # Compute prediction accuracy
        val_logits = torch.cat(self.val_step_outs, dim=0)
        true_classes = torch.cat(self.val_ys, dim=0)
        pred_classes = torch.argmax(val_logits, dim=1)

        val_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.val_step_outs.clear()
        self.val_ys.clear()

    def on_test_epoch_end(self):
        # Compute prediction accuracy
        test_logits = torch.cat(self.test_step_outs, dim=0)
        true_classes = torch.cat(self.test_ys, dim=0)
        pred_classes = torch.argmax(test_logits, dim=1)

        test_acc = torch.sum(pred_classes == true_classes) / len(true_classes)
        self.log('test_acc', test_acc, on_epoch=True, prog_bar=True)

        # clear lists
        self.test_step_outs.clear()
        self.test_ys.clear()
    
    def positive_prob(self, x):
        logits = self(x)
        # apply softmax to get probabilities
        softmax = torch.nn.Softmax(dim=1)
        return softmax(logits)[:, 1]
    
    def positive_score(self, x):
        logits = self(x)
        return logits[:, 1]
    
    def negative_score(self, x):
        logits = self(x)
        return logits[:, 0]

In [9]:
def get_discriminator_loaders(x, x_noisy, batch_size=256, wd_data=None):
    combined_x = np.concatenate([x, x_noisy], axis=0)

    dataset = torch.utils.data.TensorDataset(torch.tensor(combined_x, dtype=torch.float32), torch.tensor(combined_y, dtype=torch.int64))
    train_size = int(0.7 * len(dataset))
    val_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    batch_size = 256
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

combined_x = np.concatenate([encodings, x_noisy], axis=0)
combined_y = np.concatenate([np.ones(encodings.shape[0]), np.zeros(x_noisy.shape[0])]).astype(np.int32)
train_loader, val_loader, test_loader = get_discriminator_loaders(encodings, x_noisy, batch_size=256, wd_data=wd_data)
data_pts = torch.tensor(encodings, dtype=torch.float32, device=device)

print('combined_x.shape: ', combined_x.shape)
print('combined_y.shape: ', combined_y.shape)

combined_x.shape:  (6000, 3)
combined_y.shape:  (6000,)


In [10]:
hyperparams = {
    'dirpath': './wd_model',
    'name': 'wd_model',
    'in_dim': x_noisy.shape[1],
    'layer_widths': [256, 128, 64],
    'activation': 'relu',
    'loss_type': 'bce',
    'normalize': False,
    'batch_norm': True,
    'dropout': 0.5,
    'use_spectral_norm': True,
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'max_epochs': 300,
    'patience': 50,
    'monitor': 'val_acc',
    'mean': combined_x.mean(axis=0),
    'std': combined_x.std(axis=0),
    # 'data_pts': data_pts,
    #'data_pts': encodings,
    #'encoder': ae_model.encoder,
    # 'k': 5,
}


wd_model = Discriminator(**hyperparams)
print(wd_model)

sx, sy = next(iter(train_loader))
print(sx.shape, sy.shape)

early_stopping = pl.callbacks.EarlyStopping(monitor=hyperparams['monitor'], patience=hyperparams['patience'])
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor=hyperparams['monitor'], save_top_k=1, mode='max', 
                                                dirpath=hyperparams['dirpath'], filename=hyperparams['name'])

trainer = pl.Trainer(max_epochs=hyperparams['max_epochs'], accelerator='auto', log_every_n_steps=17, 
                     callbacks=[early_stopping, model_checkpoint])
trainer.fit(model=wd_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

self.mean:  torch.Size([3])
self.std:  torch.Size([3])
Discriminator(
  (mlp): MLP(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=256, out_features=128, bias=True)
      (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Dropout(p=0.5, inplace=False)
      (8): Linear(in_features=128, out_features=64, bias=True)
      (9): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): Dropout(p=0.5, inplace=False)
      (12): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)
torch.Size([256, 3]) torch.Size([256])


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type | Params | Mode 
--------------------------------------
0 | mlp  | MLP  | 43.2 K | train
--------------------------------------
43.2 K    Trainable params
0         Non-trainable params
43.2 K    Total params
0.173     Total estimated model params size (MB)


Epoch 50: 100%|██████████| 17/17 [00:00<00:00, 83.52it/s, v_num=286, val_loss=0.574, val_acc=0.691, train_acc=0.641] 


In [11]:
### Evaluate Old Discriminator on combined_x
# old_acc = 0
# old_wd_model.to(device)
# old_wd_model.eval()
# with torch.no_grad():
#     for i in range(0, len(combined_x), batch_size):
#         x_batch = torch.tensor(combined_x[i:i+batch_size], dtype=torch.float32).to(device)
#         logits = old_wd_model(x_batch)
#         pred_classes = torch.argmax(logits, dim=1)
#         true_classes = torch.tensor(combined_y[i:i+batch_size], dtype=torch.int64).to(device)
#         old_acc += torch.sum(pred_classes == true_classes).item()

# old_acc /= len(combined_x)

# print('Old Discriminator accuracy: ', old_acc)

### Evaluate New Discriminator on combined_x
wd_model.to(device)
wd_model.eval()
wd_acc = 0
with torch.no_grad():
    for i in range(0, len(combined_x), batch_size):
        x_batch = torch.tensor(combined_x[i:i+batch_size], dtype=torch.float32).to(device)
        logits = wd_model(x_batch)
        pred_classes = torch.argmax(logits, dim=1)
        true_classes = torch.tensor(combined_y[i:i+batch_size], dtype=torch.int64).to(device)
        wd_acc += torch.sum(pred_classes == true_classes).item()
wd_acc /= len(combined_x)

print('New Discriminator accuracy: ', wd_acc)

New Discriminator accuracy:  0.6855


In [12]:
def visualize_probs(ae_model, wd_model, combined_x, combined_y, save_file='wd_probs_new.html'):
    if combined_x.shape[1] > 3:
        wd_encodings = encode_data(combined_x, ae_model.encoder)
    else:
        wd_encodings = combined_x
    
    probs = []
    wd_model.to(device)
    wd_model.eval()
    with torch.no_grad():
        for i in range(0, len(wd_encodings), batch_size):
            x_batch = torch.tensor(wd_encodings[i:i+batch_size], dtype=torch.float32).to(device)
            probs.append(wd_model.positive_prob(x_batch).cpu().detach().numpy())
    probs = np.concatenate(probs, axis=0)

    # Plot
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                            mode='markers', marker=dict(size=2, color=combined_y, colorscale='Viridis', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                                mode='markers', marker=dict(size=2, color=probs, colorscale='Viridis', opacity=0.8)))
    fig.add_trace(go.Scatter3d(x=wd_encodings[:,0], y=wd_encodings[:,1], z=wd_encodings[:,2],
                                mode='markers', marker=dict(size=2, color=(probs > probs.mean()).astype(int), colorscale='Viridis', opacity=0.8)))
    fig.show()

    fig.write_html(f'./generated_neg/{save_file}')

    return wd_encodings, probs

_, probs = visualize_probs(ae_model, wd_model, combined_x, combined_y, save_file=f'{noise_type}_{noise_levels}_new_probs.html')

visualize_probs(ae_model, wd_model, wd_data['data'], wd_data['mask_x'], save_file=f'{noise_type}_{noise_levels}_old_probs.html')


(array([[ 0.2487267 ,  0.5307898 ,  0.14138453],
        [ 0.4516882 ,  0.5163434 , -0.33334905],
        [ 0.35367313,  0.4715617 , -0.28540057],
        ...,
        [ 1.3262057 ,  0.07980571, -0.68556845],
        [-0.7557104 ,  0.13464001, -0.16202168],
        [ 1.1320058 , -0.0158112 , -0.6627188 ]], dtype=float32),
 array([0.4732698 , 0.7268592 , 0.56414217, ..., 0.48164865, 0.55558896,
        0.34201324], dtype=float32))

### Select start/end points

In [13]:
def sample_indices_within_range(points, selected_idx=None, range_size=0.1, num_samples=20, seed=23):
    np.random.seed(seed)
    # Randomly select two points from the array
    if selected_idx is None:
        selected_indices = np.random.choice(points.shape[0], 2, replace=False)
        point1_idx, point2_idx = selected_indices[0], selected_indices[1]
        point1, point2 = points[point1_idx], points[point2_idx]
    else:
        point1_idx, point2_idx = selected_idx
        point1, point2 = points[point1_idx], points[point2_idx]    
    # Function to find indices of points within the range of a given point
    def _find_indices_within_range(point):
        distances = np.linalg.norm(points - point, axis=1)
        within_range_indices = np.where(distances <= range_size)[0]
        return within_range_indices
    
    # Find indices within range of point1 and point2
    indices_within_range1 = _find_indices_within_range(point1)
    indices_within_range2 = _find_indices_within_range(point2)
    
    # Randomly sample indices within the range
    if len(indices_within_range1) >= num_samples:
        sampled_indices_point1 = np.random.choice(indices_within_range1, num_samples, replace=False)
    else:
        sampled_indices_point1 = indices_within_range1
        
    
    if len(indices_within_range2) >= num_samples:
        sampled_indices_point2 = np.random.choice(indices_within_range2, num_samples, replace=False)
    else:
        sampled_indices_point2 = indices_within_range2
    
    return point1_idx, sampled_indices_point1, point2_idx, sampled_indices_point2

start_group = 0
end_group = 3
eb_labels = subset_data['colors']
print('eb_labels: ', eb_labels.shape)
start_idices = np.where(eb_labels == start_group)[0]
start_idx = np.random.choice(start_idices, 1)[0]
end_idices = np.where(eb_labels == end_group)[0]
end_idx = np.random.choice(end_idices, 1)[0]
start_idx = 736
end_idx = 2543
print('start_idx, end_idx: ', start_idx, end_idx)

encodings = encode_data(subset_data['data'], ae_model.encoder)
print('encodings.shape: ', encodings.shape)
point1_idx, sampled_indices_point1, point2_idx, sampled_indices_point2 = sample_indices_within_range(encodings, 
                                                                                                     selected_idx=(start_idx, end_idx),
                                                                                                     range_size=0.3, 
                                                                                                     seed=2024, num_samples=64)
x = subset_data['data']
point1 = x[point1_idx]
point2 = x[point2_idx]
samples_point1 = x[sampled_indices_point1]
samples_point2 = x[sampled_indices_point2]
print('point1, point2: ', point1.shape, point2.shape)
print('samples_point1, samples_point2: ', samples_point1.shape, samples_point2.shape)

print('point1, point2: ', point1_idx, point2_idx) # point1, point2: 736 2543 (0-3)
samples_z_point1 = ae_model.encoder(torch.tensor(samples_point1, dtype=torch.float32).to(device)).cpu().detach().numpy()
samples_z_point2 = ae_model.encoder(torch.tensor(samples_point2, dtype=torch.float32).to(device)).cpu().detach().numpy()
start_end_z = np.vstack([encodings[start_idx], encodings[end_idx]])
print('samples_z_point1, samples_z_point2: ', samples_z_point1.shape, samples_z_point2.shape)
print('start_end_z: ', start_end_z.shape)
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=start_end_z[:,0], y=start_end_z[:,1], z=start_end_z[:,2],
                            mode='markers', marker=dict(size=5, color='red', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=2, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=2, color='green', opacity=0.8)))


eb_labels:  (3000,)
start_idx, end_idx:  736 2543
encodings.shape:  (3000, 3)
point1, point2:  (50,) (50,)
samples_point1, samples_point2:  (64, 50) (64, 50)
point1, point2:  736 2543
samples_z_point1, samples_z_point2:  (64, 3) (64, 3)
start_end_z:  (2, 3)


In [14]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, x0, x1):
        self.x0 = x0
        self.x1 = x1

    def __len__(self):
        return max(len(self.x0), len(self.x1))

    def __getitem__(self, idx):
        return self.x0[idx % len(self.x0)], self.x1[idx % len(self.x1)]

def custom_collate_fn(batch):
    x0_batch = torch.stack([item[0] for item in batch])
    x1_batch = torch.stack([item[1] for item in batch])
    
    # Randomly permute the elements in the batch
    perm_x0 = torch.randperm(len(x0_batch))
    perm_x1 = torch.randperm(len(x1_batch))

    x0_batch = x0_batch[perm_x0]
    x1_batch = x1_batch[perm_x1]
    
    return x0_batch, x1_batch

# Create a dataloader
bs = 32
dataset = CustomDataset(x0=torch.tensor(samples_point1, dtype=torch.float32), 
                        x1=torch.tensor(samples_point2, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True, collate_fn=custom_collate_fn)

In [15]:
fig = go.Figure()
old_wd_data = wd_data['data']
fig.add_trace(go.Scatter3d(x=old_wd_data[:,0], y=old_wd_data[:,1], z=old_wd_data[:,2],
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=old_wd_data[:,0], y=old_wd_data[:,1], z=old_wd_data[:,2],
                            mode='markers', marker=dict(size=2, color=wd_data['mask_x'], colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))

In [32]:
from geodesic import GeodesicFM

ae_model = ae_model.to(device)
wd_model = wd_model.to(device)
old_wd_model = old_wd_model.to(device)

ae_model.eval()
wd_model.eval()
old_wd_model.eval()

for param in ae_model.encoder.parameters():
    param.requires_grad = False
for param in wd_model.parameters():
    param.requires_grad = False
for param in old_wd_model.parameters():
    param.requires_grad = False

enc_func = lambda x: ae_model.encoder(x)
#disc_func = lambda x: (wd_model.positive_proba(x) - min_prob) / (max_prob - min_prob) # normalize the probability to [0, 1]
disc_func = lambda x: 1 - wd_model.positive_prob(enc_func(x))
#disc_func = lambda x: 1 - old_wd_model.positive_proba(x)

ofm, extended_dim_func = offmanifolder_maker_new(enc_func, disc_func, disc_factor=5, 
                                                 data_encodings=torch.tensor(encodings, dtype=torch.float32).to(device)) # ofm encodes both on/off manifold points
#ofm, extended_dim_func = offmanifolder_density_maker(enc_func, disc_func, disc_factor=5, data_encodings=torch.tensor(encodings, dtype=torch.float32).to(device))

# TODO: 1. cc_k default is 2, may want to increase it. 2. symmetric can be false. 
# 3. increase hidden 4. embed t as well. 5. decreasew weight decay 6. increase length loss weight 7. train DW in embedding space
gbmodel = GeodesicFM(
    func=ofm,
    encoder=enc_func,
    input_dim=x.shape[1],
    hidden_dim=64, 
    scale_factor=1, 
    symmetric=True, 
    num_layers=3, 
    n_tsteps=100, 
    lr=1e-3,
    weight_decay=1e-4,
    flow_weight=0,
    length_weight=1,
    cc_k=5,
    density_weight=0.,
    use_density=False,
    data_pts=torch.tensor(x, dtype=torch.float32).to(device),
    visualize_training=True,
    dataloader=dataloader,
    device=device,
)

gbmodel.lr=1e-3
early_stopping = pl.callbacks.EarlyStopping(monitor='train_loss_epoch', patience=150, mode='min')
model_checkpoint = pl.callbacks.ModelCheckpoint(monitor='train_loss_epoch', save_top_k=1, mode='min', 
                                                dirpath='./eb_fm/checkpoints', filename='gbmodel')
trainer = pl.Trainer(
    # logger=logger,
    max_epochs=1,
    log_every_n_steps=20,
    accelerator=device,
    callbacks=[early_stopping, model_checkpoint]
)

# remove the training file if it exists
if os.path.exists('./eb_fm/training'):
    shutil.rmtree('./eb_fm/training')

trainer.fit(gbmodel, train_dataloaders=dataloader)
checkpoint_dir = './eb_fm/checkpoints'

#trainer.save_checkpoint(f"{checkpoint_dir}/{start_group}_{end_group}_gbmodel.ckpt")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name       | Type             | Params | Mode 
--------------------------------------------------------
0 | cc         | CondCurveOverfit | 64.1 K | train
1 | flow_model | MLP              | 14.9 K | train
--------------------------------------------------------
79.0 K    Trainable params
0         Non-trainable params
79.0 K    Total params
0.316     Total estimated model params size (MB)


Starting training at epoch 0
self.visualize_x0:  torch.Size([64, 50]) mps:0
trajectories:  torch.Size([100, 64, 50])
on_train_epoch_end, traj_z:  (100, 64, 3)
Epoch 0: 100%|██████████| 2/2 [00:08<00:00,  0.24it/s, v_num=291, loss_length_step=7.640, fm_length_step=6.420, loss_step=7.640, train_loss_step=7.640, loss_length_epoch=7.610, fm_length_epoch=6.540, loss_epoch=7.610, train_loss_epoch=7.610]self.visualize_x0:  torch.Size([64, 50]) mps:0
trajectories:  torch.Size([100, 64, 50])
on_train_epoch_end, traj_z:  (100, 64, 3)


`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 2/2 [00:09<00:00,  0.21it/s, v_num=291, loss_length_step=7.640, fm_length_step=6.420, loss_step=7.640, train_loss_step=7.640, loss_length_epoch=7.610, fm_length_epoch=6.540, loss_epoch=7.610, train_loss_epoch=7.610]

====frames===moviepy:  ['trajs_epoch_0000.png', 'trajs_epoch_0001.png']
Moviepy - Building video ./eb_fm/trajs.mp4.                                                                                                                                                                                              
Moviepy - Writing video ./eb_fm/trajs.mp4                                                                                                                                                                                                

Epoch 4:   0%|          | 0/2 [01:31<?, ?it/s, v_num=289, loss_length_step=7.410, fm_length_step=6.540, loss_step=7.410, train_loss_step=7.410, loss_length_epoch=7.430, fm_length_epoch=6.480, loss_epoch=7.430, train_



Moviepy - Done !                                                                                                                                                                                                                         
Moviepy - video ready ./eb_fm/trajs.mp4                                                                                                                                                                                                  
Epoch 4:   0%|          | 0/2 [01:31<?, ?it/s, v_num=289, loss_length_step=7.410, fm_length_step=6.540, loss_step=7.410, train_loss_step=7.410, loss_length_epoch=7.430, fm_length_epoch=6.480, loss_epoch=7.430, train_loss_epoch=7.430]

In [None]:
print(encodings.shape)
print(wd_data['data'].shape)

def visualize_extended_dim(extended_dim_func, x):
    if isinstance(x, np.ndarray):
        x = torch.tensor(x, dtype=torch.float32).to(device)
    extended_dims = extended_dim_func(x).cpu().numpy()

    x_encodings = enc_func(x).cpu().numpy()
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=x_encodings[:,0], y=x_encodings[:,1], z=x_encodings[:,2], 
                                 mode='markers', marker=dict(size=2, color=extended_dims.tolist(), opacity=0.8)))
    # add hovertext
    fig.update_traces(hovertext=[ f'{ed}' for ed in extended_dims.tolist()])
    fig.show()
#visualize_extended_dim(extended_dim_func, wd_data['data'])

In [None]:
def generate_z_traj(gbmodel, ae_model, dataloader, device, 
                    plot=False, 
                    encodings=encodings, samples_z_point1=samples_z_point1, samples_z_point2=samples_z_point2,
                    save_file=f'Bridge_{noise_type}_{noise_levels}_EB_{start_group}_{end_group}.html'):
    x0 = []
    x1 = []
    for x0_, x1_ in dataloader:
        x0.append(x0_)
        x1.append(x1_)
    x0 = torch.cat(x0, dim=0).to(device)
    x1 = torch.cat(x1, dim=0).to(device)

    ids = torch.zeros((x0.size(0),1), device=device)  # Conditional id for each pair of x0 and x1, here is a dummy.

    gbmodel = gbmodel.to(device)
    gbmodel.eval()
    ae_model = ae_model.to(device)
    ae_model.eval()

    with torch.no_grad():
        x_traj = gbmodel(x0, x1, gbmodel.ts.to(device), ids)  # [T, B, D]
        z_traj = ae_model.encoder(x_traj.flatten(0,1))  # [T*B, D]
        z0 = ae_model.encoder(x0)
        z1 = ae_model.encoder(x1)

    z_traj = z_traj.cpu().detach().numpy().reshape(x_traj.size(0), x_traj.size(1), -1)
    z0 = z0.cpu().detach().numpy()
    z1 = z1.cpu().detach().numpy()

    if plot:
        fig = go.Figure()
        fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                                   mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
        fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                                    mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
        fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                                    mode='markers', marker=dict(size=5, color='green', opacity=0.8)))
        
        for i in range(30):
            fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                                       mode='lines', line=dict(width=2, color='blue')))
        fig.show()
        fig.write_html(save_file)

    return z_traj, z0, z1

# Generate trajectories
z_traj, z0, z1 = generate_z_traj(gbmodel, ae_model, dataloader, device, plot=True)
print('GeoBridge z_traj.shape: ', z_traj.shape)

In [None]:
import torchdiffeq
from torch import nn

adjoint = False
if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

class ODEFuncWrapper(nn.Module):
    def __init__(self, flowmodel):
        super().__init__()
        self.flowmodel = flowmodel
    def forward(self, t, y):
        # Expand t to match the batch size and feature dimension
        t_expanded = t.view(1, 1).expand(y.size(0), 1)
        # Concatenate y and t along the feature dimension
        y_with_t = torch.cat((y, t_expanded), dim=-1)
        return self.flowmodel(y_with_t)

n_samples = min(100, x0.size(0))
sampled_starts = torch.tensor(x[sampled_indices_point1[:n_samples]], dtype=torch.float32).to('cpu')

print(f'Run ODE on {n_samples} samples: {sampled_starts.shape}')
flowfunc = ODEFuncWrapper(gbmodel.flow_model.to('cpu'))
with torch.no_grad():
    traj = odeint(flowfunc, sampled_starts.to('cpu'), gbmodel.ts.to('cpu'))

print('Flow Matching ODE Trajectory shape: ', traj.shape)

ae_model.eval()
with torch.no_grad():
    traj = traj.to(device)
    z_traj = ae_model.encoder(traj.flatten(0,1)) # [T*B, D]

z_traj = z_traj.cpu().detach().numpy()
z_traj = z_traj.reshape(traj.size(0), traj.size(1), -1)
print('z_traj.shape: ', z_traj.shape)


In [None]:
# Plot 2d
fig = go.Figure()
fig.add_trace(go.Scatter(x=encodings[:,0], y=encodings[:,1], mode='markers', marker=dict(size=2, color='gray')))
fig.add_trace(go.Scatter(x=samples_z_point1[:,0], y=samples_z_point1[:,1], mode='markers', marker=dict(size=5, color='blue')))
fig.add_trace(go.Scatter(x=samples_z_point2[:,0], y=samples_z_point2[:,1], mode='markers', marker=dict(size=5, color='green')))
for i in range(n_samples):
    fig.add_trace(go.Scatter(x=z_traj[:,i,0], y=z_traj[:,i,1], mode='lines', line=dict(width=2, color='blue')))
fig.show()

# save html
fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_{noise_type}_{noise_levels}_2d.html")

In [None]:
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=5, color='green', opacity=0.8)))

for i in range(30):
    fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                                 mode='lines', line=dict(width=2, color='blue')))

fig.show()

# save the html
fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_{noise_type}_{noise_levels}_3d.html")

In [None]:
# Plot positive_prob of these trajectories
traj.shape # [T, B, D]
traj_flat = traj.flatten(0,1) # [T*B, D]

bs = 32
traj_probs = []
wd_model = wd_model.to(device)
wd_model.eval()
with torch.no_grad():
    for i in range(0, traj_flat.shape[0], bs):
        x_batch = traj_flat[i:i+bs]
        prob = wd_model.positive_prob(enc_func(x_batch))
        #print(prob)
        traj_probs.append(prob.cpu().detach().numpy())
traj_probs = np.concatenate(traj_probs, axis=0) # [T*B]
traj_probs = traj_probs.reshape(traj.size(0), traj.size(1)) # [T, B]
print('traj_probs.shape: ', traj_probs.shape)

fig = go.Figure()
fig.add_trace(go.Scatter3d(x=encodings[:,0], y=encodings[:,1], z=encodings[:,2], 
                           mode='markers', marker=dict(size=2, color='gray', colorscale='Viridis', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point1[:,0], y=samples_z_point1[:,1], z=samples_z_point1[:,2],
                            mode='markers', marker=dict(size=5, color='blue', opacity=0.8)))
fig.add_trace(go.Scatter3d(x=samples_z_point2[:,0], y=samples_z_point2[:,1], z=samples_z_point2[:,2],
                            mode='markers', marker=dict(size=5, color='green', opacity=0.8)))

for i in range(30):    
    fig.add_trace(go.Scatter3d(x=z_traj[:,i,0], y=z_traj[:,i,1], z=z_traj[:,i,2],
                                 mode='markers', marker=dict(size=2, color=traj_probs[:,i].mean(), colorscale='Viridis')))
    prob_texts = []
    for ti in range(traj_probs.shape[0]):
        prob_texts.append(f'Prob: {traj_probs[ti,i]:.2f}')
        #print(ti, ':', prob_texts)
    fig.update_traces(hovertext=prob_texts)

fig.show()

fig.write_html(f"./eb_fm/EB_{start_group}_{end_group}_{noise_type}_{noise_levels}_probs.html")


