In [1]:
# Import libraries
import os
import pandas as pd
import numpy as np
import requests
import json
from pathlib import Path
from neo4j import GraphDatabase
from tqdm import tqdm

In [2]:
# Define CDC endpoint for the new dataset
url = "https://data.cdc.gov/resource/x9gk-5huc.json?$limit=50000"

# Define base URL and parameters for paginated API access
base_url = "https://data.cdc.gov/resource/x9gk-5huc.json"
limit = 50000
offset = 0
all_data = []

print("Fetching data from CDC (paginated)...")

while True:
    url = f"{base_url}?$limit={limit}&$offset={offset}"
    response = requests.get(url)
    response.raise_for_status()
    data = response.json()

    if not data:
        break

    all_data.extend(data)
    offset += limit
    print(f"Fetched {len(all_data)} records so far...")

# Load into a DataFrame
df = pd.DataFrame(all_data)

# Rename columns for clarity
column_renames = {
    "states": "reporting_area",
    "year": "year",
    "week": "week",
    "label": "disease_label",
    "m1": "cases_current_week",
    "m1_flag": "cases_current_week_flag",
    "m2": "cases_52wk_max",
    "m2_flag": "cases_52wk_max_flag",
    "m3": "cumulative_ytd_current_year",
    "m3_flag": "cumulative_ytd_current_year_flag",
    "m4": "cumulative_ytd_prev_year",
    "m4_flag": "cumulative_ytd_prev_year_flag",
    "location1": "location_text",
    "location2": "location_geotext",
    "sort_order": "sort_order",
    "geocode": "geocode"
}

df.rename(columns=column_renames, inplace=True)

Fetching data from CDC (paginated)...
Fetched 50000 records so far...
Fetched 100000 records so far...
Fetched 150000 records so far...
Fetched 200000 records so far...
Fetched 250000 records so far...
Fetched 300000 records so far...
Fetched 350000 records so far...
Fetched 400000 records so far...
Fetched 450000 records so far...
Fetched 500000 records so far...
Fetched 550000 records so far...
Fetched 600000 records so far...
Fetched 650000 records so far...
Fetched 700000 records so far...
Fetched 750000 records so far...
Fetched 800000 records so far...
Fetched 850000 records so far...
Fetched 900000 records so far...
Fetched 950000 records so far...
Fetched 1000000 records so far...
Fetched 1050000 records so far...
Fetched 1100000 records so far...
Fetched 1150000 records so far...
Fetched 1200000 records so far...
Fetched 1250000 records so far...
Fetched 1300000 records so far...
Fetched 1348060 records so far...


In [3]:
# Shape of the DataFrame
print("Shape of the dataset:")
print(df.shape)

# Preview data
print("Preview of dataset:")
display(df.head())

Shape of the dataset:
(1348060, 18)
Preview of dataset:


Unnamed: 0,reporting_area,year,week,disease_label,cases_current_week_flag,cases_52wk_max,cases_52wk_max_flag,cumulative_ytd_current_year_flag,cumulative_ytd_prev_year_flag,location_geotext,sort_order,location_text,geocode,:@computed_region_hjsp_umg2,:@computed_region_skr5_azej,cumulative_ytd_prev_year,cases_current_week,cumulative_ytd_current_year
0,US RESIDENTS,2022,1,Anthrax,-,0,-,-,-,US RESIDENTS,20220100001,,,,,,,
1,NEW ENGLAND,2022,1,Anthrax,-,0,-,-,-,NEW ENGLAND,20220100002,,,,,,,
2,CONNECTICUT,2022,1,Anthrax,-,0,-,-,-,,20220100003,CONNECTICUT,"{'type': 'Point', 'coordinates': [-72.738288, ...",24.0,1043.0,,,
3,MAINE,2022,1,Anthrax,-,0,-,-,-,,20220100004,MAINE,"{'type': 'Point', 'coordinates': [-69.06137, 4...",49.0,1725.0,,,
4,MASSACHUSETTS,2022,1,Anthrax,-,0,-,-,-,,20220100005,MASSACHUSETTS,"{'type': 'Point', 'coordinates': [-71.481104, ...",25.0,1916.0,,,


