In [1]:
import sys
sys.path.append('..')
from data.data_reader import *
from models.VAEs import *

In [2]:
import tqdm
import scanpy as sc
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [3]:
# Fetch the dataset from figshare and load it with scanpy.
# `download_file` is provided by one of the star-imports above (presumably data.data_reader).
download_file('https://plus.figshare.com/ndownloader/files/35775512','35775512.h5ad')
adata_orig = sc.read_h5ad("35775512.h5ad")
# Replace entries of X equal to +inf with 0.
# NOTE(review): -inf and NaN entries, if any exist, are left untouched by this comparison — confirm that is intended.
adata_orig.X[adata_orig.X == float("inf")]=0

File downloaded successfully to 35775512.h5ad


In [4]:
# Annotate every observation with
#   * 'gene_name': the second "_"-separated token of its index label, and
#   * 'id': its row position (0 .. n_obs-1), used later to address matrix rows/columns.
adata_orig.obs['gene_name'] = [
    obs_label.split("_")[1] for obs_label in adata_orig.obs.index
]
adata_orig.obs['id'] = range(len(adata_orig.obs))

In [5]:
def cosine_similarity(A):
  """Return the pairwise cosine-similarity matrix of the rows of ``A``.

  Parameters
  ----------
  A : array-like of shape (n, d)
      One vector per row.

  Returns
  -------
  numpy.ndarray of shape (n, n)
      Entry ``[i, j]`` is ``A[i] . A[j] / (|A[i]| * |A[j]|)``. Any pair that
      involves an all-zero row gets similarity 0 rather than the NaN (and
      RuntimeWarning) that a 0/0 division would produce.
  """
  A=np.asarray(A)
  AAt=np.matmul(A,A.transpose())
  n_A=np.sqrt((A**2).sum(axis=1)).reshape(-1,1)
  n_A=np.matmul(n_A,n_A.transpose())
  # A zero norm product implies the matching dot product is 0 as well, so
  # dividing by 1 in those positions yields exactly 0 instead of NaN.
  safe_n_A=np.where(n_A==0,1,n_A)
  return AAt/safe_n_A

### Let's first train a VAE model

In [6]:
class X_dataset(Dataset):
    """Map-style dataset wrapping an object that exposes ``.X`` and ``.obs``.

    The wrapped object must support ``len()``, row indexing on ``.X`` and an
    ``.obs`` table with a ``core_control`` column.
    """

    def __init__(self, data):
        # Only a reference is kept; nothing is copied or converted up front.
        self.data = data

    def __len__(self):
        """Number of items, delegated to the wrapped object."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return item ``idx`` as ``{'x': <row of X>, 'c': <core_control value>}`` tensors."""
        features = torch.tensor(self.data.X[idx])
        obs_row = self.data.obs.iloc[idx]
        control_flag = torch.tensor(obs_row['core_control'])
        return {'x': features, 'c': control_flag}


In [7]:
# Wrap the loaded data in the Dataset defined above and iterate over it in
# shuffled mini-batches of 32 items.
dataset=X_dataset(adata_orig)
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)

In [8]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The first argument is the length of one item's 'x' vector (the input size).
# NOTE(review): 20, 1e-9 and 512 are passed positionally to a class defined in models.VAEs;
# presumably latent size, KL weight and hidden width — confirm against that module.
autoencoder=VariationalAutoencoder(dataset[0]['x'].shape[0],20,1e-9,512,device)
opt = torch.optim.Adam(autoencoder.parameters(),lr=0.001)
# Reconstruction loss handed to the training routine.
loss_fn=torch.nn.MSELoss()

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [9]:
# Run the training routine (star-imported above). The log below shows epochs 0-99, so the
# trailing 100 is presumably the epoch count; the None is presumably an absent
# validation loader — TODO confirm against the definition of `train`.
train(autoencoder,opt,loss_fn,train_loader,None,device,100)

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 84/84 [00:02<00:00, 38.16it/s]


