In [None]:
import hvplot.polars  # type: ignore
import numpy as np
import polars as pl
import polars.selectors as cs
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torch import nn

In [None]:
from pathlib import Path

from dotenv import dotenv_values

# Read key/value settings through python-dotenv.
# NOTE(review): echoing `paths` renders every loaded entry in the cell output — confirm the
# .env file holds no secrets before committing this notebook with outputs.
paths = dotenv_values()
paths

In [None]:
# Root data directory; all input and output paths below are built relative to it.
data_path = Path(paths["DATA_DIR"])

# Expression count table. Later cells use the names of its numeric columns as gene names,
# so presumably each row is one sample — TODO confirm.
counts = pl.read_parquet(data_path / "processed-data/training_data-counts_uint.parquet")
counts

In [None]:
X = counts.select(cs.numeric()).to_numpy()

In [None]:
X_std = (X - X.mean(axis=1)[:, np.newaxis]) / X.std(axis=1)[:, np.newaxis]

In [None]:
import matplotlib.pyplot as plt

In [None]:
pca = PCA(n_components=512, svd_solver="randomized")
pca.fit(X_std.T)

In [None]:
transformed = pca.transform(X_std.T)

In [None]:
plt.plot(np.arange(512), np.cumsum(pca.explained_variance_ratio_))

In [None]:
embedding_df = pl.DataFrame(
    data=transformed, schema=[f"latent_{i}" for i in range(512)]
).with_columns(gene_name=pl.Series(counts.select(cs.numeric()).columns))
embedding_df = embedding_df.select("gene_name", cs.numeric())
embedding_df

In [None]:
embedding_df.write_parquet(data_path / "gene_embeddings/PCA-train_expression.parquet")

# Covariance and sparse PCA

In [None]:
# Covariance between the rows of X_std.T, i.e. between the numeric columns of `counts`.
# NOTE(review): the result is square in the number of those columns and no later cell in
# this notebook reads `covariance` — confirm it is still needed.
covariance = np.cov(X_std.T)
covariance

In [None]:
from sklearn.decomposition import MiniBatchSparsePCA

sparse_pca = MiniBatchSparsePCA(n_components=512, n_jobs=8, batch_size=100)
transformed_sparse_pca = sparse_pca.fit_transform(X_std.T)

In [None]:
transformed_sparse_pca

In [None]:
sparse_pca_embedding_df = pl.DataFrame(
    data=transformed_sparse_pca, schema=[f"latent_{i}" for i in range(512)]
).with_columns(gene_name=pl.Series(counts.select(cs.numeric()).columns))
sparse_pca_embedding_df = sparse_pca_embedding_df.select("gene_name", cs.numeric())
sparse_pca_embedding_df

In [None]:
sparse_pca_embedding_df.write_parquet(
    data_path / "gene_embeddings/SparsePCA-train_expression.parquet"
)

# Checking these embeddings

In [None]:
# Reload the PCA embeddings written earlier in this notebook.
df = pl.read_parquet(data_path / "gene_embeddings/PCA-train_expression.parquet")

In [None]:
df

In [None]:
# Per-column variance of the stored embedding.
df.var()

In [None]:
# Numeric (latent) columns as a torch tensor; one row per gene.
X_pca = df.select(cs.numeric()).to_torch()
X_pca.shape

In [None]:
# Scale each gene's latent vector to unit L2 norm (F.normalize defaults to p=2).
X_norm_pca = F.normalize(X_pca, dim=-1)
X_norm_pca.sum()

In [None]:
# Dot products of unit vectors, i.e. the gene-by-gene cosine similarity matrix.
gene_sim = torch.matmul(X_norm_pca, X_norm_pca.T)
gene_sim

In [None]:
gene_sim.shape

In [None]:
# Correlation of the log1p-transformed counts of three example genes.
counts.select(["SAMD11", "NOC2L", "KLHL17"]).with_columns(pl.all().log1p()).corr()

In [None]:
# Scratch check of broadcasting: the (2, 1) column scales each row of the (2, 2) matrix by
# its own factor.
torch.Tensor([[1, 2], [2, 1]]) * torch.Tensor([1, 2]).unsqueeze(-1)

In [None]:
# The (2, 1) column used in the product above.
torch.Tensor([1, 2]).unsqueeze(-1)

# VAE embeddings

In [None]:
# Simple Embeddings

from lightning.fabric import Fabric

In [None]:
# Fabric instance used below to set up the model, optimizer and dataloader and to run backward.
# NOTE(review): CUDA device index 3 is hard-coded — confirm it exists on the target machine.
fabric = Fabric(accelerator="cuda", devices=[3])
fabric

