In [3]:
"""
train_vae_gnn_align.py

Usage:
    python train_vae_gnn_align.py --data_npz path/to/mock_sc_dataset.npz

The NPZ file is expected to contain:
 - ref_expr: (N_ref, G) raw counts or numeric matrix
 - ref_labels: (N_ref,) integer labels (optional, not used for embedding training)
 - query_expr: (N_query, G)
 - query_coords: (N_query, 2)

Produces:
 - model checkpoints (vae_state.pt, gnn_state.pt)
 - final embeddings (final_embeddings.npz)
"""
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.neighbors import kneighbors_graph
from torch.utils.data import DataLoader, TensorDataset


In [4]:
# -------------------------
# Utilities & Losses
# -------------------------
def pairwise_distances(x, y=None):
    """Return the matrix of squared Euclidean distances between the rows of ``x`` and ``y``.

    If ``y`` is omitted, distances are taken between the rows of ``x`` and themselves.
    The result is built from the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b and
    clamped at zero so that round-off can never yield a negative distance.
    """
    other = x if y is None else y
    # Squared row norms, shaped so they broadcast to an (N_x, N_other) grid.
    x_sq = (x ** 2).sum(dim=1).view(-1, 1)
    other_sq = (other ** 2).sum(dim=1).view(1, -1)
    cross = torch.mm(x, other.t())
    sq_dist = x_sq + other_sq - 2.0 * cross
    return sq_dist.clamp(min=0.0)

def gaussian_kernel_matrix(x, y, sigma):
    """Evaluate the RBF kernel exp(-||a - b||^2 / (2 * sigma^2)) for every row pair of ``x`` and ``y``."""
    sq_dist = pairwise_distances(x, y)
    bandwidth = 2.0 * sigma**2
    return (-sq_dist / bandwidth).exp()

def mmd_rbf(x, y, sigmas=(1.0, 2.0, 4.0, 8.0)):
    """Maximum Mean Discrepancy between two samples using a sum of RBF kernels.

    Args:
        x: 2-D tensor whose rows are samples from the first distribution.
        y: 2-D tensor whose rows are samples from the second distribution
           (same number of columns as ``x``).
        sigmas: iterable of kernel bandwidths; the kernel means are summed over them.

    Returns:
        Scalar ``mean(Kxx) + mean(Kyy) - 2 * mean(Kxy)`` accumulated over all bandwidths.
        The means include the diagonal terms of ``Kxx`` and ``Kyy``.
    """
    # The squared distances do not depend on the bandwidth, so compute each matrix once
    # instead of once per sigma.
    d_xx = pairwise_distances(x, x)
    d_yy = pairwise_distances(y, y)
    d_xy = pairwise_distances(x, y)

    Kxx = 0.0
    Kyy = 0.0
    Kxy = 0.0
    for s in sigmas:
        denom = 2.0 * s**2
        Kxx = Kxx + torch.exp(-d_xx / denom).mean()
        Kyy = Kyy + torch.exp(-d_yy / denom).mean()
        Kxy = Kxy + torch.exp(-d_xy / denom).mean()
    return Kxx + Kyy - 2.0 * Kxy

In [5]:
# -------------------------
# Models: Pt-1 VAE
# -------------------------
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``input_dim -> hidden_dim -> hidden_dim // 2`` with a ReLU after each layer,
    followed by two linear heads that produce the latent mean and log-variance.
    Decoder: ``latent_dim -> hidden_dim // 2 -> input_dim`` with a plain linear output.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=256):
        super().__init__()
        half_hidden = hidden_dim // 2

        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        # Two heads on top of the shared encoder trunk.
        self.mu = nn.Linear(half_hidden, latent_dim)
        self.logvar = nn.Linear(half_hidden, latent_dim)

        # No output activation: the reconstruction is real-valued
        # (intended for an MSE loss on log1p-transformed data).
        decoder_layers = [
            nn.Linear(latent_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, input_dim),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for input ``x``."""
        hidden = self.enc(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` sampled from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent code ``z`` back to input space."""
        return self.dec(z)

    def forward(self, x):
        """Encode, sample and decode; returns ``(recon, mu, logvar, z)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

In [6]:
# -------------------------
# Models: Pt-1 GNN Encoder
# -------------------------

class GNNEncoder(nn.Module):
    """
    Two-layer graph encoder driven by a precomputed normalized adjacency A_norm:
      H1 = relu(A_norm @ lin1(X))
      Z  = A_norm @ lin2(H1)
    followed by a small MLP decoder that maps Z back to the input feature space.
    Both linear layers carry a bias, which is applied before the adjacency product.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=256):
        super().__init__()
        # Feature transforms; neighbourhood mixing happens in forward() via matmul with A_norm.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent -> hidden_dim // 2 -> input features.
        bottleneck = hidden_dim // 2
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, input_dim),
        )

    def forward(self, X, A_norm):
        """Return ``(recon, Z)`` for node features ``X`` (N, G) and adjacency ``A_norm`` (N, N)."""
        projected = self.lin1(X)
        H1 = torch.matmul(A_norm, projected).relu()   # (N, hidden)
        Z = torch.matmul(A_norm, self.lin2(H1))       # (N, latent)
        return self.dec(Z), Z


