# Temporal Geneformer Inference and UMAP Visualization

This notebook demonstrates how to:
1. Run inference on the test dataset using a trained temporal geneformer checkpoint
2. Extract cell embeddings from the model's latent space
3. Generate UMAP projections to visualize the learned representations
4. Create publication-quality visualizations of the latent space

The temporal geneformer model learns to predict the next cell state in temporal trajectories, which should result in embeddings that capture meaningful biological progression.


## Setup Paths and Configuration


In [None]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Define paths
checkpoint_path = Path("/workspaces/bionemo-framework/sub-packages/bionemo-geneformer/examples/temporal_geneformer_results/temporal_geneformer/dev/checkpoints/epoch=0-val_loss=7.61-step=9999-consumed_samples=80000.0-last")
data_path = Path("/workspaces/bionemo-framework/sub-packages/bionemo-geneformer/src/bionemo/geneformer/test_data/mar_por_datasets_splits/with_neighbors/test")
results_path = Path("./temporal_inference_results")

# Create results directory
results_path.mkdir(parents=True, exist_ok=True)

# Configuration
micro_batch_size = 64
seq_length = 2048
num_dataset_workers = 4
num_gpus = 1

print(f"Checkpoint: {checkpoint_path}")
print(f"Data path: {data_path}")
print(f"Results path: {results_path}")
print(f"\nCheckpoint exists: {checkpoint_path.exists()}")
print(f"Data path exists: {data_path.exists()}")


Checkpoint: /workspaces/bionemo-framework/sub-packages/bionemo-geneformer/examples/temporal_geneformer_results/temporal_geneformer/dev/checkpoints/epoch=0-val_loss=7.61-step=9999-consumed_samples=80000.0-last
Data path: /workspaces/bionemo-framework/sub-packages/bionemo-geneformer/src/bionemo/geneformer/test_data/mar_por_datasets_splits/with_neighbors/test
Results path: temporal_inference_results

Checkpoint exists: True
Data path exists: True


## Run Inference

We use the `infer_geneformer_scmap` command to run inference on the SCMAP test dataset. This version supports temporal training features with neighbor information.

Key flags:
- `--include-embeddings`: Extract cell-level embeddings (mean of all gene tokens, excluding special tokens) - enabled by default
- `--include-hiddens`: Extract per-token hidden states (optional, for more detailed analysis)
- `--include-input-ids`: Include input token IDs for mapping back to genes
- `--next-cell-prediction`: Enable temporal mode (uses neighbor information)
- `--filter-no-neighbors`: Only include cells that have neighbors in the dataset


In [2]:
!infer_geneformer_scmap \
    --data-dir {data_path} \
    --checkpoint-path {checkpoint_path} \
    --results-path {results_path} \
    --micro-batch-size {micro_batch_size} \
    --seq-length {seq_length} \
    --num-dataset-workers {num_dataset_workers} \
    --num-gpus {num_gpus} \
    --include-input-ids


Could not find the bitsandbytes CUDA binary at PosixPath('/usr/local/lib/python3.12/dist-packages/bitsandbytes/libbitsandbytes_cuda129.so')
The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
[NeMo W 2025-10-08 19:54:09 nemo_logging:405] Tokenizer vocab file: /home/ubuntu/.cache/bionemo/d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3-singlecell-scdltestdata-20241203.tar.gz.untar/cellxgene_2023-12-15_small_processed_scdl/train/geneformer.vocab already exists. Overwriting...
[NeMo I 2025-10-08 19:54:09 nemo_logging:393] No checksum provided, filename exists. Assuming it is complete.
[NeMo I 2025-10-08 19:54:09 nemo_logging:393] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true
[NeMo I 2025-10-08 19:54:09 nemo_logging:393] No checksum provided, filena

In [None]:
!infer_geneformer_scmap \
    --data-dir {data_path} \
    --checkpoint-path {checkpoint_path} \
    --results-path {results_path} \
    --micro-batch-size {micro_batch_size} \
    --seq-length {seq_length} \
    --num-dataset-workers {num_dataset_workers} \
    --num-gpus {num_gpus} \
    --include-input-ids

