In [1]:
# importing relevant libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from gensim.models import Word2Vec
from stellargraph import StellarGraph
from stellargraph.data import BiasedRandomWalk
from sklearn.manifold import TSNE
from time import time
import warnings

warnings.filterwarnings('ignore')

#### Constants

In [2]:
# constants: raw column names kept from each source file
metadata_columns = [
    'Diagnosis Code',
    'Description',
    'CMS-HCC Model Category V24',
]
inpatient_data_columns = [
    'empi',
    'visit_id',
    'visit_start_date',
    'primary_diagnosis',
]
outpatient_data_columns = [
    'empi',
    'visit_id',
    'last_date_of_service',
    'primary_diagnosis',
]

#### Read Data

In [3]:
# ICD-10 -> CMS-HCC risk adjustment mapping (raw; cleaned by prepare_metadata)
metadata = pd.read_csv(
    "../data/metadata/2022 Midyear_Final ICD-10-CM Mappings.csv"
)

# visit-level inpatient and outpatient extracts
inpatient_data = pd.read_csv(
    "../data/patient-data/df_preprocessed.csv"
)
outpatient_data = pd.read_csv(
    "../data/patient-data/df_outpatient.csv"
)

#### Helper Functions

In [230]:
def prepare_metadata(metadata, columns=None):
    """
    Cleans and prepares the HCC metadata file.

    The first two and last seven rows of the raw frame are dropped, embedded
    newlines are replaced by spaces (in the header row too, so multi-line
    column names match ``columns``), and the first remaining row is promoted
    to column names.

    Args:
        metadata (pd.DataFrame): raw frame as read from the mapping CSV.
            It is not modified.
        columns (list, optional): the three source columns to keep, in the
            order (diagnosis code, description, HCC). Defaults to the
            notebook-level ``metadata_columns``.

    Returns:
        pd.DataFrame: columns ``pd``, ``dscr`` and ``hcc``; ``hcc`` is an
        integer column with missing values filled with 0. The index starts
        at 1 (the header row occupied position 0).
    """
    if columns is None:
        columns = metadata_columns

    # No inplace operations on the slice: the caller's raw frame stays
    # untouched and no chained-assignment write can leak back into it.
    trimmed = (
        metadata.iloc[2:-7, :]
        .replace(r'\n', ' ', regex=True)
        .reset_index(drop=True)
    )

    # the first remaining row holds the real column names
    new_metadata = trimmed.iloc[1:, :].copy()
    new_metadata.columns = trimmed.iloc[0, :].tolist()

    # filtering only the required columns and giving them short names
    new_metadata = new_metadata.loc[:, columns].copy()
    new_metadata.columns = ['pd', 'dscr', 'hcc']

    # Plain column assignment (not ``.loc[:, 'hcc'] = ...``): since pandas 2
    # the .loc form writes into the existing object column and keeps its
    # dtype, so the column would not actually become integer.
    new_metadata['hcc'] = new_metadata['hcc'].fillna(0).astype('int')

    return new_metadata


def prepare_patient_data(inpatient_data, outpatient_data):
    """
    Keeps only the visit columns used downstream and renames them.

    The inpatient frame is reduced to ``inpatient_data_columns`` and the
    outpatient frame to ``outpatient_data_columns``. Both are then relabelled
    to the shared short schema ['empi', 'vid', 'vdt', 'pd'].

    Returns:
        tuple: (inpatient frame, outpatient frame) with the short column names.
    """
    short_names = ['empi', 'vid', 'vdt', 'pd']
    selections = (
        (inpatient_data, inpatient_data_columns),
        (outpatient_data, outpatient_data_columns),
    )
    inpatient_subset, outpatient_subset = (
        frame.loc[:, wanted].set_axis(short_names, axis=1)
        for frame, wanted in selections
    )
    return inpatient_subset, outpatient_subset


