In [None]:
import os
import torch
import numpy as np
import random
import time
from torch.utils.data import Dataset
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_fscore_support, accuracy_score
import torch
from timm.utils import AverageMeter,dispatch_clip_grad
from timm.models import  model_parameters
from collections import OrderedDict
from torch.nn.functional import one_hot
import random
import csv
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from torch.cuda.amp import GradScaler
from contextlib import suppress
from torch.utils.data import DataLoader, RandomSampler
from timm.utils import AverageMeter,dispatch_clip_grad
import torch.nn as nn


# config

In [None]:
from datetime import datetime

# Paths and experiment identity
dataset_root = r'F:\lung_dl\data\subtyping\TCGA-BRCA R50'  # dataset folder (label.csv lives here)
results_root = r"F:\lung_dl\code\Sub-typing\results"  # output root for checkpoints and metric CSVs
project = 'topo_mil'  # project name of the experiment (sub-folder of results_root)
seed = 2021  # random seed [2021]
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # e.g. 2025-06-19_00-40-12
title = 'TCGA-BRCA R50'+"_"+str(timestamp)+"seed_"+str(seed)  # run name; becomes the run folder

# Cross-validation
cv_fold = 5  # number of stratified folds
val_ratio = 0.  # fraction of each training fold held out for validation; 0 -> the test fold is used for validation

fold_start = 0  # first fold to run [0]
amp = False  # automatic mixed precision via torch.cuda.amp.autocast

# Optimisation
batch_size = 1  # bags per step; DAttention.forward uses squeeze(0) + torch.mm, so this must stay 1
num_workers = 0  # DataLoader worker processes
lr = 2e-4  # initial Adam learning rate [0.0002]
weight_decay = 1e-5  # Adam weight decay
num_epoch = 200  # training epochs per fold; also the CosineAnnealingLR period

# Output folder: <results_root>/<project>/<title>. It is derived from results_root instead of
# reassigning model_path onto itself, so re-running this cell does not nest run folders.
model_path = os.path.join(results_root, project, title)
# makedirs creates missing parent folders (os.mkdir raised FileNotFoundError) and tolerates re-runs.
os.makedirs(model_path, exist_ok=True)

# Dataset classes

In [None]:
class Subtyping_Dataset(Dataset):
    """Slide-level subtyping dataset that loads one feature tensor per slide on access.

    Args:
        file_name: sequence of slide identifiers; slide ``name`` is read from
            ``<root>/<name>.pt``.
        file_label: sequence of subtype label strings aligned with ``file_name``.
        root: folder holding the ``.pt`` feature files. Defaults to the
            TCGA-BRCA R50 ``pt_files`` folder that was previously hard-coded.
        negative_label: label mapped to class 0; every other label becomes class 1
            (``'IDC'`` for BRCA; pass ``'LUAD'`` for NSCLC).
    """

    def __init__(self, file_name, file_label,
                 root=r"F:\lung_dl\data\subtyping\TCGA-BRCA R50\pt_files",
                 negative_label='IDC'):
        # Name this class in super(); super(Dataset, self) starts the MRO lookup after Dataset.
        super(Subtyping_Dataset, self).__init__()
        self.patient_name = file_name
        self.patient_label = file_label
        self.root = root

        # Binary target: negative_label -> 0, anything else -> 1.
        self.slide_label = [0 if _l == negative_label else 1 for _l in self.patient_label]

    def __len__(self):
        return len(self.patient_name)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for slide ``idx``; features come from ``torch.load``."""
        label = self.slide_label[idx]
        file_path = os.path.join(self.root, self.patient_name[idx] + ".pt")
        features = torch.load(file_path)
        return features, int(label)


# Diagnosis dataset (C16Dataset): per-slide feature tensors are read from <root>/pt/<file_name>.pt
class C16Dataset(Dataset):
    """Diagnosis dataset over pre-extracted per-slide feature tensors.

    Slide ``name`` from ``file_name`` is stored at ``<root>/pt/<name>.pt`` and the
    matching entry of ``file_label`` is converted with ``int()``. When
    ``persistence`` is true every tensor is loaded once in ``__init__``;
    otherwise each item is read from disk on access. ``max_patch`` and ``_type``
    are accepted but unused; ``keep_same_psize`` and ``is_train`` are only stored.
    """

    def __init__(self, file_name=None, file_label=None,max_patch=-1,root=None,persistence=True,keep_same_psize=0,is_train=False,_type='nsclc'):
        super(C16Dataset, self).__init__()
        self.file_name = file_name
        self.slide_label = list(map(int, file_label))
        self.size = len(self.file_name)
        self.root = root
        self.persistence = persistence
        self.keep_same_psize = keep_same_psize
        self.is_train = is_train

        if persistence:
            # Cache every feature tensor in memory up front.
            feature_dir = os.path.join(root, 'pt')
            self.feats = [torch.load(os.path.join(feature_dir, name + '.pt')) for name in file_name]

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(features, label)`` for item ``idx``; the label is a Python int."""
        if not self.persistence:
            features = torch.load(os.path.join(self.root, "pt", self.file_name[idx] + '.pt'))
        else:
            features = self.feats[idx]
        return features, int(self.slide_label[idx])