## Load Inference Results

The inference results are saved as a PyTorch `.pt` file containing:
- `embeddings`: Cell-level representations (N_cells x embedding_dim)
- `hidden_states`: Per-token hidden states (N_cells x seq_len x embedding_dim)
- `input_ids`: Token IDs for each cell
- `token_logits`: Output logits (if included)
- `binary_logits`: Binary classification logits (if applicable)


In [3]:
import torch
import numpy as np

# Load predictions
predictions_file = results_path / "predictions__rank_0.pt"
print(f"Loading predictions from: {predictions_file}")

predictions = torch.load(predictions_file, weights_only=False)

# Check what keys are available
print(f"\nAvailable keys: {predictions.keys()}")

# Extract embeddings
embeddings = predictions["embeddings"].float().cpu().numpy()
print(f"\nEmbeddings shape: {embeddings.shape}")
print(f"Number of cells: {embeddings.shape[0]}")
print(f"Embedding dimension: {embeddings.shape[1]}")


Loading predictions from: temporal_inference_results/predictions__rank_0.pt

Available keys: dict_keys(['token_logits', 'binary_logits', 'input_ids', 'embeddings'])

Embeddings shape: (121698, 256)
Number of cells: 121698
Embedding dimension: 256


## Load Metadata and Dataset Labels

We'll load metadata to understand which dataset each cell belongs to. This is useful for visualizing how well the model integrates different datasets.


In [4]:
import pandas as pd
import pyarrow.parquet as pq

# Load feature metadata (contains dataset labels and other info)
features_dir = data_path / "features"

# Read all parquet files and combine them
parquet_files = sorted(features_dir.glob("dataframe_*.parquet"))
print(f"Found {len(parquet_files)} parquet files")

# Load and concatenate all dataframes
dfs = []
for pf in parquet_files:
    df = pd.read_parquet(pf)
    dfs.append(df)
    print(f"Loaded {pf.name}: {len(df)} rows")

metadata = pd.concat(dfs, ignore_index=True)
print(f"\nTotal metadata rows: {len(metadata)}")
print(f"\nMetadata columns: {metadata.columns.tolist()}")
print(f"\nFirst few rows:")
print(metadata.head())


Found 3 parquet files
Loaded dataframe_0.parquet: 21192 rows
Loaded dataframe_1.parquet: 21633 rows
Loaded dataframe_2.parquet: 18323 rows

Total metadata rows: 61148