def create_patient_hcc_mapping(patient_df, hcc_df):
    """Maps ICD-10 codes to HCCs and builds a patient x HCC indicator table
    for processing into adjacency matrices.

    Args:
        patient_df (pd.DataFrame): visit-level patient data with at least
            ``empi`` and ``pd`` (ICD-10 code) columns.
        hcc_df (pd.DataFrame): ICD-10 -> HCC mapping with ``pd`` and ``hcc``
            columns, e.g. the output of ``prepare_metadata``, where an
            ``hcc`` of 0 marks a code without an HCC.

    Returns:
        pd.DataFrame: one row per ``empi`` (sorted). It holds an ``empi``
        column followed by one 0/1 column per HCC; a column is 1 if that
        patient has at least one visit mapped to that HCC.
    """
    # Use the mapping that was passed in (previously the global ``metadata``
    # was used and ``hcc_df`` was silently ignored).
    data_merged = pd.merge(left=patient_df, right=hcc_df, on='pd', how='left')

    # One indicator column per HCC value; unmapped codes (NaN) get no column.
    # Drop the "no HCC" (0) column by label rather than position, so a real
    # HCC is not discarded when no visit maps to 0.
    hcc_dummies = pd.get_dummies(data_merged.hcc).drop(columns=0, errors='ignore')
    hcc_columns = list(hcc_dummies.columns)

    data = pd.concat([data_merged, hcc_dummies], axis=1)
    data = data.groupby('empi', as_index=False).aggregate(
        {col: 'sum' for col in hcc_columns}
    )
    print("HCC Mapping done...")

    # collapse per-patient visit counts to presence (1) / absence (0)
    data[hcc_columns] = data[hcc_columns].ne(0).astype(int)
    return data


def filter_edge_data_by_node(node, edge_data=None):
    """Returns the edges incident to ``node`` (as source or target), sorted by weight.

    Args:
        node: node id matched against the first two columns of ``edge_data``.
        edge_data (pd.DataFrame, optional): edge list whose columns are, by
            position, source node, target node and weight. When omitted, the
            notebook-level ``edge_data`` is looked up at call time.

    Returns:
        pd.DataFrame: matching rows of ``edge_data`` in ascending weight order.

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.
    """
    # Resolved at call time rather than as a def-time default. This helper is
    # defined before ``edge_data`` is built further down the notebook, so a
    # def-time default fails on Restart & Run All and would freeze a stale frame.
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    source_col, target_col, weight_col = edge_data.columns[:3]
    touches_node = (edge_data[source_col] == node) | (edge_data[target_col] == node)
    return edge_data[touches_node].sort_values(weight_col)


def filter_edge_data_by_source_node(node, edge_data=None):
    """Returns the outgoing edges of ``node``, sorted by weight.

    Args:
        node: node id matched against the first (source) column of ``edge_data``.
        edge_data (pd.DataFrame, optional): edge list whose columns are, by
            position, source node, target node and weight. When omitted, the
            notebook-level ``edge_data`` is looked up at call time.

    Returns:
        pd.DataFrame: matching rows of ``edge_data`` in ascending weight order.

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.
    """
    # Resolved at call time: a def-time default fails when this helper cell
    # runs before ``edge_data`` exists, and would otherwise go stale.
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    source_col = edge_data.columns[0]
    weight_col = edge_data.columns[2]
    return edge_data[edge_data[source_col] == node].sort_values(weight_col)


def filter_edge_data_by_target_node(node, edge_data=None):
    """Returns the incoming edges of ``node``, sorted by weight.

    Args:
        node: node id matched against the second (target) column of ``edge_data``.
        edge_data (pd.DataFrame, optional): edge list whose columns are, by
            position, source node, target node and weight. When omitted, the
            notebook-level ``edge_data`` is looked up at call time.

    Returns:
        pd.DataFrame: matching rows of ``edge_data`` in ascending weight order.

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.
    """
    # Resolved at call time: a def-time default fails when this helper cell
    # runs before ``edge_data`` exists, and would otherwise go stale.
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    target_col = edge_data.columns[1]
    weight_col = edge_data.columns[2]
    return edge_data[edge_data[target_col] == node].sort_values(weight_col)


