In [1]:
from thop import profile
from modules import attmil,clam,dsmil,transmil,rrt
import torch
import sys
import torch.nn as nn
import numpy as np
sys.path.append('./DTFD/')
from Model.Attention import Attention_with_Classifier
from Model.Attention import Attention_Gated as Attention
from Model.network import Classifier_1fc, DimReduction



Setting tau to 1.0
Setting tau to 1.0


In [2]:
# Candidate: attmil.DAttention. The leading positional 2 is presumably the number of
# classes (TODO confirm against modules/attmil.py).
# Every model cell rebinds `model`; the profiling cell measures whichever ran last.
model = attmil.DAttention(
    2,
    dropout=0.25,
    act='relu',
    test=False,
)

In [2]:
# Candidate: attmil.AttentionGated with dropout 0.25.
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = attmil.AttentionGated(
    dropout=0.25,
)

In [4]:
# Candidate: clam.CLAM_SB, 2 classes, dropout 0.25.
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = clam.CLAM_SB(
    n_classes=2,
    dropout=0.25,
)

In [17]:
# Candidate: clam.CLAM_MB, 2 classes, dropout 0.25.
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = clam.CLAM_MB(
    n_classes=2,
    dropout=0.25,
)

In [2]:
# Candidate: dsmil.MILNet, called positionally with 2, 0.25, 'relu'
# (presumably n_classes, dropout, activation - TODO confirm against modules/dsmil.py).
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = dsmil.MILNet(
    2,
    0.25,
    'relu',
)

In [2]:
# Candidate: transmil.TransMIL, 2 classes, dropout 0.25, ReLU activation.
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = transmil.TransMIL(
    n_classes=2,
    dropout=0.25,
    act='relu',
)

In [23]:
# Candidate: rrt.RRT with an explicit keyword configuration, one option per line.
# Rebinds `model`; the profiling cell measures whichever model cell ran last.
model = rrt.RRT(
    pos='none',
    peg_k=7,
    attn='rrt',
    pool='attn',
    n_layers=2,
    epeg=True,
    conv_k=15,
    ffn=False,
)

In [11]:
# DTFD
def get_cam_1d(classifier, features):
    """Project instance features onto a classifier's weight matrix (1-D class activation map).

    The second-to-last entry of ``classifier.parameters()`` is used as the weight.
    This assumes the classifier's final parameters are (weight, bias) of its last
    linear layer; TODO confirm for classifiers without a bias.

    Args:
        classifier: module whose second-to-last parameter is a (C, F) matrix.
        features: tensor of shape (B, G, F).

    Returns:
        Tensor of shape (B, C, G) with
        ``out[b, c, g] = sum_f features[b, g, f] * weight[c, f]``.
    """
    weight = [*classifier.parameters()][-2]
    return torch.einsum('bgf,cf->bcg', features, weight)
