In [69]:
import os
import csv
import pickle
import random
import pandas as pd
import numpy as np
import uuid
import json
import os.path as osp

from glob import glob
from tqdm import tqdm
from shutil import copy
from pprint import pprint
# import cv2
# import imageio
# import imagesize
from PIL import Image

import matplotlib.pyplot as plt
from itertools import groupby,chain,combinations
from functools import partial
from collections import defaultdict, Counter, OrderedDict
from scipy.special import softmax

from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, cohen_kappa_score, accuracy_score
from sklearn.metrics import average_precision_score, roc_curve, auc, precision_recall_curve, f1_score, roc_auc_score

In [77]:
from anytree import Node, RenderTree, search
from anytree.importer import JsonImporter
from anytree.exporter import JsonExporter, DotExporter
import lzma
from nltk.tree import Tree

### helper function for processing the hierarchical classification

In [86]:
def find_level_name(class_name, level=1):
    path_node_classes = classname_paths[class_name]
    if len(path_node_classes)>level:
        return path_node_classes[level]
    else:
        return None
def find_level_name_v2(class_name, level=1):
    """fill the finest label using the coarse lable"""
    path_node_classes = classname_paths[class_name]
    if len(path_node_classes)>level:
        return path_node_classes[level]
    else:
        return path_node_classes[-1]
def lca_height(class_name1, class_name2, logarithmic=True):
    """lowest common ancestor height, taking the level into acount np.log(1+height)
    """
    node1 = search.find_by_attr(TCT, class_name1)
    node2 = search.find_by_attr(TCT, class_name2)
    node1_path_names = [x.name for x in node1.path]
    node2_path_names = [x.name for x in node2.path]
    if len(node1_path_names) == len(node2_path_names):
        height = 0
        for name1, name2 in list(zip(node1_path_names, node2_path_names))[::-1]:
            if name1==name2:
                return np.log(1+height) if logarithmic else height
            else:
                height +=1
    #             return name1
    else:
        common_length = len(set(node1_path_names).intersection(set(node2_path_names)))
        longest_length = max(len(node1_path_names), len(node2_path_names))
        height = longest_length - common_length
        return height
def tree2list(node):
    res = []
    if len(node.siblings)== 0:
        res.append([node.name])
    if len(node.children) > 0:
        res.append([x.name for x in node.children])
    for x in node.children:
        res.extend(tree2list(x))
    return res

In [49]:
# lca_height('AGC-NOS','LSIL', logarithmic=False)
# lca_height('AGC-NOS','AGC-NOS', logarithmic=False)

In [84]:
def sample_df(df,num=200, name='grade', classes=None):
    if classes is None:
        name_set = list(set(df[name]))
    else:
        name_set = list(set(df[name]).intersection(classes))
    print(name_set)
    df_sampled = pd.concat([df[df[name]==x].sample(min(num, df[df[name]==x].shape[0]), random_state=23) for x in name_set])
    return df_sampled

def df_v2(df):
    df_res = df.copy()
    print('Using all data')
    df_res['level_2'] = df_res['class_name'].apply(lambda x: find_level_name_v2(x, level=2))
    df_res['level_3'] = df_res['class_name'].apply(lambda x: find_level_name_v2(x, level=3))
    return df_res
def filter_df(df_in):
    df = df_in.copy()
    df = df[df['class_name'].isin(selected_classes)]
    df['class_id'] = df['class_name'].map(classes2id)
    return df

def evaluate(df, method=2, to_level=3, cal_auc=True):
    df_res = df.copy()
    if method==1:
        print('Using partial data...')  # 
    else:
        print('Using all data...')   # default way
        df_res['level_1'] = df_res['class_name'].apply(lambda x: find_level_name_v2(x, level=1))
        df_res['level_2'] = df_res['class_name'].apply(lambda x: find_level_name_v2(x, level=2))
        df_res['level_3'] = df_res['class_name'].apply(lambda x: find_level_name_v2(x, level=3))
        df_res['level_1_id'] = df_res['level_1'].map(level_1_names2id)
        df_res['level_2_id'] = df_res['level_2'].apply(lambda x: level_2_names2id.get(x, -1))
        df_res['level_3_id'] = df_res['level_3'].apply(lambda x: level_3_names2id.get(x, -1))
    acc_lst = []
    hierarchical_distances = []
    auc_lst = []
    f1_score_lst = []
    middle_class_metric = []
    #### level 1
    level_1_acc = accuracy_score(df_res['level_1'], df_res['level_1_pred'])
    acc_lst.append(level_1_acc)
    level_1_distance = df_res[['level_1','level_1_pred']].apply(lambda x: lca_height(x[0], x[1]), axis =1).mean() # 同层的距离
    hierarchical_distances.append(level_1_distance)
    if cal_auc:
        auc_lst.append(roc_auc_score(df_res['level_1_id'], df_res.filter(regex='order*',axis=1), multi_class='ovr'))
    f1_score_lst.append(f1_score(df_res['level_1'], df_res['level_1_pred'], average='macro'))

    print('Level_1 accuracy: {:.4f}'.format(level_1_acc))
    print('level_1 hierarchical distanse: {:.4f}'.format(level_1_distance))
    print('#'*30)
    
    #### level 2
    df_level2 = df_res[(~df_res['level_2'].isna())&(df_res['level_2']!='AGC')] # 不考虑AGC
    level_2_distance = df_level2[['level_2','level_2_pred']].apply(lambda x: lca_height(x[0], x[1]), axis =1).mean()
    hierarchical_distances.append(level_2_distance)

    level_2_acc = accuracy_score(df_level2['level_2'], df_level2['level_2_pred'])
    acc_lst.append(level_2_acc)
    if cal_auc:
        auc_lst.append(roc_auc_score(df_level2['level_2_id'], df_level2.filter(regex='family*',axis=1), multi_class='ovr'))
    f1_score_lst.append(f1_score(df_level2['level_2'], df_level2['level_2_pred'], average='macro'))
    print('Level_2 accuracy: {:.4f}'.format(level_2_acc))
    print('level_2 hierarchical distande: {:.4f}'.format(level_2_distance))
    print('#'*30)
    
    if to_level == 2:
        pass
    else:
        #### level 3
        df_level3 = df_res[(~df_res['level_3'].isna())&(~df_res['level_3'].isin(['AGC','AGC-NOS', 'ADC']))]
    #     df_level3_normalcls = df_level3[~df_level3['level_3_pred'].isna()]
        level_3_distance = df_level3[['level_3','level_3_pred']].apply(lambda x: lca_height(x[0], x[1]), axis =1).mean()
        hierarchical_distances.append(level_3_distance)
    #     df_level3.fillna(value='undercls',inplace=True)
        ## only consider intra-level confusion 
        level_3_acc = accuracy_score(df_level3['level_3'], df_level3['level_3_pred'])
        acc_lst.append(level_3_acc)
        if cal_auc:
            auc_lst.append(roc_auc_score(df_level3['level_3_id'], df_level3.filter(regex='species*',axis=1), multi_class='ovr'))
        f1_score_lst.append(f1_score(df_level3['level_3'], df_level3['level_3_pred'], average='macro'))
        print('Level_3 accuracy: {:.4f}'.format(level_3_acc))
        print('level_3 hierarchical distande: {:.4f}'.format(level_3_distance))
        print('#'*30)

        print('Average accuracy: {:.4f}'.format(np.mean(acc_lst)))
        print('Average hierarchical distance: {:.4f}'.format(np.mean(hierarchical_distances)))
    res_data = acc_lst + hierarchical_distances + auc_lst# + f1_score_lst
    return res_data