def plot_disease_graph(node_data, edge_data):
    """Draws the disease graph with networkx (spring layout) and matplotlib.

    Args:
        node_data (pd.Series): one value per node, indexed by node id; stored
            as the ``hcc_count`` node attribute.
        edge_data (pd.DataFrame): edge list whose first two columns (by
            position) are the endpoint nodes. The fifth column (position 4)
            is stored as the ``e_ij_sim`` edge attribute and used as the
            spring-layout weight.
    """
    # Storing data in a networkx graph object (for graph visualisation)
    G = nx.Graph()

    for node, hcc_count in node_data.items():
        G.add_node(node, hcc_count=hcc_count)

    # itertuples + positional access on the tuple: indexing a label-indexed
    # Series with an int (``edges[0]``) is deprecated positional fallback in
    # pandas, and iterrows also upcasts mixed dtypes.
    for row in edge_data.itertuples(index=False):
        G.add_edge(row[0], row[1], e_ij_sim=row[4])

    # plotting the disease graph on an explicit axes
    fig, ax = plt.subplots(figsize=(50, 40))
    pos = nx.spring_layout(G, weight='e_ij_sim')
    nx.draw(G, pos, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=40, ax=ax)
    nx.draw_networkx_edge_labels(G, pos, ax=ax)
    plt.show()


def create_adjacency_matrices(edge_data):
    """
    Creates adjacency matrix of order 1 and proximity matrix of order 2.

    Edges are treated as undirected.

    Args:
        edge_data (pd.DataFrame): edge list with ``n_i`` and ``n_j`` endpoint
            columns.

    Returns:
        tuple(pd.DataFrame, pd.DataFrame):
            adjacency_matrix_1: symmetric 0/1 matrix, 1 where two nodes share
                an edge.
            adjacency_matrix_2: symmetric matrix whose (i, j) entry is the
                number of common neighbours of i and j; the diagonal is 0.
        Both matrices are labelled (rows and columns) by node id, in order of
        first appearance in ``n_i`` followed by ``n_j``.
    """
    # Deterministic, ordered node labels. Recent pandas versions reject a
    # ``set`` as index/columns, and set order is arbitrary anyway.
    nodes = pd.Index(
        pd.concat([edge_data.n_i, edge_data.n_j], ignore_index=True).unique()
    )
    rows = nodes.get_indexer(edge_data.n_i)
    cols = nodes.get_indexer(edge_data.n_j)

    adjacency_1 = np.zeros((len(nodes), len(nodes)), dtype=np.int64)
    adjacency_1[rows, cols] = 1
    adjacency_1[cols, rows] = 1

    # (A @ A)[i, j] = sum_k A[i, k] * A[j, k] (A is symmetric), i.e. the
    # common-neighbour count. This replaces the O(n^3) Python double loop.
    # The diagonal is kept at 0 as before.
    adjacency_2 = adjacency_1 @ adjacency_1
    np.fill_diagonal(adjacency_2, 0)

    adjacency_matrix_1 = pd.DataFrame(adjacency_1, index=nodes, columns=nodes)
    adjacency_matrix_2 = pd.DataFrame(adjacency_2, index=nodes, columns=nodes)
    return adjacency_matrix_1, adjacency_matrix_2


def create_stellargraph(node_data, edge_data):
    """Builds a StellarGraph from a node Series and an edge list.

    Args:
        node_data (pd.Series): per-node value, indexed by node id; it becomes
            the single float node feature ``x``.
        edge_data (pd.DataFrame): edge list with ``n_i``, ``n_j`` and
            ``e_ij_sim`` columns. They are cast to float and passed as
            ``source``, ``target`` and ``weight``.

    Returns:
        StellarGraph: the constructed graph.
    """
    node_features = node_data.to_frame(name='x').astype(float)
    edges = (
        edge_data.loc[:, ['n_i', 'n_j', 'e_ij_sim']]
        .astype(float)
        .rename(columns={'n_i': 'source', 'n_j': 'target', 'e_ij_sim': 'weight'})
    )
    return StellarGraph(node_features, edges=edges)


def biased_random_walk(G, weighted=False, n=100, length=20, p=0.5, q=2, seed=42):
    """Runs node2vec-style biased random walks rooted at every node of ``G``.

    Args:
        G (StellarGraph): graph to walk.
        weighted (bool): whether edge weights bias the transition probabilities.
        n (int): number of random walks per root node.
        length (int): maximum length of each random walk.
        p (float): return parameter. The unnormalised probability of stepping
            back to the previous node is 1/p.
        q (float): in-out parameter. The unnormalised probability of moving
            away from the previous node is 1/q, so q > 1 keeps walks local
            (BFS-like) and q < 1 favours exploration (DFS-like).
        seed (int): random seed for reproducible walks.

    Returns:
        The walks returned by ``BiasedRandomWalk.run``.
    """
    rw = BiasedRandomWalk(G)

    walks = rw.run(
        nodes=list(G.nodes()),  # root nodes
        length=length,
        n=n,
        p=p,
        q=q,
        weighted=weighted,
        seed=seed,
    )
    print("Number of random walks: {}".format(len(walks)))
    return walks