TRAIN: EPOCH 0: MSE: 0.07803193923263323, KL_LOSS: 8.336109025497947e-07


100%|██████████| 84/84 [00:01<00:00, 47.86it/s]


TRAIN: EPOCH 1: MSE: 0.052132335801919304, KL_LOSS: 1.4956077715548115e-06


100%|██████████| 84/84 [00:01<00:00, 45.25it/s]


TRAIN: EPOCH 2: MSE: 0.05028229949641086, KL_LOSS: 1.6773652530013179e-06


100%|██████████| 84/84 [00:01<00:00, 44.90it/s]


TRAIN: EPOCH 3: MSE: 0.04731133424987396, KL_LOSS: 1.8496417015155333e-06


100%|██████████| 84/84 [00:01<00:00, 45.51it/s]


TRAIN: EPOCH 4: MSE: 0.04652115835675171, KL_LOSS: 1.9736212664156483e-06


100%|██████████| 84/84 [00:01<00:00, 51.36it/s]


TRAIN: EPOCH 5: MSE: 0.04481830042121666, KL_LOSS: 2.101219277641403e-06


100%|██████████| 84/84 [00:01<00:00, 54.89it/s]


TRAIN: EPOCH 6: MSE: 0.04377543966152838, KL_LOSS: 2.1690019864417126e-06


100%|██████████| 84/84 [00:01<00:00, 53.43it/s]


TRAIN: EPOCH 7: MSE: 0.043423653314156194, KL_LOSS: 2.3008085801869866e-06


100%|██████████| 84/84 [00:01<00:00, 53.57it/s]


TRAIN: EPOCH 8: MSE: 0.042465624488180594, KL_LOSS: 2.3510202181244366e-06


100%|██████████| 84/84 [00:01<00:00, 52.56it/s]


TRAIN: EPOCH 9: MSE: 0.04141006534475656, KL_LOSS: 2.4148376207038037e-06


100%|██████████| 84/84 [00:01<00:00, 52.61it/s]


TRAIN: EPOCH 10: MSE: 0.04117049514094279, KL_LOSS: 2.4646782592407406e-06


100%|██████████| 84/84 [00:01<00:00, 53.43it/s]


TRAIN: EPOCH 11: MSE: 0.040496804751455784, KL_LOSS: 2.5401277458145376e-06


100%|██████████| 84/84 [00:01<00:00, 52.15it/s]


TRAIN: EPOCH 12: MSE: 0.03938284605031922, KL_LOSS: 2.6149874737971653e-06


100%|██████████| 84/84 [00:01<00:00, 51.36it/s]


TRAIN: EPOCH 13: MSE: 0.038876427204481194, KL_LOSS: 2.6395428238044716e-06


100%|██████████| 84/84 [00:01<00:00, 54.56it/s]


TRAIN: EPOCH 14: MSE: 0.038668250842463406, KL_LOSS: 2.694645442737792e-06


100%|██████████| 84/84 [00:01<00:00, 53.88it/s]


TRAIN: EPOCH 15: MSE: 0.037796587528040014, KL_LOSS: 2.7487795093639717e-06


100%|██████████| 84/84 [00:01<00:00, 53.53it/s]


TRAIN: EPOCH 16: MSE: 0.037238503784118665, KL_LOSS: 2.8029438248094743e-06


100%|██████████| 84/84 [00:01<00:00, 57.14it/s]


TRAIN: EPOCH 17: MSE: 0.037300273132998316, KL_LOSS: 2.848239197256979e-06


100%|██████████| 84/84 [00:01<00:00, 53.76it/s]


TRAIN: EPOCH 18: MSE: 0.03672049819890942, KL_LOSS: 2.853542663875227e-06


100%|██████████| 84/84 [00:01<00:00, 49.31it/s]


TRAIN: EPOCH 19: MSE: 0.035756708322358985, KL_LOSS: 2.891792105076872e-06


