In [2]:
import os
import sys


def result_file_name(args):
    """Assemble the run-specific results name from the configuration.

    The name always starts with ``results_<fold_range>_<seed>``. When at
    least one modality is configured, the modality names are appended,
    separated by underscores (e.g. ``results_5_24_clinical_miRNA``).

    Args:
        args: Object exposing ``fold_range``, ``seed`` and ``modalities``
            (a sized iterable of strings).

    Returns:
        str: The assembled name.
    """
    pieces = [f"results_{args.fold_range}_{args.seed}"]
    # len() instead of truthiness: `modalities` may be a NumPy array (one is
    # passed later in this notebook), whose truth value is ambiguous.
    if len(args.modalities) > 0:
        pieces.append("_".join(args.modalities))
    return "_".join(pieces)

class Args:
    """Hard-coded run configuration for the kidney dataset.

    Stands in for a parsed command-line namespace: every setting is a plain
    instance attribute.

    Args:
        modalities: Sequence of modality names to use. ``None`` (the default)
            selects ``["clinical", "miRNA", "mRNA", "WSI"]``. Any sequence
            passed explicitly is stored as-is.
    """

    def __init__(self, modalities=None):
        # Build the default list per instance. A list literal in the signature
        # would be shared by every default-constructed Args, so mutating
        # `args.modalities` on one instance would leak into all later ones.
        if modalities is None:
            modalities = ["clinical", "miRNA", "mRNA", "WSI"]

        self.dataset = 'kidney'
        # Location of the preprocessed input for each modality.
        self.modality_data_path = {'clinical': '../preprocess/preprocessed_data/clinical_kidney.csv',
                                    'mRNA': '../preprocess/preprocessed_data/mrna_kidney.csv',
                                    'miRNA': '../preprocess/preprocessed_data/mirna_kidney.csv',
                                    'WSI': '../preprocess/preprocessed_data/UNI2_features/TCGA-Kidney.pt'}
        self.device = "cuda"
        self.modalities = modalities
        # Presumably the raw feature dimension of each modality's input -
        # TODO confirm against the encoder.
        self.input_modality_dim = {'clinical':4, 'mRNA':2746, 'miRNA':743 , 'WSI': 1536}
        self.fold_range = 5
        self.fold = 1
        self.modality_fv_len = 128
        self.batch_size = 128
        self.seed = 24
        # Depends on fold_range, seed and modalities, so it must be computed
        # after those attributes are set.
        self.model_path = f"../logs/dgsurv/{result_file_name(self)}"
        self.num_workers = 4
        self.num_modalities = len(self.modalities)
        self.split_path = "../splits/kidney_splits"
        self.remove_missing = True


args = Args()

# Make the parent directory importable (presumably where the `method` and
# `utils` packages imported by the following cells live). The membership check
# keeps the cell idempotent: re-running it no longer stacks duplicate entries
# onto sys.path.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [3]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from method.GraphLearner import GraphLearner
from method.GraphPooling import GraphPooling
from method.GraphDataUtils import GraphInputProcessor


