In [42]:
from torch import nn
import torch
import argparse
import numpy as np
import pandas as pd
import scanpy as sc
import os
import anndata
import math

In [43]:
from modules import network,mlp
from utils import yaml_config_hook,save_model

parser = argparse.ArgumentParser()
config = yaml_config_hook("config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))
args = parser.parse_args([])
if not os.path.exists(args.model_path):
    os.makedirs(args.model_path)

torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
class_num = args.classnum

In [44]:
import scipy.sparse
sparse_X = scipy.sparse.load_npz('data/filtered_Counts.npz')
annoData = pd.read_table('data/annoData.txt')
y = annoData["cellIden"].to_numpy()
high_var_gene = 6000
# normlization and feature selection
adataSC = anndata.AnnData(X=sparse_X, obs=np.arange(sparse_X.shape[0]), var=np.arange(sparse_X.shape[1]))
sc.pp.filter_genes(adataSC, min_cells=10)
adataSC.raw = adataSC
sc.pp.highly_variable_genes(adataSC, n_top_genes=high_var_gene, flavor='seurat_v3')
sc.pp.normalize_total(adataSC, target_sum=1e4)
sc.pp.log1p(adataSC)

adataNorm = adataSC[:, adataSC.var.highly_variable]
dataframe = adataNorm.to_df()
x_ndarray = dataframe.values.squeeze()
y_ndarray = np.expand_dims(y, axis=1)
print(x_ndarray.shape,y_ndarray.shape)
dataframe.head()

  if index_name in anno:


(8569, 6000) (8569, 1)


Unnamed: 0,1,2,4,7,10,13,26,31,32,33,...,20104,20105,20108,20109,20115,20118,20121,20122,20123,20124
0,1.024218,0.0,0.0,0.0,1.302199,0.0,0.0,0.0,0.637877,0.0,...,0.0,0.0,0.36896,0.0,0.0,0.0,0.637877,0.0,0.0,0.36896
1,0.0,0.0,0.0,0.0,1.351171,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.888292,0.0,0.305824,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.4175,0.0,0.0,0.0,0.0,0.0,0.0,...,0.4175,0.0,0.0,0.0,0.4175,0.0,0.93785,0.4175,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.509045,0.0,0.0,0.509045


In [45]:
from torch.utils.data import DataLoader,random_split,TensorDataset

scDataset = TensorDataset(torch.tensor(x_ndarray, dtype=torch.float32),
                              torch.tensor(y_ndarray, dtype=torch.float32))

scTrainLength = int(len(scDataset) * 0.8)
scValidLength = len(scDataset) - scTrainLength
scTrain, scValid = random_split(scDataset, [scTrainLength, scValidLength])

scTrainDataLoader = DataLoader(scTrain, shuffle=True, batch_size=args.batch_size,drop_last=True)
scValidDataLoader = DataLoader(scValid, shuffle=True, batch_size=args.batch_size,drop_last=True)

for features, labels in scTrainDataLoader:
    print(len(features[-1]))
    print(len(features))
    print(len(labels))
    break

6000
256
256


In [46]:
# initialize model
mlp = mlp.MLP(num_genes=args.num_genes)
model = network.Network(mlp, args.feature_dim, args.classnum)
model = model.to('cuda')
# optimizer / loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

In [47]:
class InstanceLoss(nn.Module):

    def __init__(self,batch_size,temperature,device):
        super(InstanceLoss,self).__init__()
        self.batch_size=batch_size
        self.temperature=temperature
        self.device=device

        self.mask=self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")


    def mask_correlated_samples(self,batch_size):
        N=2*batch_size
        mask=torch.ones(N,N)
        mask=mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i,batch_size+1]=0
            mask[batch_size+i,i]=0
        mask=mask.bool()
        return mask

    def forward(self,z_i,z_j):
        N=2*self.batch_size
        z=torch.cat((z_i,z_j),dim=0)

        sim=torch.matmul(z,z.T)/self.temperature
        sim_i_j=torch.diag(sim,self.batch_size)
        sim_j_i=torch.diag(sim,-self.batch_size)

        positive_samples= torch.cat((sim_i_j,sim_j_i),dim=0).reshape(N,1)
        negative_samples=sim[self.mask].reshape(N,-1)

        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss

In [48]:
class ClusterLoss(nn.Module):
    def __init__(self, class_num, temperature, device):
        super(ClusterLoss, self).__init__()
        self.class_num = class_num
        self.temperature = temperature
        self.device = device

        self.mask = self.mask_correlated_clusters(class_num)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_clusters(self, class_num):
        N = 2 * class_num
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(class_num):
            mask[i, class_num + i] = 0
            mask[class_num + i, i] = 0
        mask = mask.bool()
        return mask

    def forward(self, c_i, c_j):
        p_i = c_i.sum(0).view(-1)
        p_i /= p_i.sum()
        ne_i = math.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()
        p_j = c_j.sum(0).view(-1)
        p_j /= p_j.sum()
        ne_j = math.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()
        ne_loss = ne_i + ne_j

        c_i = c_i.t()
        c_j = c_j.t()
        N = 2 * self.class_num
        c = torch.cat((c_i, c_j), dim=0)

        sim = self.similarity_f(c.unsqueeze(1), c.unsqueeze(0)) / self.temperature
        sim_i_j = torch.diag(sim, self.class_num)
        sim_j_i = torch.diag(sim, -self.class_num)

        positive_clusters = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_clusters = sim[self.mask].reshape(N, -1)

        labels = torch.zeros(N).to(positive_clusters.device).long()
        logits = torch.cat((positive_clusters, negative_clusters), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N

        return loss + ne_loss


In [49]:
# loss_device=torch.device('cuda')
# instance_loss=InstanceLoss(batch_size=args.batch_size,temperature=0.5,device=loss_device)
# cluster_loss=ClusterLoss(class_num=class_num,temperature=0.5,device=loss_device)
# loss_epoch = 0
# for step, (data,label) in enumerate(scTrainDataLoader):
#     optimizer.zero_grad()
#     x_i=data.clone().to('cuda')
#     x_j=data.clone().to('cuda')
#     z_i,z_j,c_i,c_j=model(x_i,x_j)
#     loss_instance=instance_loss(z_i,z_j)
#     loss_cluster=cluster_loss(c_i,c_j)
#     loss = loss_instance + loss_cluster
#     loss.backward()
#     optimizer.step()
#     if step % 10 == 0:
#         print(
#             f"Step [{step}/{len(scTrainDataLoader)}]\t loss_instance: {loss_instance.item()}\t loss_cluster: {loss_cluster.item()}")
#     loss_epoch += loss.item()


In [50]:
def train(instance_loss,cluster_loss,device):
    loss_epoch = 0
    for step, (data,label) in enumerate(scTrainDataLoader):
        optimizer.zero_grad()
        x_i=data.clone().to('cuda')
        x_j=data.clone().to('cuda')
        z_i,z_j,c_i,c_j=model(x_i,x_j)
        loss_instance=instance_loss(z_i,z_j)
        loss_cluster=cluster_loss(c_i,c_j)
        loss = loss_instance + loss_cluster
        loss.backward()
        optimizer.step()
        if step % 10 == 0:
            print(
                f"Step [{step}/{len(scTrainDataLoader)}]\t loss_instance: {loss_instance.item()}\t loss_cluster: {loss_cluster.item()}")
        loss_epoch += loss.item()
    return loss_epoch


In [51]:
def inference(loader, model, device):
    model.eval()
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)
        with torch.no_grad():
            c = model.forward_cluster(x)
        c = c.detach()
        feature_vector.extend(c.cpu().detach().numpy())
        labels_vector.extend(y.numpy())
        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")
    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    # print(feature_vector.shape, labels_vector.shape)
    return feature_vector, labels_vector

In [52]:
from evaluation import evaluation
def test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X, Y = inference(scValidDataLoader, model, device)
    print(X.shape,Y.shape)
    nmi, ari, f, acc = evaluation.evaluate(Y.reshape(-1), X)
    # print('NMI = {:.4f} ARI = {:.4f} F = {:.4f} ACC = {:.4f}'.format(nmi, ari, f, acc))
    return nmi, ari, f, acc

In [53]:
loss_device=torch.device('cuda')
instance_loss=InstanceLoss(batch_size=args.batch_size,temperature=0.5,device=loss_device)
cluster_loss=ClusterLoss(class_num=class_num,temperature=0.5,device=loss_device)

for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(instance_loss,cluster_loss,loss_device)
    if epoch % 10 == 0:
        save_model(args, model, optimizer, epoch)
    print(f"\nEpoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(scTrainDataLoader)} \n")
    nmi, ari, f, acc = test()
    print('Test NMI = {:.4f} ARI = {:.4f} F = {:.4f} ACC = {:.4f}'.format(nmi, ari, f, acc))
    print('========'*8+'\n')

  input = module(input)


Step [0/26]	 loss_instance: 6.231241226196289	 loss_cluster: 3.2991526126861572
Step [10/26]	 loss_instance: 6.228573322296143	 loss_cluster: 3.2989978790283203
Step [20/26]	 loss_instance: 6.226306915283203	 loss_cluster: 3.298896551132202

Epoch [0/100]	 Loss: 9.526805070730356 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.1277 ARI = 0.0626 F = 0.2868 ACC = 0.0703