100%|██████████| 84/84 [00:01<00:00, 48.38it/s]


TRAIN: EPOCH 20: MSE: 0.03554080734916386, KL_LOSS: 2.928577311238422e-06


100%|██████████| 84/84 [00:01<00:00, 45.43it/s]


TRAIN: EPOCH 21: MSE: 0.03558500661026864, KL_LOSS: 2.9821020082939262e-06


100%|██████████| 84/84 [00:01<00:00, 47.63it/s]


TRAIN: EPOCH 22: MSE: 0.035412756216135766, KL_LOSS: 2.9908363054736194e-06


100%|██████████| 84/84 [00:01<00:00, 51.14it/s]


TRAIN: EPOCH 23: MSE: 0.034792528997751926, KL_LOSS: 3.033481261484537e-06


100%|██████████| 84/84 [00:01<00:00, 50.88it/s]


TRAIN: EPOCH 24: MSE: 0.03429589540298496, KL_LOSS: 3.0747830454033863e-06


100%|██████████| 84/84 [00:01<00:00, 48.63it/s]


TRAIN: EPOCH 25: MSE: 0.03400362491430271, KL_LOSS: 3.107355718427806e-06


100%|██████████| 84/84 [00:01<00:00, 50.00it/s]


TRAIN: EPOCH 26: MSE: 0.03363494530674957, KL_LOSS: 3.1606321883936333e-06


100%|██████████| 84/84 [00:01<00:00, 49.91it/s]


TRAIN: EPOCH 27: MSE: 0.0332290266551787, KL_LOSS: 3.1853356588163434e-06


100%|██████████| 84/84 [00:01<00:00, 48.37it/s]


TRAIN: EPOCH 28: MSE: 0.03327818323547641, KL_LOSS: 3.2207766263950145e-06


100%|██████████| 84/84 [00:01<00:00, 49.90it/s]


TRAIN: EPOCH 29: MSE: 0.032947411051108724, KL_LOSS: 3.221787177868204e-06


100%|██████████| 84/84 [00:01<00:00, 50.51it/s]


TRAIN: EPOCH 30: MSE: 0.03272038394407857, KL_LOSS: 3.2437478585918487e-06


100%|██████████| 84/84 [00:01<00:00, 47.46it/s]


TRAIN: EPOCH 31: MSE: 0.0318013709038496, KL_LOSS: 3.2791702843886535e-06


100%|██████████| 84/84 [00:01<00:00, 50.51it/s]


TRAIN: EPOCH 32: MSE: 0.03190054509433962, KL_LOSS: 3.3139267083286196e-06


100%|██████████| 84/84 [00:01<00:00, 50.32it/s]


TRAIN: EPOCH 33: MSE: 0.032067826306003896, KL_LOSS: 3.3376408136323984e-06


100%|██████████| 84/84 [00:01<00:00, 54.96it/s]


TRAIN: EPOCH 34: MSE: 0.03167982359549829, KL_LOSS: 3.373257938966119e-06


100%|██████████| 84/84 [00:01<00:00, 52.60it/s]


TRAIN: EPOCH 35: MSE: 0.03140211429092146, KL_LOSS: 3.382506076290849e-06


100%|██████████| 84/84 [00:01<00:00, 50.34it/s]


TRAIN: EPOCH 36: MSE: 0.03128360930298056, KL_LOSS: 3.423025808453011e-06


100%|██████████| 84/84 [00:02<00:00, 38.12it/s]


TRAIN: EPOCH 37: MSE: 0.03137515139366899, KL_LOSS: 3.4565605373938436e-06


100%|██████████| 84/84 [00:02<00:00, 37.29it/s]


TRAIN: EPOCH 38: MSE: 0.03158539793054972, KL_LOSS: 3.5203048041477726e-06


100%|██████████| 84/84 [00:00<00:00, 99.58it/s] 


