In [None]:
import os

# Restrict the visible CUDA devices to index 0; kept ahead of the torch import,
# exactly as in the original statement order.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import torch

# Report the GPU that will be used, or abort when none is available.
if torch.cuda.is_available():
    print("device name", torch.cuda.get_device_name(0))
else:
    raise Exception("GPU not availalbe. CPU training will be too slow.")

In [None]:
import seaborn as sns
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import stats
from utils import padding, load_model, get_from_patient
from glob import glob

# Test

In [None]:
from models_hierarchy import ResNetMTL_InfoMin_CLUB
from sklearn.metrics import classification_report

# When True, the evaluation cells below additionally compute cell-level metrics.
has_types = False

# Passed to get_from_patient() as its second argument.
max_length = 50

# 3 x 7 matrix: row 0 covers columns 0-3, row 1 columns 4-5, row 2 column 6.
# Presumably this groups the 7 fine classes into the 3 coarse ones -- TODO confirm.
T21 = np.zeros((3, 7))
T21[0, 0:4] = 1.0
T21[1, 4:6] = 1.0
T21[2, 6] = 1.0
# Normalise every row so that it sums to 1.
T21 = T21 / T21.sum(axis=1, keepdims=True)
T21 = torch.from_numpy(T21)

# Directory holding the test samples. The patient ID is taken to be the part of
# each file name before the first underscore.
val_data_dir = './dataset/test/'

# Keep only .npy entries so stray files/directories (e.g. .ipynb_checkpoints)
# are not handed to np.load; sorted so the order is the same on every run.
npys = sorted(name for name in os.listdir(val_data_dir) if name.endswith('.npy'))

# patient ID -> raw patient label read from the file's 'p_label' entry.
patient_labels = {}
for npy in npys:
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only point this at trusted dataset files.
    p_label = np.load(val_data_dir + npy, allow_pickle=True)[()]['p_label']
    # If a patient has several files, the label of the last one read wins
    # (same as the original behaviour).
    patient_labels[npy.split('_')[0]] = p_label

# Unique patient IDs in a deterministic order.
patients = sorted(patient_labels)

# disease_test[k] lists the patients whose raw label is k (12 raw labels).
disease_test = [[] for _ in range(12)]
for patient in patients:
    disease_test[patient_labels[patient]].append(patient)

# Build the network on the GPU and restore the checkpoint weights.
# [3, 7] presumably gives the class counts of the two prediction levels -- TODO confirm.
model = ResNetMTL_InfoMin_CLUB([3,7], freeze=False, pretrained=True).cuda()
model_path = './checkpoints/model.best'
load_model(model, model_path)
# Evaluation must run in inference mode: a freshly constructed nn.Module is in
# training mode, which keeps Dropout active and makes BatchNorm use (and update)
# batch statistics. This is a no-op if load_model()/model.test() already do it.
model.eval()

# ---- patient-level bookkeeping ----
# Number of evaluated patients and number of correct 3-class / 7-class predictions.
total_p = correct_p_3 = correct_p_7 = 0
# Ground-truth and predicted labels, one entry per evaluated patient.
p_trues_3, p_preds_3 = [], []
p_trues_7, p_preds_7 = [], []
# Count matrices indexed as [true label][predicted label].
error_m_3 = np.zeros(shape=(3, 3))
error_m_7 = np.zeros(shape=(7, 7))

# ---- cell-level bookkeeping (only filled when has_types is True) ----
total_c = correct_c = 0
c_trues, c_preds = [], []
error_mc = np.zeros(shape=(2, 2))

# Patient-level evaluation: every file of a patient is classified on its own and
# the patient prediction is the majority vote over its files.

# Raw patient label (bucket index in `disease_test`) -> hierarchical labels.
# Position 0 is the 3-class target and position 1 the 7-class target used below;
# position 2 is not read in this cell.
label_hierarchy = {
    0: [0, 0, 0],
    1: [0, 0, 1],
    2: [0, 1, 2],
    3: [0, 1, 3],
    5: [0, 2, 4],
    6: [0, 2, 5],
    7: [0, 3, 6],
    8: [1, 4, 7],
    9: [1, 4, 8],
    10: [1, 5, 9],
    11: [2, 6, 10],
}
# Raw label 4 is deliberately left out of the evaluation.
excluded_labels = {4}
# For each 3-class prediction, the [start, stop) slice of the 7 scores in which
# the 7-class prediction is searched.
subtype_span = {0: (0, 4), 1: (4, 6), 2: (6, 7)}

# Loop-invariant: float32 copy of T21 on the GPU, wrapped in a list for model.test().
T_matrix = [T21.float().cuda()]