def add_level_pred(df):
    df['level_1_pred'] = [x.split('_')[-1] for x in df.filter(regex='order*',axis=1).idxmax(axis=1)]
    df['level_2_pred'] = [x.split('_')[-1] for x in df.filter(regex='family*',axis=1).idxmax(axis=1)]
    df['level_3_pred'] = [x.split('_')[-1] for x in df.filter(regex='species*',axis=1).idxmax(axis=1)]
    
    df['level_1_pred_score'] = df.filter(regex='order*',axis=1).max(axis=1)
    df['level_2_pred_score'] = df.filter(regex='family*',axis=1).max(axis=1)
    df['level_3_pred_score'] = df.filter(regex='species*',axis=1).max(axis=1)
    return df

def add_order_pred_score(df):
    for class_name in level_1_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        leaf_node_names =[x.name for x in  class_name_node.children]
        df['order_{}'.format(class_name)] = df[['family_{}'.format(x) for x in leaf_node_names]].sum(axis=1)
    return df

def add_order_family_pred_score(df):
    for class_name in level_1_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        leaf_node_names =[x.name for x in  class_name_node.leaves]
#         print(['species_{}'.format(x) for x in leaf_node_names])
        df['order_{}'.format(class_name)] = df[['species_{}'.format(x) for x in leaf_node_names]].sum(axis=1)
    
    for class_name in level_2_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        leaf_node_names =[x.name for x in  class_name_node.leaves]
        df['family_{}'.format(class_name)] = df[['species_{}'.format(x) for x in leaf_node_names]].sum(axis=1)
    return df

def conditional_pred_score(df):
    for class_name in level_1_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        path_node_classes = [x.name for x in class_name_node.path][1:]
        df['order_{}'.format(class_name)] = df[path_node_classes].prod(axis=1)
    
    for class_name in level_2_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        path_node_classes = [x.name for x in class_name_node.path][1:]
        df['family_{}'.format(class_name)] = df[path_node_classes].prod(axis=1)
    
    for class_name in level_3_names:
        class_name_node = search.find_by_attr(TCT, class_name)
        path_node_classes = [x.name for x in class_name_node.path][1:]
        df['species_{}'.format(class_name)] = df[path_node_classes].prod(axis=1)
    return df

def fgvc_helper(df_species_fgvc):
    df_species_fgvc['pred_class'] = [level_3_names.index(x.split('_')[-1]) for x in df_species_fgvc.filter(regex='species*',axis=1).idxmax(axis=1)]
    df_species_fgvc['level_3_pred']=df_species_fgvc['pred_class'].map(dict(zip(range(len(level_3_names)),level_3_names)))
    df_species_fgvc['level_2_pred']=df_species_fgvc['level_3_pred'].apply(lambda x: find_level_name_v2(x, level=2))
    df_species_fgvc['level_1_pred']=df_species_fgvc['level_3_pred'].apply(lambda x: find_level_name_v2(x, level=1))
    return df_species_fgvc

def get_res_summary(res_lst, model_lst):
    df_summary = pd.DataFrame(res_lst,
             index=model_lst,
            columns=['level_1_acc', 'level_2_acc', 'level_3_acc', 'level_1_hier_dist', 'level_2_hier_dist', 'level_3_hier_dist',
                     'level_1_auc', 'level_2_auc', 'level_3_auc', 
                     #'level_1_f1_score', 'level_2_f1_score', 'level_3_f1_score',
                    #'AGC_acc', 'AGC_hier_dist', 'AGC-NOS/ADC_acc', 'AGC-NOS/ADC_hier_dist'
                    ])
    df_summary['avg_acc']=df_summary[['level_1_acc','level_2_acc','level_3_acc']].mean(axis=1)
    df_summary['avg_hier_dist']=df_summary[['level_1_hier_dist','level_2_hier_dist','level_3_hier_dist']].mean(axis=1)
    columns = ['level_1_acc', 'level_2_acc', 'level_3_acc', 'avg_acc', 'level_1_hier_dist', 'level_2_hier_dist', 'level_3_hier_dist', 'avg_hier_dist',
               'level_1_auc', 'level_2_auc', 'level_3_auc', 
               #'level_1_f1_score', 'level_2_f1_score', 'level_3_f1_score',
                        #'AGC_acc', 'AGC_hier_dist', 'AGC-NOS/ADC_acc', 'AGC-NOS/ADC_hier_dist'
              ]
    df_summary = df_summary[columns]
    return df_summary.round(4)

### hierarchy class in HiCervix

In [109]:
### hierarchical tree construction for HiCervix
## the full names of the acronyms are listed in dataset/hierarchy_classification/version2023/hierarchy_names.csv
TCT = Node("TCT")
negative = Node("negative",parent=TCT) 
ASC = Node("ASC", parent=TCT) 
AGC = Node("AGC", parent=TCT) 
microbe = Node("microbe",parent=TCT)  ## microbe is the class  Organisms

ASCUS = Node("ASC-US", parent=ASC)
LSIL = Node("LSIL", parent=ASC)
ASCH = Node("ASC-H", parent=ASC)
HSIL = Node("HSIL", parent=ASC)
SCC = Node("SCC", parent=ASC)