TRAIN: EPOCH 39: MSE: 0.030709447161782356, KL_LOSS: 3.547625353695323e-06


100%|██████████| 84/84 [00:01<00:00, 82.13it/s]


TRAIN: EPOCH 40: MSE: 0.03023679785075642, KL_LOSS: 3.547261336561427e-06


100%|██████████| 84/84 [00:00<00:00, 95.42it/s]


TRAIN: EPOCH 41: MSE: 0.030131454978670393, KL_LOSS: 3.585554782689758e-06


100%|██████████| 84/84 [00:00<00:00, 85.19it/s]


TRAIN: EPOCH 42: MSE: 0.03035229905730202, KL_LOSS: 3.6228018984729715e-06


100%|██████████| 84/84 [00:01<00:00, 79.19it/s]


TRAIN: EPOCH 43: MSE: 0.030112501411210923, KL_LOSS: 3.6487055067049196e-06


100%|██████████| 84/84 [00:00<00:00, 92.64it/s]


TRAIN: EPOCH 44: MSE: 0.029748639424464533, KL_LOSS: 3.650686922117616e-06


100%|██████████| 84/84 [00:00<00:00, 94.82it/s]


TRAIN: EPOCH 45: MSE: 0.029550821653434207, KL_LOSS: 3.6791470252172164e-06


100%|██████████| 84/84 [00:00<00:00, 102.19it/s]


TRAIN: EPOCH 46: MSE: 0.02914805971972999, KL_LOSS: 3.7054184157828525e-06


100%|██████████| 84/84 [00:00<00:00, 88.60it/s] 


TRAIN: EPOCH 47: MSE: 0.02964236671548514, KL_LOSS: 3.722101401611629e-06


100%|██████████| 84/84 [00:01<00:00, 55.08it/s]


TRAIN: EPOCH 48: MSE: 0.02948521612034667, KL_LOSS: 3.757565439497869e-06


100%|██████████| 84/84 [00:01<00:00, 59.19it/s]


TRAIN: EPOCH 49: MSE: 0.02873632617826973, KL_LOSS: 3.7765876646474115e-06


100%|██████████| 84/84 [00:01<00:00, 77.24it/s]


TRAIN: EPOCH 50: MSE: 0.028294225689023733, KL_LOSS: 3.8022973339615247e-06


100%|██████████| 84/84 [00:01<00:00, 78.72it/s]


TRAIN: EPOCH 51: MSE: 0.02824989524448202, KL_LOSS: 3.812066067677647e-06


100%|██████████| 84/84 [00:00<00:00, 93.36it/s]


TRAIN: EPOCH 52: MSE: 0.028773224047784294, KL_LOSS: 3.849580290686087e-06


100%|██████████| 84/84 [00:01<00:00, 80.91it/s]


TRAIN: EPOCH 53: MSE: 0.028131195277507817, KL_LOSS: 3.862172472205809e-06


100%|██████████| 84/84 [00:00<00:00, 90.94it/s]


TRAIN: EPOCH 54: MSE: 0.028450889246804372, KL_LOSS: 3.9127705100500385e-06


100%|██████████| 84/84 [00:00<00:00, 94.53it/s] 


TRAIN: EPOCH 55: MSE: 0.02822624614817046, KL_LOSS: 3.92552685536343e-06


100%|██████████| 84/84 [00:00<00:00, 101.37it/s]


TRAIN: EPOCH 56: MSE: 0.02841033480529274, KL_LOSS: 3.971078662007563e-06


100%|██████████| 84/84 [00:01<00:00, 79.74it/s]


TRAIN: EPOCH 57: MSE: 0.02780100122271549, KL_LOSS: 3.9687939842484155e-06


100%|██████████| 84/84 [00:00<00:00, 88.09it/s]


TRAIN: EPOCH 58: MSE: 0.027676595530162256, KL_LOSS: 4.015268340632853e-06


100%|██████████| 84/84 [00:00<00:00, 88.94it/s]