# model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
def initialize_weights(module):
    """Re-initialise, in place, every Linear/Conv2d/LayerNorm inside ``module``.

    Linear and Conv2d weights get Xavier-normal init and their biases (if any)
    are zeroed; LayerNorm is reset to weight=1, bias=0. Other modules are untouched.
    """
    for sub in module.modules():
        if isinstance(sub, (nn.Conv2d, nn.Linear)):
            # Same scheme for both layer types (huggingface / CLAM convention).
            nn.init.xavier_normal_(sub.weight)
            if sub.bias is not None:
                sub.bias.data.zero_()
        elif isinstance(sub, nn.LayerNorm):
            nn.init.constant_(sub.bias, 0)
            nn.init.constant_(sub.weight, 1.0)

import torch
# Anomaly detection adds extra checks to every backward pass and slows training down;
# PyTorch recommends it only for debugging, so enable it just while hunting NaN/Inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_max_pool, GlobalAttention

class TopoAggregator(nn.Module):
    """Add to each instance's query embedding an attention-weighted mix of its top-k neighbours' keys.

    Neighbours are chosen per instance by the largest query-key dot product.

    Args:
        dim_in: size of the input's last dimension.
        dim_hidden: size of the query/key projection, which is also the output size.
        topk: maximum number of neighbours per instance; it is clamped to the number
            of instances so small bags do not make ``torch.topk`` fail.
    """

    def __init__(self, dim_in=512, dim_hidden=512, topk=6):
        super().__init__()
        self.proj_q = nn.Linear(dim_in, dim_hidden)
        self.proj_k = nn.Linear(dim_in, dim_hidden)
        self.topk = topk

    def forward(self, x):
        # The gather below indexes k as (batch, instance, feature), so x must be 3-D (B, N, dim_in).
        q = self.proj_q(x)  # Query, (B, N, H)
        k = self.proj_k(x)  # Key,   (B, N, H)

        # Pairwise similarities (B, N, N). torch.topk raises if k > N, so cap it at N.
        S = torch.matmul(q, k.transpose(-2, -1))
        n_neighbors = min(self.topk, S.size(-1))
        S_topk, idx_topk = torch.topk(S, k=n_neighbors, dim=-1)  # (B, N, K)
        idx_topk = idx_topk.to(torch.long)

        # Gather each instance's neighbour keys: (B, N, K, H).
        idx_topk_exp = idx_topk.expand(k.size(0), -1, -1)
        batch_indices = torch.arange(k.size(0), device=idx_topk.device).view(-1, 1, 1)
        K_neighbors = k[batch_indices, idx_topk_exp, :]

        # Blend neighbour keys (weight P) with the instance's own query (weight 1 - P).
        P_topk = F.softmax(S_topk, dim=2)
        X_agg = torch.mul(P_topk.unsqueeze(-1), K_neighbors) + torch.matmul((1 - P_topk).unsqueeze(-1), q.unsqueeze(2))

        G = torch.tanh(X_agg)
        # NOTE(review): 'ijkl,ijkm->ijk' sums l and m independently, i.e. it equals
        # K_neighbors.sum(-1) * G.sum(-1) rather than a per-neighbour dot product
        # ('ijkl,ijkl->ijk'). Kept unchanged so trained checkpoints behave identically — confirm intent.
        W_KA = torch.einsum('ijkl,ijkm->ijk', K_neighbors, G)
        P_KA = F.softmax(W_KA, dim=2).unsqueeze(2)  # (B, N, 1, K)
        X_topo = torch.matmul(P_KA, K_neighbors).squeeze(2)  # (B, N, H)

        # Residual connection to the query projection.
        return X_topo + q