AGCNOS = Node("AGC-NOS", parent=AGC)
AGCFN = Node("AGC-FN", parent=AGC)
ADC = Node("ADC", parent=AGC)

AGCNOS1 = Node("AGC-ECC-NOS", parent=AGCNOS)
AGCNOS2 = Node("AGC-EMC-NOS", parent=AGCNOS)

ADC1 = Node("ADC-ECC", parent=ADC)
ADC2 = Node("ADC-EMC", parent=ADC)

normal = Node("Normal", parent=negative)
endocervical = Node("ECC", parent=negative)
xiufu = Node("RPC", parent=negative)
huasheng = Node("MPC", parent=negative)
glucose = Node("PG", parent=negative)
Atrophy = Node("Atrophy", parent=negative)
EMC = Node("EMC", parent=negative)
HCG = Node("HCG", parent=negative)

FUNGI = Node("FUNGI", parent=microbe)
ACTINO = Node("ACTINO", parent=microbe)
TRI = Node("TRI", parent=microbe)
HSV = Node("HSV", parent=microbe)
CC = Node("CC", parent=microbe)

# darkclusters = Node("HCG", parent=negative)
# DotExporter(TCT_en).to_picture("TCT.png")

In [79]:
### hierarchical tree construction for HiCervix,chinese version if needed 
## which could also serve as the dictionary for english names and chinese names tranlation (if needed)
# TCT = Node("TCT")
# ASC = Node("ASC", parent=TCT) # 鳞状上皮病变
# AGC = Node("AGC", parent=TCT) # 腺上皮病变
# microbe = Node("microbe",parent=TCT) #微生物类
# negative = Node("negative",parent=TCT) # 其他类
# # OTHERS = Node("OTHERS",parent=TCT) # 其他病变类

# # 鳞状上皮病变
# ASCUS = Node("ASC-US", parent=ASC)
# LSIL = Node("LSIL", parent=ASC)
# ASCH = Node("ASC-H", parent=ASC)
# HSIL = Node("HSIL", parent=ASC)
# SCC = Node("SCC", parent=ASC)

# # 腺上皮病变
# AGCNOS = Node("AGC-NOS", parent=AGC)
# AGCFN = Node("AGC-FN", parent=AGC)
# ADC = Node("ADC", parent=AGC)

# AGCNOS1 = Node("非典型子宫内膜细胞", parent=AGCNOS)
# AGCNOS2 = Node("非典型颈管腺细胞", parent=AGCNOS)

# ADC1 = Node("颈管腺癌", parent=ADC)
# ADC2 = Node("子宫内膜腺癌", parent=ADC)

# TRI = Node("滴虫", parent=microbe)
# CC = Node("细菌性阴道病", parent=microbe)
# FUNGI = Node("放线菌", parent=microbe)
# ACTINO = Node("念珠菌", parent=microbe)
# HSV = Node("疱疹病毒感染", parent=microbe)
# # CMV = Node("多核巨细胞", parent=microbe)

# EM = Node("子宫内膜细胞", parent=negative)
# PM = Node("萎缩性改变", parent=negative)
# endocervical = Node("宫颈管细胞", parent=negative)
# xiufu = Node("修复细胞", parent=negative)
# huasheng = Node("化生细胞", parent=negative)
# glucose = Node("糖原溶解细胞", parent=negative)
# darkclusters = Node("深染细胞团", parent=negative)
# normal = Node("其他正常细胞", parent=negative)
# scancer = Node("小细胞癌", parent=OTHERS)
# TCT.leaves
# tmp = tree2list(TCT)
# len(list(chain(*tmp[1:])))
# DotExporter(TCT).to_picture("TCT.png")

In [103]:
# chn2english_names

In [104]:
# print([chn2english_names[x] for x in level_3_names])

In [113]:
class_names = ['Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG', 'ASC-US', 'LSIL',
               'ASC-H', 'HSIL', 'SCC', 'AGC', 'AGC-NOS', 'AGC-FN', 'ADC', 'AGC-ECC-NOS', 'AGC-EMC-NOS', 
               'ADC-ECC', 'ADC-EMC', 'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC']
level_1_names = ['negative','ASC','AGC','microbe']
level_2_names = ['Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG ', 'ASC-US', 'LSIL', 
                 'ASC-H', 'HSIL', 'SCC', 'AGC-NOS', 'AGC-FN', 'ADC', 'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC']
level_3_names = ['Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG ', 'ASC-US', 'LSIL', 'ASC-H', 'HSIL', 'SCC', 'AGC-FN', 
                 'AGC-ECC-NOS', 'AGC-EMC-NOS', 'ADC-ECC', 'ADC-EMC', 'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC']
# class_names = ['其他正常细胞','宫颈管细胞','修复细胞','化生细胞','糖原溶解细胞','萎缩性改变','子宫内膜细胞','深染细胞团',
#                'ASC-US','LSIL','ASC-H','HSIL', 'SCC',
#                'AGC','AGC-NOS','AGC-FN','ADC',  
#                 '非典型颈管腺细胞','非典型子宫内膜细胞', '颈管腺癌','子宫内膜腺癌',
#                '念珠菌','放线菌','滴虫','疱疹病毒感染','细菌性阴道病']
# level_1_names = ['negative','ASC','AGC','microbe']
# level_2_names = ['其他正常细胞','宫颈管细胞','修复细胞','化生细胞','糖原溶解细胞','萎缩性改变','子宫内膜细胞','深染细胞团',
#                'ASC-US','LSIL','ASC-H','HSIL', 'SCC',
#                 'AGC-NOS','AGC-FN','ADC',
#                '念珠菌','放线菌','滴虫','疱疹病毒感染','细菌性阴道病']
# # no intermediate classes such as AGC
# level_3_names = ['其他正常细胞','宫颈管细胞','修复细胞','化生细胞','糖原溶解细胞','萎缩性改变','子宫内膜细胞','深染细胞团',
#                'ASC-US','LSIL','ASC-H','HSIL', 'SCC',
#                 'AGC-FN', #'AGC','AGC-NOS', 'ADC',
#                 '非典型颈管腺细胞','非典型子宫内膜细胞', '颈管腺癌','子宫内膜腺癌',
#                '念珠菌','放线菌','滴虫','疱疹病毒感染','细菌性阴道病']
level_1_names2id = dict(zip(level_1_names, range(len(level_1_names))))
level_2_names2id = dict(zip(level_2_names, range(len(level_2_names))))
level_3_names2id = dict(zip(level_3_names, range(len(level_3_names))))