class GraphModel(nn.Module):
    """Graph-based fusion model with a hazard head and a two-class score head.

    Per-modality representations produced by ``encoder_model`` are multiplied
    by the per-sample modality mask, handed to ``GraphInputProcessor``,
    ``GraphLearner`` and ``GraphPooling`` in turn, and the pooled
    representation is fed to two linear heads.

    Args:
        n_views: Number of modalities (views). ``n_views ** 2`` edge types are
            configured for the graph components.
        n_in_feats: Length of each modality representation vector.
        encoder_model: Module called as ``encoder_model(x_modality)`` that
            returns an iterable of per-modality representations.
        gnn_arch: Forwarded to ``GraphPooling`` as ``gnn_arch``.
        pool_ratio: Fraction of ``n_views`` kept as pooled nodes (rounded up,
            at least one node).
        sparse_threshold: Forwarded to ``GraphLearner`` as ``threshold``.
        device: Forwarded to ``GraphInputProcessor`` and ``GraphLearner``.
    """

    def __init__(self, n_views, n_in_feats, encoder_model,
                 gnn_arch='gcn',
                 pool_ratio=.1, sparse_threshold=0.1,
                 device='cuda'):
        super().__init__()
        self.encoder_model = encoder_model

        n_edge_types = int(n_views**2)
        self.graph_input_processor = GraphInputProcessor(
            n_edge_types=n_edge_types, device=device)

        self.graph_learner = GraphLearner(
            n_in_feats=n_in_feats+n_views,
            n_out_feats=100,
            threshold=sparse_threshold,
            n_heads=1,
            device=device
        )

        # Never pool down to zero nodes, whatever the ratio.
        n_nodes = max(math.ceil(n_views*pool_ratio), 1)
        self.n_nodes = n_nodes

        n_feats = n_in_feats
        self.pool = GraphPooling(
            n_in_feats, n_nodes,
            n_feats, n_edge_types,
            gnn_arch=gnn_arch, pool=True
        )

        self.output_dim = n_feats
        self.hazard_layer1 = nn.Linear(self.output_dim, 1)

        self.label_layer1 = nn.Linear(self.output_dim, 2)

    def forward(self, x_modality, mask):
        """Run the encoder, graph learner and pooling, then both heads.

        Args:
            x_modality: Input passed unchanged to ``encoder_model``.
            mask: 2-D tensor indexed as ``mask[:, i]`` for modality ``i``;
                each column is broadcast over that modality's representation
                (presumably 0 marks a missing modality - TODO confirm).

        Returns:
            tuple: ``({'hazard': ..., 'score': ...}, representation_dict,
            learned_graph[1], assignment_mat)`` where ``score`` is a
            log-softmax over dim 1, ``representation_dict`` maps the modality
            index to its masked representation, and the last two items are
            the second element of the graph learner output and the second
            output of the pooling module.
        """
        representation = self.encoder_model(x_modality)

        # Scale every modality representation by its mask column. The unused
        # per-modality `index` tensor the original built here (one CPU
        # allocation plus a device copy per modality per call) was dropped.
        missing_rep = [
            mask[:, i].reshape(-1, 1) * rep
            for i, rep in enumerate(representation)
        ]

        # Keyed by modality index; used for the self-supervised loss.
        representation_dict = dict(enumerate(missing_rep))

        in_graph = self.graph_input_processor(missing_rep)
        learned_graph, _ = self.graph_learner(in_graph)
        latent_graph, assignment_mat = self.pool(learned_graph)
        # Drop the node axis of the pooled features (a no-op unless that axis
        # has size 1).
        final_representation = latent_graph[0].squeeze(1)
        hazard = self.hazard_layer1(final_representation)
        score = F.log_softmax(self.label_layer1(final_representation), dim=1)
        return {'hazard':hazard, 'score':score}, representation_dict, learned_graph[1], assignment_mat

In [4]:
import shap
import numpy as np
from utils.encoder import EncoderModel
from torch.utils.data import DataLoader
from utils.dataset import MultiModalDataset


class ClinicalEmbedder(nn.Module):
    """Embed categorical clinical columns and batch-normalise continuous ones.

    ``forward`` returns the concatenated embeddings followed by the
    normalised continuous features. The linear layer ``hidden1`` is created
    (so its weights can be loaded) but is not applied inside ``forward``.

    Args:
        m_length: Output size of ``hidden1``.
        n_continuous: Number of continuous clinical features.
        embedding_size: One ``(num_categories, embedding_dim)`` pair per
            categorical column.
    """

    def __init__(self, m_length, n_continuous=1, embedding_size=[(33, 2), (2, 1), (6, 3), (145, 2)]):
        super().__init__()
        # One embedding table per categorical column.
        tables = [nn.Embedding(n_categories, width) for n_categories, width in embedding_size]
        self.embedding_layers = nn.ModuleList(tables)
        self.n_emb = sum(table.embedding_dim for table in self.embedding_layers)
        self.n_continuous = n_continuous
        # Projection over [embeddings | continuous]; callers apply it separately.
        self.hidden1 = nn.Linear(self.n_emb + self.n_continuous, m_length)
        # Batch normalisation of the continuous features only.
        self.bn1 = nn.BatchNorm1d(self.n_continuous)
        self.emb_drop = nn.Dropout(0.4)

    def forward(self, x_categorical, x_continuous):
        """Return ``[dropout(embeddings) | batchnorm(continuous)]`` along dim 1.

        Args:
            x_categorical: Integer tensor; column ``i`` indexes embedding
                table ``i``.
            x_continuous: Tensor of continuous features fed to ``bn1``.
        """
        embedded_columns = []
        for column, table in enumerate(self.embedding_layers):
            embedded_columns.append(table(x_categorical[:, column]))
        embedded = self.emb_drop(torch.cat(embedded_columns, 1))
        normalised = self.bn1(x_continuous)
        # No linear layer here on purpose: `hidden1` is not part of forward().
        return torch.cat([embedded, normalised], 1)

