In [None]:
import os

# The current working directory needs to be in explainable_TGCNN
# (presumably the project root that contains the `src` package imported below).
print(os.getcwd())
# Only move up when `src` is not already visible, so re-running this cell does
# not keep climbing the directory tree. os.pardir is used instead of '..\\'
# because the backslash form is only a valid path on Windows.
if not os.path.isdir('src'):
    os.chdir(os.pardir)
print(os.getcwd())

import pandas as pd
import numpy as np
import tensorflow as tf
from src import create_fake_patients, whole_model_demographics_gradcam, graph_plot, ga, gc, utils

In [None]:
def _split_node_label(label):
    """Split a node label of the form '<code>_v<N>' into (code, visit)."""
    # Split on the LAST underscore so that codes which themselves contain
    # underscores (e.g. 'drug_12_v5') keep their full code part.
    code, sep, visit = label.rpartition('_')
    if not sep:
        raise ValueError(f"Node label {label!r} has no '_<visit>' suffix")
    return code, visit


def replace_visit_number(original_list):
    """Renumber the visit suffixes of an edge list so they start at v1.

    Every node label is expected to look like '<code>_v<N>' (e.g. 'A12_v91').
    The distinct visit suffixes found across all edges are ranked by their
    numeric value and replaced by that rank ('v1', 'v2', ...), so gaps in the
    visit numbers disappear while their relative order is preserved.

    Args:
        original_list: iterable of (start_node, end_node) label pairs.

    Returns:
        A new list of (start_node, end_node) tuples with renumbered visits.
        An empty input gives an empty list.

    Raises:
        ValueError: if a label has no underscore, if the text after the
            leading character of the visit suffix is not an integer, or if
            an edge does not contain exactly two labels.
    """
    # Parse every label once; the parsed form is reused for both the ranking
    # and the rebuilding step.
    parsed_edges = [
        tuple(_split_node_label(label) for label in edge)
        for edge in original_list
    ]

    # Create a mapping from original visit suffixes to new ones starting from 1.
    # Sorting is numeric (the leading 'v' is stripped) so that v9 < v10.
    distinct_visits = {visit for edge in parsed_edges for _, visit in edge}
    ordered_visits = sorted(distinct_visits, key=lambda v: int(v[1:]))
    visit_mapping = {old: f"v{rank}" for rank, old in enumerate(ordered_visits, start=1)}

    # Rebuild the edges with the renumbered visit suffixes
    return [
        (f"{code1}_{visit_mapping[visit1]}", f"{code2}_{visit_mapping[visit2]}")
        for (code1, visit1), (code2, visit2) in parsed_edges
    ]

def add_tuples_to_dict(tuples_list, subgraph_dict, predicted_outcome):
    """Record one occurrence of a subgraph under the predicted class.

    The subgraph is identified by its edges. The list of edges is sorted (the
    nodes inside each edge keep their order, so edge direction is preserved)
    and turned into a tuple, which makes the key independent of the order in
    which the edges were listed.

    Args:
        tuples_list: iterable of edge tuples describing the subgraph.
        subgraph_dict: dict mapping subgraph keys to {'s+': int, 's-': int};
            it is updated in place.
        predicted_outcome: value compared against 0.5. Anything strictly
            greater than 0.5 counts towards 's+'; everything else (including
            values for which the comparison is false, such as NaN) counts
            towards 's-'.

    Returns:
        The same ``subgraph_dict`` object, after the update.
    """
    subgraph_key = tuple(sorted(tuples_list))

    # Decide the class once so that a first-time insert and a later update
    # can never disagree about which counter a prediction belongs to.
    outcome_key = 's+' if predicted_outcome > 0.5 else 's-'

    counts = subgraph_dict.setdefault(subgraph_key, {'s+': 0, 's-': 0})
    counts[outcome_key] += 1

    return subgraph_dict


In [None]:
# --- Model / data configuration -------------------------------------------
second_TGCNN_layer = True
demo = True

include_drugs = True
max_timesteps=100

stride = 1
filter_size = 4
num_filters=16

run_name='hip_1999_to_one_year_advance_model'
# NOTE(review): run_name says "one year advance" while this is "5" - confirm which is intended.
years_in_advance = "5"

# The number of event codes depends on whether drug codes are included
if include_drugs:
    max_event_codes = 518
else:
    max_event_codes = 512
hip_or_knee = 'hip'

# fake mapping dataframe for the ReadCodes and the corresponding descriptions
read_code_map_df = pd.read_csv('fake_read_code_descriptions.csv')

# Build the model from the configuration above (num_filters is taken from the
# config variable rather than repeated as a literal) and load the trained weights.
model = whole_model_demographics_gradcam.TGCNN_Model(num_filters=num_filters, num_nodes=max_event_codes, num_time_steps=max_timesteps, 
                            filter_size=filter_size, variable_gamma=True, 
                            exponential_scaling=True, dropout_rate=0.7, lstm_units=64,
                            fcl1_units=128, LSTM_ablation=False, stride=stride, activation_type='LeakyReLU', 
                            no_timestamp=False, second_TGCNN_layer=second_TGCNN_layer, num_labels=1)
