In [None]:
import os
import warnings

# NOTE(review): the order below is intentionally unchanged — `from vis_utils import *`
# may shadow (or be shadowed by) neighbouring names, so reshuffle only after checking vis_utils.
from modules import rrt,attmil
import torch
import seaborn as sns
import numpy as np

from vis_utils import *
import h5py
from xml.dom.minidom import parse
import xml.dom.minidom
from openTSNE import TSNE
from shapely.geometry import Polygon
from copy import deepcopy

## Functions

In [None]:
def _parse_coordinates(annotation):
    """Return the ``[x, y]`` integer vertices listed under one ``<Annotation>`` element.

    Reads the first ``<Coordinates>`` child and converts the ``X`` / ``Y``
    attributes of each element child from float strings to ``int`` (truncation).
    """
    coordinates = annotation.getElementsByTagName('Coordinates')[0]
    return [
        [int(float(node.getAttribute("X"))), int(float(node.getAttribute("Y")))]
        for node in coordinates.childNodes
        # skip the whitespace text nodes that sit between <Coordinate> elements
        if node.nodeType == node.ELEMENT_NODE
    ]


def read_annotation(anno_file, return_type=False):
    """Parse polygon annotations from an annotation XML file.

    Every ``<Annotation>`` under the first ``<Annotations>`` element is routed
    by its ``PartOfGroup`` attribute:

    * ``'Exclusion'`` -> appended to the normal/excluded list;
    * ``'None'``      -> ignored;
    * anything else   -> appended to the tumor list.

    Args:
        anno_file: path or file object accepted by ``xml.dom.minidom.parse``.
        return_type: if True, also return the set of every ``PartOfGroup``
            value encountered (including ``'Exclusion'`` and ``'None'``).

    Returns:
        ``(anno_tumor, anno_normal)``, or ``(anno_tumor, anno_normal, anno_type)``
        when ``return_type`` is True. Each polygon is a list of ``[x, y]`` ints.
    """
    anno_tumor = []
    anno_normal = []
    anno_type = set()
    # Open the XML document with the minidom parser.
    dom_tree = xml.dom.minidom.parse(anno_file)
    annotations = (
        dom_tree.documentElement
        .getElementsByTagName('Annotations')[0]
        .getElementsByTagName('Annotation')
    )
    for annotation in annotations:
        group = annotation.getAttribute('PartOfGroup')
        anno_type.add(group)
        if group == 'Exclusion':
            anno_normal.append(_parse_coordinates(annotation))
        elif group != 'None':
            anno_tumor.append(_parse_coordinates(annotation))
    if return_type:
        return anno_tumor, anno_normal, anno_type
    return anno_tumor, anno_normal

In [None]:
def get_label(coords, anno_file, _l=None, patch_size=512):
    """Assign a binary tumor label (0/1) to every patch of a slide.

    A patch is the axis-aligned square spanning ``[x, x + patch_size]`` by
    ``[y, y + patch_size]``, where ``(x, y)`` is a row of ``coords``. A patch is
    labelled 1 when it intersects any tumor polygon, unless it also intersects
    an exclusion polygon. The exception is a patch that hits a tumor polygon
    lying entirely inside an exclusion polygon; that patch stays 1.

    Args:
        coords: iterable of ``(x, y)`` patch corners (e.g. an ``(N, 2)`` array).
        anno_file: annotation XML readable by ``read_annotation``; when ``None``
            the function returns ``None``.
        _l: optional per-patch values aligned with ``coords`` (presumably
            scores/labels — confirm with caller). When given, the geometric
            labels are ignored and the result is 1 exactly where ``_l > 0``.
        patch_size: patch side length in the units of ``coords``
            (default 512, the previously hard-coded value).

    Returns:
        ``np.ndarray`` of int labels, one per coordinate, or ``None``.
    """
    if anno_file is None:
        return None
    annos_tumor, annos_normal = read_annotation(anno_file)
    annos_tumor_polygon = [Polygon(_anno) for _anno in annos_tumor]
    annos_normal_polygon = [Polygon(_anno) for _anno in annos_normal]

    if _l is not None:
        # In this mode every geometric label used to be appended as 0 and then
        # overwritten from `_l`, so the per-patch intersection tests are skipped.
        label = np.zeros(len(coords), dtype=int)
        label[np.array(_l > 0)] = 1
        return label

    # Some tumor annotations lie inside exclusion regions; remember them so that
    # patches hitting these tumors are not cleared by the exclusion check below.
    tumor_in_normal_idx = {
        idx
        for idx, tumor in enumerate(annos_tumor_polygon)
        if any(tumor.covered_by(normal) for normal in annos_normal_polygon)
    }

    label = []
    for coord in coords:
        x0, y0 = coord[0], coord[1]
        patch = Polygon([
            [x0, y0],
            [x0 + patch_size, y0],
            [x0 + patch_size, y0 + patch_size],
            [x0, y0 + patch_size],
        ])
        hit_idx = [
            idx for idx, tumor in enumerate(annos_tumor_polygon)
            if patch.intersects(tumor)
        ]
        is_tumor = bool(hit_idx)
        keep_always = any(idx in tumor_in_normal_idx for idx in hit_idx)
        # A tumor hit is cleared by any exclusion overlap, unless one of the hit
        # tumors is itself inside an exclusion region.
        if is_tumor and not keep_always and any(
            patch.intersects(normal) for normal in annos_normal_polygon
        ):
            is_tumor = False
        label.append(1 if is_tumor else 0)
    return np.array(label, dtype=int)