def create_node_embeddings(walks, nodes_list):
    """Learns node embeddings from random walks with skip-gram Word2Vec.

    Args:
        walks (list): random walks (sequences of node ids), e.g. the output
            of ``biased_random_walk``.
        nodes_list (iterable): nodes whose vectors are extracted.

    Returns:
        pd.DataFrame: one column per node (labelled ``str(node)``) holding
        that node's embedding vector; one row per embedding dimension.
    """
    model = Word2Vec(walks, window=5, min_count=0, sg=1, workers=1)

    # NOTE(review): vectors are looked up by ``int(node)``. This relies on the
    # walk tokens comparing equal to those ints (e.g. 103.0 == 103); confirm
    # against the node id type used in the graph.
    # The frame is built in one go: inserting columns one by one into an empty
    # DataFrame fragments it and triggers pandas' PerformanceWarning.
    return pd.DataFrame({str(node): model.wv[int(node)] for node in nodes_list})


def TSNE_plot_node_embeddings(embeddings, random_state=42):
    """Projects node embeddings to 2-D with t-SNE and plots one labelled point per node.

    Args:
        embeddings (pd.DataFrame): one column per node (the column label is
            used as the point label), one row per embedding dimension, e.g.
            the output of ``create_node_embeddings``.
        random_state (int or None): seed passed to ``TSNE`` so the layout is
            reproducible; ``None`` gives an unseeded run.
    """
    n_samples = embeddings.shape[1]
    # sklearn's TSNE requires perplexity < n_samples; its default is 30, which
    # would raise for graphs with 30 or fewer nodes.
    perplexity = min(30.0, max(n_samples - 1, 1))
    trans = TSNE(n_components=2, perplexity=perplexity, random_state=random_state)
    tsne_coordinates = trans.fit_transform(embeddings.T)

    fig, ax = plt.subplots(figsize=(20, 18))
    ax.set_aspect("equal")
    # no ``c`` values are given, so the former ``cmap="jet"`` had no effect
    ax.scatter(tsne_coordinates[:, 0], tsne_coordinates[:, 1], alpha=0.7)
    ax.set_title("{} visualization of node embeddings".format(TSNE.__name__))
    for label, (x, y) in zip(embeddings.columns, tsne_coordinates):
        ax.annotate(label, (x, y))
    plt.show()


def plot_embeddings_correlation(embeddings):
    """Heatmap of the pairwise (Pearson, the pandas default) correlations
    between embedding columns, with the upper triangle and diagonal masked."""
    corr = embeddings.corr()
    upper_triangle = np.triu(np.ones_like(corr, dtype=bool))
    plt.figure(figsize=(20, 16))
    sns.heatmap(corr, mask=upper_triangle)



#### Preparing Data

In [5]:
# This cell overwrites the raw frames under the same names, so it is guarded
# on the cleaned schema. Re-running it is then a no-op instead of failing on
# columns that no longer exist.
if list(metadata.columns) != ['pd', 'dscr', 'hcc']:
    # preparing cleaned risk adjustment metadata
    metadata = prepare_metadata(metadata)

if list(inpatient_data.columns) != ['empi', 'vid', 'vdt', 'pd']:
    # filtering only the required columns from inpatient and outpatient data
    inpatient_data, outpatient_data = prepare_patient_data(inpatient_data, outpatient_data)

In [6]:
# Sanity check: shape and first rows of each prepared frame.
# NOTE: the "Patient Data" label refers to the inpatient frame.
print(f"Patient Data: {inpatient_data.shape}\n{inpatient_data.head()}")
print(f"\nOutpatient Data: {outpatient_data.shape}\n{outpatient_data.head()}")
print(f"\nMetadata: {metadata.shape}\n{metadata.head()}")