# File names are derived from run_name so the weights and the filters always
# come from the same run.
model.load_weights(f'{run_name}1_CNN_layer')

# Load in the filters from the model
with open(f'{run_name}1_filter.npy', 'rb') as f:
    filters = np.load(f)
filters = tf.cast(filters, dtype=tf.float16)

num_pats = 5
max_act_filt_num = 10
# NOTE(review): the literal 99 looks related to max_timesteps (100) - confirm
# against create_fake_patient_df before replacing it with an expression.
cv_patients = create_fake_patients.create_fake_patient_df(num_pats, 99, max_event_codes)

In [None]:
# Maps each activated subgraph (tuple of edges) to how often it was seen for
# patients predicted positive ('s+') and negative ('s-').
subgraph_dict = dict()

filt_type = 'median' # 'mean', 'median', 'max'

# The filter does not depend on the patient, so it is prepared once instead of
# on every iteration (assumes the ga helpers do not modify their inputs).
# NOTE(review): the literal 30 is hard-coded while max_act_filt_num from the
# config cell is unused - confirm which value is intended here.
filters_4d = ga.make_filts_4d(filters, filter_size, max_event_codes)
selected_filter = ga.get_and_reshape_filt(filters_4d, 30, filt_type=filt_type)

for pat in range(num_pats):

    input_3d, input_4d, demo_tensor, outcome, outcome_bin = utils.return_pat_from_df(cv_patients, max_event_codes, hip_or_knee, pat, max_timesteps)
    dense_tensor = tf.sparse.to_dense(input_3d)
    dense_tensor= tf.transpose(dense_tensor, perm=[2, 1, 0])
    dense_tensor = np.flip(dense_tensor, axis=0) # change the most recent events to be at the end rather than the start
    dense_tensor = tf.cast(dense_tensor, tf.float16)

    edge_act_graph = ga.filt_times_pat(selected_filter, dense_tensor, filter_size, max_timesteps, stride)
    edges_df = ga.create_edges_df_ga(dense_tensor, edge_act_graph) 

    # Get the node positions for the graph
    pos_df = graph_plot.create_position_df_gc(edges_df)
    pos_list = graph_plot.generate_pos_sequence(pos_df['max_codes_per_visit'].max())
    pos_df = graph_plot.map_y_coord_to_node(pos_df, pos_list)

    # Result is not used below; the call is kept in case it updates pos_df in place.
    read_code_pos_df = ga.map_read_code_labels(pos_df, read_code_map_df)

    edge_pos_df = ga.create_edge_pos_df(edges_df, pos_df)

    # Remove repeat rows where edge_weight_perc == 0.
    # Only a zero row that directly follows another zero row is dropped;
    # consecutive rows that share the same NON-zero weight are real edges and
    # must be kept (comparing each row to its predecessor would drop them).
    is_zero = edge_pos_df['edge_weight_perc'] == 0
    repeated_zero = is_zero & is_zero.shift(fill_value=False)
    df_unique_adjacent = edge_pos_df[~repeated_zero].reset_index(drop=True)

    # Find indices where the edge_weight_perc col has a value of 0
    split_indices = df_unique_adjacent.index[df_unique_adjacent['edge_weight_perc'] == 0].tolist()

    # Add start and end indices to make it easier to split
    split_indices = [-1] + split_indices + [len(df_unique_adjacent)]

    # Split the DataFrame into chunks; the zero-weight rows act as separators
    # and are not part of any chunk.
    chunks = [df_unique_adjacent.iloc[start+1:stop] for start, stop in zip(split_indices[:-1], split_indices[1:])]

    logits = model(input_4d, demo_tensor, training=False)
    proba = tf.sigmoid(logits)

    for chunk in chunks:
        if len(chunk) != 0:
            subgraph_list = list(zip(chunk['start_node'], chunk['end_node']))
            subgraph_list_adj_vis = replace_visit_number(subgraph_list)
            add_tuples_to_dict(subgraph_list_adj_vis, subgraph_dict, proba)

    # Report progress every 100 patients; pat is zero-based, so pat + 1
    # patients are complete at this point.
    completed = pat + 1
    if (completed % 100) == 0:
        print(f"Number of patients complete: {completed}")
        print(f"{((completed/num_pats)*100):.2f}% Complete")

In [None]:
# Flatten the dictionary into one record per subgraph
data = [
    {'subgraph': subgraph, 's+': counts['s+'], 's-': counts['s-']}
    for subgraph, counts in subgraph_dict.items()
]

# Convert the records into a DataFrame. The columns are named explicitly so
# the frame (and the ratio columns below) are still well-formed when no
# subgraph was collected; without them an empty dict raises KeyError on 's+'.
df = pd.DataFrame(data, columns=['subgraph', 's+', 's-']).astype({'s+': 'int64', 's-': 'int64'})

# Get the ratio of the subgraphs for each class
total = (df['s+'] + df['s-']).replace(0, np.nan)  # Avoid division by zero
df['Rs+'] = df['s+'] / total
df['Rs-'] = df['s-'] / total

df