In [1]:
import torch, os
import torch.nn as nn
import torch_geometric
import torch.nn.functional as F
from torch_geometric.utils.convert import to_networkx

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from itertools import combinations

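## **Intuition**
- disease record → graph: each row of `dataset.csv` (a disease plus the symptoms listed for it) becomes one graph, labelled with that disease.
- symptoms → nodes: every symptom in the record is a node, and every pair of symptoms in the record is connected by an edge.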

In [4]:
# Input CSVs (relative paths, resolved against the notebook's working directory).
node_features_csv, graph_data_csv = (
    'data/Symtoms/Symptom-severity.csv',  # columns: Symptom, weight
    'data/Symtoms/dataset.csv',           # column Disease + one column per symptom slot
)

In [5]:
def extract_node_features(csv_path=None):
    """Build a lookup table ``symptom name -> (row index, severity weight)``.

    Reads the symptom-severity CSV and normalises symptom names (lower-cased,
    surrounding whitespace stripped). This matches the cleaning that
    ``extract_graph_data`` applies to the disease dataset, so the names can be
    looked up directly.

    Args:
        csv_path: Path to a CSV with ``Symptom`` and ``weight`` columns.
            Defaults to the notebook-level ``node_features_csv``.

    Returns:
        dict mapping each normalised symptom name to ``(idx, weight)``, where
        ``idx`` is the symptom's 0-based row position in the CSV.
        NOTE(review): if two rows normalise to the same name, the later row
        wins, and the index sequence then has a gap.
    """
    if csv_path is None:
        csv_path = node_features_csv
    df_node_features = pd.read_csv(csv_path)

    # Normalise into a new Series instead of overwriting the frame's column.
    symptoms = df_node_features['Symptom'].str.lower().str.strip()
    severities = df_node_features['weight']

    return {
        symptom: (idx, severity)
        for idx, (symptom, severity) in enumerate(zip(symptoms, severities))
    }

In [6]:
def extract_graph_data(csv_path=None, random_state=None):
    """Group the symptom lists of every record in the dataset by disease.

    Args:
        csv_path: CSV with a ``Disease`` column plus one column per symptom
            slot; empty cells are allowed. Defaults to the notebook-level
            ``graph_data_csv``.
        random_state: Seed forwarded to ``DataFrame.sample`` so the row
            shuffle is reproducible. ``None`` (the default) keeps it random.

    Returns:
        graphs: dict ``disease -> list of symptom lists``, with one list per
            record. Symptom names are lower-cased and stripped. Missing cells
            and cells that are blank after stripping are dropped.
        disease_dict: dict ``disease -> integer label``, numbered in order of
            first appearance after shuffling.
    """
    if csv_path is None:
        csv_path = graph_data_csv
    # frac=1 returns every row exactly once, in random order.
    df_graph_data = pd.read_csv(csv_path).sample(frac=1, random_state=random_state)
    diseases = df_graph_data['Disease'].to_numpy()

    symptom_cols = df_graph_data.drop(columns=['Disease'])
    # Record which cells hold a value *before* astype(str) turns NaN into 'nan'.
    present = symptom_cols.notna().to_numpy()
    cleaned = (
        symptom_cols.astype(str)
        .apply(lambda col: col.str.lower().str.strip())
        .to_numpy()
    )

    graphs = {}
    for disease, names, mask in zip(diseases, cleaned, present):
        # `name` must also be non-empty: whitespace-only cells strip to ''
        # and would have no entry in the severity table.
        symptom_list = [name for name, ok in zip(names, mask) if ok and name]
        graphs.setdefault(disease, []).append(symptom_list)

    disease_dict = {disease: idx for idx, disease in enumerate(graphs)}
    return graphs, disease_dict

def create_edges(symptom_arr):
    """Build a fully connected ``edge_index`` over the given node ids.

    Duplicate ids are collapsed, with the first occurrence fixing the order.
    Every pair of distinct nodes is connected in both directions (i->j and
    j->i), and no self-loops are added.

    Args:
        symptom_arr: iterable of hashable node ids (ints when the result is
            used as a PyG ``edge_index``).

    Returns:
        Long tensor of shape ``(2, k*(k-1))``, where ``k`` is the number of
        distinct ids. Row 0 holds source nodes and row 1 destination nodes.
        With fewer than two distinct ids the shape is ``(2, 0)``.
    """
    nodes = list(dict.fromkeys(symptom_arr))
    forward = list(combinations(nodes, 2))
    # nodes are distinct, so forward + reversed pairs contain no duplicates
    # and need no set() pass (which would also scramble edge order).
    pairs = forward + [(j, i) for i, j in forward]
    src_node = [i for i, _ in pairs]
    dst_node = [j for _, j in pairs]
    # Explicit dtype: torch.tensor([[], []]) would otherwise default to float,
    # but PyG expects edge_index to be a long tensor.
    return torch.tensor([src_node, dst_node], dtype=torch.long)

def create_graph(graphs, feature_dict, disease_dict):
    """Turn every symptom record into a PyG ``Data`` graph and batch them all.

    Each record (one symptom list under a disease) becomes one graph:
    - nodes are the record's distinct symptoms;
    - each node's single feature is the symptom's severity weight;
    - every pair of distinct symptoms is connected in both directions
      (via ``create_edges``).

    Args:
        graphs: dict ``disease -> list of symptom-name lists``, as returned by
            ``extract_graph_data``.
        feature_dict: dict ``symptom name -> (global symptom index, severity)``,
            as returned by ``extract_node_features``.
        disease_dict: dict ``disease -> integer class label``.

    Returns:
        ``torch_geometric.data.Batch`` of all graphs. Each graph carries:
        - ``x``: float tensor ``(num_nodes, 1)`` of severities;
        - ``edge_index``: long tensor of *local* node indices (rows of ``x``);
        - ``symptom_ids``: long tensor ``(num_nodes,)`` of global symptom indices;
        - ``y``: long tensor ``(1,)`` holding the disease label.

    Raises:
        KeyError: if a symptom is missing from ``feature_dict`` or a disease
            is missing from ``disease_dict``.
    """
    graph_obj_arr = []
    for disease, records in graphs.items():
        label = disease_dict[disease]
        for record in records:
            # De-duplicate, keeping first-seen order, so the rows of x and the
            # edge endpoints describe exactly the same node set.
            unique_symptoms = list(dict.fromkeys(record))
            severities = [feature_dict[symptom][1] for symptom in unique_symptoms]
            global_ids = [feature_dict[symptom][0] for symptom in unique_symptoms]

            # edge_index must index rows of x (0..num_nodes-1). Using the global
            # symptom ids here would point past x, and Batch's per-graph offsets
            # would then scramble them. The .to() guards against a float empty
            # tensor for single-node records.
            edge_index = create_edges(list(range(len(unique_symptoms)))).to(torch.long)

            graph_obj = torch_geometric.data.Data(
                # float features: message-passing layers expect floating-point x
                x=torch.tensor(severities, dtype=torch.float).view(-1, 1),
                edge_index=edge_index,
                y=torch.tensor([label], dtype=torch.long),
                # keep symptom identity, which local node indices no longer encode
                symptom_ids=torch.tensor(global_ids, dtype=torch.long),
            )
            graph_obj_arr.append(graph_obj)
    # from_data_list is a classmethod; no Batch instance is needed.
    return torch_geometric.data.Batch.from_data_list(graph_obj_arr)

In [7]:
# Per-disease symptom records first, then the symptom -> (index, severity) table;
# the two loads are independent, and the batch needs both.
graphs, disease_dict = extract_graph_data()
feature_dict = extract_node_features()
graph_batch = create_graph(graphs=graphs, feature_dict=feature_dict, disease_dict=disease_dict)