In [1]:
import sys
sys.path.append('..')
from data.data_reader import *

In [2]:
import tqdm
import scanpy as sc
import pandas as pd

In [3]:
download_file('https://plus.figshare.com/ndownloader/files/35775512','35775512.h5ad')
adata_orig = sc.read_h5ad("35775512.h5ad")
adata_orig.X[adata_orig.X == float("inf")]=0

File downloaded successfully to 35775512.h5ad


In [4]:
adata_orig.obs['gene_name']=list(pd.Series(adata_orig.obs.index).apply(lambda x:x.split("_")[1]))
adata_orig.obs['id']=range(adata_orig.obs.shape[0])

In [5]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [6]:
def cosine_similarity(A):
  AAt=np.matmul(A,A.transpose())
  n_A=np.sqrt((A**2).sum(axis=1)).reshape(-1,1)
  n_A=np.matmul(n_A,n_A.transpose())
  return AAt/(n_A)

### Let's first train a VAE model

In [7]:
class X_dataset(Dataset):
    def __init__(self,data):
        self.data=data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [8]:
dataset=X_dataset(adata_orig)
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)

In [9]:
BASENUM=512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Encoder(nn.Module):
    def __init__(self, latent_dim=10,kl_coef=0.000001):
        super(Encoder, self).__init__()
        self.latent_dim=latent_dim
        self.kl_coef=kl_coef
        self.dense1=nn.Linear(dataset[0]['x'].shape[0],BASENUM)
        self.bn1=nn.BatchNorm1d(BASENUM)
        self.dense2=nn.Linear(BASENUM,BASENUM//4)
        self.bn2=nn.BatchNorm1d(BASENUM//4)
        self.mu=nn.Linear(BASENUM//4, latent_dim)
        self.logvar=nn.Linear(BASENUM//4, latent_dim)
        self.kl = 0
    def reparameterize(self, mu , logvar):
        std = torch.exp(logvar*0.5)
        eps = torch.randn_like(std).to(device)
        z = mu + eps * std
        return z
    def forward(self, x):
        bn=x.size(0)
        x=F.relu(self.bn1(self.dense1(x)))
        x=F.relu(self.bn2(self.dense2(x)))
        mu =  self.mu(x)
        logvar = self.logvar(x)
        z=self.reparameterize(mu , logvar)
        self.kl = 0.5*(logvar.exp() + mu**2 - logvar - 1).sum()*self.kl_coef
        return z


class Decoder(nn.Module):
    def __init__(self, latent_dim=8):
        super(Decoder, self).__init__()
        self.dense1=nn.Linear(latent_dim,BASENUM//4)
        self.bn1=nn.BatchNorm1d(BASENUM//4)
        self.dense2=nn.Linear(BASENUM//4,BASENUM)
        self.bn2=nn.BatchNorm1d(BASENUM)
        self.out=nn.Linear(BASENUM,dataset[0]['x'].shape[0])

    def forward(self, z):
        z = F.relu(self.bn1(self.dense1(z)))
        z = F.relu(self.bn2(self.dense2(z)))
        z = self.out(z)
        return z

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims=10,kl_coef=0.000001):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims,kl_coef).to(device)
        self.decoder = Decoder(latent_dims).to(device)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)


In [10]:
# autoencoder=VariationalAutoencoder(20,0.1)
autoencoder=VariationalAutoencoder(20,1e-9)
opt = torch.optim.Adam(autoencoder.parameters(),lr=0.001)
loss_fn=torch.nn.MSELoss()


  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [11]:
autoencoder.train()
for epoch in range(100):
    train_loss1=0
    train_loss2=0
    autoencoder.train()
    for batch in tqdm.tqdm(train_loader):
        x = batch['x'].to(device) # GPU
        c = batch['c'].to(device) # GPU
        opt.zero_grad()
        x_hat = autoencoder(x)
        loss1=loss_fn(x_hat,x)
        loss2=autoencoder.encoder.kl
        loss = loss1 + loss2
        train_loss1+=loss1.detach().cpu().numpy()
        train_loss2+=loss2.detach().cpu().numpy()
        loss.backward()
        opt.step()
    print(f"TRAIN: EPOCH {epoch}: MSE: {train_loss1/len(train_loader)}, KL_LOSS: {train_loss2/len(train_loader)}")



  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 84/84 [00:01<00:00, 71.59it/s] 


TRAIN: EPOCH 0: MSE: 0.07866987781155677, KL_LOSS: 7.507470368547534e-07


100%|██████████| 84/84 [00:00<00:00, 100.34it/s]


TRAIN: EPOCH 1: MSE: 0.0528616142858352, KL_LOSS: 1.36090612353915e-06


100%|██████████| 84/84 [00:00<00:00, 97.57it/s] 


TRAIN: EPOCH 2: MSE: 0.04972523313370489, KL_LOSS: 1.5389885845706885e-06


100%|██████████| 84/84 [00:00<00:00, 121.77it/s]


TRAIN: EPOCH 3: MSE: 0.0474938845616721, KL_LOSS: 1.7023029995146186e-06


100%|██████████| 84/84 [00:00<00:00, 110.72it/s]


TRAIN: EPOCH 4: MSE: 0.046000901077474864, KL_LOSS: 1.8361713054153889e-06


100%|██████████| 84/84 [00:00<00:00, 87.74it/s] 


TRAIN: EPOCH 5: MSE: 0.04499698295036242, KL_LOSS: 1.963261650774127e-06


100%|██████████| 84/84 [00:01<00:00, 72.83it/s] 


TRAIN: EPOCH 6: MSE: 0.04454555069761617, KL_LOSS: 2.070732389960367e-06


100%|██████████| 84/84 [00:00<00:00, 104.83it/s]


TRAIN: EPOCH 7: MSE: 0.04338913293377984, KL_LOSS: 2.162971760472352e-06


100%|██████████| 84/84 [00:00<00:00, 102.78it/s]


TRAIN: EPOCH 8: MSE: 0.04229024398539748, KL_LOSS: 2.3037469667369373e-06


100%|██████████| 84/84 [00:00<00:00, 127.22it/s]


TRAIN: EPOCH 9: MSE: 0.041679693186389546, KL_LOSS: 2.4338301749421157e-06


100%|██████████| 84/84 [00:00<00:00, 133.46it/s]


TRAIN: EPOCH 10: MSE: 0.0407896225251967, KL_LOSS: 2.428005264694851e-06


100%|██████████| 84/84 [00:00<00:00, 131.15it/s]


TRAIN: EPOCH 11: MSE: 0.0399091726479431, KL_LOSS: 2.507325635157204e-06


100%|██████████| 84/84 [00:00<00:00, 128.44it/s]


TRAIN: EPOCH 12: MSE: 0.03967428220702069, KL_LOSS: 2.5387014916521454e-06


100%|██████████| 84/84 [00:00<00:00, 140.80it/s]


TRAIN: EPOCH 13: MSE: 0.03953065720963336, KL_LOSS: 2.645340194768713e-06


100%|██████████| 84/84 [00:00<00:00, 152.22it/s]


TRAIN: EPOCH 14: MSE: 0.03819738561287522, KL_LOSS: 2.70709552383085e-06


100%|██████████| 84/84 [00:00<00:00, 136.45it/s]


TRAIN: EPOCH 15: MSE: 0.03818866406523046, KL_LOSS: 2.8018493548556546e-06


100%|██████████| 84/84 [00:00<00:00, 140.56it/s]


TRAIN: EPOCH 16: MSE: 0.03781060801286783, KL_LOSS: 2.8259308351152868e-06


100%|██████████| 84/84 [00:00<00:00, 135.85it/s]


TRAIN: EPOCH 17: MSE: 0.03685334460100248, KL_LOSS: 2.8488827540838557e-06


100%|██████████| 84/84 [00:00<00:00, 107.91it/s]


TRAIN: EPOCH 18: MSE: 0.03638520615086669, KL_LOSS: 2.8588287688892578e-06


100%|██████████| 84/84 [00:00<00:00, 108.98it/s]


TRAIN: EPOCH 19: MSE: 0.036708348497216194, KL_LOSS: 2.9296657679705753e-06


100%|██████████| 84/84 [00:00<00:00, 122.89it/s]


TRAIN: EPOCH 20: MSE: 0.03602980150442038, KL_LOSS: 2.9908620230604096e-06


100%|██████████| 84/84 [00:00<00:00, 146.33it/s]


TRAIN: EPOCH 21: MSE: 0.03548041090280527, KL_LOSS: 3.0190023034308634e-06


100%|██████████| 84/84 [00:00<00:00, 144.36it/s]


TRAIN: EPOCH 22: MSE: 0.035122004515003594, KL_LOSS: 3.030719773007496e-06


100%|██████████| 84/84 [00:00<00:00, 145.82it/s]


TRAIN: EPOCH 23: MSE: 0.03429830249487644, KL_LOSS: 3.0621406115945624e-06


100%|██████████| 84/84 [00:00<00:00, 118.24it/s]


TRAIN: EPOCH 24: MSE: 0.034260834323331005, KL_LOSS: 3.0736908776080305e-06


100%|██████████| 84/84 [00:00<00:00, 104.02it/s]


TRAIN: EPOCH 25: MSE: 0.03375462901645473, KL_LOSS: 3.129349768945152e-06


100%|██████████| 84/84 [00:01<00:00, 68.09it/s] 


TRAIN: EPOCH 26: MSE: 0.033515051317711674, KL_LOSS: 3.164834140534367e-06


100%|██████████| 84/84 [00:00<00:00, 147.43it/s]


TRAIN: EPOCH 27: MSE: 0.03388472744041965, KL_LOSS: 3.215726240211682e-06


100%|██████████| 84/84 [00:00<00:00, 148.52it/s]


TRAIN: EPOCH 28: MSE: 0.03375525109558588, KL_LOSS: 3.214301597966239e-06


100%|██████████| 84/84 [00:00<00:00, 144.68it/s]


TRAIN: EPOCH 29: MSE: 0.03349691749151264, KL_LOSS: 3.2462367260320335e-06


100%|██████████| 84/84 [00:00<00:00, 150.17it/s]


TRAIN: EPOCH 30: MSE: 0.03341481668342437, KL_LOSS: 3.292497697635727e-06


100%|██████████| 84/84 [00:00<00:00, 136.54it/s]


TRAIN: EPOCH 31: MSE: 0.03348411777101103, KL_LOSS: 3.3490978077160046e-06


100%|██████████| 84/84 [00:00<00:00, 138.36it/s]


TRAIN: EPOCH 32: MSE: 0.034200340354194246, KL_LOSS: 3.4334224264533814e-06


100%|██████████| 84/84 [00:00<00:00, 136.54it/s]


TRAIN: EPOCH 33: MSE: 0.03289748462183135, KL_LOSS: 3.4604562794508333e-06


100%|██████████| 84/84 [00:00<00:00, 144.84it/s]


TRAIN: EPOCH 34: MSE: 0.032129179287169664, KL_LOSS: 3.4439887816092994e-06


100%|██████████| 84/84 [00:00<00:00, 139.74it/s]


TRAIN: EPOCH 35: MSE: 0.03243640453244249, KL_LOSS: 3.5284312935973936e-06


100%|██████████| 84/84 [00:00<00:00, 143.35it/s]


TRAIN: EPOCH 36: MSE: 0.03229481875452967, KL_LOSS: 3.503827318105496e-06


100%|██████████| 84/84 [00:00<00:00, 140.42it/s]


TRAIN: EPOCH 37: MSE: 0.031468593360235296, KL_LOSS: 3.5426708488254612e-06


100%|██████████| 84/84 [00:00<00:00, 126.83it/s]


TRAIN: EPOCH 38: MSE: 0.030942275277560667, KL_LOSS: 3.5427931515013847e-06


100%|██████████| 84/84 [00:01<00:00, 65.84it/s]


TRAIN: EPOCH 39: MSE: 0.030701833466688793, KL_LOSS: 3.5562846716987096e-06


100%|██████████| 84/84 [00:00<00:00, 132.91it/s]


TRAIN: EPOCH 40: MSE: 0.030999477952718735, KL_LOSS: 3.5867833070196433e-06


100%|██████████| 84/84 [00:00<00:00, 120.96it/s]


TRAIN: EPOCH 41: MSE: 0.03100990504026413, KL_LOSS: 3.6203659008259864e-06


100%|██████████| 84/84 [00:00<00:00, 106.27it/s]


TRAIN: EPOCH 42: MSE: 0.03028665902093053, KL_LOSS: 3.637893081802412e-06


100%|██████████| 84/84 [00:00<00:00, 122.36it/s]


TRAIN: EPOCH 43: MSE: 0.03021816709744079, KL_LOSS: 3.703954925120862e-06


100%|██████████| 84/84 [00:00<00:00, 99.24it/s] 


TRAIN: EPOCH 44: MSE: 0.030140631721310672, KL_LOSS: 3.706148301522002e-06


100%|██████████| 84/84 [00:01<00:00, 79.96it/s] 


TRAIN: EPOCH 45: MSE: 0.02979853886756159, KL_LOSS: 3.7250784890559206e-06


100%|██████████| 84/84 [00:00<00:00, 100.78it/s]


TRAIN: EPOCH 46: MSE: 0.029583629309421496, KL_LOSS: 3.7487728236772543e-06


100%|██████████| 84/84 [00:00<00:00, 115.44it/s]


TRAIN: EPOCH 47: MSE: 0.029666936579382137, KL_LOSS: 3.752079789722172e-06


100%|██████████| 84/84 [00:00<00:00, 125.24it/s]


TRAIN: EPOCH 48: MSE: 0.029772714756074407, KL_LOSS: 3.8033870872649936e-06


100%|██████████| 84/84 [00:00<00:00, 132.16it/s]


TRAIN: EPOCH 49: MSE: 0.02942888420962152, KL_LOSS: 3.836705434964859e-06


100%|██████████| 84/84 [00:01<00:00, 69.36it/s]


TRAIN: EPOCH 50: MSE: 0.029004028754397518, KL_LOSS: 3.879204533469809e-06


100%|██████████| 84/84 [00:00<00:00, 113.92it/s]


TRAIN: EPOCH 51: MSE: 0.029288581882913906, KL_LOSS: 3.877520059956753e-06


100%|██████████| 84/84 [00:00<00:00, 97.10it/s] 


TRAIN: EPOCH 52: MSE: 0.02903040460798712, KL_LOSS: 3.9106363238889094e-06


100%|██████████| 84/84 [00:00<00:00, 117.65it/s]


TRAIN: EPOCH 53: MSE: 0.028523739028189863, KL_LOSS: 3.944992375268373e-06


100%|██████████| 84/84 [00:00<00:00, 131.97it/s]


TRAIN: EPOCH 54: MSE: 0.028993324576211826, KL_LOSS: 3.957471313870406e-06


100%|██████████| 84/84 [00:00<00:00, 121.38it/s]


TRAIN: EPOCH 55: MSE: 0.028688756039454824, KL_LOSS: 3.968823967798449e-06


100%|██████████| 84/84 [00:00<00:00, 120.15it/s]


TRAIN: EPOCH 56: MSE: 0.027981590967447983, KL_LOSS: 3.969273068703062e-06


100%|██████████| 84/84 [00:00<00:00, 116.92it/s]


TRAIN: EPOCH 57: MSE: 0.02837698815745257, KL_LOSS: 3.9746936056958696e-06


100%|██████████| 84/84 [00:00<00:00, 102.40it/s]


TRAIN: EPOCH 58: MSE: 0.028281085259680237, KL_LOSS: 4.020514471189277e-06


100%|██████████| 84/84 [00:00<00:00, 93.74it/s] 


TRAIN: EPOCH 59: MSE: 0.027902615899663596, KL_LOSS: 4.005216331799082e-06


100%|██████████| 84/84 [00:00<00:00, 138.91it/s]


TRAIN: EPOCH 60: MSE: 0.027778176834718102, KL_LOSS: 4.049146950603158e-06


100%|██████████| 84/84 [00:00<00:00, 114.88it/s]


TRAIN: EPOCH 61: MSE: 0.027852039562449568, KL_LOSS: 4.061095456195506e-06


100%|██████████| 84/84 [00:00<00:00, 124.14it/s]


TRAIN: EPOCH 62: MSE: 0.02769838302351889, KL_LOSS: 4.099446138232972e-06


100%|██████████| 84/84 [00:01<00:00, 83.59it/s]


TRAIN: EPOCH 63: MSE: 0.027957601095771507, KL_LOSS: 4.109859232704323e-06


100%|██████████| 84/84 [00:00<00:00, 138.58it/s]


TRAIN: EPOCH 64: MSE: 0.02725079602428845, KL_LOSS: 4.133075313867656e-06


100%|██████████| 84/84 [00:00<00:00, 126.86it/s]


TRAIN: EPOCH 65: MSE: 0.027522393302725896, KL_LOSS: 4.192261398895339e-06


100%|██████████| 84/84 [00:00<00:00, 117.60it/s]


TRAIN: EPOCH 66: MSE: 0.027560526837727854, KL_LOSS: 4.212447081607977e-06


100%|██████████| 84/84 [00:00<00:00, 113.14it/s]


TRAIN: EPOCH 67: MSE: 0.027599718199954146, KL_LOSS: 4.2422983651271175e-06


100%|██████████| 84/84 [00:00<00:00, 92.16it/s] 


TRAIN: EPOCH 68: MSE: 0.027097619892585845, KL_LOSS: 4.205933166867042e-06


100%|██████████| 84/84 [00:00<00:00, 131.35it/s]


TRAIN: EPOCH 69: MSE: 0.027136201127654032, KL_LOSS: 4.24695883252536e-06


100%|██████████| 84/84 [00:00<00:00, 93.85it/s] 


TRAIN: EPOCH 70: MSE: 0.026925008500083572, KL_LOSS: 4.276362931575152e-06


100%|██████████| 84/84 [00:00<00:00, 84.17it/s]


TRAIN: EPOCH 71: MSE: 0.0264242788821104, KL_LOSS: 4.27380827464471e-06


100%|██████████| 84/84 [00:01<00:00, 82.92it/s] 


TRAIN: EPOCH 72: MSE: 0.02653488889336586, KL_LOSS: 4.304659312094286e-06


100%|██████████| 84/84 [00:00<00:00, 91.17it/s] 


TRAIN: EPOCH 73: MSE: 0.02619966860151007, KL_LOSS: 4.301827567858363e-06


100%|██████████| 84/84 [00:00<00:00, 111.88it/s]


TRAIN: EPOCH 74: MSE: 0.026385820204658166, KL_LOSS: 4.324047838941797e-06


100%|██████████| 84/84 [00:00<00:00, 132.18it/s]


TRAIN: EPOCH 75: MSE: 0.026491079945117235, KL_LOSS: 4.3478759775574696e-06


100%|██████████| 84/84 [00:00<00:00, 143.94it/s]


TRAIN: EPOCH 76: MSE: 0.026334703545130435, KL_LOSS: 4.348786614541279e-06


100%|██████████| 84/84 [00:00<00:00, 146.96it/s]


TRAIN: EPOCH 77: MSE: 0.026332720936763854, KL_LOSS: 4.3807331897889005e-06


100%|██████████| 84/84 [00:00<00:00, 126.16it/s]


TRAIN: EPOCH 78: MSE: 0.0260694097461445, KL_LOSS: 4.3876034965251165e-06


100%|██████████| 84/84 [00:00<00:00, 100.54it/s]


TRAIN: EPOCH 79: MSE: 0.025926516302639528, KL_LOSS: 4.408286302003987e-06


100%|██████████| 84/84 [00:00<00:00, 144.56it/s]


TRAIN: EPOCH 80: MSE: 0.025663646337177073, KL_LOSS: 4.400747928100048e-06


100%|██████████| 84/84 [00:00<00:00, 147.87it/s]


TRAIN: EPOCH 81: MSE: 0.025831357797696478, KL_LOSS: 4.424867272843715e-06


100%|██████████| 84/84 [00:00<00:00, 123.15it/s]


TRAIN: EPOCH 82: MSE: 0.025604871190374807, KL_LOSS: 4.408951223835868e-06


100%|██████████| 84/84 [00:00<00:00, 128.10it/s]


TRAIN: EPOCH 83: MSE: 0.025525413231835478, KL_LOSS: 4.405770691088644e-06


100%|██████████| 84/84 [00:01<00:00, 57.94it/s]


TRAIN: EPOCH 84: MSE: 0.025713778899184296, KL_LOSS: 4.422842250408264e-06


100%|██████████| 84/84 [00:00<00:00, 87.41it/s] 


TRAIN: EPOCH 85: MSE: 0.025472129029887065, KL_LOSS: 4.424945865637364e-06


100%|██████████| 84/84 [00:00<00:00, 129.63it/s]


TRAIN: EPOCH 86: MSE: 0.025290073516468208, KL_LOSS: 4.443485823100803e-06


100%|██████████| 84/84 [00:00<00:00, 127.69it/s]


TRAIN: EPOCH 87: MSE: 0.02541482417533795, KL_LOSS: 4.436128614586432e-06


100%|██████████| 84/84 [00:00<00:00, 96.28it/s] 


TRAIN: EPOCH 88: MSE: 0.02533402198570825, KL_LOSS: 4.4228128244651e-06


100%|██████████| 84/84 [00:00<00:00, 131.13it/s]


TRAIN: EPOCH 89: MSE: 0.02525174424850515, KL_LOSS: 4.416914136873813e-06


100%|██████████| 84/84 [00:00<00:00, 128.63it/s]


TRAIN: EPOCH 90: MSE: 0.02536713025931801, KL_LOSS: 4.405843807966275e-06


100%|██████████| 84/84 [00:00<00:00, 114.69it/s]


TRAIN: EPOCH 91: MSE: 0.025966802567598365, KL_LOSS: 4.470549662355119e-06


100%|██████████| 84/84 [00:00<00:00, 136.01it/s]


TRAIN: EPOCH 92: MSE: 0.025086438771159875, KL_LOSS: 4.4689527006388135e-06


100%|██████████| 84/84 [00:00<00:00, 142.38it/s]


TRAIN: EPOCH 93: MSE: 0.024923702169741904, KL_LOSS: 4.482158173925113e-06


100%|██████████| 84/84 [00:00<00:00, 115.99it/s]


TRAIN: EPOCH 94: MSE: 0.024641726365579025, KL_LOSS: 4.493164269769338e-06


100%|██████████| 84/84 [00:00<00:00, 124.66it/s]


TRAIN: EPOCH 95: MSE: 0.02451686126490434, KL_LOSS: 4.496879325545576e-06


100%|██████████| 84/84 [00:00<00:00, 135.05it/s]


TRAIN: EPOCH 96: MSE: 0.024389516717443865, KL_LOSS: 4.49541326068876e-06


100%|██████████| 84/84 [00:00<00:00, 134.82it/s]


TRAIN: EPOCH 97: MSE: 0.024919298632691305, KL_LOSS: 4.477989724212724e-06


100%|██████████| 84/84 [00:01<00:00, 67.42it/s] 


TRAIN: EPOCH 98: MSE: 0.02459654227520029, KL_LOSS: 4.512763888713034e-06


100%|██████████| 84/84 [00:00<00:00, 154.31it/s]

TRAIN: EPOCH 99: MSE: 0.024836466182023287, KL_LOSS: 4.511417287068027e-06





In [12]:
autoencoder.eval()
encoded_x=[]
cs=[]
for rec in tqdm.tqdm(dataset):
    x = rec['x'].reshape(1,-1).to(device) # GPU
    c = rec['c'].reshape(1,).to(device) # GPU
    encoded_x.append(autoencoder.encoder(x).cpu().detach().numpy())
    cs.append(c.cpu().detach().numpy())
encoded_x=np.concatenate(encoded_x,axis=0)
encoded_x=(encoded_x-encoded_x.mean(axis=0,keepdims=True))/encoded_x.std(axis=0,keepdims=True)
cs=np.concatenate(cs,axis=0)
df_to_be_shown=pd.DataFrame(encoded_x,columns=[f'f{i}' for i in range(encoded_x.shape[1])])
df_to_be_shown['control']=cs

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 2679/2679 [00:02<00:00, 995.27it/s] 


### We get the cosine similarity of every two perturabation

In [13]:
cos_sim_f=cosine_similarity(np.array(df_to_be_shown.drop(['control'], axis=1)))

### Now we want to know which two perturabations are similar

In [14]:
similarity_matrix=np.zeros(cos_sim_f.shape)
similarity_db=hu_data_loader()

for gene_name in tqdm.tqdm(adata_orig.obs.gene_name.unique()):
    query=query_hu_data(similarity_db,gene_name)
    for q in query:
        if q in adata_orig.obs.gene_name.values:
            y_indices=adata_orig.obs[adata_orig.obs.gene_name==q].id
            x_indices=adata_orig.obs[adata_orig.obs.gene_name==gene_name].id
            for x_id in x_indices:
                for y_id in y_indices:
                    similarity_matrix[y_id,x_id]=1
                    similarity_matrix[x_id,y_id]=1

cos_sim_f_flatten=cos_sim_f.reshape(-1,)
similarity_matrix_flatten=similarity_matrix.reshape(-1,)
cos_sim_f_flatten1=cos_sim_f_flatten[similarity_matrix_flatten==1]
cos_sim_f_flatten0=cos_sim_f_flatten[similarity_matrix_flatten==0]

File downloaded successfully to humap2_complexes_20200809.txt


100%|██████████| 2394/2394 [00:43<00:00, 54.56it/s] 


### We want to visualize the value of recall with respect to different quantiles as thresholds for similarities

In [1]:
def get_recall(rate):
    qrate_down=np.quantile(cos_sim_f_flatten,rate)
    qrate_up=np.quantile(cos_sim_f_flatten,1-rate)
    pred_p=np.logical_or((cos_sim_f_flatten>qrate_up,cos_sim_f_flatten<qrate_down))
    pred_n=np.logical_and((cos_sim_f_flatten<qrate_up,cos_sim_f_flatten>qrate_down))
    tp=np.logical_and(pred_p,similarity_matrix_flatten==1).sum()
    fp=np.logical_and(pred_p,similarity_matrix_flatten==0).sum()
    fn=np.logical_and(pred_n,similarity_matrix_flatten==1).sum()
    return tp/(tp+fn)
def visualize_recal_vs_quantile():
    values=[]
    xs=[i*0.05 for i in range(10)]
    for i in xs:
        values.append(round(get_recall(i),3))
    temp_df=pd.DataFrame({'quantile':xs,'recall':values})
    fig=px.line(temp_df,x='quantile',y='recall',title='recall_vs_quantile',width=1000, height=400)
    fig.update_traces(mode='lines+text', text=list(map(lambda x:f'{x}-{1-x}',values)), textposition='top center')
    fig.update_layout(
    font=dict(
        family="Arial, sans-serif",
        size=10,  # Set the desired font size
        color="black"
    )
)
    fig.show()
    

In [2]:
visualize_recal_vs_quantile()

NameError: name 'np' is not defined

### Okay. now we plot the distributions of similarities divided into classes. first, those pairs that are already know to be similar. Second those that are not.

In [17]:
choice=np.random.choice(cos_sim_f_flatten1.shape[0], 2048)
cos_sim_f_flatten1=cos_sim_f_flatten1[choice]
cos_sim_f_flatten1=pd.DataFrame(cos_sim_f_flatten1,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten1, y='correlations',width=500, height=400,title="SIMILARS")
fig.show()


choice=np.random.choice(cos_sim_f_flatten0.shape[0], 2048)
cos_sim_f_flatten0=cos_sim_f_flatten0[choice]
cos_sim_f_flatten0=pd.DataFrame(cos_sim_f_flatten0,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten0, y='correlations',width=500, height=400,title="Not SIMILARS")
fig.show()

In [18]:
print("Not SIMILARS MEAN:",cos_sim_f_flatten0.mean())
print("SIMILARS MEAN:",cos_sim_f_flatten1.mean())

Not SIMILARS MEAN: correlations    0.017548
dtype: float32
SIMILARS MEAN: correlations    0.460109
dtype: float32


### Here is the visualiztion of feature vectors

In [19]:
fig=px.scatter(df_to_be_shown,x='f0',y='f1',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f2',y='f3',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f4',y='f5',color='control',width=500, height=400)
fig.show()











