In [1]:
from propDEC_end2end import propDEC
import scanpy as sc
import matplotlib.pyplot as plt

In [2]:
# create dataset
pipeline = propDEC(input_size=3000, num_classes=8, metric="t-sne", mode="soft")#soft/hard/km

train_dataset = pipeline.create_dataset("cscc", "cscc_exper.h5ad", mode="train")
test_dataset = pipeline.create_dataset("cscc", "cscc_ctrl.h5ad", mode="test")

ref_adata = train_dataset.train_adata
test_adata = test_dataset.test_adata

In [3]:
weight, celltype, labels = pipeline(epochs=5)

                                                                           

==> Saving Checkpoints
==> Finish training !


                                                                       

==> Saving Checkpoints
Got 17468 / 20813 with accuracy' 83.93%
==> resample


  utils.warn_names_duplicates("obs")


==> Write the resample annotated data to cache
==> Loading Checkpoints


                                                                        

==> Saving Checkpoints
Got 11936 / 20813 with accuracy' 57.35%
==> resample


  utils.warn_names_duplicates("obs")


==> Write the resample annotated data to cache
==> Loading Checkpoints


                                                                        

==> Saving Checkpoints
Got 16707 / 20813 with accuracy' 80.27%
==> resample


  utils.warn_names_duplicates("obs")


==> Write the resample annotated data to cache
==> Loading Checkpoints


                                                                        

==> Saving Checkpoints
Got 6249 / 20813 with accuracy' 30.02%
==> resample


  utils.warn_names_duplicates("obs")


==> Write the resample annotated data to cache
==> Loading Checkpoints


                                                                        

==> Saving Checkpoints
Got 15726 / 20813 with accuracy' 75.56%
==> resample


  utils.warn_names_duplicates("obs")


==> Write the resample annotated data to cache
==> Finish!


### .to(device='cuda')

In [4]:
class KM_Algorithm:
    """
    1.最大权重匹配
    2.输入的二分图应该是一个经过soft assignment的scores或prob矩阵
    3.二分图是以left：cell，rights：celltype来写的
    """

    def __init__(self, Bipartite_Graph):

        self.Bipartite_Graph = Bipartite_Graph

        # 左右结点数量记录
        self.left = self.Bipartite_Graph.shape[0]  # 以左边（细胞）为主
        self.right_true = self.Bipartite_Graph.shape[1] 
        self.right = self.Bipartite_Graph.shape[1] + self.left
        self.reshape_graph() 

        # step1:最高标准初始化（顶标）
        self.label_left = torch.max(self.Bipartite_Graph.cuda(), axis=1)  # 设置左边顶标为权重最大值（每行的最大值）
        label_right = torch.zeros(self.right) # 右边集合的顶标设置为0
        self.label_right = label_right.to(device='cuda')
            # 初始化辅助变量——是否已匹配
        self.visit_left = torch.zeros(self.left, dtype=bool) #全是false
        self.visit_right = torch.zeros(self.right, dtype=bool)

            # 初始化右边的匹配结果.如果已匹配就会对应匹配结果
        self.match_right = torch.empty(self.left).fill_(float('nan')).to(device='cuda')
        
            # 用inc记录需要减去的权值d，不断取最小值故初始化为较大值。权值都为负数，应该不用很大也行
        self.inc = 1000*1000*1000
        self.fail_cell = list()  # 每次匹配重新创建一个二分图匹配对象，所以这个也不用手动重置了

    def reshape_graph(self):
        new = torch.zeros((self.left, self.left)).to(device='cuda')
        self.Bipartite_Graph = torch.cat((self.Bipartite_Graph, new),dim=1)
    def match(self, cell):
        cell = int(cell) 
        self.visit_left[cell] = True # 记录下这个cell已经被寻找
        #在step2最高标准调整的时候判断子图还是不是权重最高
        for celltype in range(self.right):
            if not self.visit_right[celltype] and self.Bipartite_Graph[cell][celltype] >= 0:    
                gap = self.label_left.values[cell] + self.label_right[celltype] - self.Bipartite_Graph[cell][celltype] 
                if gap == 0: 
                    self.visit_right[celltype] = True
                    if torch.isnan(self.match_right[celltype]) or self.match(self.match_right[celltype]):
                        self.match_right[celltype] = cell# 递归匹配，匹配成功
                        return 1
                elif self.inc > gap:  #找到权值最小的差距
                    self.inc = gap  
        return 0

    def Kuh_Munkras(self): 
        self.match_right = torch.empty(self.left).fill_(float('nan')).to(device='cuda')
        for cell in range(self.left):
            while True:
                self.inc = 1000*1000  # the minimum gap
                self.reset()  # 每次寻找过的路径，所有要重置一下
                if self.match(cell):
                    break
                for k in range(self.left):
                    if self.visit_left[k]:
                        self.label_left.values[k] -= self.inc 
                for n in range(self.right):
                    if self.visit_right[n]:
                        self.label_right[n] += self.inc
        return self.fail_cell

    def calculateSum(self):
        sum = 0
        cells_celltypes = []
        self.fail_cell = [i for i in range(self.left)]
        for i in range(self.right_true):
            if not torch.isnan(self.match_right[i]):
                sum += self.Bipartite_Graph[int(self.match_right[i])][i]
                cell_celltype = (int(self.match_right[i]), i)
                cells_celltypes.append(cell_celltype)
                self.fail_cell.remove(int(self.match_right[i]))
        return cells_celltypes, self.fail_cell

    def getResult(self):
        return self.match_right

    def reset(self):
        self.visit_left = torch.zeros(self.left, dtype=bool)
        self.visit_right = torch.zeros(self.right, dtype=bool)
        

