# An experimental notebook for testing the CFBCT

In [None]:
import os
import sys
sys.argv = ['run.py']
from utils.options import process_args
from dataset.dataset_survival_tcga import Generic_MIL_Survival_Dataset
from utils.utils import *

## 1. Parameters

In [43]:
# Start from the project's default options, then override what this experiment needs.
args = process_args()

# Dataset organization: TCGA; CPTAC
ORGANIZE = "TCGA"
# Directory that contains one sub-folder per dataset organization.
# The default is the original author's local path; set the CFBCT_DATA_ROOT
# environment variable to point at your own copy without editing this cell.
DATA_ROOT = os.environ.get("CFBCT_DATA_ROOT") or "/mnt/jzy8T/jzy"
# Dataset root dir
BASE_DIR = f"{DATA_ROOT}/{ORGANIZE}"
# Dataset: BLCA; BRCA; LUAD; UCEC; GBMLGG; COAD; HNSC; STAD
DATASET = "STAD"

# Omic-modal: snn; mlp; mmlp
# Path-modal: deepset; amil; tmil; amisl; clam-sb; clam-mb; mqp
# Multi-modal: mcat; cmta; cfbct; mbct; motcat; porpoise; survpath; ponet; mgct
args.model_type = 'cfbct'
# Modality: omic; path; cluster; coattn
args.mode = 'coattn'
# Omega_k: 0.4, 0.6, 0.8, 1.0
args.W_k = 1.0
# Use tensorboard? default: False
args.log_data = False
# Use function groups or pathway groups? default: True
args.apply_sig = True
# Save model? default: False
args.save_pkl = False
# Save model checkpoints? default: False
args.save_ckp = False

# Paths and names derived from the selections above.
args.data_root_dir = f"{BASE_DIR}/{DATASET}"
args.split_dir = f"{ORGANIZE.lower()}_{DATASET.lower()}"
# Number of classes handed to the model options
# (presumably has to agree with the dataset's number of bins — TODO confirm).
args.n_classes = 4
# Folder (relative to the working directory) that holds the dataset CSV archives.
args.dataset_path = 'dataset_csv'

# Select the cross-validation fold to evaluate.
fold = 0
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the random number generators via the project's helper for reproducibility.
seed_torch(args.seed)


## 2. Load Dataset

In [46]:
# Cleaned table for the selected cohort, resolved relative to the working directory.
csv_path = f"./{args.dataset_path}/{args.split_dir}_all_clean.csv.zip"

# Build the full survival dataset from that table.
dataset_options = dict(
    csv_path=csv_path,
    mode=args.mode,
    apply_sig=args.apply_sig,
    data_dir=args.data_root_dir,
    shuffle=False,
    seed=args.seed,
    print_info=True,
    patient_strat=False,
    n_bins=4,
    label_col='survival_months',
    ignore=[],
)
dataset = Generic_MIL_Survival_Dataset(**dataset_options)

# NOTE(review): these imports sit mid-cell in the original; they are kept here
# because the star import may shadow names, so its position is part of the behavior.
from utils.generate_utils import *
from utils.utils import get_split_loader

# Pick the train/validation split file that belongs to the chosen fold.
split_dir = os.path.join('./splits', '5foldcv', args.split_dir)
split_csv = f"{split_dir}/splits_{fold}.csv"
train_dataset, val_dataset = dataset.return_splits(from_id=False, csv_path=split_csv)

# The model options take the omic group sizes from the training split.
args.omic_sizes = train_dataset.omic_sizes

# Both loaders are built with identical settings.
loader_options = dict(testing=False, mode=args.mode, batch_size=args.batch_size)
train_loader = get_split_loader(train_dataset, **loader_options)
val_loader = get_split_loader(val_dataset, **loader_options)

(0, 0) : 0
(0, 1) : 1
(1, 0) : 2
(1, 1) : 3
(2, 0) : 4
(2, 1) : 5
(3, 0) : 6
(3, 1) : 7
label column: survival_months
label dictionary: {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3, (2, 0): 4, (2, 1): 5, (3, 0): 6, (3, 1): 7}
number of classes: 8
slide-level counts:  
 7    112
