In [None]:
import hvplot.polars  # noqa
import hydra
import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as cs
import rootutils
import torch
import torch.nn.functional as F

# Anchor the notebook to the project root, searching from the parent directory.
# NOTE(review): pythonpath=True presumably puts that root on the import path so
# project modules can be imported from here — confirm against the rootutils docs.
rootutils.setup_root("../", pythonpath=True)

In [None]:
# Resolve the data directory from the dotenv configuration.
from pathlib import Path

from dotenv import dotenv_values

# Called with no arguments, so the file that gets read depends on
# python-dotenv's default lookup — TODO confirm it picks up the repo's .env.
paths = dotenv_values()

# The lookup fails here if the loaded values have no DATA_DIR entry.
data_path = Path(paths["DATA_DIR"])
data_path

In [None]:
# Load the PCA-train expression table from the gene_embeddings folder of the
# data directory. Presumably one row per gene — the next cell reads a
# "gene_name" column from it; verify against the file's schema.
df = pl.read_parquet(data_path / "gene_embeddings/PCA-train_expression.parquet")
df

In [None]:
# Gene identifiers, kept aside from the numeric feature columns.
genes = df.get_column("gene_name")

# Export only the numeric columns, as a torch dataset object rather than a
# bare tensor.
dataset = df.select(cs.numeric()).to_torch(return_type="dataset")
dataset

In [None]:
# Wrap the dataset in a DataLoader and pull one batch to experiment with.
# shuffle=False keeps the batch identical across re-runs.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset, num_workers=1, batch_size=32, shuffle=False
)
# next(iter(...)) replaces the direct __iter__()/__next__() dunder calls.
# [0] takes the first element of the loader item — presumably the feature
# tensor, since no label was requested from to_torch; verify.
batch = next(iter(dataloader))[0]
batch

In [None]:
# Shape of a single sample (the first entry of the batch).
batch[0].shape

In [None]:
# L2-normalise along the last dimension so that dot products between rows
# become cosine similarities.
pred_norm = F.normalize(batch, p=2.0, dim=-1)
# Same input and same normalisation: gene_norm holds the same values as pred_norm.
gene_norm = F.normalize(batch, p=2.0, dim=-1)
pred_norm

In [None]:
# Pairwise similarity matrix: entry (i, j) is the dot product of normalised
# rows i and j.
cos_sim = pred_norm @ pred_norm.T
cos_sim

In [None]:
# Boolean mask that is True for every off-diagonal entry of cos_sim and False
# on the diagonal (each sample's similarity with itself).
# Sized from cos_sim rather than a literal 32 so it still lines up when the
# batch holds fewer samples than the loader's batch_size.
mask = ~torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
mask

In [None]:
# Drop the self-similarity diagonal. Boolean indexing returns the selected
# entries as a flat tensor in row-major order; the view regroups them into one
# row per sample holding that sample's similarity to each of the others.
# Dimensions are read from cos_sim instead of the literals 32 and 31.
n_rows = cos_sim.shape[0]
cos_sim_off_diag = cos_sim[mask].view(n_rows, n_rows - 1)

In [None]:
# Spot-check: row 1, i.e. the similarities of sample 1 to the other samples.
cos_sim_off_diag[1, :]

In [None]:
# Each pair is weighted by its squared cosine similarity.
weights = cos_sim_off_diag ** 2
weights[1, :]

In [None]:
# Element-wise exponential of the off-diagonal similarities.
exp_cos_sim = cos_sim_off_diag.exp()
exp_cos_sim

In [None]:
# Per-sample numerator: the exponentiated similarities scaled by their weights
# and summed over the other samples.
weighted_pos = torch.sum(weights * exp_cos_sim, dim=1)
weighted_pos

In [None]:
# Per-sample denominator: the unweighted sum of exponentiated similarities.
total_sim = torch.sum(exp_cos_sim, dim=1)
total_sim

In [None]:
# Ratio of the weighted sum to the total; the 1e-8 term keeps the
# denominator away from zero.
torch.div(weighted_pos, total_sim + 1e-8)