In [None]:
def plot(
    x,
    y,
    ax=None,
    title=None,
    draw_legend=True,
    draw_centers=False,
    draw_cluster_labels=False,
    colors=None,
    legend_kwargs=None,
    label_order=None,
    **kwargs
):
    """Scatter a 2-D embedding, coloured by class label.

    Points labelled ``1`` are drawn at size 50 and every other point at size 1,
    so the positive class stands out.

    Args:
        x: array whose first two columns are the plotted coordinates.
        y: per-point class labels (converted with ``np.asarray``).
        ax: matplotlib Axes to draw on; a new 8x8 figure is created when None.
        title: optional axes title.
        draw_legend: add a legend entry ``"<label>: <count>"`` per class.
        draw_centers: mark the per-class median of the first two columns.
        draw_cluster_labels: annotate each median with ``"<label>: <count>"``
            (only used when ``draw_centers`` is True).
        colors: mapping label -> colour; defaults to the rcParams colour cycle.
        legend_kwargs: overrides merged into the ``ax.legend`` arguments.
        label_order: optional class ordering; must contain every label in ``y``.
        **kwargs: reads ``alpha`` (default 0.8) and ``fontsize`` (default 6).

    Returns:
        The artist returned by the main ``ax.scatter`` call.
    """
    # matplotlib submodules are imported explicitly where they are needed: a bare
    # `import matplotlib` does not guarantee `matplotlib.pyplot` / `matplotlib.lines`
    # are loaded as attributes.
    if ax is None:
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(figsize=(8, 8))

    if title is not None:
        ax.set_title(title)

    y = np.asarray(y)
    plot_params = {"alpha": kwargs.get("alpha", 0.8)}

    present = np.unique(y)
    if label_order is not None:
        assert all(np.isin(present, label_order))
        classes = [l for l in label_order if l in present]
    else:
        classes = present

    if colors is None:
        import matplotlib

        color_cycle = matplotlib.rcParams["axes.prop_cycle"]
        colors = {k: v["color"] for k, v in zip(classes, color_cycle())}

    point_colors = list(map(colors.get, y))
    # Emphasise the positive class: label 1 -> size 50, every other label -> size 1.
    point_size = np.where(y == 1, 50, 1)

    scatter = ax.scatter(
        x[:, 0], x[:, 1], c=point_colors, rasterized=True, s=point_size, **plot_params
    )

    if draw_centers:
        # Per-class median position in the first two dimensions.
        centers = np.array([np.median(x[y == yi, :2], axis=0) for yi in classes])
        center_colors = list(map(colors.get, classes))
        ax.scatter(
            centers[:, 0], centers[:, 1], c=center_colors, s=48, alpha=1, edgecolor="k"
        )

        if draw_cluster_labels:
            for idx, label in enumerate(classes):
                # str() the (possibly numeric) label and report this class's size;
                # the previous code did `label + str` and printed len(x) (all points).
                ax.text(
                    centers[idx, 0],
                    centers[idx, 1] + 2.2,
                    f"{label!s}: {int(np.sum(y == label))}",
                    fontsize=kwargs.get("fontsize", 6),
                    horizontalalignment="center",
                )

    # Hide ticks and axis.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("off")

    if draw_legend:
        from matplotlib.lines import Line2D

        legend_handles = [
            Line2D(
                [],
                [],
                marker="s",
                color="w",
                markerfacecolor=colors[yi],
                ms=10,
                alpha=1,
                linewidth=0,
                label=f"{yi!s}: {int(np.sum(y == yi))}",
                markeredgecolor="k",
            )
            for yi in classes
        ]
        legend_kwargs_ = dict(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)
        if legend_kwargs is not None:
            legend_kwargs_.update(legend_kwargs)
        ax.legend(handles=legend_handles, **legend_kwargs_)

    return scatter