level_names = [level_1_names, level_2_names,level_3_names,]
classname2newid = dict(zip(class_names, range(len(class_names))))
newid2classname= {v:k for k,v in classname2newid.items()}
print('Number of total annotated classes: {}'.format(len(class_names)))
for i in range(3):
    print('Number of level {} classes: {}'.format(i+1, len(level_names[i])))

Number of total annotated classes: 26
Number of level 1 classes: 4
Number of level 2 classes: 21
Number of level 3 classes: 23


In [115]:
classname_paths = {}
for class_name in class_names:
    # print(class_name)
    class_name_node = search.find_by_attr(TCT, class_name)
    path_node_classes = [x.name for x in class_name_node.path]#[1:] #exclude the root node of 'TCT'
#     print('The nodes in the path (from root node) to reach the nodes of {}'.format(class_name))
#     print(path_node_classes)
    classname_paths[class_name] = path_node_classes

In [116]:
classname_paths

{'Normal': ['TCT', 'negative', 'Normal'],
 'ECC': ['TCT', 'negative', 'ECC'],
 'RPC': ['TCT', 'negative', 'RPC'],
 'MPC': ['TCT', 'negative', 'MPC'],
 'PG': ['TCT', 'negative', 'PG'],
 'Atrophy': ['TCT', 'negative', 'Atrophy'],
 'EMC': ['TCT', 'negative', 'EMC'],
 'HCG': ['TCT', 'negative', 'HCG'],
 'ASC-US': ['TCT', 'ASC', 'ASC-US'],
 'LSIL': ['TCT', 'ASC', 'LSIL'],
 'ASC-H': ['TCT', 'ASC', 'ASC-H'],
 'HSIL': ['TCT', 'ASC', 'HSIL'],
 'SCC': ['TCT', 'ASC', 'SCC'],
 'AGC': ['TCT', 'AGC'],
 'AGC-NOS': ['TCT', 'AGC', 'AGC-NOS'],
 'AGC-FN': ['TCT', 'AGC', 'AGC-FN'],
 'ADC': ['TCT', 'AGC', 'ADC'],
 'AGC-ECC-NOS': ['TCT', 'AGC', 'AGC-NOS', 'AGC-ECC-NOS'],
 'AGC-EMC-NOS': ['TCT', 'AGC', 'AGC-NOS', 'AGC-EMC-NOS'],
 'ADC-ECC': ['TCT', 'AGC', 'ADC', 'ADC-ECC'],
 'ADC-EMC': ['TCT', 'AGC', 'ADC', 'ADC-EMC'],
 'FUNGI': ['TCT', 'microbe', 'FUNGI'],
 'ACTINO': ['TCT', 'microbe', 'ACTINO'],
 'TRI': ['TCT', 'microbe', 'TRI'],
 'HSV': ['TCT', 'microbe', 'HSV'],
 'CC': ['TCT', 'microbe', 'CC']}

In [117]:
# classname_paths

In [None]:
# hierarchy for making better mistakes
tct = Tree.fromstring("(TCT (negative (Normal Normal) (ECC ECC) (RPC RPC) (MPC MPC) (PG PG) (Atrophy Atrophy) (EMC EMC) (HCG HCG))  \
                     (ASC (ASC-US ASC-US) (LSIL LSIL) (ASC-H ASC-H) (HSIL HSIL) (SCC SCC))   \
                     (AGC (AGC-FN AGC-FN) (AGC-NOS AGC-ECC-NOS AGC-EMC-NOS) (ADC ADC-ECC ADC-EMC)) \
                     (microbe (FUNGI FUNGI) (ACTINO ACTINO) (TRI TRI) (HSV HSV) (CC CC)) \
                     )")
tct_2level = Tree.fromstring("(TCT (negative Normal  ECC RPC MPC  PG  Atrophy  EMC  HCG)  \
                     (ASC  ASC-US  LSIL  ASC-H  HSIL  SCC)   \
                     (AGC AGC-FN  AGC-NOS  ADC ) \
                     (microbe  FUNGI  ACTINO  TRI  HSV  CC) \
                     )")

# tct = Tree.fromstring("(TCT (negative (其他正常细胞 其他正常细胞) (宫颈管细胞 宫颈管细胞) (修复细胞 修复细胞) (化生细胞 化生细胞) (糖原溶解细胞 糖原溶解细胞) (萎缩性改变 萎缩性改变) (子宫内膜细胞 子宫内膜细胞) (深染细胞团 深染细胞团))  \
#                      (ASC (ASC-US ASC-US) (LSIL LSIL) (ASC-H ASC-H) (HSIL HSIL) (SCC SCC))   \
#                      (AGC (AGC-FN AGC-FN) (AGC-NOS 非典型颈管腺细胞 非典型子宫内膜细胞) (ADC 颈管腺癌 子宫内膜腺癌)) \
#                      (microbe (念珠菌 念珠菌) (放线菌 放线菌) (滴虫 滴虫) (疱疹病毒感染 疱疹病毒感染) (细菌性阴道病 细菌性阴道病)) \
#                      )")
# tct_2level = Tree.fromstring("(TCT (negative 其他正常细胞  宫颈管细胞 修复细胞 化生细胞  糖原溶解细胞  萎缩性改变  子宫内膜细胞  深染细胞团)  \
#                      (ASC  ASC-US  LSIL  ASC-H  HSIL  SCC)   \
#                      (AGC AGC-FN  AGC-NOS  ADC ) \
#                      (microbe  念珠菌  放线菌  滴虫  疱疹病毒感染  细菌性阴道病) \
#                      )")

# hierarchy for making better mistakes
tct_distances = {}
all_names = set(level_1_names + level_2_names + level_3_names)
for name1 in all_names:
    for name2 in all_names:
        tct_distances[(name1,name2)] = lca_height(name1, name2)
# with open('classification_hierarchy/making-better-mistakes/data/tct_tree.pkl','wb') as fi:
#     pickle.dump(tct, fi)

# with open('classification_hierarchy/making-better-mistakes-2level/data/tct_tree.pkl','wb') as fi:
#     pickle.dump(tct_2level, fi)
# with open('classification_hierarchy/making-better-mistakes/data/tct_distances.pkl','wb') as fi:
#     pickle.dump(tct_distances, fi)

