In [1]:
# Import utilities and TimelyGPT libraries
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'DejaVu Sans'
import time
import torch.nn.functional as F
import copy
from tqdm import tqdm
import pathlib
import geomloss
import itertools

# --- Importing TimelyGPT module and configs ---
# Add the project root directory to the Python path
project_root = str(pathlib.Path.cwd().resolve().parent)
sys.path.insert(0, project_root)
os.chdir(project_root)  # Change the current working directory to the project root

from model.TimelyGPT_CTS.layers.configs import RetNetConfig
from model.TimelyGPT_CTS.layers.Retention_layers import RetNetBlock

# --- Importing utils ---
from notebooks.BenchmarkUtils import loadSCData, tpSplitInd, splitBySpec
from optim.evaluation import globalEvaluation
from plotting.PlottingUtils import umapWithPCA


ModuleNotFoundError: No module named 'torch'

In [3]:
pip install torch numpy matplotlib time copy tqdm pathlib geomloss itertools

ValueError: The python kernel does not appear to be a conda environment.  Please use ``%pip install`` instead.

In [None]:
# VAE and TimelyGPT Model Definitions
class TimelyGPT_CellLevel(nn.Module):
    """
    Cell-level trajectory predictor built from a stack of RetNet blocks.

    A batch of initial latent cell states is rolled forward one step at a
    time: the prediction emitted at one step becomes the input of the next
    ("GPT-style" auto-regressive generation). Every block is invoked with
    ``forward_impl='recurrent'`` and handed back the second element of its
    previous output, so per-layer state is carried from step to step.
    """
    def __init__(self, config, latent_dim):
        super().__init__()

        self.num_layers = config.num_layers
        self.d_model = config.d_model
        self.latent_dim = latent_dim

        # Latent space -> model width.
        self.input_projection = nn.Linear(latent_dim, config.d_model)

        # One RetNet block per layer (no learnable time embedding is used).
        layer_stack = [RetNetBlock(config) for _ in range(self.num_layers)]
        self.blocks = nn.ModuleList(layer_stack)

        # Final normalisation, then model width -> latent space.
        self.ln_f = nn.LayerNorm(config.d_model)
        self.output_projection = nn.Linear(config.d_model, latent_dim)

        # Optional config flag; absent means "off".
        self.gradient_checkpointing = getattr(config, 'use_grad_ckp', False)

    def forward(self, cells_latent, time_indices, forward_impl='parallel'):
        """
        Auto-regressively generate one latent prediction per entry of ``time_indices``.

        Args:
            cells_latent: [n_cells, latent_dim] initial latent states (z0).
            time_indices: sequence whose *length* sets the number of generated
                steps; its values are not read here.
            forward_impl: ignored; accepted only so existing callers keep working.
                The blocks are always run in recurrent mode.
        Returns:
            Tensor [n_cells, n_timepoints, latent_dim] with the prediction of
            every step stacked along dim 1.
        """
        n_steps = len(time_indices)

        # Per-layer carried state; None means "no history yet".
        layer_states = [None] * self.num_layers

        step_input = cells_latent  # [n_cells, latent_dim]
        rollout = []

        for step in range(n_steps):
            # Blocks consume a length-1 sequence: [n_cells, 1, d_model].
            hidden = self.input_projection(step_input).unsqueeze(1)

            next_states = []
            for block, state in zip(self.blocks, layer_states):
                outputs = block(
                    hidden,
                    retention_mask=None,
                    forward_impl='recurrent',
                    past_key_value=state,
                    # Step index; presumably consumed by the block's positional
                    # encoding -- verify against RetNetBlock.
                    sequence_offset=step,
                    chunk_size=None,
                    output_retentions=False
                )
                hidden = outputs[0]             # [n_cells, 1, d_model]
                next_states.append(outputs[1])  # state handed back on the next step
            layer_states = next_states

            # Drop the length-1 axis, normalise, and map back to latent space.
            step_input = self.output_projection(self.ln_f(hidden.squeeze(1)))

            # The prediction is both recorded and fed back as the next input.
            rollout.append(step_input)

        return torch.stack(rollout, dim=1)


