In [15]:
import torch
import torch.nn as nn
import obonet
import numpy as np
import json

In [16]:
# Custom loss: cross-entropy plus an ontology-distance penalty

class TreeLoss(torch.nn.Module):
    """Cross-entropy loss plus a weighted ontology-distance penalty.

    The returned value is ``CE(pred, target) + lamda * mean_dist``, where
    ``mean_dist`` is the batch average of ``lca_dist_dict[target_node][pred_node]``
    for the arg-max target class and the arg-max predicted class of each row.
    Class index ``i`` is mapped to the ``i``-th node yielded by
    ``cell_ontology.nodes()``.

    Note:
        The distance term is produced by ``argmax`` followed by a dictionary
        lookup, so it is a plain Python number: it shifts the loss value but
        contributes no gradient. Only the cross-entropy term is differentiable.

    Args:
        lamda: Weight applied to the mean distance penalty.
        cell_ontology: Object exposing ``nodes()`` (e.g. the graph returned by
            ``obonet.read_obo``). When ``None``, the ontology is downloaded
            from the default OBO URL.
        lca_dist_dict: Required nested mapping
            ``{node_id: {node_id: distance}}``. Missing pairs count as 0.

    Raises:
        AssertionError: If ``lca_dist_dict`` is ``None``.
    """

    def __init__(self, lamda=0.5, cell_ontology=None, lca_dist_dict = None):
        super().__init__()
        # Validate first so a missing distance table fails before any download.
        assert lca_dist_dict is not None, "lca_dist_dict is required"

        self.ce = torch.nn.CrossEntropyLoss()
        self.lamda = lamda
        self.lca_dist_dict = lca_dist_dict

        # `is None` (not truthiness): an explicitly supplied but empty graph
        # must not silently trigger a network download.
        if cell_ontology is None:
            onto_path = "http://purl.obolibrary.org/obo/cl/cl-basic.obo"
            cell_ontology = obonet.read_obo(onto_path)
        self.cell_ontology = cell_ontology

        # TODO(review): class index i is assumed to correspond to the i-th
        # ontology node in iteration order -- double check this mapping.
        self.nodes = np.array(list(self.cell_ontology.nodes()))

    def forward(self, pred, target):
        """Compute the combined loss.

        Args:
            pred: ``(N, C)`` tensor of class scores.
            target: ``(N, C)`` tensor of per-class targets (e.g. one-hot rows).

        Returns:
            Scalar tensor: cross-entropy plus ``lamda`` times the mean
            looked-up distance between target and predicted nodes.
        """
        # Take the arg-max in torch on detached tensors and move only the
        # indices to the CPU: NumPy cannot consume tensors that require grad
        # or that live on an accelerator.
        target_idx = target.detach().argmax(dim=1).cpu().numpy()
        pred_idx = pred.detach().argmax(dim=1).cpu().numpy()
        target_node_id = self.nodes[target_idx]
        pred_node_id = self.nodes[pred_idx]

        n_samples = len(target_node_id)
        total_dist = sum(
            self.lca_dist(t_node, p_node)
            for t_node, p_node in zip(target_node_id, pred_node_id)
        )
        # Guard the average so an empty batch does not raise ZeroDivisionError.
        mean_dist = total_dist / n_samples if n_samples else 0.0

        return self.ce(pred, target) + self.lamda * mean_dist

    def lca_dist(self, node1, node2):
        """Return ``lca_dist_dict[node1][node2]``, or 0 when the pair is absent."""
        return self.lca_dist_dict.get(node1, {}).get(node2, 0)

In [17]:

# Load the LCA distance table from lca_distances.json. The context manager
# closes the file handle that a bare ``json.load(open(...))`` leaves open.
# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
with open("/home/pangkuan/dev/course_work/csc311/continualTraining/notebooks/lca_distances.json", "r") as lca_file:
    lca_dist_dict = json.load(lca_file)
tree_loss = TreeLoss(lamda=0.5, cell_ontology=None, lca_dist_dict = lca_dist_dict)

# Smoke test on a two-sample batch: both rows predict class 3 while the target is class 2.
tree_loss(torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]]), torch.tensor([[0., 0., 1., 0.], [0., 0., 1., 0.]]))


tensor(1.3425)

In [1]:
import scanpy as sc

# Path to the .h5ad file to load.
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
data_path = "/home/pangkuan/dev/code_test_cz_cellxgene/eda/cellxgene50K.h5ad"
# Read the file with scanpy; `adata` is reused by the cells below.
adata = sc.read(data_path)

In [2]:
# Display the object's summary repr (dimensions plus the obs/var column names).
adata

AnnData object with n_obs × n_vars = 50000 × 60664
    obs: 'soma_joinid', 'dataset_id', 'assay', 'assay_ontology_term_id', 'cell_type', 'cell_type_ontology_term_id', 'development_stage', 'development_stage_ontology_term_id', 'disease', 'disease_ontology_term_id', 'donor_id', 'is_primary_data', 'self_reported_ethnicity', 'self_reported_ethnicity_ontology_term_id', 'sex', 'sex_ontology_term_id', 'suspension_type', 'tissue', 'tissue_ontology_term_id', 'tissue_general', 'tissue_general_ontology_term_id'
    var: 'soma_joinid', 'feature_id', 'feature_name', 'feature_length'

In [4]:
# List the distinct values present in the `disease` column of `adata.obs`.
adata.obs["disease"].unique()

['normal']
Categories (1, object): ['normal']

: 