# with open('classification_hierarchy/making-better-mistakes/data/tct_tree.pkl','rb') as fi:
#     tree_data = pickle.load(fi)
# tree_data
# with lzma.open('/classification_hierarchy/making-better-mistakes/data/imagenet_ilsvrc_distances.pkl.xz', "rb") as f:
#     tmp = pickle.load(f)
# #     distance_data = DistanceDict(pickle.load(f))
#     distance_data = DistanceDict(tmp)
level_names_dict = dict(zip(['order','family','species'], level_names))
# with open('dataset/hierarchy_classification/version3/level_names_dict.pkl','wb') as fi:
#     pickle.dump(level_names_dict, fi)
# with open('dataset/hierarchy_classification/version3/level_names_dict.pkl','rb') as fo:
#     tmp = pickle.load(fo)

In [87]:
# for multi-head and FGoN
trees =[]
trees_names = []
trees_dict = {}
def find_class_id(class_name, lst):
    try:
        class_id = lst.index(class_name)
    except:
        class_id = -1
    return class_id
for i, class_name in enumerate(class_names):
    class_name_node = search.find_by_attr(TCT, class_name)
    path_node_classes = [x.name for x in class_name_node.path]#[1:] #exclude the root node of 'TCT'
    extended_node_names = [find_level_name_v2(class_name, i) for i in range(1,4)]
    extended_node_ids = [find_class_id(x, level_names[i]) for i, x in enumerate(extended_node_names)] #level_names[i].index(x)
    trees.append(extended_node_ids)
    trees_names.append(extended_node_names)
    trees_dict[class_name]=extended_node_ids 

### Data preprocessing for different methods

In [None]:
# HierSwin methods and making better mistakes
datasets = ['train.csv','val.csv','test.csv']
# datasets = ['train_image_path.csv','val_image_path.csv','test_image_path.csv']
csv_dir = 'dataset/hierarchy_classification/version2023/'
for dataset in datasets:
    csv_file = osp.join(csv_dir, dataset)
    df_tmp = pd.read_csv(csv_file)
    df_tmp = df_tmp[df_tmp['class_name'].isin(level_3_names)]
    df_tmp['level_3_id'] = df_tmp['class_name'].map(dict(zip(level_3_names,range(len(level_3_names)))))
    df_tmp.to_csv(csv_file.replace('.csv','_mbm.csv'),index=False)

In [None]:
## HRN
datasets = ['train_image_path.csv']#,'val_image_path.csv','test_image_path.csv']
csv_dir = 'dataset/hierarchy_classification/version2023/'
for dataset in datasets:
    csv_file = osp.join(csv_dir, dataset)
    df = pd.read_csv(csv_file)
    df['level_1_id'] = df['level_1'].map(level_1_names2id)
    df['level_2_name'] = df['class_name'].apply(lambda x: find_level_name_v2(x, level=2))
    df['level_3_name'] = df['class_name'].apply(lambda x: find_level_name_v2(x, level=3))
    df['level_2_id'] = df['level_2_name'].apply(lambda x: level_2_names2id.get(x, -1))
    df['level_3_id'] = df['level_3_name'].apply(lambda x: level_3_names2id.get(x, -1))
    #df[['image_path', 'image_name', 'class_id', 'level_1_id', 'level_2_id', 'level_3_id']].to_csv(csv_file.replace('.csv','_hrn.csv')) #encoding='utf-8-sig', ,  index=False
    df[['image_path', 'image_name', 'class_id', 'level_1_id', 'level_2_id', 'level_3_id']].to_csv('/mnt/group-ai-medical-abp/private/daviddecai/tmp.csv') #encoding='utf-8-sig', ,  index=False

In [None]:
### fine-grained visual classification

# level_3 finegrained，intermediate classes such as AGC,'AGC-NOS','ADC' are removed 
csv_dir = 'dataset/hierarchy_classification/version2023'
datasets = ['train','val','test']
for dataset in datasets:
    df_tmp = pd.read_csv(osp.join(csv_dir, dataset + '.csv'))
    #df_tmp = df_tmp[~df_tmp['class_name'].isin(['AGC','AGC-NOS','ADC'])] 
    df_tmp = df_v2(df_tmp)
    #df_tmp['level_3_id']= df_tmp['level_3'].map(dict(zip(level_3_names, range(len(level_3_names)))))
    df_tmp['level_3_id']= df_tmp['level_3'].apply(lambda x: level_3_names2id.get(x, -1))
    df_tmp.to_csv(osp.join(csv_dir, dataset + '_keep_species_all.csv'))

### Evaluation for different methods

In [87]:
df_multi = pd.read_csv('classification_hierarchy/vanilla_single/output_0512_384/test_image_path_res.csv')
df_multi = add_level_pred(df_multi)
multi2 = evaluate(df_multi,method=2)

Using all data...
Level_1 accuracy: 0.8809
level_1 hierarchical distanse: 0.0826
##############################
Level_2 accuracy: 0.6406
level_2 hierarchical distande: 0.3028
##############################
Level_3 accuracy: 0.6389
level_3 hierarchical distande: 0.3605
##############################
Average accuracy: 0.7201
Average hierarchical distance: 0.2486


In [74]:
df_mbm = pd.read_csv('classification_hierarchy/making-better-mistakes/hxe_tct_alpha0.4_0512_384/test_image_path_mbm_res.csv')
df_mbm = add_order_family_pred_score(df_mbm)
df_mbm = add_level_pred(df_mbm)
mbm2 = evaluate(df_mbm,method=2)

Using all data...
Level_1 accuracy: 0.9048
level_1 hierarchical distanse: 0.0660
##############################
Level_2 accuracy: 0.7577
level_2 hierarchical distande: 0.2072
##############################
Level_3 accuracy: 0.7566
level_3 hierarchical distande: 0.2396
##############################
Average accuracy: 0.8064
Average hierarchical distance: 0.1709


In [9]:
## HierSwin = making better mistakes + Swin-Transformer
df_mbm_swint = pd.read_csv('test_image_path_mbm_res_api_swint_epoch30.csv')
df_mbm_swint = add_order_family_pred_score(df_mbm_swint)
df_mbm_swint = add_level_pred(df_mbm_swint)
mbm_swint = evaluate(df_mbm_swint,method=2)