In [None]:
class GeneVAE(nn.Module):
    """Fully connected variational autoencoder over fixed-width input vectors.

    The encoder maps an input of width ``n_samples`` through 1024 -> 512 -> 256 hidden
    units to a Gaussian posterior (mean and log-variance) of size ``latent_dim``; the
    decoder mirrors it back to ``n_samples`` outputs. GELU activations are used
    throughout, with dropout after the first two hidden layers on each side.

    Args:
        n_samples: Width of each input vector (and of the reconstruction).
        latent_dim: Size of the latent space.
        dropout_rate: Dropout probability applied in encoder and decoder.
    """

    def __init__(self, n_samples, latent_dim=256, dropout_rate=0.1):
        super().__init__()
        self.n_samples = n_samples
        self.latent_dim = latent_dim

        # Encoder layers
        self.encoder_fc1 = nn.Linear(n_samples, 1024)
        self.encoder_dropout1 = nn.Dropout(dropout_rate)
        self.encoder_fc2 = nn.Linear(1024, 512)
        self.encoder_dropout2 = nn.Dropout(dropout_rate)
        self.encoder_fc3 = nn.Linear(512, 256)

        # Latent space: separate heads for the posterior mean and log-variance
        self.encoder_mean = nn.Linear(256, latent_dim)
        self.encoder_logvar = nn.Linear(256, latent_dim)

        # Decoder layers
        self.decoder_fc1 = nn.Linear(latent_dim, 256)
        self.decoder_dropout1 = nn.Dropout(dropout_rate)
        self.decoder_fc2 = nn.Linear(256, 512)
        self.decoder_dropout2 = nn.Dropout(dropout_rate)
        self.decoder_fc3 = nn.Linear(512, 1024)
        self.decoder_output = nn.Linear(1024, n_samples)

    def encode(self, x):
        """Return the posterior ``(mean, logvar)`` for inputs ``x`` of width ``n_samples``."""
        h1 = F.gelu(self.encoder_fc1(x))
        h1 = self.encoder_dropout1(h1)
        h2 = F.gelu(self.encoder_fc2(h1))
        h2 = self.encoder_dropout2(h2)
        h3 = F.gelu(self.encoder_fc3(h2))

        mean = self.encoder_mean(h3)
        logvar = self.encoder_logvar(h3)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * std`` while training; return ``mean`` in eval mode."""
        if not self.training:
            return mean
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` back to reconstructions of width ``n_samples``."""
        h3 = F.gelu(self.decoder_fc1(z))
        h3 = self.decoder_dropout1(h3)
        h4 = F.gelu(self.decoder_fc2(h3))
        h4 = self.decoder_dropout2(h4)
        h5 = F.gelu(self.decoder_fc3(h4))
        return self.decoder_output(h5)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mean, logvar)``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        decoded = self.decode(z)
        return decoded, mean, logvar

    def get_gene_embeddings(self, x):
        """Extract gene embeddings (means from latent space).

        Runs the encoder in eval mode without gradient tracking. The module's previous
        train/eval mode is restored afterwards, so calling this in the middle of training
        does not silently leave dropout and sampling disabled.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                mean, _ = self.encode(x)
        finally:
            self.train(was_training)
        return mean


def vae_loss(recon_x, x, mean, logvar, beta=1.0):
    """
    β-weighted VAE objective.

    Returns a tuple ``(total, reconstruction, kl)`` where ``total`` is
    ``reconstruction + beta * kl``. Both terms are summed over all elements,
    so a larger β puts more weight on the latent-space regulariser.
    """
    # Summed (not averaged) squared error between reconstruction and input.
    reconstruction = F.mse_loss(recon_x, x, reduction="sum")

    # Closed-form KL divergence between N(mean, exp(logvar)) and N(0, 1).
    kl_terms = 1 + logvar - mean.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    total = reconstruction + beta * divergence
    return total, reconstruction, divergence

In [None]:
# NOTE: this rebinds `X`. From here on it is a float32 torch tensor of log1p-transformed
# values, transposed so that each row corresponds to one numeric column of `counts` (a gene).
X = counts.select(cs.numeric().log1p()).to_torch("tensor").T.to(torch.float32)

In [None]:
# Single-tensor dataset: each item is a 1-tuple holding one row of X.
dataset = torch.utils.data.TensorDataset(X)
dataset

In [None]:
# The VAE input width is the number of rows in `counts`, which is the row length of `X`.
model = GeneVAE(counts.shape[0])
# model.compile()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Cosine schedule with a warm restart every 10 scheduler steps (the training loop steps it
# once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10)


# Shuffle so mini-batches are re-drawn every epoch; without it every epoch would see the
# same groups of neighbouring genes in column order.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=128, shuffle=True, num_workers=8
)

In [None]:
fabric.launch()

In [None]:
# Trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision("medium")

In [None]:
# Hand the model, optimizer and dataloader to Fabric; the returned objects replace the originals.
# NOTE(review): `scheduler` was created from the optimizer before this call — confirm it
# still drives the learning rate of the optimizer returned here.
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

In [None]:
from tqdm import tqdm

In [None]:
n_epochs = 200

model.train()
for epoch in tqdm(range(n_epochs), desc="Epochs"):
    # Running sums over the whole epoch, so the log reflects every batch rather than only
    # whichever mini-batch happened to come last.
    epoch_loss = epoch_recon = epoch_kld = 0.0
    n_seen = 0
    for (batch_data,) in dataloader:
        optimizer.zero_grad()

        recon_batch, mean, logvar = model(batch_data)
        loss, recon_loss, kld_loss = vae_loss(recon_batch, batch_data, mean, logvar, beta=1.0)

        fabric.backward(loss)

        optimizer.step()

        # Accumulate detached values so no autograd graph is kept alive between batches.
        epoch_loss += loss.detach()
        epoch_recon += recon_loss.detach()
        epoch_kld += kld_loss.detach()
        n_seen += batch_data.shape[0]
    # The losses are sums over elements, so dividing by the number of rows seen gives a
    # per-row epoch average that is comparable across epochs and batch sizes.
    if epoch % 5 == 0 and n_seen:
        print(
            f"Epoch {epoch:3d}/{n_epochs}, Loss: {epoch_loss / n_seen:.4f}, "
            f"Recon: {epoch_recon / n_seen:.4f}, KLD: {epoch_kld / n_seen:.4f}"
        )
    # One scheduler step per epoch.
    scheduler.step()

In [None]:
# Inspect one dataset item (a 1-tuple holding a single row of X).
dataset[1]

In [None]:
# Evaluation mode: dropout is disabled and GeneVAE.reparameterize returns the mean.
model.eval()

In [None]:
torch.cuda.empty_cache()

In [None]:
# Move the model to CPU before encoding the full dataset in the next cell.
model.cpu()

In [None]:
embeddings = model.encode(dataset[:][0])[0].detach()

In [None]:
embeddings.corrcoef()

In [None]:
X.corrcoef()

In [None]:
embeddings_norm = F.normalize(embeddings, dim=1)

In [None]:
torch.matmul(embeddings_norm, embeddings_norm.T)

In [None]:
# Saves the whole `model` object (not only a state_dict) to `model.pt`, a path relative to
# the current working directory rather than to `data_path`.
# NOTE(review): `model` here is the object returned by fabric.setup — confirm the saved
# file can be loaded where it is consumed.
torch.save(model, "model.pt")

# Quantile-based embeddings

In [None]:
# 256 evenly spaced quantile levels (0 to 1) evaluated along dim=1, i.e. within each row of X.
quantiles = X.quantile(
    torch.linspace(0, 1, 256),
    dim=1,
    # keepdim=True
)

In [None]:
quantiles.shape

In [None]:
# Transpose so each row is one gene's quantile profile, then scale rows to unit L2 norm.
quantiles_norm = F.normalize(quantiles.T, dim=1)
quantiles_norm.shape

In [None]:
# Cosine similarity between quantile profiles.
similarity = torch.matmul(quantiles_norm, quantiles_norm.T)

In [None]:
similarity

In [None]:
# Cosine similarity computed directly on the log1p expression rows.
X_norm = F.normalize(X, dim=-1)
torch.matmul(X_norm, X_norm.T)

In [None]:
# Cosine similarity of the PCA embeddings computed earlier, shown again for comparison.
gene_sim

In [None]:
# The second positional argument is p, so this is an L1 normalisation of each row.
F.normalize(X, 1, dim=-1).sum(-1).shape

In [None]:
# NOTE(review): argsort returns sorting indices, not ranks — if a rank-based
# (Spearman-style) correlation was intended, this does not compute it; confirm.
torch.corrcoef(torch.argsort(X, dim=-1))

In [None]:
# Correlation between the rows of X.
torch.corrcoef(X)

In [None]:
quantiles.shape

In [None]:
quantiles_df = pl.DataFrame(
    data=quantiles.T, schema=[f"quantile_{i}" for i in range(256)]
).with_columns(gene_name=pl.Series(counts.select(cs.numeric()).columns))
quantiles_df = quantiles_df.select("gene_name", cs.numeric())
quantiles_df

In [None]:
quantiles_df.select(cs.numeric()).corr()

In [None]:
quantiles_df.write_parquet(data_path / "gene_embeddings/quantiles-train_expression.parquet")