TRAIN: EPOCH 59: MSE: 0.02712983921879814, KL_LOSS: 4.018082417392829e-06


100%|██████████| 84/84 [00:00<00:00, 94.49it/s]


TRAIN: EPOCH 60: MSE: 0.02779785452765368, KL_LOSS: 4.020901758936123e-06


100%|██████████| 84/84 [00:00<00:00, 91.30it/s]


TRAIN: EPOCH 61: MSE: 0.0272605349087999, KL_LOSS: 4.0485756904845835e-06


100%|██████████| 84/84 [00:00<00:00, 94.65it/s]


TRAIN: EPOCH 62: MSE: 0.02738299922618483, KL_LOSS: 4.076209173880281e-06


100%|██████████| 84/84 [00:00<00:00, 90.00it/s]


TRAIN: EPOCH 63: MSE: 0.026950685573475703, KL_LOSS: 4.082123853420074e-06


100%|██████████| 84/84 [00:00<00:00, 85.89it/s]


TRAIN: EPOCH 64: MSE: 0.02714211256465032, KL_LOSS: 4.109143100365708e-06


100%|██████████| 84/84 [00:00<00:00, 95.11it/s]


TRAIN: EPOCH 65: MSE: 0.026816418084005516, KL_LOSS: 4.127047683747757e-06


100%|██████████| 84/84 [00:00<00:00, 85.89it/s]


TRAIN: EPOCH 66: MSE: 0.026567317501065276, KL_LOSS: 4.130305691368059e-06


100%|██████████| 84/84 [00:00<00:00, 93.35it/s]


TRAIN: EPOCH 67: MSE: 0.026621883468968526, KL_LOSS: 4.127416738291296e-06


100%|██████████| 84/84 [00:00<00:00, 95.53it/s] 


TRAIN: EPOCH 68: MSE: 0.026392052704024883, KL_LOSS: 4.14476782106292e-06


100%|██████████| 84/84 [00:00<00:00, 88.05it/s]


TRAIN: EPOCH 69: MSE: 0.02623860500309439, KL_LOSS: 4.153198333478074e-06


100%|██████████| 84/84 [00:00<00:00, 93.33it/s]


TRAIN: EPOCH 70: MSE: 0.026551746608068545, KL_LOSS: 4.167533973031823e-06


100%|██████████| 84/84 [00:00<00:00, 87.65it/s]


TRAIN: EPOCH 71: MSE: 0.02627351780288986, KL_LOSS: 4.182628169062463e-06


100%|██████████| 84/84 [00:00<00:00, 94.66it/s]


TRAIN: EPOCH 72: MSE: 0.025918057942319484, KL_LOSS: 4.202840990447363e-06


100%|██████████| 84/84 [00:00<00:00, 89.19it/s]


TRAIN: EPOCH 73: MSE: 0.026027766898984, KL_LOSS: 4.2089187402217205e-06


100%|██████████| 84/84 [00:00<00:00, 90.01it/s] 


TRAIN: EPOCH 74: MSE: 0.026275900259081806, KL_LOSS: 4.216510094898868e-06


100%|██████████| 84/84 [00:00<00:00, 97.87it/s] 


TRAIN: EPOCH 75: MSE: 0.02578908901306845, KL_LOSS: 4.233695963626988e-06


100%|██████████| 84/84 [00:00<00:00, 94.07it/s]


TRAIN: EPOCH 76: MSE: 0.02588891756853887, KL_LOSS: 4.255641155872408e-06


100%|██████████| 84/84 [00:00<00:00, 89.83it/s]


TRAIN: EPOCH 77: MSE: 0.025548147924599193, KL_LOSS: 4.2783689899633395e-06


100%|██████████| 84/84 [00:01<00:00, 83.24it/s] 


TRAIN: EPOCH 78: MSE: 0.02548340115962284, KL_LOSS: 4.295064565255979e-06


100%|██████████| 84/84 [00:01<00:00, 80.63it/s]