In [4]:
# Show columns
print("Column names:")
print(df.columns.tolist())

# Check missing values
print("Missing values per column:")
print(df.isnull().sum())

# Check data types
print("Data types:")
print(df.dtypes)

Column names:
['reporting_area', 'year', 'week', 'disease_label', 'cases_current_week_flag', 'cases_52wk_max', 'cases_52wk_max_flag', 'cumulative_ytd_current_year_flag', 'cumulative_ytd_prev_year_flag', 'location_geotext', 'sort_order', 'location_text', 'geocode', ':@computed_region_hjsp_umg2', ':@computed_region_skr5_azej', 'cumulative_ytd_prev_year', 'cases_current_week', 'cumulative_ytd_current_year']
Missing values per column:
reporting_area                            0
year                                      0
week                                      0
disease_label                             0
cases_current_week_flag               12889
cases_52wk_max                       150274
cases_52wk_max_flag                   99400
cumulative_ytd_current_year_flag      28262
cumulative_ytd_prev_year_flag         36384
location_geotext                    1097706
sort_order                                0
location_text                        250354
geocode                              

In [None]:
#looking at the frequency of occurence for the top 10 diseases in the data
# Count how many rows exist for each disease label
df['disease_label'].value_counts().head(10)

In [None]:
df_anthrax = df[df['disease_label'] == 'Anthrax']
df_anthrax.head()

unique_pairs = df[['location_text', 'week']].drop_duplicates()
len(unique_pairs)

In [5]:
# Show unique reporting areas and diseases
print("Unique reporting areas:")
print(df['reporting_area'].unique())

