In [3]:
import sys
sys.path.append('..')
from data.data_reader import *

In [4]:
import tqdm
import scanpy as sc
import pandas as pd

In [5]:
download_file('https://plus.figshare.com/ndownloader/files/35775512','35775512.h5ad')
adata_orig = sc.read_h5ad("35775512.h5ad")
adata_orig.X[adata_orig.X == float("inf")]=0

File downloaded successfully to 35775512.h5ad


In [6]:
adata_orig.obs['gene_name']=list(pd.Series(adata_orig.obs.index).apply(lambda x:x.split("_")[1]))
adata_orig.obs['id']=range(adata_orig.obs.shape[0])

In [7]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [8]:
def cosine_similarity(A):
  AAt=np.matmul(A,A.transpose())
  n_A=np.sqrt((A**2).sum(axis=1)).reshape(-1,1)
  n_A=np.matmul(n_A,n_A.transpose())
  return AAt/(n_A)

### Let's first train a VAE model

In [9]:
class X_dataset(Dataset):
    def __init__(self,data):
        self.data=data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [10]:
dataset=X_dataset(adata_orig)
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)

In [11]:
BASENUM=512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Encoder(nn.Module):
    def __init__(self, latent_dim=10,kl_coef=0.000001):
        super(Encoder, self).__init__()
        self.latent_dim=latent_dim
        self.kl_coef=kl_coef
        self.dense1=nn.Linear(dataset[0]['x'].shape[0],BASENUM)
        self.bn1=nn.BatchNorm1d(BASENUM)
        self.dense2=nn.Linear(BASENUM,BASENUM//4)
        self.bn2=nn.BatchNorm1d(BASENUM//4)
        self.mu=nn.Linear(BASENUM//4, latent_dim)
        self.logvar=nn.Linear(BASENUM//4, latent_dim)
        self.kl = 0
    def reparameterize(self, mu , logvar):
        std = torch.exp(logvar*0.5)
        eps = torch.randn_like(std).to(device)
        z = mu + eps * std
        return z
    def forward(self, x):
        bn=x.size(0)
        x=F.relu(self.bn1(self.dense1(x)))
        x=F.relu(self.bn2(self.dense2(x)))
        mu =  self.mu(x)
        logvar = self.logvar(x)
        z=self.reparameterize(mu , logvar)
        self.kl = 0.5*(logvar.exp() + mu**2 - logvar - 1).sum()*self.kl_coef
        return z


class Decoder(nn.Module):
    def __init__(self, latent_dim=8):
        super(Decoder, self).__init__()
        self.dense1=nn.Linear(latent_dim,BASENUM//4)
        self.bn1=nn.BatchNorm1d(BASENUM//4)
        self.dense2=nn.Linear(BASENUM//4,BASENUM)
        self.bn2=nn.BatchNorm1d(BASENUM)
        self.out=nn.Linear(BASENUM,dataset[0]['x'].shape[0])

    def forward(self, z):
        z = F.relu(self.bn1(self.dense1(z)))
        z = F.relu(self.bn2(self.dense2(z)))
        z = self.out(z)
        return z

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims=10,kl_coef=0.000001):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims,kl_coef).to(device)
        self.decoder = Decoder(latent_dims).to(device)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)


In [12]:
# autoencoder=VariationalAutoencoder(20,0.1)
autoencoder=VariationalAutoencoder(20,1e-9)
opt = torch.optim.Adam(autoencoder.parameters(),lr=0.001)
loss_fn=torch.nn.MSELoss()


  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [13]:
autoencoder.train()
for epoch in range(100):
    train_loss1=0
    train_loss2=0
    autoencoder.train()
    for batch in tqdm.tqdm(train_loader):
        x = batch['x'].to(device) # GPU
        c = batch['c'].to(device) # GPU
        opt.zero_grad()
        x_hat = autoencoder(x)
        loss1=loss_fn(x_hat,x)
        loss2=autoencoder.encoder.kl
        loss = loss1 + loss2
        train_loss1+=loss1.detach().cpu().numpy()
        train_loss2+=loss2.detach().cpu().numpy()
        loss.backward()
        opt.step()
    print(f"TRAIN: EPOCH {epoch}: MSE: {train_loss1/len(train_loader)}, KL_LOSS: {train_loss2/len(train_loader)}")



  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 84/84 [00:00<00:00, 84.20it/s] 


TRAIN: EPOCH 0: MSE: 0.0763701115779224, KL_LOSS: 8.426465376893637e-07


100%|██████████| 84/84 [00:00<00:00, 149.86it/s]


TRAIN: EPOCH 1: MSE: 0.051926176462854655, KL_LOSS: 1.4981188458239644e-06


100%|██████████| 84/84 [00:00<00:00, 146.71it/s]


TRAIN: EPOCH 2: MSE: 0.04896696532766024, KL_LOSS: 1.7122020954914097e-06


100%|██████████| 84/84 [00:00<00:00, 154.56it/s]


TRAIN: EPOCH 3: MSE: 0.047867352032058295, KL_LOSS: 1.823806561181603e-06


100%|██████████| 84/84 [00:00<00:00, 148.39it/s]


TRAIN: EPOCH 4: MSE: 0.04618242400742713, KL_LOSS: 1.9511368114828135e-06


100%|██████████| 84/84 [00:00<00:00, 148.66it/s]


TRAIN: EPOCH 5: MSE: 0.04541666160470673, KL_LOSS: 2.053110905752051e-06


100%|██████████| 84/84 [00:00<00:00, 142.57it/s]


TRAIN: EPOCH 6: MSE: 0.043973338923283985, KL_LOSS: 2.1793760145967443e-06


100%|██████████| 84/84 [00:00<00:00, 140.73it/s]


TRAIN: EPOCH 7: MSE: 0.04327619725483514, KL_LOSS: 2.2558455048922386e-06


100%|██████████| 84/84 [00:00<00:00, 140.01it/s]


TRAIN: EPOCH 8: MSE: 0.04260828634280534, KL_LOSS: 2.342553048471191e-06


100%|██████████| 84/84 [00:00<00:00, 136.40it/s]


TRAIN: EPOCH 9: MSE: 0.041544016472817885, KL_LOSS: 2.4132119503541444e-06


100%|██████████| 84/84 [00:00<00:00, 134.17it/s]


TRAIN: EPOCH 10: MSE: 0.04121629811734671, KL_LOSS: 2.467109540237375e-06


100%|██████████| 84/84 [00:00<00:00, 93.30it/s]


TRAIN: EPOCH 11: MSE: 0.04026206281213533, KL_LOSS: 2.545824612200574e-06


100%|██████████| 84/84 [00:00<00:00, 127.21it/s]


TRAIN: EPOCH 12: MSE: 0.04020019072950596, KL_LOSS: 2.5919810880060096e-06


100%|██████████| 84/84 [00:00<00:00, 117.56it/s]


TRAIN: EPOCH 13: MSE: 0.03941867081448436, KL_LOSS: 2.6583593813440874e-06


100%|██████████| 84/84 [00:00<00:00, 135.35it/s]


TRAIN: EPOCH 14: MSE: 0.039194195376088224, KL_LOSS: 2.6807293888201196e-06


100%|██████████| 84/84 [00:00<00:00, 118.79it/s]


TRAIN: EPOCH 15: MSE: 0.038331254718026946, KL_LOSS: 2.72722001054284e-06


100%|██████████| 84/84 [00:01<00:00, 83.34it/s]


TRAIN: EPOCH 16: MSE: 0.038048723202553536, KL_LOSS: 2.7408613781127703e-06


100%|██████████| 84/84 [00:00<00:00, 124.75it/s]


TRAIN: EPOCH 17: MSE: 0.03755666069420321, KL_LOSS: 2.770446463901386e-06


100%|██████████| 84/84 [00:00<00:00, 130.90it/s]


TRAIN: EPOCH 18: MSE: 0.03665200847068003, KL_LOSS: 2.8302905087955534e-06


100%|██████████| 84/84 [00:00<00:00, 131.88it/s]


TRAIN: EPOCH 19: MSE: 0.036354366911663896, KL_LOSS: 2.851004580861627e-06


100%|██████████| 84/84 [00:00<00:00, 128.46it/s]


TRAIN: EPOCH 20: MSE: 0.03654362454212138, KL_LOSS: 2.914207917521188e-06


100%|██████████| 84/84 [00:00<00:00, 134.64it/s]


TRAIN: EPOCH 21: MSE: 0.03539874857025487, KL_LOSS: 2.9349952499268555e-06


100%|██████████| 84/84 [00:00<00:00, 133.06it/s]


TRAIN: EPOCH 22: MSE: 0.0373732366244353, KL_LOSS: 2.9876340693677082e-06


100%|██████████| 84/84 [00:00<00:00, 135.38it/s]


TRAIN: EPOCH 23: MSE: 0.03564763763209894, KL_LOSS: 3.087834481145061e-06


100%|██████████| 84/84 [00:00<00:00, 129.11it/s]


TRAIN: EPOCH 24: MSE: 0.03537122189022955, KL_LOSS: 3.117480516612843e-06


100%|██████████| 84/84 [00:00<00:00, 133.42it/s]


TRAIN: EPOCH 25: MSE: 0.034520295404252554, KL_LOSS: 3.1449449693354836e-06


100%|██████████| 84/84 [00:00<00:00, 130.12it/s]


TRAIN: EPOCH 26: MSE: 0.03468950447582063, KL_LOSS: 3.161149228011103e-06


100%|██████████| 84/84 [00:00<00:00, 136.61it/s]


TRAIN: EPOCH 27: MSE: 0.03364168462299165, KL_LOSS: 3.1875874891998786e-06


100%|██████████| 84/84 [00:00<00:00, 140.15it/s]


TRAIN: EPOCH 28: MSE: 0.03357176316369857, KL_LOSS: 3.185965673375384e-06


100%|██████████| 84/84 [00:00<00:00, 142.04it/s]


TRAIN: EPOCH 29: MSE: 0.03314068115183285, KL_LOSS: 3.226043986614968e-06


100%|██████████| 84/84 [00:00<00:00, 136.59it/s]


TRAIN: EPOCH 30: MSE: 0.0336746719133641, KL_LOSS: 3.276094690006305e-06


100%|██████████| 84/84 [00:00<00:00, 138.27it/s]


TRAIN: EPOCH 31: MSE: 0.03288518162887721, KL_LOSS: 3.2925425389727197e-06


100%|██████████| 84/84 [00:00<00:00, 143.83it/s]


TRAIN: EPOCH 32: MSE: 0.03292125097608992, KL_LOSS: 3.337178860701464e-06


100%|██████████| 84/84 [00:00<00:00, 145.60it/s]


TRAIN: EPOCH 33: MSE: 0.032092338161809106, KL_LOSS: 3.3357444522152946e-06


100%|██████████| 84/84 [00:00<00:00, 141.34it/s]


TRAIN: EPOCH 34: MSE: 0.03260190753887097, KL_LOSS: 3.3895327619791803e-06


100%|██████████| 84/84 [00:00<00:00, 141.97it/s]


TRAIN: EPOCH 35: MSE: 0.03207799792289734, KL_LOSS: 3.4014112039975382e-06


100%|██████████| 84/84 [00:00<00:00, 144.15it/s]


TRAIN: EPOCH 36: MSE: 0.0327960114172172, KL_LOSS: 3.4427465850837635e-06


100%|██████████| 84/84 [00:00<00:00, 147.46it/s]


TRAIN: EPOCH 37: MSE: 0.03210936994513586, KL_LOSS: 3.503973197266099e-06


100%|██████████| 84/84 [00:00<00:00, 148.63it/s]


TRAIN: EPOCH 38: MSE: 0.03136368565970943, KL_LOSS: 3.508952041946551e-06


100%|██████████| 84/84 [00:00<00:00, 147.64it/s]


TRAIN: EPOCH 39: MSE: 0.03152251110545227, KL_LOSS: 3.5294101589248124e-06


100%|██████████| 84/84 [00:00<00:00, 141.97it/s]


TRAIN: EPOCH 40: MSE: 0.031003917523083232, KL_LOSS: 3.566794565098722e-06


100%|██████████| 84/84 [00:00<00:00, 143.21it/s]


TRAIN: EPOCH 41: MSE: 0.03091511150289859, KL_LOSS: 3.5861593313971492e-06


100%|██████████| 84/84 [00:00<00:00, 144.35it/s]


TRAIN: EPOCH 42: MSE: 0.030695757518212, KL_LOSS: 3.6092168437833262e-06


100%|██████████| 84/84 [00:00<00:00, 143.21it/s]


TRAIN: EPOCH 43: MSE: 0.030632300169340203, KL_LOSS: 3.6210342846927577e-06


100%|██████████| 84/84 [00:00<00:00, 151.75it/s]


TRAIN: EPOCH 44: MSE: 0.03132962854579091, KL_LOSS: 3.7019149501594212e-06


100%|██████████| 84/84 [00:00<00:00, 155.01it/s]


TRAIN: EPOCH 45: MSE: 0.031531665984186386, KL_LOSS: 3.747029942197904e-06


100%|██████████| 84/84 [00:00<00:00, 152.17it/s]


TRAIN: EPOCH 46: MSE: 0.03126332915521094, KL_LOSS: 3.784972155547924e-06


100%|██████████| 84/84 [00:00<00:00, 151.34it/s]


TRAIN: EPOCH 47: MSE: 0.030158687759900375, KL_LOSS: 3.764952769610578e-06


100%|██████████| 84/84 [00:00<00:00, 144.83it/s]


TRAIN: EPOCH 48: MSE: 0.0298545483411068, KL_LOSS: 3.810573692438387e-06


100%|██████████| 84/84 [00:00<00:00, 138.82it/s]


TRAIN: EPOCH 49: MSE: 0.02940223201931942, KL_LOSS: 3.7988388720848544e-06


100%|██████████| 84/84 [00:00<00:00, 147.80it/s]


TRAIN: EPOCH 50: MSE: 0.03005736642738893, KL_LOSS: 3.8059918584362355e-06


100%|██████████| 84/84 [00:00<00:00, 145.31it/s]


TRAIN: EPOCH 51: MSE: 0.029471066469947498, KL_LOSS: 3.870625600181181e-06


100%|██████████| 84/84 [00:00<00:00, 143.01it/s]


TRAIN: EPOCH 52: MSE: 0.029122711225811924, KL_LOSS: 3.873045156545968e-06


100%|██████████| 84/84 [00:00<00:00, 146.04it/s]


TRAIN: EPOCH 53: MSE: 0.029103618469976243, KL_LOSS: 3.927127138359494e-06


100%|██████████| 84/84 [00:00<00:00, 137.77it/s]


TRAIN: EPOCH 54: MSE: 0.028543101845397836, KL_LOSS: 3.916846510708724e-06


100%|██████████| 84/84 [00:00<00:00, 143.94it/s]


TRAIN: EPOCH 55: MSE: 0.028483388014137745, KL_LOSS: 3.941170789511532e-06


100%|██████████| 84/84 [00:00<00:00, 137.22it/s]


TRAIN: EPOCH 56: MSE: 0.028681048812965553, KL_LOSS: 3.964017483191301e-06


100%|██████████| 84/84 [00:00<00:00, 148.01it/s]


TRAIN: EPOCH 57: MSE: 0.028541172876776683, KL_LOSS: 3.989298948324306e-06


100%|██████████| 84/84 [00:00<00:00, 143.53it/s]


TRAIN: EPOCH 58: MSE: 0.02797778509557247, KL_LOSS: 3.968111662916704e-06


100%|██████████| 84/84 [00:00<00:00, 143.01it/s]


TRAIN: EPOCH 59: MSE: 0.027877937048851026, KL_LOSS: 3.996684305158331e-06


100%|██████████| 84/84 [00:00<00:00, 142.22it/s]


TRAIN: EPOCH 60: MSE: 0.027794003530981996, KL_LOSS: 3.993670599724976e-06


100%|██████████| 84/84 [00:00<00:00, 141.04it/s]


TRAIN: EPOCH 61: MSE: 0.0277829447184645, KL_LOSS: 4.000118235039519e-06


100%|██████████| 84/84 [00:00<00:00, 142.34it/s]


TRAIN: EPOCH 62: MSE: 0.02743278829646962, KL_LOSS: 4.021094934528048e-06


100%|██████████| 84/84 [00:00<00:00, 146.45it/s]


TRAIN: EPOCH 63: MSE: 0.028509762581615222, KL_LOSS: 4.085256786571076e-06


100%|██████████| 84/84 [00:00<00:00, 140.81it/s]


TRAIN: EPOCH 64: MSE: 0.027930187260998145, KL_LOSS: 4.0870804398417975e-06


100%|██████████| 84/84 [00:00<00:00, 139.10it/s]


TRAIN: EPOCH 65: MSE: 0.02817611773276613, KL_LOSS: 4.103947471126698e-06


100%|██████████| 84/84 [00:00<00:00, 143.08it/s]


TRAIN: EPOCH 66: MSE: 0.027284699425633465, KL_LOSS: 4.139972244978535e-06


100%|██████████| 84/84 [00:00<00:00, 142.54it/s]


TRAIN: EPOCH 67: MSE: 0.0271753919133473, KL_LOSS: 4.163259139507621e-06


100%|██████████| 84/84 [00:00<00:00, 143.08it/s]


TRAIN: EPOCH 68: MSE: 0.026898586412980444, KL_LOSS: 4.162618316578508e-06


100%|██████████| 84/84 [00:00<00:00, 140.63it/s]


TRAIN: EPOCH 69: MSE: 0.026956898914206596, KL_LOSS: 4.1677079978052505e-06


100%|██████████| 84/84 [00:00<00:00, 144.54it/s]


TRAIN: EPOCH 70: MSE: 0.026784513300905626, KL_LOSS: 4.17867697489696e-06


100%|██████████| 84/84 [00:00<00:00, 133.72it/s]


TRAIN: EPOCH 71: MSE: 0.026569874580239967, KL_LOSS: 4.199037134423547e-06


100%|██████████| 84/84 [00:00<00:00, 139.82it/s]


TRAIN: EPOCH 72: MSE: 0.026905100681774673, KL_LOSS: 4.2266654477749575e-06


100%|██████████| 84/84 [00:00<00:00, 147.22it/s]


TRAIN: EPOCH 73: MSE: 0.027355626912876255, KL_LOSS: 4.246769056711249e-06


100%|██████████| 84/84 [00:00<00:00, 138.74it/s]


TRAIN: EPOCH 74: MSE: 0.026934191429366667, KL_LOSS: 4.283175268851448e-06


100%|██████████| 84/84 [00:00<00:00, 139.84it/s]


TRAIN: EPOCH 75: MSE: 0.026483988721988032, KL_LOSS: 4.286936849612608e-06


100%|██████████| 84/84 [00:00<00:00, 153.08it/s]


TRAIN: EPOCH 76: MSE: 0.02635444501148803, KL_LOSS: 4.300104105173627e-06


100%|██████████| 84/84 [00:00<00:00, 154.11it/s]


TRAIN: EPOCH 77: MSE: 0.02609767250361897, KL_LOSS: 4.330520918453355e-06


100%|██████████| 84/84 [00:00<00:00, 153.89it/s]


TRAIN: EPOCH 78: MSE: 0.025973217251400154, KL_LOSS: 4.305170721506459e-06


100%|██████████| 84/84 [00:00<00:00, 152.97it/s]


TRAIN: EPOCH 79: MSE: 0.025859961397058907, KL_LOSS: 4.3207026335169625e-06


100%|██████████| 84/84 [00:00<00:00, 148.56it/s]


TRAIN: EPOCH 80: MSE: 0.025687839259349164, KL_LOSS: 4.309519979306926e-06


100%|██████████| 84/84 [00:00<00:00, 149.34it/s]


TRAIN: EPOCH 81: MSE: 0.025945373633432956, KL_LOSS: 4.333627768595588e-06


100%|██████████| 84/84 [00:00<00:00, 148.91it/s]


TRAIN: EPOCH 82: MSE: 0.025763563989173798, KL_LOSS: 4.333367620628928e-06


100%|██████████| 84/84 [00:00<00:00, 152.83it/s]


TRAIN: EPOCH 83: MSE: 0.025657529787470896, KL_LOSS: 4.32235719656438e-06


100%|██████████| 84/84 [00:00<00:00, 152.27it/s]


TRAIN: EPOCH 84: MSE: 0.025626759089174726, KL_LOSS: 4.351775628276214e-06


100%|██████████| 84/84 [00:00<00:00, 142.52it/s]


TRAIN: EPOCH 85: MSE: 0.025541537352615877, KL_LOSS: 4.3662444519603535e-06


100%|██████████| 84/84 [00:00<00:00, 150.84it/s]


TRAIN: EPOCH 86: MSE: 0.025366466337194044, KL_LOSS: 4.394876506402008e-06


100%|██████████| 84/84 [00:00<00:00, 152.40it/s]


TRAIN: EPOCH 87: MSE: 0.025181365487653585, KL_LOSS: 4.391435538764199e-06


100%|██████████| 84/84 [00:00<00:00, 153.19it/s]


TRAIN: EPOCH 88: MSE: 0.025361578024569013, KL_LOSS: 4.3894094525447665e-06


100%|██████████| 84/84 [00:00<00:00, 151.42it/s]


TRAIN: EPOCH 89: MSE: 0.024910116390812965, KL_LOSS: 4.392024582752388e-06


100%|██████████| 84/84 [00:00<00:00, 150.40it/s]


TRAIN: EPOCH 90: MSE: 0.024997664172024953, KL_LOSS: 4.4096925486363555e-06


100%|██████████| 84/84 [00:00<00:00, 150.25it/s]


TRAIN: EPOCH 91: MSE: 0.025079477489704176, KL_LOSS: 4.446507643608909e-06


100%|██████████| 84/84 [00:00<00:00, 151.42it/s]


TRAIN: EPOCH 92: MSE: 0.02472753224096128, KL_LOSS: 4.4513116839490455e-06


100%|██████████| 84/84 [00:00<00:00, 150.31it/s]


TRAIN: EPOCH 93: MSE: 0.02499489714613273, KL_LOSS: 4.4522206264576066e-06


100%|██████████| 84/84 [00:00<00:00, 154.68it/s]


TRAIN: EPOCH 94: MSE: 0.0251253598946191, KL_LOSS: 4.45021124949649e-06


100%|██████████| 84/84 [00:00<00:00, 152.94it/s]


TRAIN: EPOCH 95: MSE: 0.02505794389262086, KL_LOSS: 4.476582440864669e-06


100%|██████████| 84/84 [00:00<00:00, 146.17it/s]


TRAIN: EPOCH 96: MSE: 0.02486263902946597, KL_LOSS: 4.459635152085996e-06


100%|██████████| 84/84 [00:00<00:00, 147.50it/s]


TRAIN: EPOCH 97: MSE: 0.024746468312860953, KL_LOSS: 4.4826613275074035e-06


100%|██████████| 84/84 [00:00<00:00, 146.03it/s]


TRAIN: EPOCH 98: MSE: 0.024858988865855195, KL_LOSS: 4.482159029283225e-06


100%|██████████| 84/84 [00:00<00:00, 153.66it/s]

TRAIN: EPOCH 99: MSE: 0.024649457451665684, KL_LOSS: 4.511997406639455e-06





In [14]:
autoencoder.eval()
encoded_x=[]
cs=[]
for rec in tqdm.tqdm(dataset):
    x = rec['x'].reshape(1,-1).to(device) # GPU
    c = rec['c'].reshape(1,).to(device) # GPU
    encoded_x.append(autoencoder.encoder(x).cpu().detach().numpy())
    cs.append(c.cpu().detach().numpy())
encoded_x=np.concatenate(encoded_x,axis=0)
encoded_x=(encoded_x-encoded_x.mean(axis=0,keepdims=True))/encoded_x.std(axis=0,keepdims=True)
cs=np.concatenate(cs,axis=0)
df_to_be_shown=pd.DataFrame(encoded_x,columns=[f'f{i}' for i in range(encoded_x.shape[1])])
df_to_be_shown['control']=cs

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 2679/2679 [00:02<00:00, 1221.12it/s]


### We get the cosine similarity of every two perturabation

In [15]:
cos_sim_f=cosine_similarity(np.array(df_to_be_shown.drop(['control'], axis=1)))

### Now we want to know which two perturabations are similar

In [16]:
similarity_matrix=np.zeros(cos_sim_f.shape)
similarity_db=hu_data_loader()

for gene_name in tqdm.tqdm(adata_orig.obs.gene_name.unique()):
    query=query_hu_data(similarity_db,gene_name)
    for q in query:
        if q in adata_orig.obs.gene_name.values:
            y_indices=adata_orig.obs[adata_orig.obs.gene_name==q].id
            x_indices=adata_orig.obs[adata_orig.obs.gene_name==gene_name].id
            for x_id in x_indices:
                for y_id in y_indices:
                    similarity_matrix[y_id,x_id]=1
                    similarity_matrix[x_id,y_id]=1

cos_sim_f_flatten=cos_sim_f.reshape(-1,)
similarity_matrix_flatten=similarity_matrix.reshape(-1,)
cos_sim_f_flatten1=cos_sim_f_flatten[similarity_matrix_flatten==1]
cos_sim_f_flatten0=cos_sim_f_flatten[similarity_matrix_flatten==0]

File downloaded successfully to humap2_complexes_20200809.txt


100%|██████████| 2394/2394 [01:22<00:00, 28.99it/s] 


### We want to visualize the value of recall with respect to different quantiles as thresholds for similarities

In [17]:
def get_recall(rate):
    qrate_down=np.quantile(cos_sim_f_flatten,rate)
    qrate_up=np.quantile(cos_sim_f_flatten,1-rate)
    pred_p=np.logical_or(cos_sim_f_flatten>qrate_up,cos_sim_f_flatten<qrate_down)
    pred_n=np.logical_and(cos_sim_f_flatten<qrate_up,cos_sim_f_flatten>qrate_down)
    tp=np.logical_and(pred_p,similarity_matrix_flatten==1).sum()
    fp=np.logical_and(pred_p,similarity_matrix_flatten==0).sum()
    fn=np.logical_and(pred_n,similarity_matrix_flatten==1).sum()
    return tp/(tp+fn)
def visualize_recal_vs_quantile():
    values=[]
    xs=[i*0.05 for i in range(10)]
    for i in xs:
        values.append(get_recall(i))
    temp_df=pd.DataFrame({'quantile':xs,'recall':values})
    fig=px.line(temp_df,x='quantile',y='recall',title='recall_vs_quantile',width=1000, height=400)
    fig.update_traces(mode='lines+text', text=list(map(lambda x:round(x,2),values)), textposition='top center')
    fig.update_layout(
    font=dict(
        family="Arial, sans-serif",
        size=10,  # Set the desired font size
        color="black"
    )
)
    fig.show()
    

In [18]:
visualize_recal_vs_quantile()

### Okay. now we plot the distributions of similarities divided into classes. first, those pairs that are already know to be similar. Second those that are not.

In [19]:
choice=np.random.choice(cos_sim_f_flatten1.shape[0], 2048)
cos_sim_f_flatten1=cos_sim_f_flatten1[choice]
cos_sim_f_flatten1=pd.DataFrame(cos_sim_f_flatten1,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten1, y='correlations',width=500, height=400,title="SIMILARS")
fig.show()


choice=np.random.choice(cos_sim_f_flatten0.shape[0], 2048)
cos_sim_f_flatten0=cos_sim_f_flatten0[choice]
cos_sim_f_flatten0=pd.DataFrame(cos_sim_f_flatten0,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten0, y='correlations',width=500, height=400,title="Not SIMILARS")
fig.show()

In [20]:
print("Not SIMILARS MEAN:",cos_sim_f_flatten0.mean())
print("SIMILARS MEAN:",cos_sim_f_flatten1.mean())

Not SIMILARS MEAN: correlations    0.014843
dtype: float32
SIMILARS MEAN: correlations    0.46017
dtype: float32


### Here is the visualiztion of feature vectors

In [21]:
fig=px.scatter(df_to_be_shown,x='f0',y='f1',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f2',y='f3',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f4',y='f5',color='control',width=500, height=400)
fig.show()