class DAttention(nn.Module):
    """Attention-pooling MIL classifier.

    Instances are embedded (Linear -> ReLU -> Dropout, then the optional
    ``TopoAggregator`` module), scored by a small Tanh MLP, softmax-pooled over
    instances, and classified. ``forward`` returns unnormalised class scores;
    the training/eval loops feed them to CrossEntropyLoss and softmax.
    """

    def __init__(self,input_dim,n_classes,TopoAggregator=None):
        super(DAttention, self).__init__()
        self.L = 512  # instance embedding width
        self.D = 128  # hidden width of the attention scorer
        self.K = 1    # number of attention-pooled vectors

        embed_layers = [nn.Linear(input_dim, self.L), nn.ReLU(), nn.Dropout(0.25)]
        if TopoAggregator is not None:
            embed_layers.append(TopoAggregator)
        self.feature = nn.Sequential(*embed_layers)

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
        )
        self.classifier = nn.Sequential(nn.Linear(self.L * self.K, n_classes))

        # initialize_weights already walks module.modules(), and apply() calls it on every
        # submodule as well, so nested layers are re-initialised more than once. Kept as-is
        # because the number of random draws determines the seeded starting weights.
        self.apply(initialize_weights)

    def forward(self, x):
        # squeeze(0) drops the leading batch axis only when it has size 1 (batch_size = 1 here).
        h = self.feature(x).squeeze(0)                  # (N, L)
        scores = self.attention(h).transpose(-1, -2)    # (K, N)
        weights = F.softmax(scores, dim=-1)             # normalised over instances
        pooled = torch.mm(weights, h)                   # (K, L)
        return self.classifier(pooled)


# utils

In [None]:
def seed_torch(seed=2021):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and force deterministic cuDNN."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def readCSV(filename):
    """Read a CSV file and return all of its rows as a list of lists of strings.

    The file is opened with ``newline=''`` as the csv module requires, so CR/LF
    line breaks embedded inside quoted fields are preserved verbatim.
    """
    with open(filename, "r", newline="") as f:
        return list(csv.reader(f))

def get_patient_label(csv_file):
    """Load ``(patient_id, label)`` pairs from the first two columns of a CSV file.

    Every non-blank row, including the first, is treated as data (there is no
    header handling). Blank or whitespace-only rows, such as a trailing empty
    line, are skipped instead of raising IndexError. Prints the row count and
    the label distribution.

    Returns:
        (patients, labels): two 1-D object arrays aligned by index.
    """
    rows = [row for row in readCSV(csv_file) if any(cell.strip() for cell in row)]
    patients_list = [row[0] for row in rows]
    labels_list = [row[1] for row in rows]
    a = Counter(labels_list)
    print("patient_len:{} label_len:{}".format(len(patients_list), len(labels_list)))
    print("all_counter:{}".format(dict(a)))
    return np.array(patients_list,dtype=object), np.array(labels_list,dtype=object)