In [5]:
import torch
import pandas as pd
weight = torch.softmax(weight, dim=-1) 
square_weight = (weight ** 2) / (weight ** 2).sum(dim=1, keepdim=True) #得到初始二分图，权重为扩大差异后的概率 

## part 1: argmax ##
index_argmax = [] #如果模型认为细胞是某个细胞类型的概率大于0.5，就直接用argmax转成hard assignment
index_km = []

for i in range(len(square_weight)):
    if (square_weight[i] > 0.5).any():
        index_argmax.append(i)
    else:
        index_km.append(i)

pred_indices = square_weight[index_argmax].argmax(dim=1)
argmax_df = pd.Series(pred_indices, index=index_argmax)

In [6]:
from tqdm import tqdm
import numpy as np


In [26]:
km = KM_Algorithm(square_weight[index_km].to(device='cuda'))  #输入二分图·

In [27]:
km.Kuh_Munkras()  # 匹配

KeyboardInterrupt: 

In [8]:
cells_celltypes, _ = km.calculateSum() #

<__main__.KM_Algorithm at 0x7ff65422aca0>

In [8]:
import torch
label_left = torch.max(weight, axis=1)  # 设置左边顶标为权重最大值（每行的最大值）
label_left

torch.return_types.max(
values=tensor([7.2221, 6.4928, 5.8720,  ..., 8.1468, 7.9132, 7.7422]),
indices=tensor([2, 2, 2,  ..., 2, 2, 2]))

In [9]:
left = weight.shape[0]  # 以左边（细胞）为主
right_true = weight.shape[1] 
right = weight.shape[1] + left

In [14]:
right

20821

In [15]:
label_right = torch.zeros(right)  # 右边集合的顶标设置为0
label_right

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [33]:
cell=0
celltype=0
gap = label_left.values[cell] + label_right[celltype] - weight[cell][celltype] 

In [34]:
gap

tensor(4.8397)

In [35]:
gap ==0

tensor(False)

In [36]:
if gap == 0:
    print ('a')

In [43]:
label_left.values[0] -= 1

In [44]:
label_left

torch.return_types.max(
values=tensor([5.2221, 5.4928, 4.8720,  ..., 7.1468, 6.9132, 6.7422]),
indices=tensor([2, 2, 2,  ..., 2, 2, 2]))

In [52]:
label_right[0] += 1

In [53]:
label_right

tensor([1., 0., 0.,  ..., 0., 0., 0.])

### cpu

In [4]:
import torch
import pandas as pd
weight = torch.softmax(weight, dim=-1) 
square_weight = (weight ** 2) / (weight ** 2).sum(dim=1, keepdim=True) #得到初始二分图，权重为扩大差异后的概率 

## part 1: argmax ##
index_argmax = [] #如果模型认为细胞是某个细胞类型的概率大于0.5，就直接用argmax转成hard assignment
index_km = []

for i in range(len(square_weight)):
    if (square_weight[i] > 0.5).any():
        index_argmax.append(i)
    else:
        index_km.append(i)

pred_indices = square_weight[index_argmax].argmax(dim=1)
argmax_df = pd.Series(pred_indices, index=index_argmax)