In [None]:
def tsne(feat, coords=None, anno_file=None, _l=None, n_jobs=8, random_state=None, **kwargs):
    """Embed ``feat`` with openTSNE and scatter the result via ``plot``.

    Args:
        feat: feature matrix passed to ``TSNE.fit``.
        coords, anno_file, _l: forwarded to ``get_label`` to colour the points.
            If labelling fails or returns ``None``, every point gets label 1.
        n_jobs: worker count for openTSNE (default 8, the previous hard-coded value).
        random_state: openTSNE seed; ``None`` (default) keeps the old unseeded behaviour.
        **kwargs: forwarded to ``plot``.

    Returns:
        Whatever ``plot`` returns (the main scatter artist).
    """
    try:
        label = get_label(coords, anno_file, _l)
    except Exception as err:
        # Labels are optional, so fall back to a single class — but say why, rather
        # than a bare `except:` that also swallowed KeyboardInterrupt/SystemExit.
        warnings.warn(f"get_label failed ({err!r}); plotting without labels")
        label = None
    embedding = TSNE(n_jobs=n_jobs, random_state=random_state).fit(feat)
    y = label if label is not None else np.ones(len(embedding), dtype=int)
    return plot(embedding, y, **kwargs)

## Initialize models

In [None]:
# RRT-MIL model, frozen and in eval mode for visualization.
_rrt = rrt.RRTMIL(
    pos='none', attn='rrt', ic=True, n_layers=2, da_act='tanh', trans_conv=True,
    moe_fl_enable=False, moe_mlp=False, input_dim=1024, moe_mask_diag=False,
    minmax_weight=True, moe_k=1, all_shortcut=True, l2_n_heads=8,
    l1_shortcut=True, conv_k=15,
).eval().requires_grad_(False)

# The model stays on CPU, so map GPU-saved checkpoint tensors to CPU as well
# (otherwise loading fails on machines without CUDA).
cpt = torch.load(
    '/data/tangwenhao/output/mil/mil_clam_c16_other/rrt_attn_convk15_N1SCsmoe_l2mm/fold_1_model_best_auc.pt',
    map_location='cpu',
)
# strict=False tolerates key mismatches; the returned report is the cell output.
_rrt.load_state_dict(cpt['model'], strict=False)

In [None]:
# Ntrans baseline (built from rrt.RRTMIL with attn='ntrans'), frozen and in eval mode.
_ntrans = rrt.RRTMIL(
    pos='none', attn='ntrans', n_layers=1, da_act='tanh', ic=True, all_shortcut=False,
).eval().requires_grad_(False)

# Keep checkpoint tensors on CPU to match the model's device.
cpt = torch.load(
    '/data/tangwenhao/output/mil/mil_clam_c16_other/ntrans_attn/fold_1_model_best_auc.pt',
    map_location='cpu',
)
# strict=False tolerates key mismatches; the returned report is the cell output.
_ntrans.load_state_dict(cpt['model'], strict=False)

In [None]:
# ABMIL on 1024-d (R50) features, frozen and in eval mode.
_abmil = attmil.DAttention(2, dropout=True, act='relu', test=False, input_dim=1024).eval().requires_grad_(False)
cpt = torch.load(
    '/home/tangwenhao/code/mil/vis/attn_seed2021_fold_1_model_best_auc.pt',
    map_location='cpu',  # model stays on CPU
)
abmil_load_report = _abmil.load_state_dict(cpt['model'], strict=False)