Using all data...
Level_1 accuracy: 0.9208
level_1 hierarchical distanse: 0.0549
##############################
Level_2 accuracy: 0.7836
level_2 hierarchical distande: 0.1821
##############################
Level_3 accuracy: 0.7835
level_3 hierarchical distande: 0.2082
##############################
Average accuracy: 0.8293
Average hierarchical distance: 0.1484


In [12]:
# df_mbm_swint
def plt_roc_curve_multiclass(df, classnames=('negative','ASC','AGC','Organisms'), save_fig = False, return_res=False):
    class_ids = list(set(df['class_id']))
    class_ids.sort()
    class_id_names = dict(zip(range(len(class_ids)),classnames))
    colors_list = ['darkorange','deeppink', 'cornflowerblue','aqua']
    fprs = []
    tprs = []
    res = {}
    plt.figure()
    for i, pos_label in enumerate(class_ids):
        class_name = class_id_names[pos_label]
        gt = df['class_id']
        pred_score = df[class_name]
        fpr, tpr, thresholds = roc_curve(gt, pred_score, pos_label=pos_label)
        roc_auc = auc(fpr, tpr)
        lw = 1.5
        res[class_name]=[fpr, tpr]
        plt.plot(fpr, tpr, color=colors_list[i],
                 lw=lw, label='{} AUC = {:.3f}'.format(class_name, roc_auc))
        plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver operating characteristic')
        plt.legend(loc="lower right")
    doctor_fpr_tpr_v1 = [0.062267152511992795, 0.8034326881370635]
    doctor_fpr_tpr_v2 = [0.05344528593653488, 0.8211803296464975]
    res['doctor1'] = doctor_fpr_tpr_v1
    res['doctor2'] = doctor_fpr_tpr_v2
    plt.plot(doctor_fpr_tpr_v1[0], doctor_fpr_tpr_v1[1], marker="o", markersize=6, markeredgecolor="black", markerfacecolor="orange") #markerfacecolor="green"
    plt.plot(doctor_fpr_tpr_v2[0], doctor_fpr_tpr_v2[1], marker="o", markersize=6, markeredgecolor="black", markerfacecolor="orange") #markerfacecolor="green"
    if save_fig:
        plt.savefig('hierarchy_roc.png',dpi=300)
    plt.show()
    return res

In [24]:
def compare_alpha(csv_file):
    df_alpha =  pd.read_csv(csv_file)
    df_alpha = add_order_family_pred_score(df_alpha)
    df_alpha = add_level_pred(df_alpha)
    alpha = evaluate(df_alpha, method=2)
    return alpha

In [43]:
# compare alpha
# alpha2 = compare_alpha('classification_hierarchy/making-better-mistakes-swinT/hxe_tct_alpha0.4_0412_alpha0.2/test_image_path_mbm_res.csv')
# alpha3 = compare_alpha('classification_hierarchy/making-better-mistakes-swinT/hxe_tct_alpha0.4_0412_alpha0.3/test_image_path_mbm_res.csv')
# alpha4 = compare_alpha('test_image_path_mbm_res_api_swint_epoch30.csv')
# alpha5 = compare_alpha('classification_hierarchy/making-better-mistakes-swinT/hxe_tct_alpha0.4_0412_alpha0.5/test_image_path_mbm_res.csv')
# alpha6 = compare_alpha('classification_hierarchy/making-better-mistakes-swinT/hxe_tct_alpha0.4_0412_alpha0.6/test_image_path_mbm_res.csv')

In [28]:
df_summary_alpha = get_res_summary([alpha2, alpha3,  alpha4, alpha5, alpha6], [0.2, 0.3, 0.4, 0.5, 0.6])
df_summary_alpha
# df_summary_alpha.to_csv('hierarchy_0412/comparison_summary_alpha.csv')

Unnamed: 0,level_1_acc,level_2_acc,level_3_acc,avg_acc,level_1_hier_dist,level_2_hier_dist,level_3_hier_dist,avg_hier_dist,level_1_auc,level_2_auc,level_3_auc,level_1_f1_score,level_2_f1_score,level_3_f1_score
0.2,0.9188,0.7779,0.7773,0.8246,0.0563,0.1875,0.2166,0.1535,0.986,0.9804,0.9808,0.9103,0.7953,0.7852
0.3,0.9195,0.7845,0.7839,0.8293,0.0558,0.1824,0.2081,0.1488,0.9862,0.982,0.9827,0.9148,0.804,0.7984
0.4,0.9208,0.7836,0.7835,0.8293,0.0549,0.1821,0.2082,0.1484,0.987,0.9824,0.9829,0.9165,0.8002,0.7942
0.5,0.9152,0.7807,0.7802,0.8254,0.0588,0.1867,0.2149,0.1535,0.9859,0.9822,0.983,0.9092,0.7962,0.7875
0.6,0.9188,0.7839,0.7833,0.8287,0.0563,0.1831,0.2101,0.1498,0.9873,0.9821,0.9832,0.9149,0.8006,0.7907


In [71]:
df_FGoN = pd.read_csv('classification_hierarchy/FGoN/output_0412_384/test_image_path_res.csv')
df_FGoN = add_level_pred(df_FGoN)
FGoN2 = evaluate(df_FGoN,method=2)

Using all data...
Level_1 accuracy: 0.9085
level_1 hierarchical distanse: 0.0635
##############################
Level_2 accuracy: 0.7229
level_2 hierarchical distande: 0.2344
##############################
Level_3 accuracy: 0.7115
level_3 hierarchical distande: 0.2886
##############################
Average accuracy: 0.7810
Average hierarchical distance: 0.1955


In [72]:
## HRN 
df_hrn = pd.read_csv('classification_hierarchy/test_image_path_hrn_res.csv')
# multiple binary sigmoid for multi-class classification
df_hrn[['order_' + x for x in level_1_names]] = softmax(df_hrn.filter(regex='order*',axis=1).to_numpy(), axis = 1)
df_hrn[['family_' + x for x in level_2_names]] = softmax(df_hrn.filter(regex='family*',axis=1).to_numpy(), axis = 1)
df_hrn[['species_' + x for x in level_3_names]] = softmax(df_hrn.filter(regex='species*',axis=1).to_numpy(), axis = 1)
df_hrn = add_level_pred(df_hrn)
hrn2 = evaluate(df_hrn,method=2)