Unique reporting areas:
['US RESIDENTS' 'NEW ENGLAND' 'CONNECTICUT' 'MAINE' 'MASSACHUSETTS'
 'NEW HAMPSHIRE' 'RHODE ISLAND' 'VERMONT' 'MIDDLE ATLANTIC' 'NEW JERSEY'
 'NEW YORK' 'NEW YORK CITY' 'PENNSYLVANIA' 'EAST NORTH CENTRAL' 'ILLINOIS'
 'INDIANA' 'MICHIGAN' 'OHIO' 'WISCONSIN' 'WEST NORTH CENTRAL' 'IOWA'
 'KANSAS' 'MINNESOTA' 'MISSOURI' 'NEBRASKA' 'NORTH DAKOTA' 'SOUTH DAKOTA'
 'SOUTH ATLANTIC' 'DELAWARE' 'DISTRICT OF COLUMBIA' 'FLORIDA' 'GEORGIA'
 'MARYLAND' 'NORTH CAROLINA' 'SOUTH CAROLINA' 'VIRGINIA' 'WEST VIRGINIA'
 'EAST SOUTH CENTRAL' 'ALABAMA' 'KENTUCKY' 'MISSISSIPPI' 'TENNESSEE'
 'WEST SOUTH CENTRAL' 'ARKANSAS' 'LOUISIANA' 'OKLAHOMA' 'TEXAS' 'MOUNTAIN'
 'ARIZONA' 'COLORADO' 'IDAHO' 'MONTANA' 'NEVADA' 'NEW MEXICO' 'UTAH'
 'WYOMING' 'PACIFIC' 'ALASKA' 'CALIFORNIA' 'HAWAII' 'OREGON' 'WASHINGTON'
 'US TERRITORIES' 'AMERICAN SAMOA' 'NORTHERN MARIANA ISLANDS' 'GUAM'
 'PUERTO RICO' 'U.S. VIRGIN ISLANDS' 'NON-US RESIDENTS' 'TOTAL'
 'U.S. Residents' 'New England' 'Connecticut' 'Maine

In [6]:
print("Unique year values:")
print(df['year'].unique())

print("Unique week values:")
print(df['week'].unique())

print("Unique disease labels:")
print(df['disease_label'].unique())

Unique year values:
['2022' '2023' '2024' '2025']
Unique week values:
['1' '2' '3' '4' '5' '6' '7' '8' '9' '10' '11' '12' '13' '14' '15' '16'
 '17' '18' '19' '20' '21' '22' '23' '24' '25' '26' '27' '28' '29' '30'
 '31' '32' '33' '34' '35' '36' '37' '38' '39' '40' '41' '42' '43' '44'
 '45' '46' '47' '48' '49' '50' '51' '52']
Unique disease labels:
['Anthrax' 'Arboviral diseases, Chikungunya virus disease'
 'Arboviral diseases, Eastern equine encephalitis virus disease'
 'Arboviral diseases, Jamestown Canyon  virus disease'
 'Arboviral diseases, La Crosse  virus disease'
 'Arboviral diseases, Powassan virus disease'
 'Arboviral diseases, St. Louis encephalitis virus disease'
 'Arboviral diseases, West Nile virus disease'
 'Arboviral diseases, Western equine encephalitis virus disease'
 'Babesiosis' 'Botulism, Foodborne' 'Botulism, Infant'
 'Botulism, Other (wound & unspecified)' 'Brucellosis'
 'Campylobacteriosis' 'Candida auris, clinical'
 'Carbapenemase-producing carbapenem-resistant E

In [7]:
df["reporting_area"] = df["reporting_area"].str.strip().str.upper()
df["disease_label"] = df["disease_label"].str.strip().str.upper()

df = df[df["disease_label"] == "CHLAMYDIA TRACHOMATIS INFECTION"]

critical_cols = ["reporting_area", "disease_label", "year", "week"] # Columns already 0 NAs
df.dropna(subset=critical_cols, inplace=True) 

df = df[~df["reporting_area"].isin({
    "TOTAL", "US RESIDENTS", "U.S. RESIDENTS", "NEW ENGLAND", "MIDDLE ATLANTIC",
    "EAST NORTH CENTRAL", "WEST NORTH CENTRAL", "MOUNTAIN", "PACIFIC",
    "SOUTH ATLANTIC", "WEST SOUTH CENTRAL", "EAST SOUTH CENTRAL", "NEW YORK CITY"
})]

df[["year", "week", "cases_current_week"]] = df[["year", "week", "cases_current_week"]].apply(pd.to_numeric, errors="coerce")

df.dropna(subset=["year", "week", "cases_current_week"], inplace=True)
df = df.astype({"year": int, "week": int, "cases_current_week": int})

# Create weekly difference
df["cases_diff"] = df.groupby("reporting_area")["cases_current_week"].diff().fillna(0).astype(int)

# Create report ID
df["report_id"] = df["reporting_area"] + "|" + df["year"].astype(str) + "|" + df["week"].astype(str)

os.makedirs("data", exist_ok=True)
df.to_csv("data/Chlamydia.csv", index=False)
print("Created Dataframe")

Created Dataframe


In [10]:
df

Int64Index([   1263,    1266,    1267,    1269,    1270,    1272,    1274,
               1276,    1278,    1280,
            ...
            1341039, 1341040, 1341041, 1341042, 1341044, 1341045, 1341047,
            1341048, 1341050, 1341051],
           dtype='int64', length=6594)

In [None]:
# Filter for Chlamydia by the exact disease label

In [11]:
# Create Dataframes for Neo4j
report_cols = ["report_id", 
               "reporting_area", 
               "year", 
               "week", 
               "cases_current_week", 
               "cases_diff",
               "cumulative_ytd_current_year", 
               "cumulative_ytd_prev_year"
              ]

report_df = df[report_cols].drop_duplicates().copy()
report_df.rename(columns={"reporting_area": "location_id"}, inplace=True)

# Location–Report edges
loc_rep_edges = report_df[["location_id", "report_id"]].drop_duplicates().copy()

# NEXT (temporal) edges
report_df_sorted = report_df.sort_values(["location_id", "year", "week"])
temporal_edges = []
prev_row = None
for idx, row in report_df_sorted.iterrows():
    if prev_row is not None and row["location_id"] == prev_row["location_id"]:
        temporal_edges.append({
            "report_id": prev_row["report_id"],
            "next_report_id": row["report_id"]
        })
    prev_row = row
temporal_df = pd.DataFrame(temporal_edges)

# Save
os.makedirs("data", exist_ok=True)
report_df.to_csv("data/nodes_report.csv", index=False)
loc_rep_edges.to_csv("data/edges_location_report.csv", index=False)
temporal_df.to_csv("data/edges_temporal.csv", index=False)
print("Created CSVs for Neo4j graph")

Created CSVs for Neo4j graph


In [12]:
report_df

Unnamed: 0,report_id,location_id,year,week,cases_current_week,cases_diff,cumulative_ytd_current_year,cumulative_ytd_prev_year
1263,MAINE|2022|1,MAINE,2022,1,45,0,45,53
1266,RHODE ISLAND|2022|1,RHODE ISLAND,2022,1,60,0,60,51
1267,VERMONT|2022|1,VERMONT,2022,1,11,0,11,16
1269,NEW JERSEY|2022|1,NEW JERSEY,2022,1,334,0,334,683
1270,NEW YORK|2022|1,NEW YORK,2022,1,293,0,293,118
...,...,...,...,...,...,...,...,...
1341045,WYOMING|2025|14,WYOMING,2025,14,22,1,406.0,423.0
1341047,ALASKA|2025|14,ALASKA,2025,14,54,17,1128.0,1400.0
1341048,CALIFORNIA|2025|14,CALIFORNIA,2025,14,2012,-8,41244.0,46133.0
1341050,OREGON|2025|14,OREGON,2025,14,221,-24,3356.0,3837.0


In [13]:
loc_rep_edges

Unnamed: 0,location_id,report_id
1263,MAINE,MAINE|2022|1
1266,RHODE ISLAND,RHODE ISLAND|2022|1
1267,VERMONT,VERMONT|2022|1
1269,NEW JERSEY,NEW JERSEY|2022|1
1270,NEW YORK,NEW YORK|2022|1
...,...,...
1341045,WYOMING,WYOMING|2025|14
1341047,ALASKA,ALASKA|2025|14
1341048,CALIFORNIA,CALIFORNIA|2025|14
1341050,OREGON,OREGON|2025|14


In [14]:
temporal_df

Unnamed: 0,report_id,next_report_id
0,ALABAMA|2022|1,ALABAMA|2022|2
1,ALABAMA|2022|2,ALABAMA|2022|3
2,ALABAMA|2022|3,ALABAMA|2022|5
3,ALABAMA|2022|5,ALABAMA|2022|6
4,ALABAMA|2022|6,ALABAMA|2022|7
...,...,...
6534,WYOMING|2025|9,WYOMING|2025|10
6535,WYOMING|2025|10,WYOMING|2025|11
6536,WYOMING|2025|11,WYOMING|2025|12
6537,WYOMING|2025|12,WYOMING|2025|13


In [15]:
# Create border edges
state_borders = {
    "Alabama": ["Florida", "Georgia", "Mississippi", "Tennessee"],
    "Alaska": [],
    "Arizona": ["California", "Colorado", "Nevada", "New Mexico", "Utah"],
    "Arkansas": ["Louisiana", "Mississippi", "Missouri", "Oklahoma", "Tennessee", "Texas"],
    "California": ["Arizona", "Nevada", "Oregon"],
    "Colorado": ["Arizona", "Kansas", "Nebraska", "New Mexico", "Oklahoma", "Utah", "Wyoming"],
    "Connecticut": ["Massachusetts", "New York", "Rhode Island"],
    "Delaware": ["Maryland", "New Jersey", "Pennsylvania"],
    "Florida": ["Alabama", "Georgia"],
    "Georgia": ["Alabama", "Florida", "North Carolina", "South Carolina", "Tennessee"],
    "Hawaii": [],
    "Idaho": ["Montana", "Nevada", "Oregon", "Utah", "Washington", "Wyoming"],
    "Illinois": ["Indiana", "Iowa", "Kentucky", "Missouri", "Wisconsin"],
    "Indiana": ["Illinois", "Kentucky", "Michigan", "Ohio"],
    "Iowa": ["Illinois", "Minnesota", "Missouri", "Nebraska", "South Dakota", "Wisconsin"],
    "Kansas": ["Colorado", "Missouri", "Nebraska", "Oklahoma"],
    "Kentucky": ["Illinois", "Indiana", "Missouri", "Ohio", "Tennessee", "Virginia", "West Virginia"],
    "Louisiana": ["Arkansas", "Mississippi", "Texas"],
    "Maine": ["New Hampshire"],
    "Maryland": ["Delaware", "Pennsylvania", "Virginia", "West Virginia"],
    "Massachusetts": ["Connecticut", "New Hampshire", "New York", "Rhode Island", "Vermont"],
    "Michigan": ["Indiana", "Ohio", "Wisconsin"],
    "Minnesota": ["Iowa", "North Dakota", "South Dakota", "Wisconsin"],
    "Mississippi": ["Alabama", "Arkansas", "Louisiana", "Tennessee"],
    "Missouri": ["Arkansas", "Illinois", "Iowa", "Kansas", "Kentucky", "Nebraska", "Oklahoma", "Tennessee"],
    "Montana": ["Idaho", "North Dakota", "South Dakota", "Wyoming"],
    "Nebraska": ["Colorado", "Iowa", "Kansas", "Missouri", "South Dakota", "Wyoming"],
    "Nevada": ["Arizona", "California", "Idaho", "Oregon", "Utah"],
    "New Hampshire": ["Maine", "Massachusetts", "Vermont"],
    "New Jersey": ["Delaware", "New York", "Pennsylvania"],
    "New Mexico": ["Arizona", "Colorado", "Oklahoma", "Texas", "Utah"],
    "New York": ["Connecticut", "Massachusetts", "New Jersey", "Pennsylvania", "Vermont"],
    "North Carolina": ["Georgia", "South Carolina", "Tennessee", "Virginia"],
    "North Dakota": ["Minnesota", "Montana", "South Dakota"],
    "Ohio": ["Indiana", "Kentucky", "Michigan", "Pennsylvania", "West Virginia"],
    "Oklahoma": ["Arkansas", "Colorado", "Kansas", "Missouri", "New Mexico", "Texas"],
    "Oregon": ["California", "Idaho", "Nevada", "Washington"],
    "Pennsylvania": ["Delaware", "Maryland", "New Jersey", "New York", "Ohio", "West Virginia"],
    "Rhode Island": ["Connecticut", "Massachusetts"],
    "South Carolina": ["Georgia", "North Carolina"],
    "South Dakota": ["Iowa", "Minnesota", "Montana", "Nebraska", "North Dakota", "Wyoming"],
    "Tennessee": ["Alabama", "Arkansas", "Georgia", "Kentucky", "Mississippi", "Missouri", "North Carolina", "Virginia"],
    "Texas": ["Arkansas", "Louisiana", "New Mexico", "Oklahoma"],
    "Utah": ["Arizona", "Colorado", "Idaho", "Nevada", "New Mexico", "Wyoming"],
    "Vermont": ["Massachusetts", "New Hampshire", "New York"],
    "Virginia": ["Kentucky", "Maryland", "North Carolina", "Tennessee", "West Virginia"],
    "Washington": ["Idaho", "Oregon"],
    "West Virginia": ["Kentucky", "Maryland", "Ohio", "Pennsylvania", "Virginia"],
    "Wisconsin": ["Illinois", "Iowa", "Michigan", "Minnesota"],
    "Wyoming": ["Colorado", "Idaho", "Montana", "Nebraska", "South Dakota", "Utah"]
}

# Create edge list
edges = []
for state, neighbors in state_borders.items():
    for neighbor in neighbors:
        edges.append({"from": state.upper(), "to": neighbor.upper()})

# Save as CSV
borders_df = pd.DataFrame(edges)
borders_df.to_csv("data/edges_borders.csv", index=False)
borders_df

Unnamed: 0,from,to
0,ALABAMA,FLORIDA
1,ALABAMA,GEORGIA
2,ALABAMA,MISSISSIPPI
3,ALABAMA,TENNESSEE
4,ARIZONA,CALIFORNIA
...,...,...
209,WYOMING,IDAHO
210,WYOMING,MONTANA
211,WYOMING,NEBRASKA
212,WYOMING,SOUTH DAKOTA


In [16]:
# Setup Neo4j connection
driver = GraphDatabase.driver("bolt://neo4j:7687", auth=("neo4j", "ucb_mids_w205"))

def run_query(tx, query, parameters=None):
    tx.run(query, parameters or {})

def batch_execute(session, query, dataframe, parameter_keys):
    for _, row in tqdm(dataframe.iterrows(), total=len(dataframe)):
        params = {key: row[key] for key in parameter_keys}
        session.execute_write(run_query, query, params)

def run_cypher(query):
    with driver.session() as session:
        return list(session.run(query))

def run_write_cypher(query):
    with driver.session() as session:
        session.run(query)

# Load DataFrames
report_df = pd.read_csv("data/nodes_report.csv")
lr_df = pd.read_csv("data/edges_location_report.csv")
temp_df = pd.read_csv("data/edges_temporal.csv")
borders_df = pd.read_csv("data/edges_borders.csv")

In [17]:
# Load data into Neo4j ===
with driver.session() as session:
    print("Clearing database...")
    session.run("MATCH (n) DETACH DELETE n")

    print("Loading Location nodes...")
    unique_locations = report_df["location_id"].drop_duplicates().to_frame()
    unique_locations["location_name"] = unique_locations["location_id"].str.title()

    batch_execute(session, """
        MERGE (:Location {id: $location_id, name: $location_name})
    """, unique_locations, ["location_id", "location_name"])

    print("Loading Report nodes...")
    batch_execute(session, """
        MERGE (:Report {
            id: $report_id,
            location_id: $location_id,
            year: toInteger($year),
            week: toInteger($week),
            cases: toInteger($cases_current_week),
            diff: toInteger($cases_diff)
        })
    """, report_df, ["report_id", "location_id", "year", "week", "cases_current_week", "cases_diff"])

    print("Creating CONTAINS relationships...")
    batch_execute(session, """
        MATCH (l:Location {id: $location_id}), (r:Report {id: $report_id})
        MERGE (l)-[:CONTAINS]->(r)
    """, lr_df, ["location_id", "report_id"])

    print("Creating NEXT relationships...")
    batch_execute(session, """
        MATCH (r1:Report {id: $report_id}), (r2:Report {id: $next_report_id})
        MERGE (r1)-[:NEXT]->(r2)
    """, temp_df, ["report_id", "next_report_id"])

    print("Creating BORDERS relationships...")
    for _, row in tqdm(borders_df.iterrows(), total=len(borders_df)):
        session.execute_write(lambda tx: tx.run("""
            MATCH (s1:Location {id: $from}), (s2:Location {id: $to})
            MERGE (s1)-[:BORDERS]->(s2)
            MERGE (s2)-[:BORDERS]->(s1)
        """, {"from": row["from"], "to": row["to"]}))

driver.close()
print("Completed building the graph in Neo4j!")

Clearing database...
Loading Location nodes...


100%|█████████████████████████████████████████████| 55/55 [00:01<00:00, 40.22it/s]


Loading Report nodes...


100%|████████████████████████████████████████| 6594/6594 [00:52<00:00, 126.09it/s]


Creating CONTAINS relationships...


100%|████████████████████████████████████████| 6594/6594 [01:03<00:00, 104.63it/s]


Creating NEXT relationships...


100%|█████████████████████████████████████████| 6539/6539 [01:14<00:00, 87.95it/s]


Creating BORDERS relationships...


100%|██████████████████████████████████████████| 214/214 [00:01<00:00, 158.66it/s]

Completed building the graph in Neo4j!