# ABMIL on 512-d (PLIP) features.
_abmil_plip = attmil.DAttention(2, dropout=True, act='relu', test=False, input_dim=512).eval().requires_grad_(False)
cpt = torch.load(
    '/data/tangwenhao/output/mil/mil_clam_c16_other/plip_abmil/fold_1_model_best_auc.pt',
    map_location='cpu',
)
abmil_plip_load_report = _abmil_plip.load_state_dict(cpt['model'], strict=False)

# strict=False silently tolerates missing/unexpected keys; show both reports
# (previously only the last expression — the PLIP report — was displayed).
print('ABMIL (R50):', abmil_load_report)
print('ABMIL (PLIP):', abmil_plip_load_report)

In [None]:
# TCGA slide: features only — no patch coordinates or annotation labels.
# NOTE(review): the next cell reassigns `_f`, `feat`, `coords` and `label`; run the one you need.
_f = 'TCGA-33-4532-01Z-00-DX1.32ab8c26-7cdc-4e55-8c70-5a35d83f81a2'
feat = torch.load(f'/nas/zhangxiaoxian/tcga/zft/pt_files/{_f}.pt')
coords = label = None
print(feat.shape)

In [None]:
# Slide `test_040` (presumably CAMELYON16, per the c16 paths): PLIP + R50 features,
# patch coordinates, and patch labels derived from the lesion annotation XML.
_f = 'test_040'
# map_location keeps the features on CPU, matching the CPU-resident models.
plip_feat = torch.load('/data/tangwenhao/c16_clam_bio_seg/plip/pt/' + _f + '.pt', map_location='cpu')

# Read the arrays into memory inside `with` so the HDF5 file is closed; previously the
# handle stayed open and `coords` was a lazy h5py Dataset bound to it.
with h5py.File('/home/tangwenhao/dataset/mil/c16_clam_bio_seg/h5/' + _f + '.h5', "r") as patch:
    feat = torch.as_tensor(patch['features'][:], dtype=torch.float32)
    coords = patch['coords'][:]
print(feat.size())

label = get_label(coords, "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml")
# Number of negative and positive patches.
print(int((label == 0).sum()), int((label == 1).sum()))

## Attention Distributions

### Ntrans

In [None]:
# Forward pass without autograd, also returning the attention scores.
with torch.no_grad():
    x, attns = _ntrans(feat.unsqueeze(0), return_attn=True)
print(x.softmax(dim=-1))  # softmax over the last dim of the model output

# Attention scores ranked high -> low. `attns[0]` is presumably the single bag's
# scores — confirm the shape against the model.
attn_sorted = attns[0].sort(descending=True).values
sns.scatterplot(x=np.arange(attn_sorted.size(0)), y=attn_sorted, s=100, alpha=0.6)

### ABMIL w/ PLIP

In [None]:
# ABMIL on PLIP features: forward without autograd, returning attention scores.
with torch.no_grad():
    x, attns = _abmil_plip.forward(plip_feat.unsqueeze(0), return_attn=True)
print(x.softmax(dim=-1))  # softmax over the last dim of the model output

# Rank attention scores from largest to smallest (`attns[0]`: presumably the bag's scores).
attn_sorted = attns[0].sort(descending=True).values
sns.scatterplot(x=np.arange(attn_sorted.size(0)), y=attn_sorted, s=100, alpha=0.6)

### ABMIL w/ R50

In [None]:
# ABMIL on R50 features: forward without autograd, returning attention scores.
with torch.no_grad():
    x, attns = _abmil.forward(feat.unsqueeze(0), return_attn=True)
print(x.softmax(dim=-1))  # softmax over the last dim of the model output

# Rank attention scores from largest to smallest (`attns[0]`: presumably the bag's scores).
attn_sorted = attns[0].sort(descending=True).values
sns.scatterplot(x=np.arange(attn_sorted.size(0)), y=attn_sorted, s=100, alpha=0.6)

### RRT-MIL