class InterpEncoderModel(EncoderModel):
    """``EncoderModel`` variant whose forward pass returns a plain list.

    Construction is delegated entirely to ``EncoderModel``; only ``forward``
    is overridden.

    Args:
        modalities: Modality names, forwarded to ``EncoderModel``.
        modality_fv_len: Forwarded to ``EncoderModel``.
        input_modality_dim: Forwarded to ``EncoderModel``.
        drop_out_p: Dropout probability, forwarded to ``EncoderModel``.
    """

    def __init__(self, modalities, modality_fv_len, input_modality_dim, drop_out_p=0.5):
        # Forward the caller's drop_out_p. The original hard-coded 0.5 here,
        # so any non-default value was silently ignored.
        super().__init__(modalities, modality_fv_len, input_modality_dim, drop_out_p=drop_out_p)

    def forward(self, x_modality):
        """Encode each modality and return the results in modality order.

        Args:
            x_modality: Mapping from modality name to that modality's input.

        Returns:
            list: One representation per entry of ``self.data_modalities``
            (set by the parent class), in that order.
        """
        return [
            self.modality_pipeline[modality](x_modality[modality])
            for modality in self.data_modalities
        ]


def embed_clinical(x_modal, clinical_embedder, modalities):
    """Swap the raw clinical inputs of a batch for their embedding.

    Args:
        x_modal: Batch dict holding ``'clinical_categorical'`` and
            ``'clinical_continuous'`` plus one entry per non-clinical
            modality. It is modified in place: ``'clinical'`` is set to the
            embedding and the two raw clinical keys are removed.
        clinical_embedder: Callable invoked as
            ``clinical_embedder(categorical, continuous)``.
        modalities: Iterable of modality names that fixes the order of the
            non-clinical entries in the result.

    Returns:
        dict: A new dict with ``'clinical'`` first, followed by every
        non-clinical name in ``modalities`` mapped to its value in ``x_modal``.
    """
    embedded = clinical_embedder(
        x_modal['clinical_categorical'], x_modal['clinical_continuous']
    )

    # Update the caller's dict as well: store the embedding, drop the raw inputs.
    x_modal['clinical'] = embedded
    for raw_key in ('clinical_categorical', 'clinical_continuous'):
        x_modal.pop(raw_key, None)

    # Clinical always comes first; the rest follow the order of `modalities`.
    ordered = {'clinical': embedded}
    ordered.update(
        (name, x_modal[name]) for name in modalities if name != 'clinical'
    )
    return ordered
    
def _prepare_shap_batch(x_modal, mask, clinical_embedder, modalities, device):
    """Move one batch to ``device`` and replace raw clinical inputs by their embedding."""
    x_modal = {name: tensor.to(device) for name, tensor in x_modal.items()}
    return embed_clinical(x_modal, clinical_embedder, modalities), mask.to(device)