def data_split(full_list, ratio, shuffle=True,label=None,label_balance_val=True):
    """Randomly split ``full_list`` into a validation part and a training part.

    :param full_list: numpy array of sample indices (it indexes ``label`` and is
        boolean-masked, so a plain list is not enough).
    :param ratio: fraction of the samples (per class when balanced) put in the val part.
    :param shuffle: shuffle with the ``random`` module before cutting. In the
        unbalanced branch ``full_list`` itself is shuffled in place.
    :param label: per-sample labels indexed by the entries of ``full_list``.
    :param label_balance_val: when true and ``label`` is given, split each class
        separately so the val part keeps the class ratio.
    :return: ``(val_set, train_set)``.
    """
    if label_balance_val and label is not None:
        # Per-class split. Previously the results were collected but never assigned to
        # val_set/train_set, so this branch always raised UnboundLocalError.
        _label = label[full_list]
        val_set, train_set = [], []
        for _l in np.unique(_label):
            _list = full_list[_label == _l]  # boolean mask -> copy, safe to shuffle
            offset = int(len(_list) * ratio)
            if shuffle:
                random.shuffle(_list)
            val_set.extend(_list[:offset])
            train_set.extend(_list[offset:])
        return val_set, train_set

    n_total = len(full_list)
    offset = int(n_total * ratio)
    if n_total == 0 or offset < 1:
        return [], full_list
    if shuffle:
        random.shuffle(full_list)
    return full_list[:offset], full_list[offset:]

def get_kflod(k, patients_array, labels_array,val_ratio=False,label_balance_val=True):
    """Build stratified k-fold splits, optionally carving a validation set out of each training fold.

    :param k: number of folds; anything not greater than 1 raises NotImplementedError.
    :param patients_array: numpy array of patient ids.
    :param labels_array: numpy array of labels aligned with ``patients_array``.
    :param val_ratio: fraction of each training fold moved to validation via
        ``data_split``; 0 / False stores empty lists instead.
    :param label_balance_val: forwarded to ``data_split``.
    :return: six numpy object arrays indexed by fold, in the order train patients,
        train labels, test patients, test labels, val patients, val labels.
    """
    if not k > 1:
        raise NotImplementedError
    splitter = StratifiedKFold(n_splits=k)

    keys = ('train_p', 'train_l', 'test_p', 'test_l', 'val_p', 'val_l')
    per_fold = {key: [] for key in keys}
    for tr_idx, te_idx in splitter.split(patients_array, labels_array):
        if val_ratio != 0.:
            # data_split returns (val, train) and shuffles with the `random` module.
            va_idx, tr_idx = data_split(tr_idx, val_ratio, True, labels_array, label_balance_val)
            fold_val_p, fold_val_l = patients_array[va_idx], labels_array[va_idx]
        else:
            fold_val_p, fold_val_l = [], []

        per_fold['train_p'].append(patients_array[tr_idx])
        per_fold['train_l'].append(labels_array[tr_idx])
        per_fold['test_p'].append(patients_array[te_idx])
        per_fold['test_l'].append(labels_array[te_idx])
        per_fold['val_p'].append(fold_val_p)
        per_fold['val_l'].append(fold_val_l)

    return tuple(np.array(per_fold[key], dtype=object) for key in keys)

def optimal_thresh(fpr, tpr, thresholds, p=0):
    """Return ``(fpr, tpr, threshold)`` at the ROC point minimising ``(fpr - tpr) - p*tpr/(fpr+tpr+1)``.

    With the default ``p=0`` this is the point that maximises Youden's J (tpr - fpr).
    """
    penalty = p * tpr / (fpr + tpr + 1)
    best = np.argmin((fpr - tpr) - penalty, axis=0)
    return fpr[best], tpr[best], thresholds[best]