In [None]:
# RRT-MIL: forward without autograd, returning pooling attention and transformer attention.
with torch.no_grad():
    x, attns, trans_attns = _rrt.forward(feat.unsqueeze(0), return_attn=True, return_trans_attn=True)
print(x.softmax(dim=-1))  # softmax over the last dim of the model output

# Raw (un-normalized) pooling attention ranked high -> low; `attns[0]` is presumably
# the single bag's scores — confirm the shape.
attn_sorted = attns[0].sort(descending=True).values
sns.scatterplot(x=np.arange(attn_sorted.size(0)), y=attn_sorted, s=100, alpha=0.6)

## Feature Visualization (t-SNE)

### PLIP Features

In [None]:
# t-SNE of the raw PLIP patch features, coloured by annotation label.
anno_path = "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml"
os.makedirs('./vis_figure', exist_ok=True)  # savefig does not create missing directories
# torch.as_tensor reuses the tensor; torch.tensor(tensor) copies it and emits a UserWarning.
feat_vis = tsne(torch.as_tensor(plip_feat).squeeze(0), coords, anno_path, draw_legend=False)
feat_vis.get_figure().savefig('./vis_figure/ori_plip.png', dpi=300, bbox_inches='tight')

### R50 Features

In [None]:
# t-SNE of the raw R50 patch features, coloured by annotation label.
anno_path = "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml"
os.makedirs('./vis_figure', exist_ok=True)  # savefig does not create missing directories
# torch.as_tensor reuses the tensor; torch.tensor(tensor) copies it and emits a UserWarning.
feat_vis = tsne(torch.as_tensor(feat).squeeze(0), coords, anno_path, draw_legend=False)
feat_vis.get_figure().savefig('./vis_figure/ori.png', dpi=300, bbox_inches='tight')

### Features After FC in ABMIL

In [None]:
# t-SNE of R50 features after ABMIL's `feature` projection.
anno_path = "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml"
os.makedirs('./vis_figure', exist_ok=True)  # savefig does not create missing directories
with torch.no_grad():
    # torch.as_tensor avoids the copy + UserWarning of torch.tensor(tensor).
    abmil_fc_emb = _abmil.feature(torch.as_tensor(feat).squeeze(0))
feat_fc = tsne(abmil_fc_emb, coords, anno_path, draw_legend=False)
feat_fc.get_figure().savefig('./vis_figure/fc.png', dpi=300, bbox_inches='tight')

### Features After Ntrans

In [None]:
# t-SNE of patch features after the Ntrans encoder (per-instance, no pooling).
anno_path = "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml"
os.makedirs('./vis_figure', exist_ok=True)  # savefig does not create missing directories
with torch.no_grad():
    # torch.as_tensor avoids the copy + UserWarning of torch.tensor(tensor).
    ntrans_emb = _ntrans.online_encoder(
        _ntrans.dp(_ntrans.patch_to_emb(torch.as_tensor(feat).squeeze(0))), no_pool=True
    )
# Separate names for the embedding and the plot artist (previously both were `feat_ntrans`).
feat_ntrans = tsne(ntrans_emb, coords, anno_path, draw_legend=False)
feat_ntrans.get_figure().savefig('./vis_figure/ntrans.png', dpi=300, bbox_inches='tight')

### Features After RRT-MIL

In [None]:
# t-SNE of patch features after the RRT-MIL encoder (per-instance, no pooling).
anno_path = "/home/tangwenhao/dataset/mil/c16_clam_bio_seg/vis/c_16_lesion_annotation/" + _f + ".xml"
os.makedirs('./vis_figure', exist_ok=True)  # savefig does not create missing directories
with torch.no_grad():
    # torch.as_tensor avoids the copy + UserWarning of torch.tensor(tensor).
    rrt_emb = _rrt.online_encoder(
        _rrt.dp(_rrt.patch_to_emb(torch.as_tensor(feat).squeeze(0))), no_pool=True
    )
# Separate names for the embedding and the plot artist (previously both were `feat_rrt`).
feat_rrt = tsne(rrt_emb, coords, anno_path, draw_legend=False)
feat_rrt.get_figure().savefig('./vis_figure/rrt_as.png', dpi=300, bbox_inches='tight')