In [1]:
import sys
sys.path.append('..')
from data.data_reader import *

In [2]:
import tqdm
import scanpy as sc
import pandas as pd

In [3]:
# Fetch the .h5ad file from figshare, then load it as an AnnData object.
H5AD_URL = 'https://plus.figshare.com/ndownloader/files/35775512'
H5AD_PATH = '35775512.h5ad'
download_file(H5AD_URL, H5AD_PATH)
adata_orig = sc.read_h5ad(H5AD_PATH)
# Overwrite every +inf entry of the matrix with 0, in place.
# Only positive infinity is matched; -inf and NaN entries are left as they are.
positive_inf_mask = adata_orig.X == float("inf")
adata_orig.X[positive_inf_mask] = 0

File downloaded successfully to 35775512.h5ad


In [4]:
# The second "_"-separated field of each obs index label is stored as 'gene_name'.
obs_labels = adata_orig.obs.index
adata_orig.obs['gene_name'] = [label.split("_")[1] for label in obs_labels]
# 'id' is the running row position (0 .. n_rows - 1); later cells use it to index matrices.
adata_orig.obs['id'] = range(len(adata_orig.obs))

In [5]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [6]:
def cosine_similarity(A):
    """Pairwise cosine similarity between the rows of a 2-D array.

    Parameters
    ----------
    A : array-like of shape (n, d)
        One vector per row.

    Returns
    -------
    numpy.ndarray of shape (n, n)
        Entry ``[i, j]`` is ``dot(A[i], A[j]) / (|A[i]| * |A[j]|)``.
        Any pair that involves an all-zero row has no defined cosine; those
        entries are returned as 0 instead of the NaN (and divide-by-zero
        warning) that a plain division would produce.
    """
    A = np.asarray(A)
    gram = A @ A.T
    row_norms = np.sqrt((A ** 2).sum(axis=1))
    norm_products = np.outer(row_norms, row_norms)
    # Pre-fill with zeros so that positions skipped by `where` stay 0.
    result = np.zeros(
        gram.shape, dtype=np.result_type(gram.dtype, norm_products.dtype)
    )
    np.divide(gram, norm_products, out=result, where=norm_products > 0)
    return result