In [5]:
class KM_Algorithm:
    """
    1.最大权重匹配
    2.输入的二分图应该是一个经过soft assignment的scores或prob矩阵
    3.二分图是以left：cell，rights：celltype来写的
    """

    def __init__(self, Bipartite_Graph):

        self.Bipartite_Graph = Bipartite_Graph

        # 左右结点数量记录
        self.left = self.Bipartite_Graph.shape[0]  # 以左边（细胞）为主
        self.right_true = self.Bipartite_Graph.shape[1] 
        self.right = self.Bipartite_Graph.shape[1] + self.left
        self.reshape_graph() 

        # step1:最高标准初始化（顶标）
        self.label_left = np.max(self.Bipartite_Graph, axis=1)  # 设置左边顶标为权重最大值（每行的最大值）
        self.label_right = np.zeros(self.right)  # 右边集合的顶标设置为0

        # 初始化辅助变量——是否已匹配
        self.visit_left = np.zeros(self.left, dtype=bool) #全是false
        self.visit_right = np.zeros(self.right, dtype=bool)

        # 初始化右边的匹配结果.如果已匹配就会对应匹配结果
        self.match_right = np.empty(self.right) * np.nan #全是nan

        # 用inc记录需要减去的权值d，不断取最小值故初始化为较大值。权值都为负数，应该不用很大也行
        self.inc = 1000*1000*1000
        self.fail_cell = list()  # 每次匹配重新创建一个二分图匹配对象，所以这个也不用手动重置了

    def reshape_graph(self):
        new = np.ones((self.left, self.left)) * 0 #全0方阵
        self.Bipartite_Graph = np.column_stack((self.Bipartite_Graph, new)) #在右边拼上矩阵
        #new = torch.zeros((self.left, self.left)).to(device='cuda')
        #self.Bipartite_Graph = torch.cat((self.Bipartite_Graph, new),dim=1)
    def match(self, cell):
        cell = int(cell) 
        self.visit_left[cell] = True # 记录下这个cell已经被寻找
        
        for celltype in range(self.right):
            if not self.visit_right[celltype] and self.Bipartite_Graph[cell][celltype] >= 0:    
                  # 如果这个celltype还没访问过       # celltype仍未匹配并且它们之间存在匹配的可能性(不可匹配的点设置为负数，取反后变正数,故正数不可取)
                gap = self.label_left[cell] + self.label_right[celltype] - self.Bipartite_Graph[cell][celltype]  # gap也不会取到不能匹配的那条边
                if gap == 0:   # 差值为0，是可行的替换。所以可以直接尝试替换。后面不行再去将这个一起减去gap。这个列表是记录希望匹配的
                    self.visit_right[celltype] = True
                    # celltype未被匹配，或虽然已被匹配，但是已匹配对象(cell)有其他可选备胎。这里那些是否已访问的列表不需要重置，因为不改变前面的尝试匹配
                    if np.isnan(self.match_right[celltype]) or self.match(self.match_right[celltype]):
                        self.match_right[celltype] = cell# 递归匹配，匹配成功
                        return 1
                # 找到权值最小的差距
                elif self.inc > gap:
                    self.inc = gap  # 等于0的gap不会存在这，所以只要存在可能匹配的情况，gap就不会等于原来
        return 0

    def Kuh_Munkras(self):
        self.match_right = np.empty(self.right) * np.nan
        
        for cell in range(self.left):
            while True:
                self.inc = 1000*1000  # the minimum gap
                self.reset()  # 每次寻找过的路径，所有要重置一下
                # 可找到可行匹配
                if self.match(cell):
                    break #如果返回1的话就会立即break这个while循环
                # 不能找到可行匹配
                # (1)将所有在增广路中的cell方点的label全部减去最小常数
                # (2)将所有在增广路中的celltype方点的label全部加上最小常数
                for k in range(self.left):
                    if self.visit_left[k]:
                        self.label_left[k] -= self.inc
                for n in range(self.right):
                    if self.visit_right[n]:
                        self.label_right[n] += self.inc
        return self.fail_cell

    def calculateSum(self):
        sum = 0
        cells_celltypes = []
        self.fail_cell = [i for i in range(self.left)]
        for i in range(self.right_true):
            if not np.isnan(self.match_right[i]):
                sum += self.Bipartite_Graph[int(self.match_right[i])][i]
                cell_celltype = (int(self.match_right[i]), i)
                cells_celltypes.append(cell_celltype)
                self.fail_cell.remove(int(self.match_right[i]))
         #得到的sum是最短路径
        return cells_celltypes, self.fail_cell
            #匹配成功           #匹配失败
        
    def getResult(self):
        return self.match_right

    def reset(self):
        self.visit_left = np.zeros(self.left, dtype=bool)
        self.visit_right = np.zeros(self.right, dtype=bool)

In [9]:
## part2
from tqdm import tqdm
import numpy as np
loops = 0
total_loops = len(index_km)//8 +1
pred_indices = []
pred_celltype=[] 

for loops in tqdm(range(total_loops)): 
    km = KM_Algorithm(square_weight[index_km])  #输入二分图.to(device='cuda')
    km.Kuh_Munkras()  # 匹配
    cells_celltypes, _ = km.calculateSum() #.to(device='cpu').detach()

    cells = [index_km[cells[0]] for cells in cells_celltypes] #对应到原始索引的细胞index
    celltypes = [cells[1] for cells in cells_celltypes] #细胞名不需要对应回去
    pred_indices += cells #因为是列表
    pred_celltype += celltypes
    index_km = [x for x in index_km if x not in cells]