# --- VAE Model Definition ---
class Encoder(nn.Module):
    """
    MLP encoder producing the parameters of a diagonal Gaussian in latent space.

    Args:
        n_genes: width of the input vectors.
        latent_dim: width of the latent space.
        hidden_dims: widths of the hidden Linear+ReLU layers, in order. May be
            empty, in which case the two heads act directly on the input.
            (The default list is only read, never mutated.)
    """
    def __init__(self, n_genes, latent_dim, hidden_dims=[128, 128]):
        super(Encoder, self).__init__()
        layers = []
        in_dim = n_genes
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        # An empty nn.Sequential passes its input through unchanged.
        self.encoder_net = nn.Sequential(*layers)
        # `in_dim` is the width feeding the heads: the last hidden width, or
        # n_genes when there are no hidden layers. Indexing hidden_dims[-1]
        # would raise IndexError for an empty list.
        self.fc_mu = nn.Linear(in_dim, latent_dim)
        self.fc_log_var = nn.Linear(in_dim, latent_dim)

    def forward(self, x):
        """Return ``(mu, log_var)``, each shaped like ``x`` with the last dim replaced by latent_dim."""
        h = self.encoder_net(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var

class Decoder(nn.Module):
    """
    MLP decoder mapping latent vectors back to gene space.

    Args:
        latent_dim: width of the latent space.
        n_genes: width of the reconstructed vectors.
        hidden_dims: hidden widths given in *encoder* order; they are applied
            reversed here. May be empty, in which case the decoder is a single
            Linear layer. (The default list is only read, never mutated.)
    """
    def __init__(self, latent_dim, n_genes, hidden_dims=[128, 128]):
        super(Decoder, self).__init__()
        layers = []
        in_dim = latent_dim
        # Use a reversed architecture for the decoder
        reversed_hidden_dims = list(reversed(hidden_dims))
        for h_dim in reversed_hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        # `in_dim` is the last hidden width, or latent_dim when there are no
        # hidden layers; indexing reversed_hidden_dims[-1] would raise
        # IndexError for an empty list.
        layers.append(nn.Linear(in_dim, n_genes))
        self.decoder_net = nn.Sequential(*layers)

    def forward(self, z):
        """
        Decode latent vectors.

        A 3-D input [n_cells, n_timepoints, latent_dim] is flattened, decoded
        and reshaped back to [n_cells, n_timepoints, n_genes]; any other input
        is passed straight to the network.
        """
        if z.dim() == 3:
            n_cells, n_timepoints, latent_dim = z.shape
            z_flat = z.reshape(n_cells * n_timepoints, latent_dim)
            recon_flat = self.decoder_net(z_flat)
            # Spell the last dimension out: a -1 placeholder cannot be inferred
            # when n_cells or n_timepoints is 0.
            return recon_flat.reshape(n_cells, n_timepoints, recon_flat.shape[-1])
        return self.decoder_net(z)

class VAE(nn.Module):
    """Variational auto-encoder pairing an ``Encoder`` with a ``Decoder``."""

    def __init__(self, n_genes, latent_dim, enc_hidden_dims=[128, 128], dec_hidden_dims=[128, 128]):
        super().__init__()
        self.encoder = Encoder(n_genes, latent_dim, enc_hidden_dims)
        self.decoder = Decoder(latent_dim, n_genes, dec_hidden_dims)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code, decode it.

        Returns:
            ``(recon_x, z, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), z, mu, log_var

# --- Combined VAE-TimelyGPT Model ---
class VAETimelyGPT(nn.Module):
    """Wraps a VAE and a TimelyGPT predictor into one encode -> roll out -> decode model."""

    def __init__(self, vae, timely_gpt):
        super().__init__()
        self.vae = vae
        self.timely_gpt = timely_gpt

    def forward(self, initial_cells_obs, time_indices, forward_impl='parallel'):
        """
        Run the full pipeline on a batch of initial cells.

        Steps:
            1. Encode the cells and sample an initial latent state.
            2. Let ``timely_gpt`` produce the latent sequence from that state.
            3. Decode the latent sequence back to observation space.

        Returns:
            ``(recon_obs, (mu, log_var), initial_cells_obs, latent_seq)`` --
            the tuple layout is labelled "scNODE style" by the original author.
        """
        # 1. Encode and sample.
        mu, log_var = self.vae.encoder(initial_cells_obs)
        z_start = self.vae.reparameterize(mu, log_var)

        # 2. Latent rollout.
        latent_seq = self.timely_gpt(z_start, time_indices, forward_impl)

        # 3. Decode, and hand back the inputs alongside the encoder statistics.
        return self.vae.decoder(latent_seq), (mu, log_var), initial_cells_obs, latent_seq


def plotPredTestTime(true_umap, pred_umap, true_tps, pred_tps, test_tps_list, save_path=None):
    """
    Plot a side-by-side UMAP comparison of true and predicted cells at the test timepoints.

    Both panels draw the full true embedding in gray as a backdrop. The left
    panel highlights the true cells of each test timepoint, the right panel the
    predicted cells of each test timepoint.

    Args:
        true_umap: 2-D array of true-cell embedding coordinates (columns 0 and 1 are plotted).
        pred_umap: 2-D array of predicted-cell embedding coordinates.
        true_tps: per-row timepoint labels for ``true_umap``.
        pred_tps: per-row timepoint labels for ``pred_umap``.
        test_tps_list: timepoints to highlight; one colour each (colours cycle).
        save_path: optional file path; when given the figure is written there
            and missing parent directories are created.
    """
    import os
    import matplotlib.colors as mcolors

    palette = list(mcolors.TABLEAU_COLORS.values())
    gray_color = "#D3D3D3"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # (axis, title, coordinates to highlight, labels selecting the highlight)
    panels = (
        (ax1, "True Data", true_umap, true_tps),
        (ax2, "Predictions (VAE + TimelyGPT)", pred_umap, pred_tps),
    )
    for ax, title, points, point_tps in panels:
        ax.set_title(title, fontsize=15)
        # Gray backdrop is always the *true* embedding, on both panels.
        ax.scatter(true_umap[:, 0], true_umap[:, 1], c=gray_color, s=20, alpha=0.5, label="other")
        for i, t in enumerate(test_tps_list):
            mask = point_tps == t
            if np.any(mask):
                ax.scatter(points[mask, 0], points[mask, 1],
                           c=palette[i % len(palette)], s=30, alpha=1.0, label=f"t={int(t)}")
        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")
        ax.legend(loc="best")

    fig.tight_layout()

    # Save BEFORE showing: with inline/interactive backends plt.show() releases
    # the figure, so a later plt.savefig() would write an empty canvas.
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
    plt.close(fig)


In [None]:
# --- Training Configs ---
DATA_NAME, SPLIT_TYPE = "zebrafish", "three_interpolation"
LATENT_DIM = 64
NUM_HEADS = 8  # must divide d_model (d_model is set to LATENT_DIM when the model is built)
LATENT_COEFF = 1.0  # Regularization coefficient for latent trajectory smoothing (beta)
N_PRED_CELLS = 5000  # Number of cells to predict
EPOCHS = 10  # Maximum number of main training epochs
ITERS_PER_EPOCH = 100  # Number of iterations per epoch
BATCH_SIZE = 32
PRETRAIN_LR = 3e-4
PRETRAIN_ITERS = 500  # Number of pre-training iterations
LR = 1e-3
SEED = 42  # Seeds NumPy and PyTorch so Restart & Run All reproduces the same run
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fail here, with a readable message, rather than deep inside the attention layers.
assert LATENT_DIM % NUM_HEADS == 0, (
    f"NUM_HEADS ({NUM_HEADS}) must divide LATENT_DIM ({LATENT_DIM})"
)

# Batch sampling uses np.random; weight init and latent sampling use torch.
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")
print(f"Epochs: {EPOCHS}")
# The main-training scheduler is built with gamma=0.95; 0.99 is the
# pre-training scheduler's decay, which applies to PRETRAIN_LR instead.
print(f"Learning rate: {LR} with exponential decay (gamma=0.95)")
print(f"Pre-training iterations: {PRETRAIN_ITERS}")
print(f"Latent smoothing coefficient: {LATENT_COEFF}")


In [None]:
# --- Data loading and preprocessing ---
# loadSCData / tpSplitInd / splitBySpec are project helpers; their exact return
# contracts are not visible from this notebook.
ann_data, cell_tps, cell_types, n_genes, n_tps, all_tps = loadSCData(DATA_NAME, SPLIT_TYPE)
train_tps_idx, test_tps_idx = tpSplitInd(DATA_NAME, SPLIT_TYPE)
data = ann_data.X

# Group cells by timepoint label. Labels 1..n_tps are iterated, so entry k of
# traj_data holds the cells whose label in cell_tps equals k + 1.
rows_per_tp = [np.where(cell_tps == tp_label)[0] for tp_label in range(1, n_tps + 1)]
traj_data = [torch.FloatTensor(data[rows, :]) for rows in rows_per_tp]
all_tps = list(all_tps)  # plain list for later iteration
train_data, test_data = splitBySpec(traj_data, train_tps_idx, test_tps_idx)
n_cells = [tp_cells.shape[0] for tp_cells in traj_data]

print(f"# timepoints={n_tps}, # genes={n_genes}")
print(f"# cells per timepoint: {n_cells}")
print(f"Train timepoints: {train_tps_idx}")
print(f"Test timepoints: {test_tps_idx}")

In [None]:
# --- Model  ---
# 1. VAE: gene space <-> latent space
vae = VAE(n_genes, LATENT_DIM).to(DEVICE)

# 2. TimelyGPT: the latent predictor. Model width equals the latent width.
retnet_kwargs = dict(
    num_layers=3,
    num_heads=NUM_HEADS,
    d_model=LATENT_DIM,
    qk_dim=LATENT_DIM,
    v_dim=LATENT_DIM,
    ffn_proj_size=200,
    use_bias_in_msr=False,
    use_bias_in_mlp=True,
    use_bias_in_msr_out=False,
    use_default_gamma=True,
    # TimelyGPT_CellLevel.forward always calls its blocks with
    # forward_impl='recurrent', whatever is configured here.
    forward_impl='parallel',
)
timely_gpt_config = RetNetConfig(**retnet_kwargs)
timely_gpt_config.use_grad_ckp = False  # picked up by TimelyGPT_CellLevel.__init__

timely_gpt_model = TimelyGPT_CellLevel(config=timely_gpt_config, latent_dim=LATENT_DIM)

# 3. Combined model; .to(DEVICE) here also moves the TimelyGPT parameters.
model = VAETimelyGPT(vae, timely_gpt_model).to(DEVICE)

# Default optimizer over every parameter of the combined model.
optimizer = optim.Adam(model.parameters(), lr=LR)

In [None]:
# ======================================================
# Phase 1: Fast Pre-training - Train only VAE Encoder and Decoder
# ======================================================
# Wall-clock start; the training summary printed after Phase 2 reads it.
train_start_time = time.time()

# Prepare time indices for training
train_time_indices = torch.LongTensor(train_tps_idx).to(DEVICE)
train_tps_tensor = torch.FloatTensor(train_tps_idx).to(DEVICE)

print("\n[Phase 1] Fast Pre-training: Training VAE Encoder and Decoder only...")
latent_encoder = model.vae.encoder
obs_decoder = model.vae.decoder
# Pool the cells of every training timepoint; pre-training ignores time.
all_train_data = torch.cat(train_data, dim=0).to(DEVICE)

if PRETRAIN_ITERS > 0:
    # Only train encoder and decoder parameters
    dim_reduction_params = itertools.chain(*[latent_encoder.parameters(), obs_decoder.parameters()])
    dim_reduction_optimizer = torch.optim.Adam(params=dim_reduction_params, lr=PRETRAIN_LR, betas=(0.95, 0.99))
    pretrain_scheduler = torch.optim.lr_scheduler.ExponentialLR(dim_reduction_optimizer, gamma=0.99)
    latent_encoder.train()
    obs_decoder.train()

    best_pretrain_loss = float('inf')
    best_pretrain_state = None

    # Sampling without replacement cannot draw more rows than exist.
    pretrain_batch_size = min(BATCH_SIZE, all_train_data.shape[0])

    pbar = tqdm(range(PRETRAIN_ITERS), desc="Pre-training VAE")
    for i in pbar:
        # Sample random batch from all training data
        rand_idx = np.random.choice(all_train_data.shape[0], size=pretrain_batch_size, replace=False)
        batch_data = all_train_data[rand_idx, :]

        dim_reduction_optimizer.zero_grad()

        # Encode -> sample -> decode (NO KL divergence term)
        latent_mu, latent_log_var = latent_encoder(batch_data)
        latent_sample = model.vae.reparameterize(latent_mu, latent_log_var)
        recon_obs = obs_decoder(latent_sample)

        # Reconstruction MSE loss only (no KL divergence)
        recon_loss = torch.mean((recon_obs - batch_data) ** 2)
        batch_loss = recon_loss.item()

        # Snapshot BEFORE the optimizer step: the loss above was produced by
        # the current weights, so these are the weights that belong with it.
        # Snapshotting after step() would pair the loss with weights that were
        # never evaluated.
        # NOTE(review): the criterion is a single mini-batch loss and therefore
        # noisy; a held-out or running-average loss would be a steadier signal.
        if batch_loss < best_pretrain_loss:
            best_pretrain_loss = batch_loss
            best_pretrain_state = {
                'encoder': copy.deepcopy(latent_encoder.state_dict()),
                'decoder': copy.deepcopy(obs_decoder.state_dict())
            }

        recon_loss.backward()
        dim_reduction_optimizer.step()
        pretrain_scheduler.step()

        # Update progress bar
        pbar.set_postfix({"Loss": f"{batch_loss:.6f}"})

    print(f"Pre-training completed. Best loss: {best_pretrain_loss:.6f}")

    # Load the best pre-trained model
    if best_pretrain_state:
        print("Loading best pre-trained model state...")
        latent_encoder.load_state_dict(best_pretrain_state['encoder'])
        obs_decoder.load_state_dict(best_pretrain_state['decoder'])


In [None]:
# ======================================================
# Phase 2: Dynamic Training - Train full model with Sinkhorn + Latent Smoothing
# ======================================================
print("\n[Phase 2] Dynamic Training: Training full model (VAE + TimelyGPT)...")
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, betas=(0.95, 0.99))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
ot_solver = geomloss.SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.5, debias=True, backend="tensorized")
loss_list = []
iters_per_epoch = ITERS_PER_EPOCH