Using all data...
Level_1 accuracy: 0.9014
level_1 hierarchical distanse: 0.0684
##############################
Level_2 accuracy: 0.7435
level_2 hierarchical distande: 0.2198
##############################
Level_3 accuracy: 0.7665
level_3 hierarchical distande: 0.2253
##############################
Average accuracy: 0.8038
Average hierarchical distance: 0.1712


In [1107]:
level_2_class_acc = {}
level_2_class_recall = {}
level_2_class_precision = {}
for class_name in level_2_names:
    df_sub = df_lht[df_lht['level_2'].isin([class_name])]
    level_2_class_acc[class_name]=accuracy_score(df_sub['level_1'],df_sub['level_1_pred'])

In [99]:
# evaluate(df_families, to_level=2)

In [100]:
### order & family two level classification with making better mistakes
# df_families = pd.read_csv('csv_results/test_image_path_mbm_2level_res.csv')
df_families = pd.read_csv('csv_results/test_image_path_mbm_res_0512_swinT_2levels.csv')
df_families = add_order_pred_score(df_families)
df_families['level_2_pred'] = [x.split('_')[-1] for x in df_families.filter(regex='family*',axis=1).idxmax(axis=1)]
df_families['level_1_pred'] = df_families['level_2_pred'].apply(lambda x: find_level_name_v2(x, level=1))
# df_families['level_2_id'] = df_families['level_2'].apply(lambda x: level_2_names2id.get(x, -1))
evaluate(df_families, to_level=2)

Using all data...
Level_1 accuracy: 0.9125
level_1 hierarchical distanse: 0.0607
##############################
Level_2 accuracy: 0.7813
level_2 hierarchical distande: 0.1871
##############################


[0.9124676072559746,
 0.7813129858911604,
 0.06067283123818697,
 0.1870736183415377,
 0.9864062338514833,
 0.982281043867004]

In [107]:
### species fine-grained visual classification
df_api_net = pd.read_csv('/classification_FGVC/API-Net-master-384/test_image_path_keep_species_all_res.csv')
df_api_net = add_order_family_pred_score(df_api_net)
df_api_net = fgvc_helper(df_api_net)
api_net2 = evaluate(df_api_net, method=2)

Using all data...
Level_1 accuracy: 0.8619
level_1 hierarchical distanse: 0.0957
##############################
Level_2 accuracy: 0.7037
level_2 hierarchical distande: 0.2613
##############################
Level_3 accuracy: 0.7679
level_3 hierarchical distande: 0.2303
##############################
Average accuracy: 0.7778
Average hierarchical distance: 0.1958


In [80]:
## CAL model
df_cal = pd.read_csv('classification_FGVC/CAL/fgvc/test_image_path_keep_species_all_res.csv')
df_cal = add_order_family_pred_score(df_cal)
df_cal = fgvc_helper(df_cal)
cal2 = evaluate(df_cal, method=2)

Using all data...
Level_1 accuracy: 0.8555
level_1 hierarchical distanse: 0.1001
##############################
Level_2 accuracy: 0.7000
level_2 hierarchical distande: 0.2649
##############################
Level_3 accuracy: 0.7620
level_3 hierarchical distande: 0.2386
##############################
Average accuracy: 0.7725
Average hierarchical distance: 0.2012


In [81]:
##  swinT
df_swint = pd.read_csv('classification_hierarchy/vanilla_species_384/output_0512_swint/test_image_path_keep_species_all_res.csv')
df_swint = add_order_family_pred_score(df_swint)
df_swint = fgvc_helper(df_swint)
swint = evaluate(df_swint, method=2)

Using all data...
Level_1 accuracy: 0.8675
level_1 hierarchical distanse: 0.0919
##############################
Level_2 accuracy: 0.7080
level_2 hierarchical distande: 0.2558
##############################
Level_3 accuracy: 0.7738
level_3 hierarchical distande: 0.2234
##############################
Average accuracy: 0.7831
Average hierarchical distance: 0.1903


In [82]:
## FGVC-PIM model
df_pim = pd.read_csv('classification_FGVC/FGVC-PIM-master/test_image_path_keep_species_all_res.csv')
df_pim = add_order_family_pred_score(df_pim)
df_pim = fgvc_helper(df_pim)
pim2 = evaluate(df_pim, method=2)

Using all data...
Level_1 accuracy: 0.8649
level_1 hierarchical distanse: 0.0937
##############################
Level_2 accuracy: 0.7011
level_2 hierarchical distande: 0.2608
##############################
Level_3 accuracy: 0.7633
level_3 hierarchical distande: 0.2326
##############################
Average accuracy: 0.7764
Average hierarchical distance: 0.1957


In [29]:
# df_summary_fgvc = get_res_summary([transfg2, cal2, pim2, api_net2], ['TranFG', 'CAL', 'FGVC-PIM', 'API-net'])
# df_summary_fgvc
# df_summary_fgvc.to_csv('hierarchy_0412/comparison_summary_res_fgvc.csv')

In [101]:
# resnet50 model 
df_resnet50 = pd.read_csv('classification_hierarchy/vanilla_species/output_0512_resnet50_384/test_image_path_keep_species_all_res.csv')
df_resnet50 = add_order_family_pred_score(df_resnet50)
df_resnet50 = fgvc_helper(df_resnet50)
resnet50 = evaluate(df_resnet50, method=2)

Using all data...
Level_1 accuracy: 0.8629
level_1 hierarchical distanse: 0.0950
##############################
Level_2 accuracy: 0.7055
level_2 hierarchical distande: 0.2595
##############################
Level_3 accuracy: 0.7658
level_3 hierarchical distande: 0.2312
##############################
Average accuracy: 0.7781
Average hierarchical distance: 0.1953


In [91]:
# df_resnet101 = pd.read_csv('classification_hierarchy/vanilla_species/output_0512_resnet101/test_image_path_keep_species_all_res.csv')
# df_resnet101 = add_order_family_pred_score(df_resnet101)
# df_resnet101 = fgvc_helper(df_resnet101)
# resnet101 = evaluate(df_resnet101, method=2)

In [84]:
df_efficientnet = pd.read_csv('classification_hierarchy/vanilla_species/output_0512_efficientnet_b3_384/test_image_path_keep_species_all_res.csv')
df_efficientnet = add_order_family_pred_score(df_efficientnet)
df_efficientnet = fgvc_helper(df_efficientnet)
efficientnet = evaluate(df_efficientnet, method=2)