Patient Data: (20541, 4)
          empi                               vid         vdt     pd
0  M0000040556  nM0000040556:1088927097671487508  2018-01-20  K5641
1  M0000040556  nM0000040556:1801256381439324181  2018-02-04   G458
2  M0000040556  nM0000040556:1339948222969081413  2018-02-18   K254
3  M0000040556  nM0000040556:1014145580172622435  2018-04-13  R5381
4  M0000040556   nM0000040556:130095445752129940  2018-06-27   I160

Outpatient Data: (371942, 4)
          empi                               vid         vdt      pd
0  M0000040556  nM0000040556:1002776852509021534  2018-06-03    N183
1  M0000040556  nM0000040556:1003353323421991148  2018-11-07  I87312
2  M0000040556  nM0000040556:1005893621657975338  2019-05-06   I2510
3  M0000040556  nM0000040556:1006150904029681462  2018-10-26  I69354
4  M0000040556  nM0000040556:1007502285099269954  2018-12-03   K8590

Metadata: (10981, 3)
      pd                   dscr  hcc
1  A0103      Typhoid pneumonia  115
2  A0104      Typhoid arthr

#### Combining Inpatient (IP) & Outpatient (OP) Data

In [7]:
# stack inpatient and outpatient visits (same four short columns) row-wise
data_combined = pd.concat((inpatient_data, outpatient_data))

### Weighted PageRank

In [8]:
# left join: every visit is kept; codes absent from the mapping get a NaN hcc
join_params = dict(left=data_combined, right=metadata, on='pd', how='left')
data_combined_pagerank = pd.merge(**join_params)

In [9]:
# Treat HCC 0 as "no HCC" and drop incomplete rows, then put each patient's
# visits in chronological order.
data_combined_pagerank = (
    data_combined_pagerank
    .assign(hcc=lambda df: df.hcc.replace(0, np.nan))
    .dropna()
    .sort_values(['empi', 'vdt'])
    .reset_index(drop=True)
)

# Pair every visit's HCC with the HCC of the same patient's next visit.
data_combined_pagerank['hcc_nxt'] = data_combined_pagerank.groupby('empi')['hcc'].shift(-1)

# A patient's last visit has no successor; self-transitions are discarded.
index_drop = data_combined_pagerank.index[data_combined_pagerank['hcc_nxt'].isna()]
data_combined_pagerank = (
    data_combined_pagerank
    .drop(index_drop)
    .query('hcc != hcc_nxt')
    .reset_index(drop=True)
)

In [10]:
# Keep only the first occurrence of each (patient, source HCC, target HCC)
# transition, so a patient contributes at most 1 to each edge count.
# drop_duplicates on the real columns replaces the former concatenated string
# key. That key raised TypeError for non-string empi values and could collide.
data_combined_pagerank = (
    data_combined_pagerank
    .drop_duplicates(subset=['empi', 'hcc', 'hcc_nxt'], keep='first')
    .reset_index(drop=True)
)

In [11]:
# keep just the (source, target) HCC pair as strings plus a per-row counter
data_combined_pagerank = (
    data_combined_pagerank[['hcc', 'hcc_nxt']]
    .astype(str)
    .assign(
        edge=lambda df: df.hcc + ", " + df.hcc_nxt,
        cnt=1,
    )
)

In [185]:
# weighted directed edge list: weight = number of rows for each HCC transition
edge_data = (
    data_combined_pagerank
    .groupby(['hcc', 'hcc_nxt'], as_index=False)
    .agg(weight=('cnt', 'count'))
    .rename(columns={'hcc': 'source', 'hcc_nxt': 'target'})
    .astype(float)
)

In [186]:
# (rows, columns) of the edge list before the minimum-weight filter is applied
edge_data.shape

(2752, 3)

In [217]:
# keep only transitions observed at least 10 times (weight >= 10)
edge_data = edge_data.query('weight >= 10')
edge_data.shape

(561, 3)

In [218]:
# unweighted and weight-annotated directed graphs over the filtered transitions
G = nx.from_pandas_edgelist(
    edge_data, source='source', target='target', create_using=nx.DiGraph
)
G_weighted = nx.from_pandas_edgelist(
    edge_data,
    source='source',
    target='target',
    edge_attr='weight',
    create_using=nx.DiGraph,
)