# `bucket` (not `patients`) so the global patient list is not overwritten.
for index, bucket in enumerate(disease_test):
    if index in excluded_labels:
        continue
    patient_label_p = label_hierarchy.get(index)
    if patient_label_p is None:
        # Unknown raw label: report it and skip instead of failing further down.
        print('error patient label!')
        continue

    for patient in bucket:
        # The '_' separator keeps a patient ID from also matching the files of a
        # longer ID that starts with the same characters (e.g. "P1" vs "P10").
        data = glob(val_data_dir + patient + '_*.npy')
        if not data:
            # Without files there is nothing to vote on (np.bincount would fail).
            print('no .npy files found for patient', patient)
            continue

        total_p += 1
        p_trues_3.append(patient_label_p[0])
        p_trues_7.append(patient_label_p[1])

        # Per-file predictions for this patient.
        votes_3 = []
        votes_7 = []
        # Running sums of the per-file scores. torch.zeros is required here:
        # torch.DoubleTensor(1, n) returns uninitialised memory.
        logits_out_3_all = torch.zeros(1, 3, dtype=torch.float64).cuda()
        logits_out_7_all = torch.zeros(1, 7, dtype=torch.float64).cuda()

        for npy in data:
            cell_imgs, patient_label, cell_rates, cell_types, original_imgs = get_from_patient(npy, max_length)
            cell_imgs = torch.stack(cell_imgs).cuda()

            with torch.no_grad():
                out_instance, A_raw, logits = model.test(cell_imgs, T_matrix)

            if has_types:
                # Cell-level metrics: predicted class per instance vs. annotated type.
                cell_types = cell_types.cuda()
                predicted_cells = torch.argmax(out_instance, 1)
                total_c += predicted_cells.size(0)
                correct_c += predicted_cells.eq(cell_types).sum().item()
                # Move to host once instead of once per cell.
                predicted_np = predicted_cells.cpu().numpy()
                types_np = cell_types.cpu().numpy()
                for i in range(predicted_np.shape[0]):
                    c_p = predicted_np[i]
                    c_t = types_np[i]
                    c_preds.append(c_p)
                    c_trues.append(c_t)
                    error_mc[c_t][c_p] += 1

            # Swap the two axes, add a leading axis, softmax over axis 1.
            # Assumes logits[k] is 2-D with at least two entries along its second
            # axis, so that index 1 exists after the transpose -- TODO confirm.
            logits_3 = torch.softmax(logits[0].permute(1, 0).unsqueeze(0), dim=1)
            logits_7 = torch.softmax(logits[1].permute(1, 0).unsqueeze(0), dim=1)

            # Keep entry 1 along the softmax axis as the per-class score.
            logits_out_3 = logits_3.transpose(1, 0)[1]
            logits_out_7 = logits_7.transpose(1, 0)[1]
            logits_out_3_all += logits_out_3
            logits_out_7_all += logits_out_7

            pre3 = int(torch.argmax(logits_out_3, dim=1)[0].item())
            votes_3.append(pre3)

            # The 7-class prediction is restricted to the slice that belongs to
            # the predicted 3-class label.
            start, stop = subtype_span[pre3]
            pre7 = int(torch.argmax(logits_out_7[0][start:stop], dim=0).item()) + start
            votes_7.append(pre7)

        # Majority vote over the files (ties go to the smallest class index).
        p_pre_3 = int(np.argmax(np.bincount(np.array(votes_3))))
        p_preds_3.append(p_pre_3)

        p_pre_7 = int(np.argmax(np.bincount(np.array(votes_7))))
        p_preds_7.append(p_pre_7)

        correct_p_3 += int(p_pre_3 == patient_label_p[0])
        error_m_3[patient_label_p[0]][p_pre_3] += 1
        correct_p_7 += int(p_pre_7 == patient_label_p[1])
        error_m_7[patient_label_p[1]][p_pre_7] += 1

# Results

In [None]:
from sklearn.metrics import classification_report
# Display names for the class indices of the 3-class and 7-class tasks.
p_names_3 = ['AML', 'ALL', 'Normal']
p_names_7 = ['M2', 'M3', 'M5', 'M7', 'L1L2', 'L3','Normal']

# `labels` is passed explicitly: without it classification_report derives the
# classes from the data and raises when a class is missing from both the truths
# and the predictions, because the count no longer matches target_names.
print(classification_report(p_trues_3, p_preds_3, labels=list(range(len(p_names_3))),
                            target_names=p_names_3, zero_division=0, digits=4))
print(classification_report(p_trues_7, p_preds_7, labels=list(range(len(p_names_7))),
                            target_names=p_names_7, zero_division=0, digits=4))
print('======')

# Count matrices indexed as [true label][predicted label].
print('class3\n',error_m_3)
print('======')
print('class7\n',error_m_7)

if has_types:
    # Cell-level report, only available when cell types were evaluated.
    c_names = ['Normal', 'Tumor']
    print(classification_report(c_trues, c_preds, labels=list(range(len(c_names))),
                                target_names=c_names, zero_division=0, digits=4))
    print('======')
    print(error_mc)