print("==> KM algortithm over!")

100%|██████████| 7/7 [00:00<00:00, 44.14it/s]

==> KM algortithm over!





In [10]:
km_df = pd.DataFrame(pred_celltype, index=pred_indices)

In [12]:
## part 3: prediction results ##
pred_df = pd.concat([argmax_df,km_df],axis=0)
pred_df = pred_df.sort_index(ascending=True) #细胞数量行1列的矩阵，行索引是细胞index，值是预测的label
proportion = pd.DataFrame(pred_df[0].value_counts()/len(pred_df))
proportion = proportion.sort_index(ascending=True) #细胞类型数量行1列的矩阵
proportion.index = celltype
proportion = pd.Series(proportion.iloc[:,0])#不然后面有些函数还得分类讨论

In [13]:
proportion

B Cell              0.003411
Endothelial Cell    0.000336
Epithelial          0.983616
Fibroblast          0.000577
Melanocyte          0.000336
Myeloid cells       0.009321
NK                  0.000769
Tcell               0.001634
Name: 0, dtype: float64

In [16]:
from utils import adjust
proportion_1 = adjust(proportion) #不知道为啥，老是导入就不行，单独用那个函数就可以

In [17]:
proportion_1

B Cell              0.003405
Endothelial Cell    0.000998
Epithelial          0.981670
Fibroblast          0.000998
Melanocyte          0.000998
Myeloid cells       0.009303
NK                  0.000998
Tcell               0.001630
Name: 0, dtype: float64

In [18]:
pred_proportion = proportion.rename("predict_1")

In [19]:
pred_proportion

B Cell              0.003411
Endothelial Cell    0.000336
Epithelial          0.983616
Fibroblast          0.000577
Melanocyte          0.000336
Myeloid cells       0.009321
NK                  0.000769
Tcell               0.001634
Name: predict_1, dtype: float64

In [28]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, precision_recall_fscore_support

def metrics(scores, labels, mode, pred_labels):
    '''
    evaluate the prediction results
    soft assignment对于单个细胞预测的准确率还是得利用hard assignment，因为它跳过了对单个细胞的预测、直接预测比例
    '''
    dict_correct = {i: 0 for i in range(8)}
    dict_samples = {i: 0 for i in range(8)}
    y_true, y_pred=[], []
    cell_type = [i for i in range(8)]

    if mode == 'soft' or 'hard':
        _, prediction = scores.max(dim=1) #predictions是(cell_num,1)的矩阵（张量的写法是(cell_num,)）
        _, indices = labels.max(dim=1)#_是一个占位符，第一个是样本的最大值，第二个是其索引
    elif mode == 'km':
        if pred_labels is None:
            raise ValueError("The pred_labels is required!") 
        else:
            prediction = pred_labels
            _, indices = labels.max(dim=1)
    else:
        raise ValueError("Invalid mode!")
        
    for i in range(prediction.size(0)):
        m = indices[i].item() 
        n = prediction[i].item()
        y_true.append(m)
        y_pred.append(n)
        if m in dict_correct:
            dict_samples[m] += 1
            if m == n:
                dict_correct[m] += 1
        else:
            print("The number of cell type is not suitable!")
                
    #（1）每个cell type的指标
    p,r,f,num_true=precision_recall_fscore_support(y_true=y_true,y_pred=y_pred,labels=cell_type,average=None)
    precision = pd.DataFrame(p)  
    recall = pd.DataFrame(r)
    
    #（2）整体的指标
    f1_score_micro = f1_score(y_true, y_pred,labels=cell_type,average='micro')
    f1_score_macro = f1_score(y_true, y_pred,labels=cell_type,average='macro')
    f1score = pd.DataFrame([f1_score_micro,f1_score_macro], index=['f1score_micro','f1score_macro'])   
    
    num_correct = sum(dict_correct.values())
    num_samples = sum(dict_samples.values())
    accuracy = float(num_correct) / float(num_samples)
    print(f"Got {num_correct} / {num_samples} with accuracy' {accuracy * 100:.2f}%")
        
    return accuracy, f1score, precision, recall

In [29]:
_, f1scores, precision, recall = metrics(weight,labels,'km', pred_df)

Got 15726 / 20813 with accuracy' 75.56%


In [30]:
f1scores

Unnamed: 0,0
f1score_micro,0.755585
f1score_macro,0.130155


In [31]:
precision

Unnamed: 0,0
0,0.0
1,0.0
2,0.758165
3,0.666667
4,0.0
5,0.871287
6,0.0
7,0.416667


In [32]:
recall

Unnamed: 0,0
0,0.0
1,0.0
2,0.995067
3,0.017937
4,0.0
5,0.044177
6,0.0
7,0.033259