# The predictor emits one state per generated step, and evaluation reads step
# `t` as timepoint `t` (it rolls out over all timepoints and indexes the result
# by timepoint). Training must use the same convention: roll out over every
# step up to the last training timepoint and supervise only the steps that are
# training timepoints. Rolling out len(train_tps_idx) steps and matching step k
# to the k-th training timepoint would shift every step after the first
# held-out timepoint relative to evaluation.
# Assumes train_tps_idx holds 0-based integer positions into the timepoint
# list and that train_data[0] is timepoint 0 -- confirm against tpSplitInd / splitBySpec.
rollout_steps = int(max(train_tps_idx)) + 1
rollout_time_indices = torch.arange(rollout_steps, device=DEVICE)

# Sampling without replacement cannot draw more rows than exist.
first_tp_batch_size = min(BATCH_SIZE, train_data[0].shape[0])

epoch_pbar = tqdm(range(1, EPOCHS + 1), desc="Training Progress")
for epoch in epoch_pbar:
    model.train()
    epoch_losses = []

    # Inner loop for iterations within each epoch
    for iter_idx in range(iters_per_epoch):
        # Sample mini-batch from first timepoint
        rand_idx = np.random.choice(train_data[0].shape[0], size=first_tp_batch_size, replace=False)
        batch_data = train_data[0][rand_idx, :].to(DEVICE)

        optimizer.zero_grad()

        # Forward pass with data from first timepoint.
        # recon_obs: [batch, rollout_steps, n_genes]; latent_seq: [batch, rollout_steps, latent_dim]
        recon_obs, first_latent_dist, first_tp_data, latent_seq = model(
            batch_data, rollout_time_indices, forward_impl='parallel'
        )

        # Compute loss: Sinkhorn OT + Latent Trajectory Smoothing
        ot_loss = 0.0
        for t_idx, t in enumerate(train_tps_idx):
            # Step `t` of the rollout is the prediction for timepoint `t`;
            # train_data[t_idx] holds the observed cells of that timepoint.
            pred_x = recon_obs[:, int(t), :]  # [batch_size, n_genes]
            true_cells = train_data[t_idx]    # [n_cells_at_t, n_genes]

            # Subsample first, then move only the subsample to the device.
            subsample_size = min(200, true_cells.shape[0])
            subsample_idx = np.random.choice(true_cells.shape[0], subsample_size, replace=False)
            true_x = true_cells[subsample_idx].to(DEVICE)
            ot_loss += ot_solver(pred_x, true_x)

        # Compute trajectory smoothness loss (mean squared step-to-step change)
        latent_drift_loss = torch.mean((latent_seq[:, 1:, :] - latent_seq[:, :-1, :]) ** 2)

        # Total loss
        loss = ot_loss + LATENT_COEFF * latent_drift_loss

        loss.backward()
        optimizer.step()

        loss_list.append((loss.item(), ot_loss.item(), latent_drift_loss.item()))
        epoch_losses.append(loss.item())

    # Update learning rate (once per epoch)
    scheduler.step()

    # Update epoch progress bar
    avg_epoch_loss = np.mean(epoch_losses)
    epoch_pbar.set_postfix({
        "Loss": f"{avg_epoch_loss:.4f}",
        "LR": f"{scheduler.get_last_lr()[0]:.6f}"
    })