TRAIN: EPOCH 79: MSE: 0.025386150216772443, KL_LOSS: 4.297380831715037e-06


100%|██████████| 84/84 [00:01<00:00, 81.90it/s]


TRAIN: EPOCH 80: MSE: 0.02521563618488255, KL_LOSS: 4.3104720253679606e-06


100%|██████████| 84/84 [00:01<00:00, 74.18it/s]


TRAIN: EPOCH 81: MSE: 0.02548656856552476, KL_LOSS: 4.300178307489876e-06


100%|██████████| 84/84 [00:00<00:00, 94.58it/s]


TRAIN: EPOCH 82: MSE: 0.025267041620931456, KL_LOSS: 4.308733101273295e-06


100%|██████████| 84/84 [00:00<00:00, 94.21it/s]


TRAIN: EPOCH 83: MSE: 0.02478446771523782, KL_LOSS: 4.343781851968656e-06


100%|██████████| 84/84 [00:00<00:00, 87.28it/s]


TRAIN: EPOCH 84: MSE: 0.02498812953542386, KL_LOSS: 4.338057947063048e-06


100%|██████████| 84/84 [00:00<00:00, 85.41it/s]


TRAIN: EPOCH 85: MSE: 0.02504148099216677, KL_LOSS: 4.350546210692603e-06


100%|██████████| 84/84 [00:01<00:00, 82.69it/s]


TRAIN: EPOCH 86: MSE: 0.02484331146946975, KL_LOSS: 4.3513858664872326e-06


100%|██████████| 84/84 [00:00<00:00, 94.65it/s]


TRAIN: EPOCH 87: MSE: 0.02482575818984991, KL_LOSS: 4.399786178055365e-06


100%|██████████| 84/84 [00:00<00:00, 94.20it/s]


TRAIN: EPOCH 88: MSE: 0.024710811098061856, KL_LOSS: 4.381220032016808e-06


100%|██████████| 84/84 [00:00<00:00, 95.10it/s]


TRAIN: EPOCH 89: MSE: 0.024839439296296666, KL_LOSS: 4.383374982266798e-06


100%|██████████| 84/84 [00:00<00:00, 96.44it/s]


TRAIN: EPOCH 90: MSE: 0.02488989591421116, KL_LOSS: 4.4162005922641716e-06


100%|██████████| 84/84 [00:00<00:00, 95.55it/s]


TRAIN: EPOCH 91: MSE: 0.02479938626111973, KL_LOSS: 4.4547037012827265e-06


100%|██████████| 84/84 [00:00<00:00, 89.21it/s]


TRAIN: EPOCH 92: MSE: 0.024696454450133302, KL_LOSS: 4.4594068256065795e-06


100%|██████████| 84/84 [00:00<00:00, 91.64it/s] 


TRAIN: EPOCH 93: MSE: 0.024376445028576114, KL_LOSS: 4.46382980661448e-06


100%|██████████| 84/84 [00:00<00:00, 85.77it/s]


TRAIN: EPOCH 94: MSE: 0.024773239468534786, KL_LOSS: 4.4813557965859394e-06


100%|██████████| 84/84 [00:00<00:00, 96.93it/s]


TRAIN: EPOCH 95: MSE: 0.024533350264564865, KL_LOSS: 4.476790734099181e-06


100%|██████████| 84/84 [00:00<00:00, 90.01it/s]


TRAIN: EPOCH 96: MSE: 0.023979664491933016, KL_LOSS: 4.475941425750664e-06


100%|██████████| 84/84 [00:00<00:00, 95.08it/s]


TRAIN: EPOCH 97: MSE: 0.023911138619517998, KL_LOSS: 4.44744551941767e-06


100%|██████████| 84/84 [00:00<00:00, 93.34it/s]


TRAIN: EPOCH 98: MSE: 0.02421164315282589, KL_LOSS: 4.472711939992483e-06