Step [0/26]	 loss_instance: 6.2114410400390625	 loss_cluster: 3.2988059520721436


  input = module(input)


Step [10/26]	 loss_instance: 6.184502124786377	 loss_cluster: 3.298677682876587
Step [20/26]	 loss_instance: 6.115828990936279	 loss_cluster: 3.298475742340088

Epoch [1/100]	 Loss: 9.454911818871132 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.1104 ARI = -0.0026 F = 0.3360 ACC = 0.0475

Step [0/26]	 loss_instance: 6.0420122146606445	 loss_cluster: 3.2982559204101562


  input = module(input)


Step [10/26]	 loss_instance: 5.9172821044921875	 loss_cluster: 3.29764461517334
Step [20/26]	 loss_instance: 5.855574607849121	 loss_cluster: 3.296943426132202

Epoch [2/100]	 Loss: 9.193093593303974 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.3755 ARI = 0.2985 F = 0.4635 ACC = 0.0267



  input = module(input)


Step [0/26]	 loss_instance: 5.743744850158691	 loss_cluster: 3.295684576034546
Step [10/26]	 loss_instance: 5.679439067840576	 loss_cluster: 3.294421434402466
Step [20/26]	 loss_instance: 5.657205581665039	 loss_cluster: 3.293341636657715

Epoch [3/100]	 Loss: 8.968069406656118 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.4789 ARI = 0.4235 F = 0.5414 ACC = 0.0469



  input = module(input)


Step [0/26]	 loss_instance: 5.621006488800049	 loss_cluster: 3.2922980785369873
Step [10/26]	 loss_instance: 5.551555156707764	 loss_cluster: 3.2902541160583496
Step [20/26]	 loss_instance: 5.471215724945068	 loss_cluster: 3.2890233993530273

Epoch [4/100]	 Loss: 8.815011134514442 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.4947 ARI = 0.3478 F = 0.4704 ACC = 0.0892



  input = module(input)


Step [0/26]	 loss_instance: 5.4052934646606445	 loss_cluster: 3.288037061691284
Step [10/26]	 loss_instance: 5.3280110359191895	 loss_cluster: 3.2848637104034424
Step [20/26]	 loss_instance: 5.249216079711914	 loss_cluster: 3.2842748165130615

Epoch [5/100]	 Loss: 8.601975807776817 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.5582 ARI = 0.4657 F = 0.5670 ACC = 0.0755



  input = module(input)


Step [0/26]	 loss_instance: 5.22073221206665	 loss_cluster: 3.2817556858062744
Step [10/26]	 loss_instance: 5.166556358337402	 loss_cluster: 3.2804930210113525
Step [20/26]	 loss_instance: 5.138511657714844	 loss_cluster: 3.2769439220428467

Epoch [6/100]	 Loss: 8.44427673633282 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.5878 ARI = 0.5222 F = 0.6110 ACC = 0.6530



  input = module(input)


Step [0/26]	 loss_instance: 5.108696937561035	 loss_cluster: 3.274916648864746
Step [10/26]	 loss_instance: 5.083678722381592	 loss_cluster: 3.2718863487243652
Step [20/26]	 loss_instance: 5.049376487731934	 loss_cluster: 3.2654192447662354

Epoch [7/100]	 Loss: 8.34597264803373 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.5595 ARI = 0.4840 F = 0.5792 ACC = 0.6296

Step [0/26]	 loss_instance: 5.030228614807129	 loss_cluster: 3.263798475265503


  input = module(input)


Step [10/26]	 loss_instance: 5.007358551025391	 loss_cluster: 3.2585785388946533
Step [20/26]	 loss_instance: 4.972696304321289	 loss_cluster: 3.2534735202789307

Epoch [8/100]	 Loss: 8.260188542879545 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.5024 ARI = 0.3906 F = 0.5018 ACC = 0.5534



  input = module(input)


Step [0/26]	 loss_instance: 4.9634623527526855	 loss_cluster: 3.2483842372894287
Step [10/26]	 loss_instance: 4.932798385620117	 loss_cluster: 3.2410998344421387
Step [20/26]	 loss_instance: 4.906030654907227	 loss_cluster: 3.23221755027771

Epoch [9/100]	 Loss: 8.165750320141132 

Step [0/6]	 Computing features...
Features shape (1536,)
(1536,) (1536, 1)
Test NMI = 0.5057 ARI = 0.4347 F = 0.5453 ACC = 0.6100

Step [0/26]	 loss_instance: 4.877933502197266	 loss_cluster: 3.2284905910491943


  input = module(input)


Step [10/26]	 loss_instance: 4.8615498542785645	 loss_cluster: 3.219658851623535


KeyboardInterrupt: 