In [39]:
from torch import nn
import torch
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import os
import anndata
import math
import hnswlib
from torch.utils.data import DataLoader,random_split,TensorDataset

In [40]:
from modules import network,mlp,contrastive_loss
from utils import yaml_config_hook,save_model

parser = argparse.ArgumentParser()
config = yaml_config_hook("config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
args = parser.parse_args([])
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)

# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
# torch.cuda.manual_seed(args.seed)
# np.random.seed(args.seed)
class_num = args.classnum

In [41]:
import scipy.sparse
sparse_X = scipy.sparse.load_npz('data/filtered_Counts.npz')
annoData = pd.read_table('data/annoData.txt')
y = annoData["cellIden"].to_numpy()
high_var_gene = args.num_genes
# normlization and feature selection
adataSC = anndata.AnnData(X=sparse_X, obs=np.arange(sparse_X.shape[0]), var=np.arange(sparse_X.shape[1]))
sc.pp.filter_genes(adataSC, min_cells=10)
adataSC.raw = adataSC
sc.pp.highly_variable_genes(adataSC, n_top_genes=high_var_gene, flavor='seurat_v3')
sc.pp.normalize_total(adataSC, target_sum=1e4)
sc.pp.log1p(adataSC)

adataNorm = adataSC[:, adataSC.var.highly_variable]
dataframe = adataNorm.to_df()
x_ndarray = dataframe.values.squeeze()
y_ndarray = np.expand_dims(y, axis=1)
print(x_ndarray.shape,y_ndarray.shape)
dataframe.head()

  if index_name in anno:


In [None]:
dim=2000
bank=hnswlib.Index(space='cosine',dim=dim)
bank.init_index(max_elements=x_ndarray.shape[0], ef_construction=100, M=16)
bank.set_ef(100)
bank.set_num_threads(4)
bank.add_items(x_ndarray)

In [None]:
temp=np.resize(x_ndarray[0], (10, *x_ndarray[0].shape))
print(temp.shape)

(10, 2000)


In [None]:
contrasts=np.zeros((x_ndarray.shape[0],10,dim))
for step,(x,y) in enumerate(zip(x_ndarray,y_ndarray)):
    if y in [8,10,11,13,14]:
        temp=np.tile(x, 10).reshape((10,2000))
        print(temp)
        contrasts[step]=temp
    else:
        labels,distances=bank.knn_query(x,k=10)
        contrasts[step]=x_ndarray[labels.tolist()]

  contrasts[step]=x_ndarray[labels.tolist()]


[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[1.9767033 0.        0.        ... 0.        0.        0.       ]
 [1.9767033 0.        0.        ... 0.        0.        0.       ]
 [1.9767033 0.        0.        ... 0.        0.        0.       ]
 ...
 [1.9767033 0.        0.        ... 0.        0.        0.       ]
 [1.9767033 0.        0.        ... 0.        0.        0.       ]
 [1.9767033 0.        0.        ... 0.        0.        0.       ]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0.        0.        0.        ... 0.        0.        1.2615825]
 [0.        0.        0.        ... 0.        0.        1

In [None]:
print(contrasts[1])
print(x_ndarray[1])

[[0.         1.35117137 0.         ... 0.         0.88829178 0.        ]
 [0.         1.95613039 0.         ... 0.36026391 0.83339739 0.36026391]
 [0.         0.73517883 0.         ... 0.         0.         0.73517883]
 ...
 [0.         0.         0.         ... 0.         0.5090453  0.5090453 ]
 [0.         0.78044915 0.         ... 1.02017438 0.46450487 0.        ]
 [0.         0.         0.         ... 0.         0.4774029  0.79920191]]
[0.        1.3511714 0.        ... 0.        0.8882918 0.       ]


In [None]:
temp=np.where(y_ndarray==14)[0]
print(temp)

[1900 1917 3077 3085 6298 6328 7833]


In [None]:

print(contrasts[temp[0]].shape)
print(contrasts[temp[0]])

(10, 2000)
[[0.         0.         0.         ... 0.         0.         1.79426384]
 [0.         0.         0.         ... 0.         0.         1.79426384]
 [0.         0.         0.         ... 0.         0.         1.79426384]
 ...
 [0.         0.         0.         ... 0.         0.         1.79426384]
 [0.         0.         0.         ... 0.         0.         1.79426384]
 [0.         0.         0.         ... 0.         0.         1.79426384]]


In [None]:
from torch.utils.data import DataLoader,random_split,TensorDataset
scDataset = TensorDataset(torch.tensor(x_ndarray, dtype=torch.float32),
                              torch.tensor(y_ndarray, dtype=torch.float32))

scDataLoader = DataLoader(scDataset, shuffle=True, batch_size=args.batch_size,drop_last=True)

for features, labels in scDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

scGenDataLoader = DataLoader(scDataset, shuffle=False, batch_size=args.batch_size,drop_last=True)

for features, labels in scDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

2000
1024
1024
2000
1024
1024


In [None]:
class StaticMemoryBank():

    def __init__(self,batch_size,x,y,dim):
        self.batch_size=batch_size
        self.dim=dim
        self.bank=hnswlib.Index(space='cosine',dim=dim)
        self.bank.init_index(max_elements=8569, ef_construction=100, M=16)
        self.bank.set_ef(100)
        self.bank.set_num_threads(4)
        self.bank.add_items(x)
        self.x_data=x
        self.y_data=y

    def generate_data(self,x_data,y_data):
        contrasts=np.zeros((self.batch_size,11,self.dim))
        rare_list=[]
        for step,(x,y) in enumerate(zip(x_data,y_data)):
            if y in [8,10,11,13,14]:
                temp=np.tile(x, 10).reshape((10,2000))
                temp=np.row_stack((temp,x))
                rare_list.append(step)
                contrasts[step]=temp
            else:
                labels,distances=self.bank.knn_query(x,k=10)
                temp=x_ndarray[labels.tolist()]
                temp=np.row_stack((temp,x))
                contrasts[step]=temp
        
        return contrasts,rare_list


In [None]:
staticMemoryBank=StaticMemoryBank(1024,x_ndarray,y_ndarray,dim=2000)

In [None]:

for (x,y) in scDataLoader:
    # print(x)
    # print(y)
    contrast_data,rare_list=staticMemoryBank.generate_data(x,y)
    print(contrast_data.shape)

    print(contrast_data[0])

    for i in rare_list:
        print(contrast_data[i].shape)
    
    break

  temp=x_ndarray[labels.tolist()]


(1024, 11, 2000)
[[0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.56665921 0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 ...
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.         0.         0.         ... 0.         0.         0.        ]]
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
(11, 2000)