5     48
4     30
6     30
2     30
3     11
0     30
1     27
Name: label, dtype: int64
Patient-LVL; Number of samples registered in class 0: 30
Slide-LVL; Number of samples registered in class 0: 30
Patient-LVL; Number of samples registered in class 1: 27
Slide-LVL; Number of samples registered in class 1: 27
Patient-LVL; Number of samples registered in class 2: 30
Slide-LVL; Number of samples registered in class 2: 30
Patient-LVL; Number of samples registered in class 3: 11
Slide-LVL; Number of samples registered in class 3: 11
Patient-LVL; Number of samples registered in class 4: 30
Slide-LVL; Number of samples registered in class 4: 30
Patient-LVL; Number of samples registered in class 5: 48
Slide-LVL; Numbe

## 3. Load Model Checkpoint

In [47]:
# Path to the downloaded model checkpoint
ckp_base='<checkpoint path>'
args.path_load_model = os.path.join(ckp_base, f's_{fold}_checkpoint.pt')
# Weights of the three branches (as noted by the authors: the values at which
# the best performance was obtained).
cfc = [0.1, 0.1, 0.1]

# Instantiate the architecture selected in `args` and move it to the target device.
model = generate_model(args=args).to(device)

# map_location=device covers both the CPU and the GPU case in one call: tensors
# are remapped onto the device chosen for this run instead of the device index
# that was recorded when the checkpoint was saved.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(args.path_load_model, map_location=device)

# strict=False tolerates missing/unexpected keys; the returned report is the
# cell's last expression so any mismatch is shown in the output.
model.load_state_dict(checkpoint, strict=False)

<All keys matched successfully>

## 4. Running the Experiment on the Validation Loader

In [48]:
# Evaluate the loaded model on the validation split and collect per-slide results.
model.eval()
patient_results = {}

# Slide identifiers in dataset order. They are looked up by batch position below,
# which assumes the loader yields one slide per batch in sequential order — TODO confirm.
slide_ids = val_loader.dataset.slide_data['slide_id']

# One entry per batch for each quantity needed by the c-index computation.
n_batches = len(val_loader)
all_risk_scores = np.zeros(n_batches)
all_censorships = np.zeros(n_batches)
all_event_times = np.zeros(n_batches)

# The model takes six omic inputs, passed as x_omic1 ... x_omic6.
N_OMIC_GROUPS = 6

for batch_idx, (data_WSI, data_omic, label, event_time, c) in enumerate(val_loader):
    # Move every input to `device` (not a hard-coded .cuda()) so the cell also
    # runs on CPU-only machines.
    data_WSI = data_WSI.to(device)
    omic_inputs = {
        f'x_omic{group}': data_omic[0][group - 1].type(torch.FloatTensor).to(device)
        for group in range(1, N_OMIC_GROUPS + 1)
    }
    label = label.type(torch.LongTensor).to(device)
    c = c.type(torch.FloatTensor).to(device)

    slide_id = slide_ids.iloc[batch_idx]

    with torch.no_grad():
        output = model(cfc=cfc, x_path=data_WSI, **omic_inputs)

    # Risk is the negated sum over dim 1 of the model's 'cf_S' output.
    # The original cell first stored a risk computed from output['S'] and then
    # overwrote the same dictionary entry with this one; only the final value
    # was ever kept, so the dead intermediate store has been dropped.
    risk = -torch.sum(output['cf_S'], dim=1).cpu().numpy()

    # .item() requires exactly one element, matching the one-sample-per-batch
    # assumption already implied by label.item() below.
    all_risk_scores[batch_idx] = risk.item()
    all_censorships[batch_idx] = c.item()
    all_event_times[batch_idx] = event_time.item()

    patient_results[slide_id] = {
        'slide_id': np.array(slide_id),
        'risk': risk,
        'disc_label': label.item(),
        'survival': event_time.item(),
        'censorship': c.item(),
    }


## 5. C-index Result

In [49]:
# The scorer takes an event indicator as its first argument, so the
# censorship flags are inverted (1 - censorship) and cast to bool.
event_observed = (1 - all_censorships).astype(bool)
cindex_result = concordance_index_censored(
    event_observed,
    all_event_times,
    all_risk_scores,
    tied_tol=1e-08,
)
# The first element of the result is the concordance index itself.
c_index = cindex_result[0]
print(f"val c-index: {c_index:.4f}")

val c-index: 0.7372


: 