Metadata columns: ['soma_joinid', 'feature_id', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm']

First few rows:
   soma_joinid       feature_id feature_name  feature_length      nnz  \
0            1  ENSG00000121410         A1BG            3999  5640476   
1            2  ENSG00000268895     A1BG-AS1            3374  3071864   
2            4  ENSG00000175899          A2M            6318  7894261   
3            5  ENSG00000245105      A2M-AS1            2948  1637794   
4            8  ENSG00000184389      A3GALT2            1023   439067   

   n_measured_obs  n_cells     mt  n_cells_by_counts  mean_counts  \
0        62641311    15746  False          

In [5]:
# Extract dataset labels if available
if 'dataset' in metadata.columns:
    dataset_labels = metadata['dataset'].values
    print(f"Dataset distribution:")
    print(metadata['dataset'].value_counts())
elif 'dataset_index' in metadata.columns:
    dataset_labels = metadata['dataset_index'].values
    print(f"Dataset distribution:")
    print(metadata['dataset_index'].value_counts())
else:
    # If no dataset column, create dummy labels based on filename patterns
    print("No dataset column found. Available columns:")
    print(metadata.columns.tolist())
    # Create dummy labels (split in half for visualization)
    n_cells = len(metadata)
    dataset_labels = np.array(["Dataset A"] * (n_cells // 2) + ["Dataset B"] * (n_cells - n_cells // 2))
    print("\nCreated dummy dataset labels for visualization purposes")

# Ensure embeddings and labels have same length
print(f"\nEmbeddings shape: {embeddings.shape[0]}")
print(f"Labels length: {len(dataset_labels)}")
assert len(dataset_labels) == embeddings.shape[0], "Mismatch between embeddings and labels!"


No dataset column found. Available columns:
['soma_joinid', 'feature_id', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm']

Created dummy dataset labels for visualization purposes

Embeddings shape: 121698
Labels length: 61148


AssertionError: Mismatch between embeddings and labels!

## Generate UMAP Projection

UMAP (Uniform Manifold Approximation and Projection) is a dimensionality reduction technique that preserves both local and global structure. It's particularly effective for visualizing high-dimensional embeddings.

Key parameters:
- `n_neighbors`: Controls balance between local and global structure (15-50 typical)
- `min_dist`: Controls cluster tightness (0.0-0.5 typical)
- `metric`: Distance metric ('cosine' works well for embeddings)
- `n_components`: Number of dimensions to reduce to (2 for visualization)


In [None]:
import umap
from sklearn.preprocessing import StandardScaler

# Standardize embeddings (optional but often helpful)
scaler = StandardScaler()
embeddings_scaled = scaler.fit_transform(embeddings)

print("Computing UMAP projection...")
print("This may take a few minutes for large datasets...")

# Create UMAP reducer
reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    n_components=2,
    metric='cosine',
    random_state=42,
    verbose=True
)

# Fit and transform
umap_embedding = reducer.fit_transform(embeddings_scaled)

print(f"\nUMAP embedding shape: {umap_embedding.shape}")
print(f"UMAP1 range: [{umap_embedding[:, 0].min():.2f}, {umap_embedding[:, 0].max():.2f}]")
print(f"UMAP2 range: [{umap_embedding[:, 1].min():.2f}, {umap_embedding[:, 1].max():.2f}]")


## Save UMAP Coordinates

Save the UMAP coordinates so you don't have to recompute them.


In [None]:
# Save UMAP coordinates
np.save(results_path / "umap_coordinates.npy", umap_embedding)
print(f"UMAP coordinates saved to {results_path / 'umap_coordinates.npy'}")

# Save metadata with UMAP coordinates
metadata_with_umap = metadata.copy()
metadata_with_umap['UMAP1'] = umap_embedding[:, 0]
metadata_with_umap['UMAP2'] = umap_embedding[:, 1]
metadata_with_umap['dataset_label'] = dataset_labels
metadata_with_umap.to_csv(results_path / "metadata_with_umap.csv", index=False)
print(f"Metadata with UMAP saved to {results_path / 'metadata_with_umap.csv'}")


## Visualize UMAP Projection

Create a publication-quality UMAP visualization showing how the temporal geneformer model integrates different datasets in the latent space.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set up plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("notebook", font_scale=1.2)

# Create figure
fig, ax = plt.subplots(1, 1, figsize=(12, 10))

# Get unique labels and assign colors
unique_labels = np.unique(dataset_labels)
print(f"Unique labels: {unique_labels}")

# Use tab10 colormap for distinct colors
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))

# Plot each dataset separately
for i, label in enumerate(unique_labels):
    mask = dataset_labels == label
    ax.scatter(
        umap_embedding[mask, 0],
        umap_embedding[mask, 1],
        c=[colors[i]],
        label=f"{label}",
        alpha=0.6,
        s=5,
        rasterized=True  # Better performance for many points
    )

# Formatting
ax.set_xlabel('UMAP1', fontsize=14, fontweight='bold')
ax.set_ylabel('UMAP2', fontsize=14, fontweight='bold')
ax.set_title('Temporal Geneformer - Latent Space Integration', fontsize=16, fontweight='bold')
ax.legend(title="Dataset", loc='best', frameon=True, fancybox=True, framealpha=0.9, fontsize=12)

# Add annotation
ax.text(0.5, 1.02, "UMAP projection with each dataset highlighted",
        ha='center', va='bottom', transform=ax.transAxes, fontsize=11, style='italic')

# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()

# Save figure
output_file = results_path / "temporal_geneformer_umap.png"
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\nFigure saved to {output_file}")

plt.show()