Using all data...
Level_1 accuracy: 0.8673
level_1 hierarchical distanse: 0.0919
##############################
Level_2 accuracy: 0.7037
level_2 hierarchical distande: 0.2593
##############################
Level_3 accuracy: 0.7665
level_3 hierarchical distande: 0.2260
##############################
Average accuracy: 0.7792
Average hierarchical distance: 0.1924


In [85]:
#mobilenet
df_mobilenet = pd.read_csv('classification_hierarchy/vanilla_species/output_0512_mobilenet_v2_384/test_image_path_keep_species_all_res.csv')
df_mobilenet = add_order_family_pred_score(df_mobilenet)
df_mobilenet = fgvc_helper(df_mobilenet)
mobilenet = evaluate(df_mobilenet, method=2)

Using all data...
Level_1 accuracy: 0.8629
level_1 hierarchical distanse: 0.0950
##############################
Level_2 accuracy: 0.7055
level_2 hierarchical distande: 0.2597
##############################
Level_3 accuracy: 0.7666
level_3 hierarchical distande: 0.2341
##############################
Average accuracy: 0.7783
Average hierarchical distance: 0.1963


In [104]:
df_summary = get_res_summary([resnet50, mobilenet, efficientnet, multi2, FGoN2, hrn2, mbm2, cal2, api_net2, pim2, swint, mbm_swint],
                             ['resnet50', 'mobilenet','efficientnet','MHN','FGoN', 'HRN','MBM','CAL', 'API-Net','FGVC-PIM','swint', 'HierSwin'])
df_summary

Unnamed: 0,level_1_acc,level_2_acc,level_3_acc,avg_acc,level_1_hier_dist,level_2_hier_dist,level_3_hier_dist,avg_hier_dist,level_1_auc,level_2_auc,level_3_auc
resnet50,0.8629,0.7055,0.7658,0.7781,0.095,0.2595,0.2312,0.1953,0.9726,0.9674,0.9828
mobilenet,0.8629,0.7055,0.7666,0.7783,0.095,0.2597,0.2341,0.1963,0.9742,0.9693,0.9831
efficientnet,0.8673,0.7037,0.7665,0.7792,0.0919,0.2593,0.226,0.1924,0.9747,0.9682,0.9815
MHN,0.8809,0.6406,0.6389,0.7201,0.0826,0.3028,0.3605,0.2486,0.9692,0.8714,0.8425
FGoN,0.9085,0.7229,0.7115,0.781,0.0635,0.2344,0.2886,0.1955,0.9759,0.9375,0.9252
HRN,0.9014,0.7435,0.7665,0.8038,0.0684,0.2198,0.2253,0.1712,0.9784,0.9352,0.981
MBM,0.9048,0.7577,0.7566,0.8064,0.066,0.2072,0.2396,0.1709,0.9818,0.979,0.9798
CAL,0.8555,0.7,0.762,0.7725,0.1001,0.2649,0.2386,0.2012,0.9739,0.9693,0.9821
API-Net,0.8604,0.6994,0.7635,0.7744,0.0968,0.264,0.2353,0.1987,0.9744,0.9627,0.9799
FGVC-PIM,0.8649,0.7011,0.7633,0.7764,0.0937,0.2608,0.2326,0.1957,0.9729,0.9671,0.9805


In [105]:
df_summary.to_csv('hierarchy_0412/benchmark_summary.csv')

In [91]:
tct_trees = np.array([[0, 0, 0],
                     [0, 1, 1],
                     [0, 2, 2],
                     [0, 3, 3],
                     [0, 4, 4],
                     [0, 5, 5],
                     [0, 6, 6],
                     [0, 7, 7],
                     [1, 8, 8],
                     [1, 9, 9],
                     [1, 10, 10],
                     [1, 11, 11],
                     [1, 12, 12],
                     [2, -1, -1],
                     [2, 13, -1],
                     [2, 14, 13],
                     [2, 15, -1],
                     [2, 13, 14],
                     [2, 13, 15],
                     [2, 15, 16],
                     [2, 15, 17],
                     [3, 16, 18],
                     [3, 17, 19],
                     [3, 18, 20],
                     [3, 19, 21],
                     [3, 20, 22]])

In [36]:
len(np.unique(tct_trees[:,2]))

24

In [None]:
4+21+23

In [None]:
0~3, 4~24, 25~47

In [40]:
tct_trees.shape

(26, 3)

In [118]:
class_id2class_name = dict(zip(df['class_id'],df['class_name']))
class_id2class_name = OrderedDict(sorted(class_id2class_name.items())) 
class_id2class_name

OrderedDict([(0, 'Normal'),
             (1, 'ECC'),
             (2, 'RPC'),
             (3, 'MPC'),
             (4, 'PG'),
             (5, 'Atrophy'),
             (6, 'EMC'),
             (7, 'HCG '),
             (8, 'ASC-US'),
             (9, 'LSIL'),
             (10, 'ASC-H'),
             (11, 'HSIL'),
             (12, 'SCC'),
             (15, 'AGC-FN'),
             (17, 'AGC-ECC-NOS'),
             (18, 'AGC-EMC-NOS'),
             (19, 'ADC-ECC'),
             (20, 'ADC-EMC'),
             (21, 'FUNGI'),
             (22, 'ACTINO'),
             (23, 'TRI'),
             (24, 'HSV'),
             (25, 'CC')])

In [84]:
## # species, order, family
num_species = 23
num_family = 21
num_order = 4
families_interval = [1,]*num_family
families_interval[13] = 2
families_interval[15] = 2
order_interval = [8, 5, 5, 5]
tct_trees0 = np.zeros((num_species,3),dtype=np.int)
tct_trees0[:,0] = np.arange(23)# + num_family + num_order
family_inds = []
for i in range(num_family):
    family_inds.extend([i]*families_interval[i])

order_inds = []
for i in range(num_order):
    order_inds.extend([i]*order_interval[i])
tct_trees0[:,1] = order_inds
tct_trees0[:,2] = family_inds

tct_trees1 = tct_trees0.copy()
tct_trees1[:,0] += num_family + num_order
tct_trees1[:,2] += num_order
tct_trees2 = tct_trees0.copy()
tct_trees2 += 1

In [3]:
# tct_trees1
# tct_trees2