In [14]:
from torch import nn
import torch
import argparse
import numpy as np
import os
import hnswlib
from modules import network,mlp
from modules.memory_bank import StaticMemoryBank_for_MSLOSS
from modules.multi_similarity_loss import MultiSimilarityLoss,MultiSimilarityLoss_Boost
from utils import yaml_config_hook,save_model,data_preocess
from evaluation import  evaluation_tools
import warnings
warnings.filterwarnings("ignore")

In [3]:
parser = argparse.ArgumentParser()
config = yaml_config_hook("config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
args = parser.parse_args([])
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)

class_num = args.classnum

In [4]:
x_ndarray,y_ndarray=data_preocess.data_process(x_path='data/filtered_Counts.npz',y_path='data/annoData.txt',args=args)

X Shape: (8569, 2000), Y Shape: (8569, 1)


In [5]:
from torch.utils.data import DataLoader,TensorDataset
scDataset = TensorDataset(torch.tensor(x_ndarray, dtype=torch.float32))

scDataLoader = DataLoader(scDataset, shuffle=True, batch_size=args.batch_size,drop_last=True)

for features in scDataLoader:
    print(len(features))
    print(len(features[0]))
    break

1
512


In [38]:
class StaticMemoryBank_for_MSLOSS_SelfEnhanced():

    def __init__(self,batch_size,x,dim,nn_counts):
        self.batch_size=batch_size
        self.dim=dim
        self.nn_counts=nn_counts
        self.bank=hnswlib.Index(space='cosine',dim=dim)
        self.bank.init_index(max_elements=8569, ef_construction=100, M=16)
        self.bank.set_ef(100)
        self.bank.set_num_threads(4)
        self.bank.add_items(x)
        self.x_data=x
 
    def generate_data(self,sample):

        labels,distances=self.bank.knn_query(sample,k=self.nn_counts)
        pseudolabel=np.arange(labels.shape[0])
        pseudolabel=np.repeat(pseudolabel,self.nn_counts).reshape(-1)
        
        print(labels[0])
        self_index=labels[:,0]
        labels[:,-1]=self_index
        labels[:,-2]=self_index
        labels[:,-3]=self_index
        print(self_index.shape)
        print(labels.shape)
        print(labels[0])
        labels=labels.reshape(-1)

        data=self.x_data[labels]

        return data,pseudolabel

In [41]:


memoryBank=StaticMemoryBank_for_MSLOSS_SelfEnhanced(batch_size=args.batch_size,x=x_ndarray,dim=2000,nn_counts=args.NN_COUNT)
for features, in scDataLoader:
    feature,label=memoryBank.generate_data(features.numpy())
    print(label.shape)
    break

[8064 7282 7868 3099 7901 7292 7961 7291 7894 7286]
(512,)
(512, 10)
[8064 7282 7868 3099 7901 7292 7961 8064 8064 8064]
(5120,)


In [5]:
mlpp=mlp.MLP()
model=network.Network(mlpp,feature_dim=args.feature_dim)
model.to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
memory_bank=StaticMemoryBank_for_MSLOSS(args.batch_size,x_ndarray,args.num_genes,args.NN_COUNT)
ms_loss_boost=MultiSimilarityLoss_Boost(args=args)
ms_loss=MultiSimilarityLoss(args=args)

In [6]:

loss_epoch=0
for step,data in enumerate(scDataLoader):
    optimizer.zero_grad()
    
    data,pseudolabel=memory_bank.generate_data(data[0].numpy())
    data=torch.tensor(data,dtype=torch.float32).to('cuda')
    pseudolabel=torch.tensor(pseudolabel,dtype=torch.long).to('cuda')
    embedding=model(data)
#     loss=ms_loss(embedding,pseudolabel)
    loss_1=ms_loss_boost(embedding,pseudolabel)
    loss_1.backward()
    optimizer.step()
    loss_epoch+=loss_1.item()
    if step % 2 == 0:
            print(f"Step [{step}/{len(scDataLoader)}]\t,  MSLoss_BoostL{loss_1.item()}")
    break


torch.Size([2560, 2560])
torch.Size([2560, 2560])
torch.Size([25600])
torch.Size([6528000])
tensor([True, True, True,  ..., True, True, True], device='cuda:0')
tensor(0.9999, device='cuda:0', grad_fn=<MaxBackward1>)
tensor([0.9000, 0.8988, 0.8992,  ..., 0.8989, 0.8993, 0.9000], device='cuda:0',
       grad_fn=<SubBackward0>)
tensor([True, True, True,  ..., True, True, True], device='cuda:0')
torch.Size([25600])
torch.Size([6528000])
Step [0/33]	,  MSLoss_BoostL5.466931343078613