In [7]:
# -------------------------
# Training function
# -------------------------
def _build_normalized_adjacency(coords, k_nn):
    """Build the dense, symmetrically normalized kNN adjacency D^-1/2 A D^-1/2 from ``coords``."""
    n_nodes = coords.shape[0]
    A = kneighbors_graph(coords, n_neighbors=k_nn, mode='connectivity', include_self=True).toarray().astype(float)
    A = (A + A.T) / 2.0  # symmetrize the directed kNN graph
    # NOTE(review): include_self=True normally already puts a 1 on the diagonal, so adding the
    # identity leaves each self-loop with weight 2. Kept unchanged to preserve the trained
    # behaviour -- confirm the extra self-weight is intended.
    A = A + np.eye(n_nodes)
    inv_sqrt_deg = 1.0 / np.sqrt(A.sum(axis=1) + 1e-8)
    # Row/column scaling by broadcasting: same values as diag(d) @ A @ diag(d) without
    # allocating dense diagonal matrices or doing two (N, N) x (N, N) matmuls.
    return A * inv_sqrt_deg[:, None] * inv_sqrt_deg[None, :]


def train(
    data_npz,
    outdir="mock_sc_align",
    latent_dim=10,
    batch_size=128,
    epochs=50,
    lr=1e-3,
    k_nn=8,
    alpha_vae=1.0,
    alpha_gnn=1.0,
    alpha_comp=1.0,
    device=None,
    seed=None,
):
    """Jointly train a VAE on reference cells and a graph encoder on query cells.

    The VAE is trained on minibatches of the reference expression matrix; the GNN encoder is
    run full-batch on the query expression matrix using a kNN graph built from the query
    coordinates. An RBF-kernel MMD term between the two latent batches is added to the loss.

    Args:
        data_npz: path to an ``.npz`` archive containing ``ref_expr`` (N_ref, G),
            ``query_expr`` (N_query, G) and ``query_coords`` (N_query, d).
        outdir: directory (created if missing) for checkpoints and embeddings.
        latent_dim: size of the latent space shared by both encoders.
        batch_size: reference minibatch size.
        epochs: number of passes over the reference set.
        lr: Adam learning rate (one optimizer over both models).
        k_nn: number of neighbours for the query kNN graph.
        alpha_vae, alpha_gnn, alpha_comp: weights of the VAE loss, the GNN reconstruction
            loss and the MMD term in the total loss.
        device: torch device (or device string); defaults to CUDA when available, else CPU.
        seed: optional integer; when given, seeds the global NumPy and torch RNGs before
            anything random happens. ``None`` leaves the RNG state untouched.

    Raises:
        ValueError: if an input matrix is empty or not 2-D, or the shapes are inconsistent.

    Side effects:
        Writes ``vae_state.pt``, ``gnn_state.pt`` and ``final_embeddings.npz`` (keys
        ``z_ref``, ``z_query``) into ``outdir`` and prints progress. Returns ``None``.
    """
    if seed is not None:
        # NumPy drives the query subsampling; torch drives weight init, shuffling and
        # the reparameterization noise.
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(outdir, exist_ok=True)

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
    # arbitrary code -- only point this at trusted files.
    # The context manager closes the underlying zip file once the arrays are materialized.
    with np.load(data_npz, allow_pickle=True) as npz:
        ref_expr = npz["ref_expr"].astype(np.float32)            # (N_ref, G)
        query_expr = npz["query_expr"].astype(np.float32)        # (N_query, G)
        query_coords = npz["query_coords"].astype(np.float32)    # (N_query, d)

    # Preprocessing: log1p transform (common simple normalization)
    ref_x = np.log1p(ref_expr)
    query_x = np.log1p(query_expr)

    if ref_x.ndim != 2 or query_x.ndim != 2:
        raise ValueError("ref_expr and query_expr must be 2-D (cells x genes) matrices")
    N_ref, G = ref_x.shape
    N_query = query_x.shape[0]
    if N_ref == 0 or N_query == 0:
        raise ValueError("ref_expr and query_expr must each contain at least one cell")
    if query_x.shape[1] != G:
        raise ValueError(
            f"gene dimension mismatch: ref_expr has {G} genes, query_expr has {query_x.shape[1]}"
        )
    if query_coords.shape[0] != N_query:
        raise ValueError(
            f"query_coords has {query_coords.shape[0]} rows but query_expr has {N_query}"
        )

    # Normalized adjacency for the query cells, from kNN on their coordinates
    A_norm_np = _build_normalized_adjacency(query_coords, k_nn)
    A_norm = torch.tensor(A_norm_np, dtype=torch.float32, device=device)

    # Convert to torch tensors
    ref_tensor = torch.tensor(ref_x, dtype=torch.float32, device=device)
    query_tensor = torch.tensor(query_x, dtype=torch.float32, device=device)

    # Dataloader for reference (VAE uses minibatches)
    ref_ds = TensorDataset(ref_tensor)
    ref_loader = DataLoader(ref_ds, batch_size=batch_size, shuffle=True, drop_last=False)

    # instantiate models
    vae = VAE(input_dim=G, latent_dim=latent_dim).to(device)
    gnn = GNNEncoder(input_dim=G, latent_dim=latent_dim).to(device)

    # single optimizer over both models' parameters
    opt = optim.Adam(list(vae.parameters()) + list(gnn.parameters()), lr=lr)

    mse = nn.MSELoss(reduction="mean")

    print(f"Training on device: {device}; N_ref={N_ref}, N_query={N_query}, genes={G}")
    for ep in range(1, epochs+1):
        vae.train()
        gnn.train()
        epoch_vae_loss = 0.0
        epoch_gnn_loss = 0.0
        epoch_comp_loss = 0.0
        seen = 0

        # The VAE is minibatched over the reference set; the GNN needs the whole adjacency,
        # so it is re-run full-batch on the query set for every reference minibatch, always
        # with the current weights.
        for batch_idx, (batch_x,) in enumerate(ref_loader):
            batch_size_curr = batch_x.size(0)
            # VAE forward (on ref batch)
            recon_b, mu_b, logvar_b, z_b = vae(batch_x)

            # VAE losses: reconstruction (MSE) + KL
            recon_loss_b = mse(recon_b, batch_x)
            kl_b = -0.5 * torch.mean(1 + logvar_b - mu_b.pow(2) - logvar_b.exp())
            vae_loss = recon_loss_b + kl_b

            # GNN forward: full-batch on the query tensor
            recon_q, z_q = gnn(query_tensor, A_norm)   # recon_q: (N_query, G), z_q: (N_query, latent)
            gnn_recon_loss = mse(recon_q, query_tensor)

            # Comparative loss: MMD between z_b and a random subset of z_q of matching size
            # (or all of z_q when the query set is smaller than the batch).
            q_sub_size = min(z_q.shape[0], z_b.shape[0])
            # Drawn from NumPy's global RNG (seeded above when `seed` is given).
            q_idx = np.random.choice(z_q.shape[0], size=q_sub_size, replace=False)
            z_q_sub = z_q[q_idx, :]

            comp_loss = mmd_rbf(z_b, z_q_sub)

            total_loss = alpha_vae * vae_loss + alpha_gnn * gnn_recon_loss + alpha_comp * comp_loss

            opt.zero_grad()
            total_loss.backward()
            opt.step()

            epoch_vae_loss += vae_loss.item() * batch_size_curr
            epoch_gnn_loss += gnn_recon_loss.item() * batch_size_curr  # approximate bookkeeping
            epoch_comp_loss += comp_loss.item() * batch_size_curr
            seen += batch_size_curr

        epoch_vae_loss /= seen
        epoch_gnn_loss /= seen
        epoch_comp_loss /= seen

        # diagnostics: cosine similarity / distance between the mean embeddings
        with torch.no_grad():
            vae.eval()
            gnn.eval()
            # eval() does not disable the reparameterization noise, so use the posterior
            # mean as the deterministic reference embedding.
            z_ref_full, _ = vae.encode(ref_tensor)         # (N_ref, latent)
            _, z_query_full = gnn(query_tensor, A_norm)    # (N_query, latent)
            mean_ref = z_ref_full.mean(dim=0)
            mean_q = z_query_full.mean(dim=0)
            mean_cos = nn.functional.cosine_similarity(mean_ref.unsqueeze(0), mean_q.unsqueeze(0)).item()
            mean_euc = torch.norm(mean_ref - mean_q).item()

        if ep % max(1, epochs//10) == 0 or ep == 1:
            print(f"[Epoch {ep:03d}/{epochs}] VAE_loss={epoch_vae_loss:.4f} | GNN_recon={epoch_gnn_loss:.4f} | MMD={epoch_comp_loss:.4f} | mean_cos={mean_cos:.4f} | mean_euc={mean_euc:.4f}")

    # Save models & embeddings
    torch.save(vae.state_dict(), os.path.join(outdir, "vae_state.pt"))
    torch.save(gnn.state_dict(), os.path.join(outdir, "gnn_state.pt"))

    # final embeddings (posterior mean for the reference, so repeated exports agree)
    with torch.no_grad():
        vae.eval()
        gnn.eval()
        z_ref_final, _ = vae.encode(ref_tensor)
        _, z_query_final = gnn(query_tensor, A_norm)

    z_ref_final_np = z_ref_final.cpu().numpy()
    z_query_final_np = z_query_final.cpu().numpy()
    np.savez_compressed(os.path.join(outdir, "final_embeddings.npz"),
                        z_ref=z_ref_final_np,
                        z_query=z_query_final_np)
    print("Saved models and embeddings to", outdir)

In [None]:
## -------------------------
## Loading the data set
## -------------------------

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Open the mock single-cell dataset archive.
data = np.load("mock_sc/mock_sc_dataset.npz", allow_pickle=True)


Device: cpu


In [12]:
## -------------------------
## Model training
## -------------------------

# Run training with the default hyperparameters on the mock dataset.
train(data_npz="mock_sc/mock_sc_dataset.npz")

Training on device: cpu; N_ref=1000, N_query=1000, genes=100
[Epoch 001/50] VAE_loss=2.2319 | GNN_recon=2.6820 | MMD=1.2835 | mean_cos=0.9253 | mean_euc=0.5144
[Epoch 005/50] VAE_loss=0.8314 | GNN_recon=0.5087 | MMD=0.3560 | mean_cos=0.9909 | mean_euc=0.4289
[Epoch 010/50] VAE_loss=0.6218 | GNN_recon=0.3258 | MMD=0.3151 | mean_cos=0.9738 | mean_euc=0.4691
[Epoch 015/50] VAE_loss=0.5213 | GNN_recon=0.2864 | MMD=0.3196 | mean_cos=0.9259 | mean_euc=0.3761
[Epoch 020/50] VAE_loss=0.4921 | GNN_recon=0.2849 | MMD=0.3012 | mean_cos=0.7953 | mean_euc=0.3842
[Epoch 025/50] VAE_loss=0.4838 | GNN_recon=0.2818 | MMD=0.2728 | mean_cos=0.6315 | mean_euc=0.2811
[Epoch 030/50] VAE_loss=0.4752 | GNN_recon=0.2858 | MMD=0.2833 | mean_cos=0.4570 | mean_euc=0.3176
[Epoch 035/50] VAE_loss=0.4762 | GNN_recon=0.2835 | MMD=0.2881 | mean_cos=0.0085 | mean_euc=0.3589
[Epoch 040/50] VAE_loss=0.4713 | GNN_recon=0.2811 | MMD=0.2807 | mean_cos=0.2447 | mean_euc=0.1653
[Epoch 045/50] VAE_loss=0.4659 | GNN_recon=0.279