100%|██████████| 84/84 [00:00<00:00, 92.34it/s]

TRAIN: EPOCH 99: MSE: 0.02427032689696976, KL_LOSS: 4.498267209046822e-06





In [10]:
# Encode every item of `dataset` with the trained autoencoder. The result is used below as a
# DataFrame holding a 'control' column plus feature columns (e.g. 'f0' ... 'f5').
df_to_be_shown=encode(autoencoder,dataset,device)

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 2679/2679 [00:04<00:00, 618.83it/s]


### We compute the cosine similarity between every pair of perturbations

In [11]:
# Pairwise cosine similarity between the encoded rows; the 'control' column is dropped
# first so that only the feature columns enter the computation.
cos_sim_f=cosine_similarity(np.array(df_to_be_shown.drop(['control'], axis=1)))

### Now we want to know which pairs of perturbations are known to be similar

In [12]:
# Build a symmetric 0/1 matrix, aligned with cos_sim_f, whose entry [i, j] is 1 when the
# genes of observations i and j are reported together by the reference database.
similarity_matrix=np.zeros(cos_sim_f.shape)
similarity_db=hu_data_loader()

# Group the observation ids by gene name once, instead of re-scanning adata_orig.obs
# for every (gene, partner) combination inside the loops.
ids_by_gene={}
for obs_gene,obs_id in zip(adata_orig.obs['gene_name'].to_numpy(),adata_orig.obs['id'].to_numpy()):
    ids_by_gene.setdefault(obs_gene,[]).append(obs_id)

for gene_name in tqdm.tqdm(adata_orig.obs.gene_name.unique()):
    x_indices=ids_by_gene[gene_name]
    query=query_hu_data(similarity_db,gene_name)
    # NOTE(review): if the query result contains gene_name itself, self-pairs (including the
    # diagonal, whose cosine similarity is 1) are labelled similar — confirm this is intended.
    for q in query:
        y_indices=ids_by_gene.get(q)
        if y_indices is None:
            # Partner gene is not present in this dataset.
            continue
        # Mark the whole block of (x, y) pairs in both orientations at once.
        similarity_matrix[np.ix_(y_indices,x_indices)]=1
        similarity_matrix[np.ix_(x_indices,y_indices)]=1

# Flatten both matrices the same way so that position k refers to the same pair in each,
# then split the similarities by label (1 = known similar, 0 = not known to be similar).
cos_sim_f_flatten=cos_sim_f.reshape(-1,)
similarity_matrix_flatten=similarity_matrix.reshape(-1,)
cos_sim_f_flatten1=cos_sim_f_flatten[similarity_matrix_flatten==1]
cos_sim_f_flatten0=cos_sim_f_flatten[similarity_matrix_flatten==0]

File downloaded successfully to humap2_complexes_20200809.txt


100%|██████████| 2394/2394 [00:54<00:00, 44.06it/s] 


### We want to visualize the value of recall with respect to different quantiles as thresholds for similarities

In [13]:
def get_recall(rate, scores=None, labels=None):
    """Recall of the "extreme similarity" predictor at a given quantile rate.

    A pair is predicted positive when its score lies strictly above the
    ``1 - rate`` quantile or strictly below the ``rate`` quantile of all
    scores. Every other pair, including those exactly on a threshold, is
    predicted negative, so each true positive pair is counted exactly once.

    Parameters
    ----------
    rate : float
        Quantile level in [0, 1] defining the two thresholds.
    scores : array-like, optional
        Flat array of similarity scores. Defaults to the notebook-level
        ``cos_sim_f_flatten``.
    labels : array-like, optional
        Flat array aligned with ``scores``; entries equal to 1 mark true
        positive pairs. Defaults to the notebook-level
        ``similarity_matrix_flatten``.

    Returns
    -------
    float
        ``tp / (tp + fn)``, or 0.0 when there are no positive labels at all.
    """
    if scores is None:
        scores=cos_sim_f_flatten
    if labels is None:
        labels=similarity_matrix_flatten
    scores=np.asarray(scores)
    positives=np.asarray(labels)==1

    qrate_down=np.quantile(scores,rate)
    qrate_up=np.quantile(scores,1-rate)
    pred_p=np.logical_or(scores>qrate_up,scores<qrate_down)

    tp=np.logical_and(pred_p,positives).sum()
    # Predicted negatives are the complement of pred_p, hence tp + fn == number of positives.
    n_positives=positives.sum()
    if n_positives==0:
        return 0.0
    return tp/n_positives