def five_scores(bag_labels, bag_predictions,sub_typing=False):
    """Compute accuracy, ROC-AUC, precision, recall and F-score for bag-level scores.

    Scores are binarised at the ROC threshold chosen by ``optimal_thresh``.

    Args:
        bag_labels: true labels, with 1 as the positive class.
        bag_predictions: one score per bag for class 1 (the loops pass softmax probabilities).
        sub_typing: use 'macro' averaging for precision/recall/F-score instead of 'binary'.

    Returns:
        (accuracy, auc_value, precision, recall, fscore)
    """
    fpr, tpr, threshold = roc_curve(bag_labels, bag_predictions, pos_label=1)
    _, _, threshold_optimal = optimal_thresh(fpr, tpr, threshold)
    auc_value = roc_auc_score(bag_labels, bag_predictions)
    # Binarise in a single pass. The former two in-place assignments re-zeroed entries that
    # had just been set to 1 whenever threshold_optimal > 1 and some scores exceeded 1.
    bag_predictions = np.where(np.asarray(bag_predictions) >= threshold_optimal, 1, 0)
    avg = 'macro' if sub_typing else 'binary'
    precision, recall, fscore, _ = precision_recall_fscore_support(bag_labels, bag_predictions, average=avg)
    accuracy = accuracy_score(bag_labels, bag_predictions)
    return accuracy, auc_value, precision, recall, fscore

In [None]:
def train_loop(model,loader,optimizer,device,amp_autocast,criterion,scheduler,k,epoch):
    """Run one training epoch, step the LR scheduler, then score the model on the same loader.

    ``k`` and ``epoch`` are accepted for a uniform call signature but not used.

    Returns:
        (mean training loss, epoch start time, training end time,
         [accuracy, auc, precision, recall, fscore] on ``loader`` in eval mode)
    """
    start = time.time()
    train_loss_log = 0.
    model.train()
    for data in loader:
        optimizer.zero_grad()
        if isinstance(data[0],(list,tuple)):
            # Multi-tensor bag: move every element; the batch size comes from the first one.
            for j in range(len(data[0])):
                data[0][j] = data[0][j].to(device)
            bag = data[0]
            batch_size = data[0][0].size(0)
        else:
            bag = data[0].to(device)
            batch_size = bag.size(0)

        label = data[1].to(device)

        with amp_autocast():
            train_logits = model(bag)
            train_loss = criterion(train_logits.view(batch_size,-1),label)

        # NOTE(review): with amp enabled this backward runs without a GradScaler (imported in the
        # first cell but unused), so fp16 gradients may underflow — confirm before turning amp on.
        train_loss.backward()
        optimizer.step()

        train_loss_log += train_loss.item()

    end = time.time()
    train_loss_log = train_loss_log/len(loader)
    scheduler.step()

    # Training-set metrics: this evaluation is exactly what val_loop does, so reuse it
    # instead of keeping a second copy of the loop here.
    accuracy, auc_value, precision, recall, fscore, _ = val_loop(model, loader, device, criterion, epoch)

    return train_loss_log,start,end,[accuracy, auc_value, precision, recall, fscore]

