In [1]:
from torch import nn
import torch
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import os
import anndata
import math
import hnswlib

In [2]:
from modules import network,mlp,contrastive_loss
from utils import yaml_config_hook

parser = argparse.ArgumentParser()
config = yaml_config_hook("config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
args = parser.parse_args([])
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)

# torch.manual_seed(args.seed)
# torch.cuda.manual_seed_all(args.seed)
# torch.cuda.manual_seed(args.seed)
# np.random.seed(args.seed)
class_num = args.classnum

In [4]:
import scipy.sparse
sparse_X = scipy.sparse.load_npz('data/filtered_Counts.npz')
annoData = pd.read_table('data/annoData.txt')
y = annoData["cellIden"].to_numpy()
high_var_gene = args.num_genes
# normlization and feature selection
adataSC = anndata.AnnData(X=sparse_X, obs=np.arange(sparse_X.shape[0]), var=np.arange(sparse_X.shape[1]))
sc.pp.filter_genes(adataSC, min_cells=10)
adataSC.raw = adataSC
sc.pp.highly_variable_genes(adataSC, n_top_genes=high_var_gene, flavor='seurat_v3')
sc.pp.normalize_total(adataSC, target_sum=1e4)
sc.pp.log1p(adataSC)

adataNorm = adataSC[:, adataSC.var.highly_variable]
dataframe = adataNorm.to_df()
x_ndarray = dataframe.values.squeeze()
y_ndarray = np.expand_dims(y, axis=1)
print(x_ndarray.shape,y_ndarray.shape)
dataframe.head()

  if index_name in anno:


(8569, 2000) (8569, 1)


Unnamed: 0,2,10,13,41,45,62,68,106,133,147,...,19763,19786,19808,19854,19883,20021,20073,20109,20121,20124
0,0.0,1.302199,0.0,0.0,0.0,0.0,0.36896,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.637877,0.36896
1,0.0,1.351171,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.888292,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.711146,0.0,0.0,0.0,0.0,...,0.0,0.4175,0.0,0.0,0.93785,0.0,0.0,0.0,0.93785,0.0
4,0.0,0.0,0.0,0.0,0.0,0.509045,0.0,0.0,0.0,0.0,...,0.0,0.509045,0.0,0.0,0.0,0.0,0.0,0.0,0.509045,0.509045


In [5]:
mlp = mlp.MLP(num_genes=args.num_genes)
model = network.Network(mlp, args.feature_dim, args.classnum)
model = model.to('cuda')
# optimizer / loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

In [6]:
from torch.utils.data import DataLoader,TensorDataset
scDataset = TensorDataset(torch.tensor(x_ndarray, dtype=torch.float32),
                              torch.tensor(y_ndarray, dtype=torch.float32))

scDataLoader = DataLoader(scDataset, shuffle=True, batch_size=1024,drop_last=True)

for features, labels in scDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

2000
1024
1024


MemoryBank

首先初始化时传入各种参数数值并建立一个基于原始数据的hnsw的Bank

而后每一次调用该对象时，传入一个batch的数据，并以矩阵的形式返回该batch的每个样本的Top 10近邻

而每次需要结合Embedding进行更新bank时使用update函数重新建立索引

In [16]:
class MemoryBank():

    # 初始化，传入参数
    def __init__(self,batch_size,full_data,topK=10):
        self.topK=topK
        self.batch_size=batch_size
        self.bank=None
        self.full_data=full_data

    # 根据在updateBank中更新的hnsw对象以及输入的数据data（这里可以是embedding）提取TopK个近邻的数据
    # 返回的结果是一个形状为[TopK,batch_size,num_genes]的数组，从第一个维度来看，
    # 每个[batch_size,num_genes]的子数组都是根据输入的数据data寻找的一个近邻，一共TopK个
    def generateContrast(self,data):
        if self.bank is not None:
            contrasts=np.empty((self.topK,args.batch_size,args.num_genes))
            labels,distances=self.bank.knn_query(data,k=self.topK)
            
            print(labels)

            for step,label in enumerate(labels):
                contrasts[:,step]=self.full_data[label.tolist()]
            return contrasts
        else:
            print('Memory Bank has not been initialized......')
            raise NotImplementedError()
            return None

    # 根据输入的embedding更新hnsw对象
    def updateBank(self,embedding):
        num_elements=len(embedding)
        dim=embedding.shape[1]
        self.bank=hnswlib.Index(space='cosine',dim=dim)
        self.bank.init_index(max_elements=num_elements, ef_construction=100, M=16)
        self.bank.set_ef(100)
        self.bank.set_num_threads(4)
        self.bank.add_items(embedding)
        

        

使用Model中的MLP生成整个数据集的embedding

需要另外新建一个不打乱数据顺序的数据集来往Model中传参数

并且需要注意在迭代更新Memory Bank后必须使用经过model处理的样本来进行索引查询，否则差距巨大！

In [21]:
embeddings=np.empty((0,mlp.rep_dim))
model.eval()
with torch.no_grad():
    for step, (x, y) in enumerate(scDataLoader):
        x=x.to('cuda')
        embedding=model.forward_embedding(x).cpu().detach().numpy()                 
        embeddings=np.row_stack((embeddings,embedding))