In [7]:
class X_dataset(Dataset):
    """Map-style dataset that serves one row of ``data`` per item.

    ``data`` must support ``len()``, expose a row-indexable ``X`` and an
    ``obs`` table that supports ``.iloc`` and has a ``'core_control'`` column.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 'x': row ``idx`` of the matrix; 'c': that row's 'core_control' value.
        features = torch.tensor(self.data.X[idx])
        obs_row = self.data.obs.iloc[idx]
        control_value = torch.tensor(obs_row['core_control'])
        return {'x': features, 'c': control_value}


In [8]:
# Wrap the AnnData object and serve it in shuffled mini-batches of 32 rows.
dataset = X_dataset(adata_orig)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

In [9]:
# Width of the widest hidden layer; the narrower hidden layer is BASENUM // 4.
BASENUM=512
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _resolve_feature_dim(feature_dim):
    """Return ``feature_dim``, or the row width of the notebook-level ``dataset`` when it is None."""
    if feature_dim is None:
        # Fallback keeps the original behaviour: size the network from the first dataset item.
        return dataset[0]['x'].shape[0]
    return int(feature_dim)


class Encoder(nn.Module):
    """Two-layer MLP encoder that samples a latent vector.

    Parameters
    ----------
    latent_dim : int
        Size of the latent vector.
    kl_coef : float
        Multiplier applied to the KL term stored in ``self.kl``.
    input_dim : int, optional
        Number of input features. When omitted it is read from the
        notebook-level ``dataset`` (``dataset[0]['x'].shape[0]``).

    After every ``forward`` call ``self.kl`` holds
    ``kl_coef * 0.5 * sum(exp(logvar) + mu**2 - logvar - 1)``, summed over
    every element of the batch (it is not averaged).
    """

    def __init__(self, latent_dim=10,kl_coef=0.000001, input_dim=None):
        super().__init__()
        input_dim = _resolve_feature_dim(input_dim)
        self.latent_dim = latent_dim
        self.kl_coef = kl_coef
        # Layers are created in a fixed order so seeded initialisation is reproducible.
        self.dense1 = nn.Linear(input_dim, BASENUM)
        self.bn1 = nn.BatchNorm1d(BASENUM)
        self.dense2 = nn.Linear(BASENUM, BASENUM // 4)
        self.bn2 = nn.BatchNorm1d(BASENUM // 4)
        self.mu = nn.Linear(BASENUM // 4, latent_dim)
        self.logvar = nn.Linear(BASENUM // 4, latent_dim)
        self.kl = 0

    def reparameterize(self, mu , logvar):
        """Return ``mu + eps * exp(logvar / 2)`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(logvar * 0.5)
        # randn_like already matches std's device and dtype, so the noise follows
        # the module wherever it lives instead of being forced onto the global device.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x`` of shape (batch, input_dim) into a sampled latent of shape (batch, latent_dim).

        NOTE(review): a sample is drawn in eval mode as well, so embeddings
        produced after ``.eval()`` are still stochastic.
        """
        hidden = F.relu(self.bn1(self.dense1(x)))
        hidden = F.relu(self.bn2(self.dense2(hidden)))
        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        z = self.reparameterize(mu, logvar)
        self.kl = 0.5 * (logvar.exp() + mu ** 2 - logvar - 1).sum() * self.kl_coef
        return z


class Decoder(nn.Module):
    """Two-layer MLP decoder mapping a latent vector back to feature space.

    Parameters
    ----------
    latent_dim : int
        Size of the incoming latent vector.
        NOTE(review): the default (8) differs from ``Encoder``'s default (10);
        ``VariationalAutoencoder`` always passes the value explicitly.
    output_dim : int, optional
        Number of reconstructed features. When omitted it is read from the
        notebook-level ``dataset`` (``dataset[0]['x'].shape[0]``).
    """

    def __init__(self, latent_dim=8, output_dim=None):
        super().__init__()
        output_dim = _resolve_feature_dim(output_dim)
        self.dense1 = nn.Linear(latent_dim, BASENUM // 4)
        self.bn1 = nn.BatchNorm1d(BASENUM // 4)
        self.dense2 = nn.Linear(BASENUM // 4, BASENUM)
        self.bn2 = nn.BatchNorm1d(BASENUM)
        self.out = nn.Linear(BASENUM, output_dim)

    def forward(self, z):
        """Decode ``z`` of shape (batch, latent_dim) to shape (batch, output_dim); no output activation."""
        hidden = F.relu(self.bn1(self.dense1(z)))
        hidden = F.relu(self.bn2(self.dense2(hidden)))
        return self.out(hidden)


class VariationalAutoencoder(nn.Module):
    """Encoder/decoder pair; both sub-modules are moved to the global ``device``.

    Parameters
    ----------
    latent_dims : int
        Latent size shared by encoder and decoder.
    kl_coef : float
        KL weight forwarded to the encoder.
    input_dim : int, optional
        Number of input/output features. When omitted it is read once from
        the notebook-level ``dataset``.

    The KL term of the most recent forward pass is available as
    ``self.encoder.kl``.
    """

    def __init__(self, latent_dims=10,kl_coef=0.000001, input_dim=None):
        super().__init__()
        input_dim = _resolve_feature_dim(input_dim)
        self.encoder = Encoder(latent_dims, kl_coef, input_dim=input_dim).to(device)
        self.decoder = Decoder(latent_dims, output_dim=input_dim).to(device)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.decoder(self.encoder(x))


In [10]:
# VAE with a 20-dimensional latent space and a KL weight of 1e-9
# (the previously tried alternative was kl_coef=0.1).
autoencoder = VariationalAutoencoder(latent_dims=20, kl_coef=1e-9)
# Adam over all model parameters; the reconstruction loss is mean-squared error.
opt = torch.optim.Adam(params=autoencoder.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()


  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [11]:
# Number of full passes over train_loader.
N_EPOCHS = 100

# Training mode is set once: nothing inside the loop switches the model to eval.
autoencoder.train()
for epoch in range(N_EPOCHS):
    # Running sums of the per-batch losses, used only for the epoch report.
    train_loss1 = 0.0
    train_loss2 = 0.0
    for batch in tqdm.tqdm(train_loader):
        # Only the expression vectors are needed here; batch['c'] is not used in training.
        x = batch['x'].to(device)
        opt.zero_grad()
        x_hat = autoencoder(x)
        loss1 = loss_fn(x_hat, x)          # reconstruction error
        loss2 = autoencoder.encoder.kl     # KL term set by the encoder's forward pass (already scaled by kl_coef)
        loss = loss1 + loss2
        loss.backward()
        opt.step()
        # .item() yields a plain Python float without keeping the autograd graph alive.
        train_loss1 += loss1.item()
        train_loss2 += loss2.item()
    n_batches = len(train_loader)
    print(f"TRAIN: EPOCH {epoch}: MSE: {train_loss1/n_batches}, KL_LOSS: {train_loss2/n_batches}")



  0%|          | 0/84 [00:00<?, ?it/s]

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 84/84 [00:01<00:00, 77.97it/s] 


TRAIN: EPOCH 0: MSE: 0.07920623238065413, KL_LOSS: 8.315684678416617e-07


100%|██████████| 84/84 [00:00<00:00, 156.88it/s]


TRAIN: EPOCH 1: MSE: 0.05247495205895532, KL_LOSS: 1.5055436655248264e-06


100%|██████████| 84/84 [00:00<00:00, 158.16it/s]


TRAIN: EPOCH 2: MSE: 0.04935496491158292, KL_LOSS: 1.758328740682676e-06


100%|██████████| 84/84 [00:00<00:00, 157.38it/s]


TRAIN: EPOCH 3: MSE: 0.04722175343583027, KL_LOSS: 1.9392740839629413e-06


100%|██████████| 84/84 [00:00<00:00, 152.61it/s]


TRAIN: EPOCH 4: MSE: 0.04672922384703443, KL_LOSS: 2.018308330369629e-06


100%|██████████| 84/84 [00:00<00:00, 153.53it/s]


TRAIN: EPOCH 5: MSE: 0.04534792760387063, KL_LOSS: 2.1557219954935308e-06


100%|██████████| 84/84 [00:00<00:00, 153.68it/s]


TRAIN: EPOCH 6: MSE: 0.04395974356503714, KL_LOSS: 2.194646459061221e-06


100%|██████████| 84/84 [00:00<00:00, 136.16it/s]


TRAIN: EPOCH 7: MSE: 0.04292813706256095, KL_LOSS: 2.3617932031681576e-06


100%|██████████| 84/84 [00:00<00:00, 114.68it/s]


TRAIN: EPOCH 8: MSE: 0.041823065595790035, KL_LOSS: 2.434106087483607e-06


100%|██████████| 84/84 [00:00<00:00, 153.78it/s]


TRAIN: EPOCH 9: MSE: 0.04154249134340456, KL_LOSS: 2.5198372451909347e-06


100%|██████████| 84/84 [00:00<00:00, 121.54it/s]


TRAIN: EPOCH 10: MSE: 0.04128543611261107, KL_LOSS: 2.5808121032847944e-06


100%|██████████| 84/84 [00:00<00:00, 118.33it/s]


TRAIN: EPOCH 11: MSE: 0.040365570513088075, KL_LOSS: 2.633674934518889e-06


100%|██████████| 84/84 [00:00<00:00, 149.45it/s]


TRAIN: EPOCH 12: MSE: 0.04019630598347811, KL_LOSS: 2.6485552138767943e-06


100%|██████████| 84/84 [00:00<00:00, 99.13it/s] 


TRAIN: EPOCH 13: MSE: 0.03935144710842343, KL_LOSS: 2.71779326848922e-06


100%|██████████| 84/84 [00:00<00:00, 126.61it/s]


TRAIN: EPOCH 14: MSE: 0.03860723739489913, KL_LOSS: 2.7569388326550585e-06


100%|██████████| 84/84 [00:00<00:00, 135.26it/s]


TRAIN: EPOCH 15: MSE: 0.03806936794093677, KL_LOSS: 2.797959098188585e-06


100%|██████████| 84/84 [00:00<00:00, 152.33it/s]


TRAIN: EPOCH 16: MSE: 0.03723400600609325, KL_LOSS: 2.830155616655767e-06


100%|██████████| 84/84 [00:00<00:00, 148.15it/s]


TRAIN: EPOCH 17: MSE: 0.03685428430548027, KL_LOSS: 2.8687068612227794e-06


100%|██████████| 84/84 [00:00<00:00, 151.45it/s]


TRAIN: EPOCH 18: MSE: 0.03692592620583517, KL_LOSS: 2.9334173659446517e-06


100%|██████████| 84/84 [00:00<00:00, 150.67it/s]


TRAIN: EPOCH 19: MSE: 0.037227107877177854, KL_LOSS: 2.968798929443784e-06


100%|██████████| 84/84 [00:00<00:00, 152.15it/s]


TRAIN: EPOCH 20: MSE: 0.0363629384941998, KL_LOSS: 3.006516906333716e-06


100%|██████████| 84/84 [00:00<00:00, 129.19it/s]


TRAIN: EPOCH 21: MSE: 0.03580966578530414, KL_LOSS: 3.085038017724444e-06


100%|██████████| 84/84 [00:00<00:00, 148.18it/s]


TRAIN: EPOCH 22: MSE: 0.035045697464652005, KL_LOSS: 3.105029553093378e-06


100%|██████████| 84/84 [00:00<00:00, 152.61it/s]


TRAIN: EPOCH 23: MSE: 0.034680941164316165, KL_LOSS: 3.148265399150183e-06


100%|██████████| 84/84 [00:00<00:00, 140.76it/s]


TRAIN: EPOCH 24: MSE: 0.03419930867052504, KL_LOSS: 3.185032583241125e-06


100%|██████████| 84/84 [00:00<00:00, 118.97it/s]


TRAIN: EPOCH 25: MSE: 0.03414485780965714, KL_LOSS: 3.2143388791285233e-06


100%|██████████| 84/84 [00:00<00:00, 115.75it/s]


TRAIN: EPOCH 26: MSE: 0.03370320745965555, KL_LOSS: 3.2209948049707497e-06


100%|██████████| 84/84 [00:00<00:00, 103.89it/s]


TRAIN: EPOCH 27: MSE: 0.03349873822714601, KL_LOSS: 3.2656369653306595e-06


100%|██████████| 84/84 [00:00<00:00, 93.26it/s] 


TRAIN: EPOCH 28: MSE: 0.033860686368175914, KL_LOSS: 3.2871930887354427e-06


100%|██████████| 84/84 [00:01<00:00, 79.57it/s]


TRAIN: EPOCH 29: MSE: 0.03334149740458954, KL_LOSS: 3.2951879613089354e-06


100%|██████████| 84/84 [00:00<00:00, 118.15it/s]


TRAIN: EPOCH 30: MSE: 0.03293393942571822, KL_LOSS: 3.3427368532015646e-06


100%|██████████| 84/84 [00:00<00:00, 142.62it/s]


TRAIN: EPOCH 31: MSE: 0.03241873390617825, KL_LOSS: 3.3699051002821805e-06


100%|██████████| 84/84 [00:00<00:00, 150.67it/s]


TRAIN: EPOCH 32: MSE: 0.03188634105026722, KL_LOSS: 3.3746157360578114e-06


100%|██████████| 84/84 [00:00<00:00, 136.91it/s]


TRAIN: EPOCH 33: MSE: 0.03221378850174092, KL_LOSS: 3.411932041109096e-06


100%|██████████| 84/84 [00:00<00:00, 150.73it/s]


TRAIN: EPOCH 34: MSE: 0.031842321600942386, KL_LOSS: 3.4448088047697854e-06


100%|██████████| 84/84 [00:00<00:00, 146.05it/s]


TRAIN: EPOCH 35: MSE: 0.031520582580318056, KL_LOSS: 3.4668723477201033e-06


100%|██████████| 84/84 [00:00<00:00, 143.87it/s]


TRAIN: EPOCH 36: MSE: 0.032819858707842375, KL_LOSS: 3.525177819916453e-06


100%|██████████| 84/84 [00:00<00:00, 142.50it/s]


TRAIN: EPOCH 37: MSE: 0.03173920942381734, KL_LOSS: 3.5598869338909503e-06


100%|██████████| 84/84 [00:00<00:00, 149.30it/s]


TRAIN: EPOCH 38: MSE: 0.03175708917634828, KL_LOSS: 3.5546070139822577e-06


100%|██████████| 84/84 [00:00<00:00, 148.10it/s]


TRAIN: EPOCH 39: MSE: 0.030706984656197683, KL_LOSS: 3.560465954489614e-06


100%|██████████| 84/84 [00:00<00:00, 154.13it/s]


TRAIN: EPOCH 40: MSE: 0.03120336408859917, KL_LOSS: 3.639887944743913e-06


100%|██████████| 84/84 [00:00<00:00, 153.92it/s]


TRAIN: EPOCH 41: MSE: 0.030468882704597144, KL_LOSS: 3.6595685709731437e-06


100%|██████████| 84/84 [00:00<00:00, 156.95it/s]


TRAIN: EPOCH 42: MSE: 0.030720806991060574, KL_LOSS: 3.685399194962104e-06


100%|██████████| 84/84 [00:00<00:00, 148.42it/s]


TRAIN: EPOCH 43: MSE: 0.030105477797665765, KL_LOSS: 3.709909257005555e-06


100%|██████████| 84/84 [00:00<00:00, 124.48it/s]


TRAIN: EPOCH 44: MSE: 0.029641647534888415, KL_LOSS: 3.741474686302397e-06


100%|██████████| 84/84 [00:00<00:00, 133.43it/s]


TRAIN: EPOCH 45: MSE: 0.029741298739931414, KL_LOSS: 3.7585360192202815e-06


100%|██████████| 84/84 [00:00<00:00, 155.42it/s]


TRAIN: EPOCH 46: MSE: 0.029986015166200343, KL_LOSS: 3.7827436769822226e-06


100%|██████████| 84/84 [00:00<00:00, 149.16it/s]


TRAIN: EPOCH 47: MSE: 0.029393992431107022, KL_LOSS: 3.8065483176112704e-06


100%|██████████| 84/84 [00:00<00:00, 154.63it/s]


TRAIN: EPOCH 48: MSE: 0.029229032430088238, KL_LOSS: 3.81850191474985e-06


100%|██████████| 84/84 [00:00<00:00, 157.83it/s]


TRAIN: EPOCH 49: MSE: 0.02865310104209043, KL_LOSS: 3.854109972204848e-06


100%|██████████| 84/84 [00:00<00:00, 155.73it/s]


TRAIN: EPOCH 50: MSE: 0.029021345633303837, KL_LOSS: 3.889811003838466e-06


100%|██████████| 84/84 [00:00<00:00, 149.95it/s]


TRAIN: EPOCH 51: MSE: 0.028881196597857133, KL_LOSS: 3.899767891977847e-06


100%|██████████| 84/84 [00:00<00:00, 151.58it/s]


TRAIN: EPOCH 52: MSE: 0.028839012962721643, KL_LOSS: 3.927483541182093e-06


100%|██████████| 84/84 [00:00<00:00, 153.33it/s]


TRAIN: EPOCH 53: MSE: 0.028708185490575574, KL_LOSS: 3.950052652206588e-06


100%|██████████| 84/84 [00:00<00:00, 151.24it/s]


TRAIN: EPOCH 54: MSE: 0.028672941893871342, KL_LOSS: 3.966574933569757e-06


100%|██████████| 84/84 [00:00<00:00, 154.90it/s]


TRAIN: EPOCH 55: MSE: 0.028618611506230775, KL_LOSS: 4.000783825458279e-06


100%|██████████| 84/84 [00:00<00:00, 156.79it/s]


TRAIN: EPOCH 56: MSE: 0.0281431870978503, KL_LOSS: 3.99292579394179e-06


100%|██████████| 84/84 [00:00<00:00, 155.66it/s]


TRAIN: EPOCH 57: MSE: 0.028100160083600452, KL_LOSS: 4.032728197346712e-06


100%|██████████| 84/84 [00:00<00:00, 157.45it/s]


TRAIN: EPOCH 58: MSE: 0.027678012204844327, KL_LOSS: 4.044602782133075e-06


100%|██████████| 84/84 [00:00<00:00, 155.20it/s]


TRAIN: EPOCH 59: MSE: 0.027892816346138716, KL_LOSS: 4.072617468322971e-06


100%|██████████| 84/84 [00:00<00:00, 155.30it/s]


TRAIN: EPOCH 60: MSE: 0.028093901063714708, KL_LOSS: 4.0717967792824355e-06


100%|██████████| 84/84 [00:00<00:00, 158.46it/s]


TRAIN: EPOCH 61: MSE: 0.027566759137525446, KL_LOSS: 4.0893422340646105e-06


100%|██████████| 84/84 [00:00<00:00, 155.75it/s]


TRAIN: EPOCH 62: MSE: 0.027459649258248862, KL_LOSS: 4.121207815560843e-06


100%|██████████| 84/84 [00:00<00:00, 152.96it/s]


TRAIN: EPOCH 63: MSE: 0.027451623297695602, KL_LOSS: 4.153242776911134e-06


100%|██████████| 84/84 [00:00<00:00, 146.12it/s]


TRAIN: EPOCH 64: MSE: 0.027167655882381257, KL_LOSS: 4.173955395409783e-06


100%|██████████| 84/84 [00:00<00:00, 142.72it/s]


TRAIN: EPOCH 65: MSE: 0.027299688131149327, KL_LOSS: 4.192524586631494e-06


100%|██████████| 84/84 [00:00<00:00, 147.74it/s]


TRAIN: EPOCH 66: MSE: 0.026870780314008396, KL_LOSS: 4.207211494152337e-06


100%|██████████| 84/84 [00:00<00:00, 158.78it/s]


TRAIN: EPOCH 67: MSE: 0.026863344811967442, KL_LOSS: 4.2290254728392205e-06


100%|██████████| 84/84 [00:00<00:00, 160.25it/s]


TRAIN: EPOCH 68: MSE: 0.026719442613068082, KL_LOSS: 4.231767047171854e-06


100%|██████████| 84/84 [00:00<00:00, 159.33it/s]


TRAIN: EPOCH 69: MSE: 0.02698106924071908, KL_LOSS: 4.248462985179642e-06


100%|██████████| 84/84 [00:00<00:00, 151.67it/s]


TRAIN: EPOCH 70: MSE: 0.026550948176355588, KL_LOSS: 4.271866443906176e-06


100%|██████████| 84/84 [00:00<00:00, 160.46it/s]


TRAIN: EPOCH 71: MSE: 0.026442328645359902, KL_LOSS: 4.278307436661202e-06


100%|██████████| 84/84 [00:00<00:00, 154.46it/s]


TRAIN: EPOCH 72: MSE: 0.026405746848987683, KL_LOSS: 4.293321072033168e-06


100%|██████████| 84/84 [00:00<00:00, 155.31it/s]


TRAIN: EPOCH 73: MSE: 0.026370073890402204, KL_LOSS: 4.301511169268495e-06


100%|██████████| 84/84 [00:00<00:00, 155.75it/s]


TRAIN: EPOCH 74: MSE: 0.026227649345639207, KL_LOSS: 4.3005460600484405e-06


100%|██████████| 84/84 [00:00<00:00, 153.74it/s]


TRAIN: EPOCH 75: MSE: 0.026699292761761518, KL_LOSS: 4.327592139794606e-06


100%|██████████| 84/84 [00:00<00:00, 158.74it/s]


TRAIN: EPOCH 76: MSE: 0.0261570398828813, KL_LOSS: 4.358828050499142e-06


100%|██████████| 84/84 [00:00<00:00, 159.12it/s]


TRAIN: EPOCH 77: MSE: 0.026089385196211793, KL_LOSS: 4.340006379758658e-06


100%|██████████| 84/84 [00:00<00:00, 156.55it/s]


TRAIN: EPOCH 78: MSE: 0.025413784841519026, KL_LOSS: 4.350413048216849e-06


100%|██████████| 84/84 [00:00<00:00, 151.23it/s]


TRAIN: EPOCH 79: MSE: 0.02592531835571641, KL_LOSS: 4.378761148126657e-06


100%|██████████| 84/84 [00:00<00:00, 155.03it/s]


TRAIN: EPOCH 80: MSE: 0.026893571351787875, KL_LOSS: 4.447099194121187e-06


100%|██████████| 84/84 [00:00<00:00, 158.36it/s]


TRAIN: EPOCH 81: MSE: 0.025768562308734373, KL_LOSS: 4.422165044984251e-06


100%|██████████| 84/84 [00:00<00:00, 159.66it/s]


TRAIN: EPOCH 82: MSE: 0.026328560940566518, KL_LOSS: 4.464644416116027e-06


100%|██████████| 84/84 [00:00<00:00, 152.52it/s]


TRAIN: EPOCH 83: MSE: 0.02551392306174551, KL_LOSS: 4.453676722061486e-06


100%|██████████| 84/84 [00:00<00:00, 160.78it/s]


TRAIN: EPOCH 84: MSE: 0.025331471935801562, KL_LOSS: 4.441778972228522e-06


100%|██████████| 84/84 [00:00<00:00, 151.62it/s]


TRAIN: EPOCH 85: MSE: 0.02509396057575941, KL_LOSS: 4.484212248761261e-06


100%|██████████| 84/84 [00:00<00:00, 160.01it/s]


TRAIN: EPOCH 86: MSE: 0.025176121675897212, KL_LOSS: 4.502020057576114e-06


100%|██████████| 84/84 [00:00<00:00, 157.08it/s]


TRAIN: EPOCH 87: MSE: 0.02481955804285549, KL_LOSS: 4.532322447760505e-06


100%|██████████| 84/84 [00:00<00:00, 159.17it/s]


TRAIN: EPOCH 88: MSE: 0.024751161446883566, KL_LOSS: 4.539260108534411e-06


100%|██████████| 84/84 [00:00<00:00, 159.98it/s]


TRAIN: EPOCH 89: MSE: 0.02482182461590994, KL_LOSS: 4.5456819991938706e-06


100%|██████████| 84/84 [00:00<00:00, 160.05it/s]


TRAIN: EPOCH 90: MSE: 0.02478150138631463, KL_LOSS: 4.549065484601063e-06


100%|██████████| 84/84 [00:00<00:00, 156.51it/s]


TRAIN: EPOCH 91: MSE: 0.02444187006247895, KL_LOSS: 4.547462714028716e-06


100%|██████████| 84/84 [00:00<00:00, 154.47it/s]


TRAIN: EPOCH 92: MSE: 0.02426132032026847, KL_LOSS: 4.5444788974063635e-06


100%|██████████| 84/84 [00:00<00:00, 156.88it/s]


TRAIN: EPOCH 93: MSE: 0.02427459525920096, KL_LOSS: 4.514488158032951e-06


100%|██████████| 84/84 [00:00<00:00, 146.96it/s]


TRAIN: EPOCH 94: MSE: 0.024151814524971303, KL_LOSS: 4.480702646755422e-06


100%|██████████| 84/84 [00:00<00:00, 154.33it/s]


TRAIN: EPOCH 95: MSE: 0.024398053907567545, KL_LOSS: 4.510190928201718e-06


100%|██████████| 84/84 [00:00<00:00, 155.42it/s]


TRAIN: EPOCH 96: MSE: 0.025089822044329985, KL_LOSS: 4.5096985180255915e-06


100%|██████████| 84/84 [00:00<00:00, 152.71it/s]


TRAIN: EPOCH 97: MSE: 0.024667176950190748, KL_LOSS: 4.4905209614669385e-06


100%|██████████| 84/84 [00:00<00:00, 152.73it/s]


TRAIN: EPOCH 98: MSE: 0.024192569333882558, KL_LOSS: 4.492937743331519e-06


100%|██████████| 84/84 [00:00<00:00, 155.92it/s]

TRAIN: EPOCH 99: MSE: 0.024254227012750647, KL_LOSS: 4.517291642969212e-06





In [12]:
# Encode every record one at a time with the trained encoder.
# NOTE(review): Encoder.forward returns a *sampled* latent even in eval mode,
# so these embeddings differ from run to run unless the RNG is seeded.
autoencoder.eval()
encoded_x=[]
cs=[]
# Inference only: no_grad avoids building an autograd graph for each record.
with torch.no_grad():
    for rec in tqdm.tqdm(dataset):
        x = rec['x'].reshape(1,-1).to(device)
        encoded_x.append(autoencoder.encoder(x).cpu().numpy())
        # The label never needs to leave the CPU.
        cs.append(rec['c'].reshape(1,).numpy())
encoded_x=np.concatenate(encoded_x,axis=0)
# Standardise each latent dimension to zero mean and unit variance across records.
encoded_x=(encoded_x-encoded_x.mean(axis=0,keepdims=True))/encoded_x.std(axis=0,keepdims=True)
cs=np.concatenate(cs,axis=0)
# One row per record: latent features f0..f{k-1} plus the 'control' label.
df_to_be_shown=pd.DataFrame(encoded_x,columns=[f'f{i}' for i in range(encoded_x.shape[1])])
df_to_be_shown['control']=cs

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 2679/2679 [00:02<00:00, 1093.86it/s]


In [13]:
# Cosine similarity between every pair of records, using only the latent feature columns.
latent_features = df_to_be_shown.drop(columns=['control']).to_numpy()
cos_sim_f = cosine_similarity(latent_features)

In [14]:
# similarity_matrix[i, j] == 1 when the genes of rows i and j are reported
# together by query_hu_data, otherwise 0. Same shape as cos_sim_f.
similarity_matrix=np.zeros(cos_sim_f.shape)
similarity_db=hu_data_loader()

# Map each gene name to the row ids that carry it, built in a single pass.
# This replaces re-filtering adata_orig.obs for every (gene, partner) pair.
# Assumes the values yielded by query_hu_data are hashable gene names — TODO confirm.
gene_to_ids = {}
for gene, row_id in zip(adata_orig.obs['gene_name'].to_numpy(),
                        adata_orig.obs['id'].to_numpy()):
    gene_to_ids.setdefault(gene, []).append(row_id)

for gene_name in tqdm.tqdm(adata_orig.obs.gene_name.unique()):
    x_indices = gene_to_ids[gene_name]
    for q in query_hu_data(similarity_db, gene_name):
        y_indices = gene_to_ids.get(q)
        if y_indices is None:
            # Partner gene is not present in this dataset.
            continue
        # Mark the whole block of row pairs, in both orientations.
        similarity_matrix[np.ix_(y_indices, x_indices)] = 1
        similarity_matrix[np.ix_(x_indices, y_indices)] = 1

# Split the flattened cosine similarities by the 0/1 label of each pair.
# NOTE(review): the diagonal (each record paired with itself) is included in
# this split — confirm that is intended.
cos_sim_f_flatten=cos_sim_f.reshape(-1,)
similarity_matrix_flatten=similarity_matrix.reshape(-1,)
cos_sim_f_flatten1=cos_sim_f_flatten[similarity_matrix_flatten==1]
cos_sim_f_flatten0=cos_sim_f_flatten[similarity_matrix_flatten==0]

File downloaded successfully to humap2_complexes_20200809.txt


100%|██████████| 2394/2394 [00:31<00:00, 74.95it/s] 


In [15]:
def _sample_and_plot(values, title, n_samples=2048):
    """Draw ``n_samples`` entries of ``values`` (with replacement) and show them as a violin plot.

    Returns the drawn values as a one-column DataFrame named 'correlations'.
    """
    picked = np.random.choice(values.shape[0], n_samples)
    sample_df = pd.DataFrame(values[picked], columns=['correlations'])
    px.violin(sample_df, y='correlations', width=500, height=400, title=title).show()
    return sample_df


# Both names are rebound from the full arrays to the sampled DataFrames, so
# this cell is meant to run once after the cell that builds the arrays.
cos_sim_f_flatten1 = _sample_and_plot(cos_sim_f_flatten1, "SIMILARS")
cos_sim_f_flatten0 = _sample_and_plot(cos_sim_f_flatten0, "Not SIMILARS")

In [17]:
# Mean sampled cosine similarity of each group, printed in the same order as before.
for label, sampled in (("Not SIMILARS MEAN:", cos_sim_f_flatten0),
                       ("SIMILARS MEAN:", cos_sim_f_flatten1)):
    print(label, sampled.mean())

Not SIMILARS MEAN: correlations    0.02002
dtype: float32
SIMILARS MEAN: correlations    0.436908
dtype: float32


In [18]:
# Scatter plots of consecutive latent-dimension pairs, coloured by the 'control' label.
for x_col, y_col in (('f0', 'f1'), ('f2', 'f3'), ('f4', 'f5')):
    fig = px.scatter(df_to_be_shown, x=x_col, y=y_col, color='control', width=500, height=400)
    fig.show()











