In [None]:
import json
import networkx as nx
import matplotlib.pyplot as plt

In [26]:
# Split symptoms into a list
def extract_symptoms(data):

    # Checking if data is a list, because we convert to string first
    if data is None:
        return []
    
    if isinstance(data, list):
        data = ", ".join(str (symptom) for symptom in data)
    else:
        data = str(data)


    # Splitting on commas
    split_comma_strings = [symptom.strip() for symptom in data.split(',')]

    # Want to handle any "and" that separates symptoms.
    # Ex. 'Headache and fever'
    # Ex. 'Headache'
    # Headache should count as a common symptom for both diseases
    split_and_symptoms = []
    for symp in split_comma_strings:
        symp_part = [symptom.strip() for symptom in symp.split(' and ')]
        split_and_symptoms.extend(symp_part)


    # Then just removing any trailing or leading punctuation if there is any.
    # Probably not but just making sure since its a large dataset
    strip_symptoms = []
    for symptom in split_and_symptoms:
        new = symptom.strip('.,')
        strip_symptoms.append(new)

    # Also normalizing the string to just lowercase for consistency
    symptoms = []
    for symptom in strip_symptoms:
        new = symptom.lower()
        symptoms.append(new)

    return symptoms

In [9]:
# Loading the json dataset and removing the treatment field since it isn't needed right now.
# Gotta iterate through each field and make a new dictionary

def remove_treatment_field():
    new_dataset = []

    try:

        with open('disease_database_en.json', 'r', encoding='utf-8') as fd:
            data = json.load(fd)

            for entry in data:
                new_entry = {
                    "disease_id": entry["disease_id"],
                    "disease": entry["disease"],
                    "symptoms": entry["common_symptom"]
                }
                new_dataset.append(new_entry)
                
        with open('new_dataset.json', 'w', encoding='utf-8') as output:
            json.dump(new_dataset, output, indent = 2)

        print("Saved dataset!")
        return new_dataset

    except Exception as error:
        print(f"An error occured: {error}")

remove_treatment_field()

Saved dataset!


[{'disease_id': '1656164150939770881',
  'disease': 'Carotid Sinus Syndrome',
  'symptoms': 'Syncope, Cardiac Arrest, Dizziness, Fatigue, Tinnitus, Cerebral Ischemia, Hypotension'},
 {'disease_id': '1656164157617102850',
  'disease': 'Cranial Osteomyelitis',
  'symptoms': 'Headache and fever, localized scalp redness, swelling, warmth, and pain'},
 {'disease_id': '1656164143901728770',
  'disease': 'Tuberculous Osteomyelitis',
  'symptoms': 'Low-grade fever, fatigue, night sweats, loss of appetite and weight, joint dysfunction, local swelling in limbs and joints due to tuberculosis, local pain, deformity'},
 {'disease_id': '1656164192798924801',
  'disease': 'Lipodystrophy',
  'symptoms': 'Hematuria, pyuria, acute abdominal pain, edema, fever, azotemia, hypertension, slow growth, fever, edema'},
 {'disease_id': '1656164179163242497',
  'disease': 'Vulvar Pemphigus',
  'symptoms': 'Blisters, pain, anorexia, mucosal damage, hypersensitivity, and ulcers'},
 {'disease_id': '1656164134867197

In [None]:
def network(filename, output_filename):
    try:
        # read data from input file
        with open(filename, 'r', encoding = 'utf-8') as f:
            data = json.load(f)
        G = nx.Graph()

        # add nodes and edges
        for entry in data:

            # It's getting the field as a list first in json
            if isinstance(entry.get("disease"), list):
                disease = str(entry.get("disease")[0]).strip()
            
            # It isn't a list
            else:
                disease = str(entry.get("disease", "")).strip()

            symptoms = extract_symptoms(entry.get("symptoms"))
            print(f"Symptoms: {symptoms}")

            if not disease:
                continue

            if not symptoms:
                continue

            G.add_node(disease, type = 'disease')

            for symptom in symptoms:

                if not symptom:
                    continue
                G.add_node(symptom, type = 'symptom')
                G.add_edge(disease, symptom)

        nx.write_graphml(G, output_filename)

        return G
    except Exception as e:
        print(f"An error occured: {e}")
        return None

network('new_dataset.json', 'network_graph.graphml')

Symptoms: ['syncope', 'cardiac arrest', 'dizziness', 'fatigue', 'tinnitus', 'cerebral ischemia', 'hypotension']
Symptoms: ['headache', 'fever', 'localized scalp redness', 'swelling', 'warmth', 'and pain']
Symptoms: ['low-grade fever', 'fatigue', 'night sweats', 'loss of appetite', 'weight', 'joint dysfunction', 'local swelling in limbs', 'joints due to tuberculosis', 'local pain', 'deformity']
Symptoms: ['hematuria', 'pyuria', 'acute abdominal pain', 'edema', 'fever', 'azotemia', 'hypertension', 'slow growth', 'fever', 'edema']
Symptoms: ['blisters', 'pain', 'anorexia', 'mucosal damage', 'hypersensitivity', 'and ulcers']
Symptoms: ['protrusion of bone surfaces', 'bone pain', 'swelling', 'pain in the long bones']
Symptoms: ['headache', 'dizziness', 'restlessness', 'vomiting', 'heart failure', 'renal failure', 'urea clearance disorder', 'nephrotic syndrome', 'increased diastolic pressure', 'fever', 'outpatient hypertension', 'edema']
Symptoms: ['palpitations', 'shortness of breath', 'cya

<networkx.classes.graph.Graph at 0x1d0b6b8cc10>

In [28]:
def visualize_network(G):
    plt.figure(figsize=(10,8))
    pos = nx.spring_layout(G)

    disease_nodes = [node for node, attr in G.nodes(data=True) if attr.get('type') == 'disease']
    symptom_nodes = [node for node, attr in G.nodes(data=True) if attr.get('type') == 'symptom']

    nx.draw_networkx_nodes(G, pos, nodelist=disease_nodes, node_color='blue', node_size=800, label='Diseases')
    nx.draw_networkx_nodes(G, pos, nodelist=symptom_nodes, node_color='red', node_size=600, label='Symptoms')
    nx.draw_networkx_edges(G, pos, alpha=0.5)
    nx.draw_networkx_labels(G, pos, font_size=10)

    plt.scatter([], [], color='blue', label='Diseases', s=100)
    plt.scatter([], [], color='red', label='Symptoms', s=100)
    plt.legend()
 
    plt.title("Disease-Symptom Network")
    plt.show()