class DTFD(nn.Module):
    """DTFD MIL model assembled from the modules under ./DTFD/ (Model.Attention, Model.network).

    Forward pass on one bag of instance features:
      1. Split the bag along the instance axis into ``group`` chunks (pseudo-bags).
      2. For each non-empty chunk:
         - reduce the features with ``dimReduction``;
         - scale each row by the gated-attention weight from ``attention`` and sum the rows;
         - score the pooled vector with ``classifier``.
      3. Distill per-chunk pseudo-features according to ``distill``:
           'MaxMinS' - the instances with the highest and lowest CAM softmax score
                       for the last class;
           'MaxS'    - only the highest-scoring instance;
           'AFS'     - the attention-pooled chunk feature.
      4. Concatenate the pseudo-features and pass them to ``UClassifier``.

    Every constructor argument is optional. The defaults reproduce the previously
    hard-coded arguments: DimReduction(1024, 512), 2 classes, dropout 0.25,
    5 groups, 'MaxMinS'.
    """

    DISTILL_MODES = ('MaxMinS', 'MaxS', 'AFS')

    def __init__(self, in_dim=1024, mid_dim=512, n_classes=2, dropout=0.25,
                 group=5, distill='MaxMinS') -> None:
        super().__init__()
        if group < 1:
            raise ValueError(f"group must be >= 1, got {group}")
        if distill not in self.DISTILL_MODES:
            raise ValueError(f"distill must be one of {self.DISTILL_MODES}, got {distill!r}")
        # Registration order is unchanged so parameters()/state_dict() ordering stays the same.
        self.classifier = Classifier_1fc(mid_dim, n_classes, dropout)
        self.attention = Attention(mid_dim)
        self.dimReduction = DimReduction(in_dim, mid_dim, dropout=dropout)
        self.UClassifier = Attention_with_Classifier(L=mid_dim, num_cls=n_classes, droprate=dropout)
        self.group = group
        self.distill = distill

    def forward(self, x):
        """Run both tiers on a single bag.

        Args:
            x: instance features of one bag, shape (N, in_dim). A (1, N, in_dim)
               tensor is also accepted; its leading size-1 dimension is dropped.

        Returns:
            The output of ``UClassifier`` on the concatenated pseudo-features.

        Raises:
            ValueError: if the bag has no instances or ``self.distill`` is unknown.
        """
        if x.dim() == 3 and x.size(0) == 1:
            x = x[0]
        if x.size(0) == 0:
            raise ValueError("DTFD.forward received an empty bag (0 instances)")

        slide_pseudo_feat = []
        # Per-chunk tier-1 predictions: not returned, but the classifier call is kept
        # so it remains part of the forward pass.
        slide_sub_preds = []
        # tensor_split sizes chunks like np.array_split (the first N % group chunks get one
        # extra row). It slices x directly, so chunks stay on x's device; the previous
        # CPU torch.LongTensor index broke index_select for CUDA inputs.
        for subFeat_tensor in torch.tensor_split(x, self.group, dim=0):
            if subFeat_tensor.size(0) == 0:
                # Fewer instances than groups: an empty chunk has nothing to distill.
                continue
            tmidFeat = self.dimReduction(subFeat_tensor)                  # n x fs
            tAA = self.attention(tmidFeat).squeeze(0)                     # presumably (1, n) -> (n,)
            tattFeats = torch.einsum('ns,n->ns', tmidFeat, tAA)           # rows scaled by attention
            tattFeat_tensor = torch.sum(tattFeats, dim=0).unsqueeze(0)    # 1 x fs
            slide_sub_preds.append(self.classifier(tattFeat_tensor))      # 1 x cls

            slide_pseudo_feat.append(self._distill(tmidFeat, tattFeats, tattFeat_tensor))

        pseudo_feats = torch.cat(slide_pseudo_feat, dim=0)
        return self.UClassifier(pseudo_feats)

    def _distill(self, tmidFeat, tattFeats, tattFeat_tensor):
        """Return this chunk's pseudo-feature rows according to ``self.distill``."""
        if self.distill not in self.DISTILL_MODES:
            raise ValueError(f"distill must be one of {self.DISTILL_MODES}, got {self.distill!r}")
        if self.distill == 'AFS':
            return tattFeat_tensor

        # Per-instance class scores from the classifier weights (CAM); n x cls after transpose.
        patch_pred_logits = get_cam_1d(self.classifier, tattFeats.unsqueeze(0)).squeeze(0)
        patch_pred_softmax = torch.softmax(patch_pred_logits.transpose(0, 1), dim=1)
        _, sort_idx = torch.sort(patch_pred_softmax[:, -1], descending=True)

        if self.distill == 'MaxS':
            return tmidFeat.index_select(dim=0, index=sort_idx[:1])
        # 'MaxMinS': highest- and lowest-scoring instance (the same row twice if n == 1).
        topk_idx = torch.cat([sort_idx[:1], sort_idx[-1:]], dim=0)
        return tmidFeat.index_select(dim=0, index=topk_idx)
# Instantiate DTFD with its no-argument configuration; rebinds `model` for the profiling cell.
model = DTFD()

In [14]:
# Profile whichever model cell ran last (each one rebinds `model`).
# thop.profile calls model(*inputs), so `inputs` must be a tuple of positional arguments.
# A bare (1, N, D) tensor would be unpacked along dim 0 into one (N, D) argument,
# so that (N, D) bag is passed explicitly here.
# thop counts multiply-accumulate operations (MACs); FLOPs are roughly 2 x MACs.
N_INSTANCES = 9000   # patches per synthetic bag
FEAT_DIM = 1024      # per-patch feature size fed to the models
dummy_bag = torch.rand(N_INSTANCES, FEAT_DIM)
macs, params = profile(model, inputs=(dummy_bag,))
print('MACs = ' + str(macs / 1000**3) + 'G')
print('Params = ' + str(params / 1000**2) + 'M')

[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
506
FLOPs = 10.061387008G
Params = 1.181954M