def calculate_shap_values(args, model, clinical_embedder, device):
    """Compute SHAP values on the test split, using the train split as background.

    Args:
        args: Configuration providing ``modalities``, ``modality_data_path``,
            ``num_workers`` and ``batch_size`` (also passed to
            ``MultiModalDataset``).
        model: Model handed to ``shap.DeepExplainer``.
        clinical_embedder: Module used by ``embed_clinical`` to turn the raw
            clinical inputs into their embedding before explanation.
        device: Device the batches are moved to.

    Returns:
        list: One array per explained input, in the order returned by
        ``shap_values``, with all test batches concatenated along axis 0.

    Raises:
        IndexError: If the test split yields no batches.
    """
    train_dataset = MultiModalDataset(args, 'train', args.modalities, args.modality_data_path)
    train_dataloader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False, num_workers=args.num_workers)
    # batch_size == len(dataset), so the first batch is the whole train split.
    background, _, background_mask = next(iter(train_dataloader))
    background, background_mask = _prepare_shap_batch(
        background, background_mask, clinical_embedder, args.modalities, device)
    # NOTE(review): in upstream shap the third positional parameter of
    # DeepExplainer is `session`, and the second of shap_values is
    # `ranked_outputs`. Passing the mask there presumably relies on a
    # project-specific shap build - confirm against the installed version.
    e = shap.DeepExplainer(model, background, background_mask)

    test_dataset = MultiModalDataset(args, 'test', args.modalities, args.modality_data_path)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    per_batch_values = []
    for x_modal, _, mask in test_dataloader:
        x_modal, mask = _prepare_shap_batch(
            x_modal, mask, clinical_embedder, args.modalities, device)
        per_batch_values.append(e.shap_values(x_modal, mask))

    # Concatenate once per explained input instead of growing each array
    # batch by batch, which re-copied the accumulated data on every step.
    n_outputs = len(per_batch_values[0])
    return [
        np.concatenate([batch_values[i] for batch_values in per_batch_values], axis=0)
        for i in range(n_outputs)
    ]

In [5]:
import numpy as np


all_modalities = ["clinical", "miRNA", "mRNA", "WSI"]
all_modalities = np.array(all_modalities)
all_shape_values = []
model_base_path = args.model_path

# Compute SHAP values for the trained model of every fold.
for fold in range(args.fold_range):
    args = Args(all_modalities)
    args.fold = fold
    device = torch.device(args.device)
    encoder_model = InterpEncoderModel(args.modalities, args.modality_fv_len, args.input_modality_dim)
    model = GraphModel(args.num_modalities, args.modality_fv_len, encoder_model, device=args.device)

    # Read the checkpoint once and reuse it for the model and the clinical
    # embedder (it was previously loaded from disk twice per fold).
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    model_path = os.path.join(args.model_path, f"model_fold{args.fold}.pt")
    model_dict = torch.load(model_path)
    model.load_state_dict(model_dict)
    model.to(device)

    # Split the clinical branch: ClinicalEmbedder (embeddings + batch norm)
    # runs outside the model via embed_clinical, and only its linear layer
    # stays inside. SHAP therefore attributes to the embedded clinical features.
    clinical_embedder = ClinicalEmbedder(args.modality_fv_len)
    prefix = "encoder_model.clinical_submodel"
    # startswith (not `in`): the slice below is only valid for a leading prefix.
    clinical_embedder_dict = {
        key[len(prefix)+1:]: value
        for key, value in model_dict.items()
        if key.startswith(prefix + ".")
    }
    clinical_embedder.load_state_dict(clinical_embedder_dict)
    clinical_embedder.to(device)
    clinical_embedder.eval()
    model.eval()
    model.encoder_model.clinical_submodel = clinical_embedder.hidden1
    model.encoder_model.modality_pipeline['clinical'] = clinical_embedder.hidden1

    shap_values = calculate_shap_values(args, model, clinical_embedder, device)
    all_shape_values.append(shap_values)

# Stack the folds: one concatenation per explained input rather than growing
# each array fold by fold.
cat_shap_values = [
    np.concatenate([fold_values[i] for fold_values in all_shape_values], axis=0)
    for i in range(len(all_shape_values[0]))
]
print(args.dataset)
for shap_value in cat_shap_values:
    print(shap_value.shape)
# Create the target directory first so the save cannot fail after the long run.
os.makedirs("./output", exist_ok=True)
torch.save( cat_shap_values, f"./output/{args.dataset}_shap_values.pt")



kidney
(939, 9, 1)
(939, 743, 1)
(939, 2746, 1)
(939, 1536, 1)