# for embedding in embeddings:
#     print(embedding)
#         # Model输出的embedding的形状是batch_size x mlp.rep_dim
print(embeddings.shape)
test_bank=MemoryBank(batch_size=args.batch_size,full_data=x_ndarray,topK=10)
test_bank.updateBank(x_ndarray)
test_result_1=test_bank.generateContrast(x_ndarray[0])
print(test_result_1.shape)
test_bank.updateBank(embedding=embeddings)

# 查询时必须经过model的mlp传输生成embedding才可查询
test_x_0=model.forward_embedding(torch.tensor(x_ndarray[0],dtype=torch.float32).to('cuda')).cpu().detach().numpy()  
test_result_2=test_bank.generateContrast(test_x_0)
print(test_result_2.shape)

(8192, 128)
[[   0 1202 1197 1207  573 1216    1 1221   46 6371]]
(10, 1024, 2000)
[[   0 3834 4771 1197 1207  172  706 5455 1221 1203]]
(10, 1024, 2000)


In [7]:
# memoryBank=MemoryBank(args.batch_size,x_ndarray,topK=10)
# memoryBank.updateBank(x_ndarray)
# for featres,tags in scDataLoader:
#     result=memoryBank.generateContrast(featres.numpy())
#     print(result.shape)
#     # print(result[0])

(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)
(10, 1024, 2000)


In [8]:
memory_bank=MemoryBank(batch_size=args.batch_size,full_data=x_ndarray,topK=10)
instance_loss=contrastive_loss.InstanceLoss(batch_size=args.batch_size,temperature=0.5,device='cuda')
cluster_loss=contrastive_loss.ClusterLoss(class_num=class_num,temperature=0.5,device='cuda')
memory_bank.updateBank(x_ndarray)

训练函数一次会针对一个batch中的所有样本生成一个形状为 [TopK,batch_size,num_genes] 的矩阵，分别对应原有batch中的每一个样本的TopK近邻。

而后针对这一个针对所有batch内样本索引TopK近邻的对比矩阵分别进行TopK次对比学习，对于batch内的每一个样本而言，每一次均是使用一个近邻构成正样本对、以及其余的batch_size*2-2的样本构成负样本对。

In [9]:
def train(instance_loss,cluster_loss,memory_bank):
    loss_epoch=0
    for step,(data,label) in enumerate(scDataLoader):
        # optimizer.zero_grad()
        
        contrast_samples=memory_bank.generateContrast(data)
        iter_times=contrast_samples.shape[0]
        for sample in contrast_samples:
            print(f'sample shape:\n{sample.shape}')
            print(f'data shape:\n{data.shape}')
            x_i=data.clone().to('cuda')
            x_j=torch.tensor(sample,dtype=torch.float32).to('cuda')

            z_i,z_j,c_i,c_j=model(x_i,x_j)
            loss_instance=instance_loss(z_i,z_j)
            loss_cluster=cluster_loss(c_i,c_j)
            loss = loss_instance + loss_cluster
            print(f'------ loss:\n{loss}')
            # loss.backward()
            # optimizer.step()
        if step % 1000 == 0:
                print(f"Step [{step}/{len(scDataLoader)}]\t loss_instance: {loss_instance.item()}\t loss_cluster: {loss_cluster.item()}")
            
    return loss_epoch/iter_times


sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])


  input = module(input)


------ loss:
10.92076587677002
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.92131233215332
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921283721923828
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921366691589355
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921391487121582
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921318054199219
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921403884887695
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921327590942383
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.921369552612305
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.92131233215332
sample shape:
(1024, 2000)
data shape:
torch.Size([1024, 2000])
------ loss:
10.920676231384277
sample shap

In [11]:
scGenDataLoader = DataLoader(scDataset, shuffle=False, batch_size=1024,drop_last=True)

for features, labels in scDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

2000
1024
1024


In [None]:
from utils import  save_model
loss_device=torch.device('cuda')
instance_loss=contrastive_loss.InstanceLoss(batch_size=args.batch_size,temperature=0.5,device=loss_device)
cluster_loss=contrastive_loss.ClusterLoss(class_num=class_num,temperature=0.5,device=loss_device)
accs=[]
losses=[]

memory_bank=MemoryBank(batch_size=args.batch_size,full_data=x_ndarray,topK=10)
memory_bank.updateBank(x_ndarray)

for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(instance_loss,cluster_loss,memory_bank=memory_bank)
    losses.append(loss_epoch/len(scDataLoader))
    if epoch % 10 == 0:
        save_model(args, model, optimizer, epoch)
    print(f"\nEpoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(scDataLoader)} \n")
    
    if epoch %20 ==0:
        model.eval()
        with torch.no_grad():
            for step, (x, y) in enumerate(scGenDataLoader):
                # print(type(x))
                x=x.to('cuda')
                embedding=model.forward_embedding(x).cpu().detach().numpy()                 
                embeddings=np.row_stack((embeddings,embedding))
                memory_bank.updateBank(embeddings)
    # nmi, ari, f, acc = test()
    # accs.append(acc)
    # print('Test NMI = {:.4f} ARI = {:.4f} F = {:.4f} ACC = {:.4f}'.format(nmi, ari, f, acc))
    # print('========'*8+'\n')