In [219]:
# PageRank (damping 0.85) over the weighted graph; nx.pagerank uses the
# 'weight' edge attribute by default. Importances are sorted ascending.
weighted_pagerank = nx.pagerank(G_weighted, alpha=0.85)
pagerank_importances = pd.Series(weighted_pagerank).sort_values().rename_axis('hcc')

#### Relational Scorer

In [221]:
def compute_in_degree(node, edge_data=None):
    """Computes the weighted in-degree of ``node``: the sum of the weights of
    all edges whose target is ``node``.

    Args:
        node: node id matched against the second (target) column of
            ``edge_data``.
        edge_data (pd.DataFrame, optional): edge list whose second column is
            the target node and which has a ``weight`` column. When omitted,
            the notebook-level ``edge_data`` is looked up at call time.

    Returns:
        The summed weight (0 when ``node`` has no incoming edges).

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.
    """
    # resolved at call time so the default never refers to a stale frame
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    target_col = edge_data.columns[1]
    return edge_data.loc[edge_data[target_col] == node, 'weight'].sum()


def identify_outgoing_neighbours(base_nodes, edge_data=None):
    """Returns the set of nodes reachable in one hop from any of ``base_nodes``.

    Args:
        base_nodes (list): base disease (source) node ids.
        edge_data (pd.DataFrame, optional): edge list with ``source`` and
            ``target`` columns. When omitted, the notebook-level ``edge_data``
            is looked up at call time. (Previously this argument was accepted
            but ignored.)

    Returns:
        set: target node ids of all edges leaving a base node.

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.
    """
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    # one vectorised mask instead of one filter per base node
    leaves_base = edge_data.source.isin(list(base_nodes))
    return set(edge_data.loc[leaves_base, 'target'])


def compute_relational_score(base_nodes, edge_data=None):
    """Scores each disease reachable in one hop from ``base_nodes``.

    For every target node t with at least one incoming edge from a base node:
        score(t) = sum over base nodes b with an edge b -> t of
                   weight(b -> t) * pagerank_importances[b]

    Args:
        base_nodes (list): source HCC codes to score from.
        edge_data (pd.DataFrame, optional): edge list with ``source``,
            ``target`` and ``weight`` columns. When omitted, the notebook-level
            ``edge_data`` is looked up at call time. It is now passed to every
            helper; previously it was ignored.

    Returns:
        pd.Series: scores indexed by target node, ascending, with zero scores
        dropped.

    Raises:
        NameError: if ``edge_data`` is omitted and no notebook-level
            ``edge_data`` exists yet.

    Note:
        Reads the notebook-level ``pagerank_importances`` Series (indexed by
        HCC). Base nodes missing from it contribute nothing, because the
        aligned product is NaN there and ``sum`` skips NaN.
    """
    if edge_data is None:
        if 'edge_data' not in globals():
            raise NameError("edge_data was not passed and no notebook-level `edge_data` is defined")
        edge_data = globals()['edge_data']

    # loop-invariant: PageRank of the base nodes, aligned by HCC label below
    importances = pagerank_importances[pagerank_importances.index.isin(base_nodes)].sort_index()

    score_dict = {}
    for node in identify_outgoing_neighbours(base_nodes, edge_data):
        incoming = filter_edge_data_by_target_node(node, edge_data)
        # weight of each base -> node edge, indexed by its source HCC
        weights = (
            incoming[incoming.source.isin(base_nodes)]
            .set_index('source')
            .weight
            .sort_index()
        )
        # the in-degree of ``node`` is not part of the score, so it is not computed
        score_dict[node] = (weights * importances).sum()

    return pd.Series(score_dict, dtype=float).sort_values().replace(0, np.nan).dropna()
    

In [258]:
# Score the one-hop neighbours of HCC 103. Scores are sorted ascending, so
# tail(10) shows the 10 highest.
base_nodes = [103]
scores = compute_relational_score(base_nodes=base_nodes)
scores.tail(10)

79.0     0.150335
111.0    0.150335
136.0    0.180402
2.0      0.195435
19.0     0.195435
99.0     0.210469
18.0     0.210469
85.0     0.285636
96.0     0.345770
100.0    0.481071
dtype: float64