def val_loop(model,loader,device,criterion,epoch):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    ``epoch`` is accepted for a uniform call signature but not used.

    Returns:
        accuracy, auc_value, precision, recall, fscore (from ``five_scores`` on the
        class-1 softmax probabilities) and the AverageMeter mean of the per-batch losses.
    """
    model.eval()
    loss_meter = AverageMeter()
    labels_seen = []
    positive_probs = []
    with torch.no_grad():
        for data in loader:
            targets = data[1]
            if len(targets) > 1:
                labels_seen.extend(targets.tolist())
            else:
                labels_seen.append(targets.item())

            inputs = data[0]
            if isinstance(inputs, (list, tuple)):
                # Multi-tensor bag: move each element in place.
                for j, part in enumerate(inputs):
                    inputs[j] = part.to(device)
                bag = inputs
                n_bags = inputs[0].size(0)
            else:
                bag = inputs.to(device)
                n_bags = bag.size(0)

            targets_on_device = targets.to(device)
            logits = model(bag)
            loss = criterion(logits.view(n_bags, -1), targets_on_device)
            positive_probs.append(torch.softmax(logits, dim=-1)[:, 1].cpu().squeeze().numpy())
            loss_meter.update(loss, 1)

    accuracy, auc_value, precision, recall, fscore = five_scores(labels_seen, positive_probs, False)
    return accuracy, auc_value, precision, recall, fscore, loss_meter.avg


In [None]:
def one_fold(k,ckc_metric,train_p, train_l, test_p, test_l,val_p,val_l):
    """Train and evaluate the TopoAggregator + DAttention model on cross-validation fold ``k``.

    Reads the config-cell globals ``seed``, ``amp``, ``batch_size``, ``num_workers``,
    ``lr``, ``weight_decay``, ``num_epoch``, ``val_ratio`` and ``model_path``.

    Side effects:
        * saves ``fold_{k}_epoch_{e}_best_auc_{auc}.pt`` in ``model_path`` each time the
          validation AUC improves;
        * writes per-epoch metrics to ``metrics_fold_{k}.csv`` in ``model_path``;
        * appends the best-AUC epoch's accuracy/precision/recall/fscore/AUC to the lists
          held in ``ckc_metric`` (``te_auc`` and ``te_fs`` are left untouched).

    Returns:
        ``[acs, pre, rec, fs, auc, te_auc, te_fs]``, the same list objects that were passed in.
    """
    import pandas as pd

    # ---> Initialization
    seed_torch(seed)

    amp_autocast = torch.cuda.amp.autocast if amp else suppress

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    acs,pre,rec,fs,auc,te_auc,te_fs = ckc_metric

    train_set = Subtyping_Dataset(train_p[k],train_l[k])
    test_set = Subtyping_Dataset(test_p[k],test_l[k])
    if val_ratio != 0.:
        val_set = Subtyping_Dataset(val_p[k],val_l[k])
    else:
        # NOTE(review): without a validation split the test fold drives epoch/checkpoint
        # selection below, so the reported best-epoch metrics are optimistically biased.
        val_set = test_set

    train_loader = DataLoader(train_set, batch_size=batch_size, sampler=RandomSampler(train_set), num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # build networks; input_dim is the feature width (1024 for these R50 features, 512 for PLIP)
    age = TopoAggregator(dim_in=512, dim_hidden=512, topk=6).to(device)
    model = DAttention(input_dim=1024, n_classes=2, TopoAggregator=age).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epoch, 0)

    optimal_ac, opt_pre, opt_re, opt_fs, opt_auc,opt_epoch = 0, 0, 0, 0,0,0
    epoch_start = 0

    metric_columns = [
        'epoch',
        'train_loss',
        'val_loss',
        'train_accuracy',
        'val_accuracy',
        'train_auc_value',
        'val_auc_value',
        'train_precision',
        'val_precision',
        'train_recall',
        'val_recall',
        'train_fscore',
        'val_fscore',
    ]
    # One dict per epoch, turned into a DataFrame once at the end; concatenating inside the
    # loop was quadratic.
    metric_rows = []
    train_time_meter = AverageMeter()
    for epoch in range(epoch_start, num_epoch):
        train_loss,start,end,train_metric = train_loop(model,train_loader,optimizer,device,amp_autocast,criterion,scheduler,k,epoch)
        train_time_meter.update(end-start)
        accuracy, auc_value, precision, recall, fscore, test_loss = val_loop(model,val_loader,device,criterion,epoch)
        # val_loop feeds loss tensors into its AverageMeter, so the average is typically a
        # (possibly CUDA) 0-d tensor; store a plain float so the CSV holds numbers, not "tensor(...)".
        test_loss = float(test_loss)
        val_metric=[accuracy, auc_value, precision, recall, fscore]
        print('\r Epoch [%d/%d] train loss: %.1E, test loss: %.1E, accuracy: %.3f, auc_value:%.3f, precision: %.3f, recall: %.3f, fscore: %.3f , time: %.3f(%.3f)' %
        (epoch+1, num_epoch, train_loss, test_loss, accuracy, auc_value, precision, recall, fscore, train_time_meter.val,train_time_meter.avg))

        metric_rows.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': test_loss,
            'train_accuracy':train_metric[0],
            'val_accuracy':val_metric[0],
            'train_auc_value':train_metric[1],
            'val_auc_value':val_metric[1],
            'train_precision':train_metric[2],
            'val_precision':val_metric[2],
            'train_recall':train_metric[3],
            'val_recall':val_metric[3],
            'train_fscore':train_metric[4],
            'val_fscore':val_metric[4]
        })
        if auc_value > opt_auc:
            optimal_ac = accuracy
            opt_pre = precision
            opt_re = recall
            opt_fs = fscore
            opt_auc = auc_value
            opt_epoch = epoch

            os.makedirs(model_path, exist_ok=True)

            best_pt = {
                'model': model.state_dict(),
            }

            torch.save(best_pt, os.path.join(model_path, 'fold_{fold}_epoch_{e}_best_auc_{oa}.pt'.format(fold=k,e=epoch,oa=opt_auc)))

    acs.append(optimal_ac)
    pre.append(opt_pre)
    rec.append(opt_re)
    fs.append(opt_fs)
    auc.append(opt_auc)
    csv_path = os.path.join(model_path, f'metrics_fold_{k}.csv')
    pd.DataFrame(metric_rows, columns=metric_columns).to_csv(csv_path, index=False)

    return [acs,pre,rec,fs,auc,te_auc,te_fs]


In [8]:
# set seed
seed_torch(seed)

# Labels live in label.csv under dataset_root (the config cell), the same file as before.
p, l = get_patient_label(os.path.join(dataset_root, "label.csv"))
# Seeded shuffle of the patient order before the stratified split.
index = list(range(len(p)))
random.shuffle(index)
p = p[index]
l = l[index]

# get_kflod needs k > 1; fail here with a clear message instead of a NameError on train_p below.
if cv_fold <= 1:
    raise ValueError('cv_fold must be greater than 1, got {}'.format(cv_fold))
train_p, train_l, test_p, test_l,val_p,val_l = get_kflod(cv_fold, p, l,val_ratio)

acs, pre, rec,fs,auc,te_auc,te_fs=[],[],[],[],[],[],[]
ckc_metric = [acs, pre, rec,fs,auc,te_auc,te_fs]


for k in range(fold_start, cv_fold):
    print('Start %d-fold cross validation: fold %d ' % (cv_fold, k))
    ckc_metric = one_fold(k,ckc_metric,train_p, train_l, test_p, test_l,val_p,val_l)

# Mean/std over folds of each fold's best-epoch metrics.
for metric_name, values in (('accuracy', acs), ('auc', auc), ('precision', pre), ('recall', rec), ('fscore', fs)):
    print('Cross validation %s mean: %.3f, std %.3f ' % (metric_name, np.mean(np.array(values)), np.std(np.array(values))))

 Epoch [101/200] train loss: 1.2E-01, test loss: 4.8E-01, accuracy: 0.800, auc_value:0.861, precision: 0.508, recall: 0.775, fscore: 0.614 , time: 337.294(347.998)
 Epoch [102/200] train loss: 9.0E-02, test loss: 5.5E-01, accuracy: 0.795, auc_value:0.848, precision: 0.500, recall: 0.825, fscore: 0.623 , time: 353.360(348.051)
 Epoch [103/200] train loss: 1.0E-01, test loss: 5.5E-01, accuracy: 0.795, auc_value:0.845, precision: 0.500, recall: 0.775, fscore: 0.608 , time: 343.011(348.002)
 Epoch [104/200] train loss: 1.1E-01, test loss: 6.5E-01, accuracy: 0.754, auc_value:0.835, precision: 0.446, recall: 0.825, fscore: 0.579 , time: 345.725(347.980)
 Epoch [105/200] train loss: 9.7E-02, test loss: 5.5E-01, accuracy: 0.815, auc_value:0.852, precision: 0.534, recall: 0.775, fscore: 0.633 , time: 340.820(347.912)
 Epoch [106/200] train loss: 1.0E-01, test loss: 5.5E-01, accuracy: 0.774, auc_value:0.836, precision: 0.471, recall: 0.800, fscore: 0.593 , time: 342.346(347.859)
 Epoch [107/200]