In [11]:
import os
import csv
import pickle
import random
import pandas as pd
import numpy as np
import uuid
import json
import os.path as osp

from glob import glob
from tqdm import tqdm
from shutil import copy
from pprint import pprint
# import cv2
# import imageio
# import imagesize
from PIL import Image

import matplotlib.pyplot as plt
from itertools import groupby,chain,combinations
from functools import partial
from collections import defaultdict, Counter, OrderedDict
from scipy.special import softmax

from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, cohen_kappa_score, accuracy_score
from sklearn.metrics import average_precision_score, roc_curve, auc, precision_recall_curve, f1_score, roc_auc_score

In [12]:
from anytree import Node, RenderTree, search
from anytree.importer import JsonImporter
from anytree.exporter import JsonExporter, DotExporter
import lzma
from nltk.tree import Tree

In [5]:
# !pip install anytree

In [21]:
def lca_height(class_name1, class_name2, logarithmic=True):
    """Height of the lowest common ancestor (LCA) of two classes in the ``TCT`` tree.

    The height is the number of edges between the LCA and the deeper of the
    two nodes, i.e. ``max(depth1, depth2) - depth(LCA)``. It is 0 when both
    names refer to the same node.

    Args:
        class_name1: name of the first node, looked up in the global ``TCT`` tree.
        class_name2: name of the second node, looked up in the global ``TCT`` tree.
        logarithmic: when True (default) return ``np.log(1 + height)``;
            when False return the raw integer height.

    Returns:
        ``np.log(1 + height)`` or the integer ``height``, depending on
        ``logarithmic``.
    """
    node1 = search.find_by_attr(TCT, class_name1)
    node2 = search.find_by_attr(TCT, class_name2)
    node1_path_names = [x.name for x in node1.path]
    node2_path_names = [x.name for x in node2.path]

    # Both paths start at the root, so the shared ancestors are exactly the
    # common prefix of the two root-to-node paths.
    common_length = 0
    for name1, name2 in zip(node1_path_names, node2_path_names):
        if name1 != name2:
            break
        common_length += 1

    longest_length = max(len(node1_path_names), len(node2_path_names))
    height = longest_length - common_length
    # Apply the same scaling whether or not the two nodes sit at the same
    # depth; previously pairs at different depths always got the raw height.
    return np.log(1 + height) if logarithmic else height
def find_level_name_v2(class_name, level=1):
    """Return the label of ``class_name`` at the requested hierarchy level.

    The root-to-node path is read from the global ``classname_paths`` (index 0
    is the root). When the class is annotated at a coarser level than
    ``level``, its own (last) label on the path is returned instead.
    Surrounding whitespace in ``class_name`` is ignored.
    """
    path_from_root = classname_paths[class_name.strip()]
    # Fall back to the last entry when the path is too short for `level`.
    position = level if level < len(path_from_root) else -1
    return path_from_root[position]

### Class hierarchy in HiCervix

In [14]:
### hierarchical tree construction for HiCervix
## the full names of the acronyms are listed in dataset/hierarchy_classification/version2023/hierarchy_names.csv
# Root of the hierarchy; every node below is attached through its `parent=` argument.
TCT = Node("TCT")
# Level 1: the four direct children of the root.
negative = Node("negative",parent=TCT) 
ASC = Node("ASC", parent=TCT) 
AGC = Node("AGC", parent=TCT) 
microbe = Node("microbe",parent=TCT)  ## microbe is the class  Organisms

# Level 2 under ASC.
ASCUS = Node("ASC-US", parent=ASC)
LSIL = Node("LSIL", parent=ASC)
ASCH = Node("ASC-H", parent=ASC)
HSIL = Node("HSIL", parent=ASC)
SCC = Node("SCC", parent=ASC)

# Level 2 under AGC.
AGCNOS = Node("AGC-NOS", parent=AGC)
AGCFN = Node("AGC-FN", parent=AGC)
ADC = Node("ADC", parent=AGC)

# Level 3: the only nodes with a third level are the children of AGC-NOS and ADC.
AGCNOS1 = Node("AGC-ECC-NOS", parent=AGCNOS)
AGCNOS2 = Node("AGC-EMC-NOS", parent=AGCNOS)

ADC1 = Node("ADC-ECC", parent=ADC)
ADC2 = Node("ADC-EMC", parent=ADC)

# Level 2 under negative.
# NOTE(review): the variable names `xiufu` / `huasheng` are not English (presumably
# pinyin); see hierarchy_names.csv for what RPC / MPC stand for.
normal = Node("Normal", parent=negative)
endocervical = Node("ECC", parent=negative)
xiufu = Node("RPC", parent=negative)
huasheng = Node("MPC", parent=negative)
glucose = Node("PG", parent=negative)
Atrophy = Node("Atrophy", parent=negative)
EMC = Node("EMC", parent=negative)
HCG = Node("HCG", parent=negative)

# Level 2 under microbe.
FUNGI = Node("FUNGI", parent=microbe)
ACTINO = Node("ACTINO", parent=microbe)
TRI = Node("TRI", parent=microbe)
HSV = Node("HSV", parent=microbe)
CC = Node("CC", parent=microbe)

# Optional: render the tree to an image.
# DotExporter(TCT).to_picture("TCT.png")  # the original comment used `TCT_en`, which is not defined in the visible cells

In [15]:
# Every annotated class (leaf and intermediate nodes); the list position is the class id.
class_names = [
    'Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG',
    'ASC-US', 'LSIL', 'ASC-H', 'HSIL', 'SCC',
    'AGC', 'AGC-NOS', 'AGC-FN', 'ADC', 'AGC-ECC-NOS', 'AGC-EMC-NOS', 'ADC-ECC', 'ADC-EMC',
    'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC',
]
# Label vocabulary of each hierarchy level; the list position is the per-level id.
level_1_names = ['negative', 'ASC', 'AGC', 'microbe']
level_2_names = [
    'Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG',
    'ASC-US', 'LSIL', 'ASC-H', 'HSIL', 'SCC',
    'AGC-NOS', 'AGC-FN', 'ADC',
    'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC',
]
level_3_names = [
    'Normal', 'ECC', 'RPC', 'MPC', 'PG', 'Atrophy', 'EMC', 'HCG',
    'ASC-US', 'LSIL', 'ASC-H', 'HSIL', 'SCC',
    'AGC-FN', 'AGC-ECC-NOS', 'AGC-EMC-NOS', 'ADC-ECC', 'ADC-EMC',
    'FUNGI', 'ACTINO', 'TRI', 'HSV', 'CC',
]

# name -> id lookup tables, one per level.
level_1_names2id = {name: idx for idx, name in enumerate(level_1_names)}
level_2_names2id = {name: idx for idx, name in enumerate(level_2_names)}
level_3_names2id = {name: idx for idx, name in enumerate(level_3_names)}

level_names = [level_1_names, level_2_names, level_3_names]

# Two-way mapping between class names and class ids.
classname2newid = {name: idx for idx, name in enumerate(class_names)}
newid2classname = dict(enumerate(class_names))

print('Number of total annotated classes: {}'.format(len(class_names)))
for i, names_at_level in enumerate(level_names):
    print('Number of level {} classes: {}'.format(i + 1, len(names_at_level)))

Number of total annotated classes: 26
Number of level 1 classes: 4
Number of level 2 classes: 21
Number of level 3 classes: 23


In [16]:
# For every annotated class, the node names on the path from the root 'TCT'
# (included, at index 0) down to the class node itself.
classname_paths = {}
for class_name in class_names:
    classname_paths[class_name] = [
        node.name for node in search.find_by_attr(TCT, class_name).path
    ]

In [17]:
classname_paths

{'Normal': ['TCT', 'negative', 'Normal'],
 'ECC': ['TCT', 'negative', 'ECC'],
 'RPC': ['TCT', 'negative', 'RPC'],
 'MPC': ['TCT', 'negative', 'MPC'],
 'PG': ['TCT', 'negative', 'PG'],
 'Atrophy': ['TCT', 'negative', 'Atrophy'],
 'EMC': ['TCT', 'negative', 'EMC'],
 'HCG': ['TCT', 'negative', 'HCG'],
 'ASC-US': ['TCT', 'ASC', 'ASC-US'],
 'LSIL': ['TCT', 'ASC', 'LSIL'],
 'ASC-H': ['TCT', 'ASC', 'ASC-H'],
 'HSIL': ['TCT', 'ASC', 'HSIL'],
 'SCC': ['TCT', 'ASC', 'SCC'],
 'AGC': ['TCT', 'AGC'],
 'AGC-NOS': ['TCT', 'AGC', 'AGC-NOS'],
 'AGC-FN': ['TCT', 'AGC', 'AGC-FN'],
 'ADC': ['TCT', 'AGC', 'ADC'],
 'AGC-ECC-NOS': ['TCT', 'AGC', 'AGC-NOS', 'AGC-ECC-NOS'],
 'AGC-EMC-NOS': ['TCT', 'AGC', 'AGC-NOS', 'AGC-EMC-NOS'],
 'ADC-ECC': ['TCT', 'AGC', 'ADC', 'ADC-ECC'],
 'ADC-EMC': ['TCT', 'AGC', 'ADC', 'ADC-EMC'],
 'FUNGI': ['TCT', 'microbe', 'FUNGI'],
 'ACTINO': ['TCT', 'microbe', 'ACTINO'],
 'TRI': ['TCT', 'microbe', 'TRI'],
 'HSV': ['TCT', 'microbe', 'HSV'],
 'CC': ['TCT', 'microbe', 'CC']}

### Data preprocessing for different methods

In [18]:
# hierarchy for HierSwin and making better mistakes
# Three-level hierarchy as an nltk Tree; classes without a third level repeat their own name as the leaf.
tct = Tree.fromstring("(TCT (negative (Normal Normal) (ECC ECC) (RPC RPC) (MPC MPC) (PG PG) (Atrophy Atrophy) (EMC EMC) (HCG HCG))  \
                     (ASC (ASC-US ASC-US) (LSIL LSIL) (ASC-H ASC-H) (HSIL HSIL) (SCC SCC))   \
                     (AGC (AGC-FN AGC-FN) (AGC-NOS AGC-ECC-NOS AGC-EMC-NOS) (ADC ADC-ECC ADC-EMC)) \
                     (microbe (FUNGI FUNGI) (ACTINO ACTINO) (TRI TRI) (HSV HSV) (CC CC)) \
                     )")
# Two-level variant; only referenced by the commented-out dump below.
tct_2level = Tree.fromstring("(TCT (negative Normal  ECC RPC MPC  PG  Atrophy  EMC  HCG)  \
                     (ASC  ASC-US  LSIL  ASC-H  HSIL  SCC)   \
                     (AGC AGC-FN  AGC-NOS  ADC ) \
                     (microbe  FUNGI  ACTINO  TRI  HSV  CC) \
                     )")

# LCA-based distance for every ordered pair of level-1/2/3 names (self-pairs included).
all_names = set(level_1_names + level_2_names + level_3_names)
tct_distances = {
    (name1, name2): lca_height(name1, name2)
    for name1 in all_names
    for name2 in all_names
}

# Level vocabularies keyed by the order/family/species naming.
level_names_dict = dict(zip(['order','family','species'], level_names))

# Persist the three artefacts into the current working directory.
pickle_targets = [
    ('tct_tree.pkl', tct),
    ('tct_distances.pkl', tct_distances),
    ('level_names_dict.pkl', level_names_dict),
]
for pkl_name, payload in pickle_targets:
    with open(pkl_name, 'wb') as fi:
        pickle.dump(payload, fi)
# with open('classification_hierarchy/making-better-mistakes-2level/data/tct_tree.pkl','wb') as fi:
#     pickle.dump(tct_2level, fi)

# with open('classification_hierarchy/making-better-mistakes/data/tct_tree.pkl','rb') as fi:
#     tree_data = pickle.load(fi)
# tree_data
# with lzma.open('/classification_hierarchy/making-better-mistakes/data/imagenet_ilsvrc_distances.pkl.xz', "rb") as f:
#     tmp = pickle.load(f)
# #     distance_data = DistanceDict(pickle.load(f))
#     distance_data = DistanceDict(tmp)

In [19]:
# level_names_dict

In [22]:
# for multi-head and FGoN
# Containers filled by the loop at the end of this cell, one entry per name in `class_names`:
trees =[]  # per class: list of the level-1, level-2 and level-3 ids (-1 where the name is absent from that level)
trees_names = []  # per class: the level-1, level-2 and level-3 names those ids were looked up from
trees_dict = {}  # class name -> the same id list that is appended to `trees`
def find_class_id(class_name, lst):
    """Return the position of ``class_name`` in ``lst``, or -1 when it is absent.

    Args:
        class_name: label to look up.
        lst: sequence of labels (anything with an ``index`` method, e.g. a list).

    Returns:
        The index of the first occurrence, or -1 if ``class_name`` is not in ``lst``.
    """
    try:
        return lst.index(class_name)
    except ValueError:
        # Only "not found" maps to -1; any other error (e.g. `lst` is not a
        # sequence) is a real bug and must not be hidden.
        return -1
# For every annotated class, record its label at levels 1..3 and the matching
# per-level ids. A class annotated above a given level keeps its own name
# there, which is absent from that level's vocabulary and therefore gets id -1.
for class_name in class_names:
    extended_node_names = [find_level_name_v2(class_name, level) for level in range(1, 4)]
    extended_node_ids = [
        find_class_id(name, names_at_level)
        for name, names_at_level in zip(extended_node_names, level_names)
    ]
    trees.append(extended_node_ids)
    trees_names.append(extended_node_names)
    trees_dict[class_name] = extended_node_ids

In [None]:
# HierSwin methods and making better mistakes
# For each split keep only the samples labelled with a level-3 class, attach
# the level-3 id and write the result next to the input as "<split>_mbm.csv".
datasets = ['train.csv','val.csv','test.csv']
# datasets = ['train_image_path.csv','val_image_path.csv','test_image_path.csv']
csv_dir = 'dataset/hierarchy_classification/version2023/'
for dataset in datasets:
    csv_file = osp.join(csv_dir, dataset)
    df_tmp = pd.read_csv(csv_file)
    # Some annotation files carry whitespace-padded names (e.g. 'HCG '); strip
    # them so those rows are not silently dropped by the filter below.
    df_tmp = df_tmp.assign(class_name=df_tmp['class_name'].str.strip())
    # .copy() so the column assignment below works on an independent frame
    # rather than on a slice of the frame that was read (SettingWithCopyWarning).
    df_tmp = df_tmp[df_tmp['class_name'].isin(level_3_names)].copy()
    df_tmp['level_3_id'] = df_tmp['class_name'].map(level_3_names2id)
    df_tmp.to_csv(csv_file.replace('.csv','_mbm.csv'),index=False)

In [None]:
## HRN
# For each split, derive the level-1/2/3 ids of every sample and write the
# HRN columns next to the input as "<split>_hrn.csv". The frame of the last
# processed split stays available as `df` for the cells below.
datasets = ['train_image_path.csv']#,'val_image_path.csv','test_image_path.csv']
csv_dir = 'dataset/hierarchy_classification/version2023/'
for dataset in datasets:
    csv_file = osp.join(csv_dir, dataset)
    df = pd.read_csv(csv_file)
    df['level_1_id'] = df['level_1'].map(level_1_names2id)
    # A class annotated above level 2/3 keeps its own name at that level ...
    df['level_2_name'] = df['class_name'].apply(lambda x: find_level_name_v2(x, level=2))
    df['level_3_name'] = df['class_name'].apply(lambda x: find_level_name_v2(x, level=3))
    # ... and such names are missing from the level vocabulary, hence id -1.
    df['level_2_id'] = df['level_2_name'].apply(lambda x: level_2_names2id.get(x, -1))
    df['level_3_id'] = df['level_3_name'].apply(lambda x: level_3_names2id.get(x, -1))
    # Write one file per split beside its input instead of a single machine-specific
    # absolute path that every split would overwrite. The row index is still written,
    # as before; pass index=False / encoding='utf-8-sig' here if the consumer needs it.
    df[['image_path', 'image_name', 'class_id', 'level_1_id', 'level_2_id', 'level_3_id']].to_csv(csv_file.replace('.csv','_hrn.csv'))

In [None]:
### fine-grained visual classification

# Level-3 fine-grained labels; intermediate classes such as 'AGC', 'AGC-NOS', 'ADC' have no level-3 id and get -1.
csv_dir = 'dataset/hierarchy_classification/version2023'
datasets = ['train','val','test']
for dataset in datasets:
    split_csv = osp.join(csv_dir, dataset + '.csv')
    species_csv = osp.join(csv_dir, dataset + '_keep_species_all.csv')
    df_tmp = pd.read_csv(split_csv)
    #df_tmp = df_tmp[~df_tmp['class_name'].isin(['AGC','AGC-NOS','ADC'])] 
    # NOTE(review): `df_v2` is not defined in the visible cells; it is expected to
    # return a frame with a 'level_3' column - confirm where it comes from.
    df_tmp = df_v2(df_tmp)
    #df_tmp['level_3_id']= df_tmp['level_3'].map(dict(zip(level_3_names, range(len(level_3_names)))))
    df_tmp['level_3_id'] = df_tmp['level_3'].apply(lambda name: level_3_names2id.get(name, -1))
    df_tmp.to_csv(species_csv)

In [91]:
# One row per annotated class, in class-id order; columns are the
# level-1, level-2 and level-3 ids, with -1 where the class has no label at that level.
tct_trees = np.array([
    (0, 0, 0),     # Normal
    (0, 1, 1),     # ECC
    (0, 2, 2),     # RPC
    (0, 3, 3),     # MPC
    (0, 4, 4),     # PG
    (0, 5, 5),     # Atrophy
    (0, 6, 6),     # EMC
    (0, 7, 7),     # HCG
    (1, 8, 8),     # ASC-US
    (1, 9, 9),     # LSIL
    (1, 10, 10),   # ASC-H
    (1, 11, 11),   # HSIL
    (1, 12, 12),   # SCC
    (2, -1, -1),   # AGC
    (2, 13, -1),   # AGC-NOS
    (2, 14, 13),   # AGC-FN
    (2, 15, -1),   # ADC
    (2, 13, 14),   # AGC-ECC-NOS
    (2, 13, 15),   # AGC-EMC-NOS
    (2, 15, 16),   # ADC-ECC
    (2, 15, 17),   # ADC-EMC
    (3, 16, 18),   # FUNGI
    (3, 17, 19),   # ACTINO
    (3, 18, 20),   # TRI
    (3, 19, 21),   # HSV
    (3, 20, 22),   # CC
])

In [36]:
len(np.unique(tct_trees[:,2]))

24

In [None]:
4+21+23

In [None]:
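# Flat label ranges when the three levels are concatenated (4 + 21 + 23 = 48 labels):
# level 1 -> 0-3, level 2 -> 4-24, level 3 -> 25-47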

In [40]:
tct_trees.shape

(26, 3)

In [118]:
# class_id -> class_name for the ids present in `df` (the frame left over from
# the HRN cell), ordered by ascending class_id.
# NOTE(review): names are taken verbatim from the CSV; the echoed output shows
# 'HCG ' with a trailing space.
unsorted_id2name = dict(zip(df['class_id'], df['class_name']))
class_id2class_name = OrderedDict(
    (class_id, unsorted_id2name[class_id]) for class_id in sorted(unsorted_id2name)
)
class_id2class_name

OrderedDict([(0, 'Normal'),
             (1, 'ECC'),
             (2, 'RPC'),
             (3, 'MPC'),
             (4, 'PG'),
             (5, 'Atrophy'),
             (6, 'EMC'),
             (7, 'HCG '),
             (8, 'ASC-US'),
             (9, 'LSIL'),
             (10, 'ASC-H'),
             (11, 'HSIL'),
             (12, 'SCC'),
             (15, 'AGC-FN'),
             (17, 'AGC-ECC-NOS'),
             (18, 'AGC-EMC-NOS'),
             (19, 'ADC-ECC'),
             (20, 'ADC-EMC'),
             (21, 'FUNGI'),
             (22, 'ACTINO'),
             (23, 'TRI'),
             (24, 'HSV'),
             (25, 'CC')])

In [84]:
## Label tables with columns (species, order, family), one row per level-3 class.
num_species = 23
num_family = 21
num_order = 4
# Number of species per family: one each, except families 13 and 15 which have two.
families_interval = [1,]*num_family
families_interval[13] = 2
families_interval[15] = 2
# Number of species per order.
order_interval = [8, 5, 5, 5]
# NOTE(review): level_3_names lists AGC-FN (13) before AGC-ECC-NOS / AGC-EMC-NOS (14, 15),
# so the run-length expansion below pairs species 13/14/15 with families 13/13/14, while
# the hand-written `tct_trees` table pairs them with families 14/13/13. Confirm which
# ordering the consumer of these tables expects.

# `np.int` was removed in NumPy 1.24; the builtin `int` yields the same dtype.
tct_trees0 = np.zeros((num_species,3),dtype=int)
tct_trees0[:,0] = np.arange(num_species)# + num_family + num_order

# Expand the per-family / per-order counts into one index per species.
family_inds = [family for family, count in enumerate(families_interval) for _ in range(count)]
order_inds = [order for order, count in enumerate(order_interval) for _ in range(count)]
tct_trees0[:,1] = order_inds
tct_trees0[:,2] = family_inds

# Variant with globally unique ids: orders 0-3, families 4-24, species 25-47.
tct_trees1 = tct_trees0.copy()
tct_trees1[:,0] += num_family + num_order
tct_trees1[:,2] += num_order
# Variant with 1-based ids in every column.
tct_trees2 = tct_trees0.copy()
tct_trees2 += 1

In [3]:
# tct_trees1
# tct_trees2