# In [None]:
import anndata
import numpy as np
import matplotlib.pyplot as plt
import umap
from sklearn.decomposition import IncrementalPCA

from wassersteinwormhole import Wormhole
from wassersteinwormhole.DefaultConfig import DefaultConfig

# -----------------------------------------------------
# 1. Load Data
# -----------------------------------------------------
# Read the two .h5ad files in order; the first result is bound to rna_data
# and the second to atac_data.
rna_data, atac_data = (
    anndata.read_h5ad(h5ad_path)
    for h5ad_path in (
        "/home/ubuntu/GSE278572_gene_expression.h5ad",
        "/home/ubuntu/GSE244184_scATAC_harmonized.h5ad",
    )
)
print("Data loaded.")

# -----------------------------------------------------
# 2. Dimensionality Reduction with IncrementalPCA
# -----------------------------------------------------
pca_dim = 50       # number of PCA components
chunk_size = 5000  # how many cells to process at once in memory


def _chunk_bounds(n_rows, max_rows, min_rows):
    """Return ordered ``(start, end)`` row ranges covering ``range(n_rows)``.

    Each range holds at most ``max_rows`` rows, except that a trailing range
    shorter than ``min_rows`` is merged into the range before it.  This
    matters because ``IncrementalPCA.partial_fit`` can reject a batch that
    has fewer rows than ``n_components`` (older scikit-learn releases enforce
    this on every call), so an unlucky remainder would abort the fit on its
    very last chunk.  A dataset that fits in a single range is returned
    unchanged, whatever its size.
    """
    bounds = [
        (start, min(n_rows, start + max_rows))
        for start in range(0, n_rows, max_rows)
    ]
    if len(bounds) > 1 and bounds[-1][1] - bounds[-1][0] < min_rows:
        _, tail_end = bounds.pop()
        prev_start, _ = bounds.pop()
        bounds.append((prev_start, tail_end))
    return bounds


def _dense_rows(adata, start, end):
    """Return rows ``[start, end)`` of ``adata.X`` as a dense array."""
    rows = adata.X[start:end]
    # Sparse matrices expose toarray(); a dense X has no such method, so
    # calling it unconditionally would raise AttributeError.
    return rows.toarray() if hasattr(rows, "toarray") else np.asarray(rows)


def _fit_in_chunks(model, adata, bounds):
    """Feed ``adata.X`` to ``model.partial_fit`` one row range at a time."""
    for start, end in bounds:
        model.partial_fit(_dense_rows(adata, start, end))


def _transform_in_chunks(model, adata, bounds):
    """Project ``adata.X`` with ``model`` range by range and stack the rows."""
    return np.vstack(
        [model.transform(_dense_rows(adata, start, end)) for start, end in bounds]
    )


pca_rna = IncrementalPCA(n_components=pca_dim)
pca_atac = IncrementalPCA(n_components=pca_dim)

rna_n = rna_data.shape[0]
atac_n = atac_data.shape[0]

# Only one chunk at a time is converted to a dense array, so peak memory is
# bounded by the chunk size rather than by the full matrix.
rna_bounds = _chunk_bounds(rna_n, chunk_size, pca_dim)
atac_bounds = _chunk_bounds(atac_n, chunk_size, pca_dim)

# --- A) PARTIAL_FIT on RNA in chunks
_fit_in_chunks(pca_rna, rna_data, rna_bounds)
print("Partial PCA fit done for RNA.")

# --- B) PARTIAL_FIT on ATAC in chunks
_fit_in_chunks(pca_atac, atac_data, atac_bounds)
print("Partial PCA fit done for ATAC.")

# --- C) Transform RNA in chunks and combine
rna_pca = _transform_in_chunks(pca_rna, rna_data, rna_bounds)
print("RNA PCA shape:", rna_pca.shape)

# --- D) Transform ATAC in chunks and combine
atac_pca = _transform_in_chunks(pca_atac, atac_data, atac_bounds)
print("ATAC PCA shape:", atac_pca.shape)

# -----------------------------------------------------
# 3. Prepare Data for Wormhole
# -----------------------------------------------------
# Two entries, RNA first and ATAC second, each holding the PCA coordinates
# of every cell of that modality.
# NOTE(review): the model therefore sees exactly two point clouds — confirm
# that one cloud per whole modality is the intended granularity.
point_clouds = [rna_pca, atac_pca]

# -----------------------------------------------------
# 4. Configure and Initialize Wormhole
# -----------------------------------------------------
# Only these three fields are overridden; every other setting keeps whatever
# DefaultConfig provides.
wormhole_settings = {
    "emb_dim": 128,         # embedding dimension of the latent space
    "dist_func_enc": "GW",  # distance function for the encoder
    "dist_func_dec": "GW",  # distance function for the decoder
}
config = DefaultConfig(**wormhole_settings)
wormhole = Wormhole(point_clouds=point_clouds, config=config)
print("Wormhole model initialized.")

# -----------------------------------------------------
# 5. Train Wormhole
#    Tune training_steps and batch_size to the available resources
# -----------------------------------------------------
wormhole.train(
    training_steps=5000,
    batch_size=16,
    verbose=500,
)
print("Wormhole model trained.")

# -----------------------------------------------------
# 6. Encode the Full Combined Dataset
# -----------------------------------------------------
# Stack the RNA rows on top of the ATAC rows:
# shape (N_RNA + N_ATAC, pca_dim).
combined_data = np.vstack(point_clouds)
# NOTE(review): the model was constructed from a list of point clouds, but
# encode() is handed one 2-D array of cells here — confirm that this is an
# input form Wormhole.encode accepts and that it returns one embedding per row.
embeddings = wormhole.encode(combined_data)
print("Combined data encoded.")

# -----------------------------------------------------
# 7. UMAP on the Wormhole Embeddings
# -----------------------------------------------------
# UMAP is run with the library's default settings.
reducer = umap.UMAP()
umap_emb = reducer.fit_transform(embeddings)
print("UMAP reduction done.")

# -----------------------------------------------------
# 8. Plot and Save
# -----------------------------------------------------
# Rows before n_rna_cells are drawn as RNA and the remaining rows as ATAC.
# NOTE(review): this assumes umap_emb has one row per cell in RNA-then-ATAC
# order, i.e. that the encoder returned one embedding per input row — confirm.
n_rna_cells = rna_pca.shape[0]
modality_rows = (
    ("RNA-seq", slice(None, n_rna_cells)),
    ("ATAC-seq", slice(n_rna_cells, None)),
)

plt.figure(figsize=(8, 6))
for modality_label, rows in modality_rows:
    plt.scatter(
        umap_emb[rows, 0],
        umap_emb[rows, 1],
        label=modality_label,
        alpha=0.6,
    )
plt.title("Integration of RNA and ATAC Data (IncrementalPCA + GW Wormhole)")
plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
plt.legend()
plt.grid(False)

# Write the image to disk before show(), then display the figure.
output_png = "wormhole_integration_umap.png"
plt.savefig(output_png, dpi=300, bbox_inches="tight")
plt.show()
print("Plot saved as wormhole_integration_umap.png.")