# Training summary (duration is measured from the start of Phase 1)
train_duration = time.time() - train_start_time
print(f"\nTraining completed! Total epochs: {EPOCHS}")
print(f"Training time: {train_duration:.2f} seconds ({train_duration/60:.2f} minutes)")

In [None]:
# Visualization - loss curve
# Each loss_list entry is (total loss, OT term, latent drift term), one per iteration.
if len(loss_list) > 0:
    panel_titles = ("Loss", "OT Term", "Dynamic Reg")
    fig, axes = plt.subplots(3, 1, figsize=(8, 6))
    for column, (ax, title) in enumerate(zip(axes, panel_titles)):
        ax.set_title(title)
        ax.plot([each[column] for each in loss_list])
    axes[-1].set_xlabel("Dynamic Learning Iter")
    fig.tight_layout()

    loss_plot_path = "figures/TimelyGPT_Loss.png"
    os.makedirs(os.path.dirname(loss_plot_path), exist_ok=True)
    # Save BEFORE showing: with inline backends plt.show() releases the figure,
    # so saving afterwards would write an empty image.
    fig.savefig(loss_plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Loss plot saved to {loss_plot_path}")


In [None]:
# Generate predictions by sampling from t0 cell latent distribution
model.eval()

with torch.no_grad():
    # 1. Encode all real t0 cells to get their latent distributions
    t0_real_cells = traj_data[0].to(DEVICE)
    mus, log_vars = model.vae.encoder(t0_real_cells)
    stds = torch.exp(0.5 * log_vars)
    num_real_t0_cells = t0_real_cells.shape[0]

    # 2. Sample N_PRED_CELLS latent variables from the mixture of per-cell
    #    Gaussians: pick a cell uniformly (with replacement), then sample from it.
    component_indices = np.random.choice(num_real_t0_cells, size=N_PRED_CELLS, replace=True)
    selected_mus = mus[component_indices]
    selected_stds = stds[component_indices]

    # Sample using reparameterization trick
    eps = torch.randn_like(selected_stds)
    z_initial_sampled = selected_mus + eps * selected_stds  # Shape: [N_PRED_CELLS, LATENT_DIM]
    print(f"Sampled {N_PRED_CELLS} initial latent states.")

    # 3. Use sampled latent states as input to TimelyGPT
    all_time_indices = torch.LongTensor(all_tps).to(DEVICE)

    # Generate trajectory using TimelyGPT and decoder
    # predictions_latent: [N_PRED_CELLS, n_all_tps, latent_dim]
    predictions_latent = model.timely_gpt(z_initial_sampled, all_time_indices, forward_impl='parallel')

    # 4. Decode predictions to gene expression space
    # all_recon_obs_tensor: [N_PRED_CELLS, n_tps, n_genes]
    all_recon_obs_tensor = model.vae.decoder(predictions_latent)

    # Reshape for evaluation: [n_tps, n_cells, n_genes]
    all_recon_obs = all_recon_obs_tensor.permute(1, 0, 2).cpu().numpy()


print(f"Predicted data shape: {all_recon_obs.shape}")
print(f"Predicted cells at {len(all_tps)} timepoints")

# Evaluate predictions on all timepoints
print("\n Evaluation metrics for ALL timepoints:")
print(f"{'Time':<6} {'Type':<6} {'OT':<10} {'L2':<10} {'CorrDist':<10}")
print("-"*70)

# Keep every timepoint's metrics so the test summary below can reuse them
# instead of running globalEvaluation a second time on the same arrays.
metrics_by_tp = {}
for t_idx in all_tps:
    true_data_t = traj_data[t_idx].numpy()
    pred_data_t = all_recon_obs[t_idx]

    # Compute standard metrics
    metrics = globalEvaluation(true_data_t, pred_data_t)
    metrics_by_tp[t_idx] = metrics

    # Determine timepoint type
    tp_type = "TRAIN" if t_idx in train_tps_idx else "TEST"

    # Print results
    print(f"t={t_idx:<4} {tp_type:<6} {metrics['ot']:<10.4f} {metrics['l2']:<10.4f} "
          f"{metrics['corr']:<10.4f}")

print("\n Summary for TEST timepoints only:")
test_metrics = {'ot': [], 'l2': [], 'corr': []}
for t_idx in test_tps_idx:
    if t_idx in metrics_by_tp:
        metrics = metrics_by_tp[t_idx]
    else:
        # Only reached if a test timepoint was not part of all_tps.
        metrics = globalEvaluation(traj_data[t_idx].numpy(), all_recon_obs[t_idx])
    for metric_name in test_metrics:
        test_metrics[metric_name].append(metrics[metric_name])

print(f"Average OT: {np.mean(test_metrics['ot']):.4f} ± {np.std(test_metrics['ot']):.4f}")
print(f"Average L2: {np.mean(test_metrics['l2']):.4f} ± {np.std(test_metrics['l2']):.4f}")
print(f"Average CorrDist: {np.mean(test_metrics['corr']):.4f} ± {np.std(test_metrics['corr']):.4f}")

# Generate UMAP visualization
print("\nGenerating UMAP visualization...")

# Prepare true data
true_data_list = [each.detach().numpy() if isinstance(each, torch.Tensor) else each for each in traj_data]
true_all = np.concatenate(true_data_list, axis=0)

# Fit UMAP on real data only
true_umap_traj, umap_model, pca_model = umapWithPCA(true_all, n_neighbors=50, min_dist=0.1, pca_pcs=50)

# Transform predictions using fitted UMAP
pred_data_list = [all_recon_obs[t_idx] for t_idx in all_tps]
pred_all = np.concatenate(pred_data_list, axis=0)
pred_umap_traj = umap_model.transform(pca_model.transform(pred_all))

# Generate timepoint labels (one label per row, in concatenation order)
true_cell_tps = np.concatenate([np.repeat(t_idx, traj_data[t_idx].shape[0]) for t_idx in all_tps])
pred_cell_tps = np.concatenate([np.repeat(t_idx, all_recon_obs[t_idx].shape[0]) for t_idx in all_tps])

# Create visualization
save_path = "figures/TimelyGPT_Results.png"
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # the figure is written here
plotPredTestTime(true_umap_traj, pred_umap_traj, true_cell_tps, pred_cell_tps, test_tps_idx, save_path=save_path)