def visualize_recal_vs_quantile():
    """Plot recall (as returned by ``get_recall``) against the quantile rate.

    Rates 0.00, 0.05, ..., 0.45 are evaluated; every point of the line is
    annotated with its recall rounded to two decimals.
    """
    quantile_rates = [step*0.05 for step in range(10)]
    recalls = [get_recall(quantile_rate) for quantile_rate in quantile_rates]

    curve = pd.DataFrame({'quantile': quantile_rates, 'recall': recalls})
    fig = px.line(
        curve,
        x='quantile',
        y='recall',
        title='recall_vs_quantile',
        width=1000,
        height=400,
    )

    point_labels = [round(recall, 2) for recall in recalls]
    fig.update_traces(mode='lines+text', text=point_labels, textposition='top center')

    # Small black Arial text for all figure labels.
    label_font = dict(family="Arial, sans-serif", size=10, color="black")
    fig.update_layout(font=label_font)
    fig.show()
    

In [14]:
# Render the recall-vs-quantile line chart defined in the previous cell.
visualize_recal_vs_quantile()

### Now we plot the distribution of similarities for each class: first, the pairs that are already known to be similar; second, the pairs that are not.

In [15]:
# Draw a fixed-size random sample of similarities per class and show each as a violin plot.
# The samples are always drawn from the untouched flat arrays (cos_sim_f_flatten and
# similarity_matrix_flatten) with a freshly seeded generator, so this cell can be re-run
# safely and always produces the same figures.
N_VIOLIN_SAMPLES=2048
violin_rng=np.random.default_rng(0)

def sample_similarities(label, n_samples=N_VIOLIN_SAMPLES):
    """Return ``n_samples`` similarities (drawn with replacement) of pairs whose label equals ``label``.

    The result is a one-column DataFrame named 'correlations'.
    """
    pool=cos_sim_f_flatten[similarity_matrix_flatten==label]
    choice=violin_rng.choice(pool.shape[0], n_samples)
    return pd.DataFrame(pool[choice],columns=['correlations'])

# The sampled frames keep the names the following cell reads.
cos_sim_f_flatten1=sample_similarities(1)
fig=px.violin(cos_sim_f_flatten1, y='correlations',width=500, height=400,title="SIMILARS")
fig.show()


cos_sim_f_flatten0=sample_similarities(0)
fig=px.violin(cos_sim_f_flatten0, y='correlations',width=500, height=400,title="Not SIMILARS")
fig.show()

In [16]:
# Mean similarity per class. At this point both names hold the 2048-row sampled
# DataFrames built in the previous cell, so these are means over the samples,
# not over all pairs.
print("Not SIMILARS MEAN:",cos_sim_f_flatten0.mean())
print("SIMILARS MEAN:",cos_sim_f_flatten1.mean())

Not SIMILARS MEAN: correlations    0.033704
dtype: float32
SIMILARS MEAN: correlations    0.450013
dtype: float32


### Here is the visualization of the feature vectors

In [17]:
# Scatter three pairs of encoded feature columns against each other, coloured by the
# 'control' column, one figure per pair.
for x_col, y_col in (('f0', 'f1'), ('f2', 'f3'), ('f4', 'f5')):
    fig = px.scatter(df_to_be_shown, x=x_col, y=y_col, color='control', width=500, height=400)
    fig.show()











