## Predicting WNV Prevalence at a County Level in the United States Using AlphaEarth Embedding Data (2017 - 2024)

### Contents:

 1. Converting Google Earth Assets to CSV files  
 2. Appending WNV case data to each file
 3. Obtaining population data from 2017 to 2024 for all counties
 4. WNV case data visualized at a national level   
 5. Machine learning model evaluation 


In [5]:
from pathlib import Path
from glob import glob
import sys
import pickle
from dotenv import load_dotenv
import os
from datacommons_client.client import DataCommonsClient

# fix for the `utils` import error: add the project root (wnv_embeddings) to sys.path
PROJECT_ROOT = Path.cwd().parents[1]  # <-- wnv_embeddings
sys.path.insert(0, str(PROJECT_ROOT))
import ee

from utils.utils import convert_to_df
import pandas as pd
import geopandas as gpd
import numpy as np

from PIL import Image
import time

### 1. Converting Google Earth Assets to ~50 CSV Files 
The task generation script lives in `utils/utils.py`; `main.py` is responsible for actually starting those tasks in Google Earth Engine (in the cloud). \
Below, I convert the GEE assets that were stored server-side into local CSVs.


* National average embeddings data per county for all states (2017 to 2024).
* Each asset represents one state (according to the FIPS code).
* State FIPS Codes available here: https://transition.fcc.gov/oet/info/maps/census/fips/fips.txt

#### Saving as CSV to `.\notebooks\national_embeddings\all_embeddings_csvs`

Using the `convert_to_df()` function from `utils/utils.py`:

In [3]:
# will prompt you to authorize access to GEE
# this is needed to obtain assets from the cloud saved under your account
ee.Authenticate()

# enter your own registered project name here
ee.Initialize(project="wnv-embeddings")

In [4]:
state_fips_codes = [
    "01", "02", "04", "05", "06", "08", "09", "10", "11", "12",
    "13", "15", "16", "17", "18", "19", "20", "21", "22", "23",
    "24", "25", "26", "27", "28", "29", "30", "31", "32", "33",
    "34", "35", "36", "37", "38", "39", "40", "41", "42", "44",
    "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", "56"
  ]

In [None]:
# =============CONVERT GEE ASSETS TO CSVS============= #
# ONLY RUN ONCE TO CONVERT ALL STATE ASSETS (ONE PER FIPS CODE ABOVE) TO CSV #

# now obtaining the csvs
# csv_destination = Path("all_embeddings_csvs")
# csv_destination.mkdir(parents=True, exist_ok=True)

# for fips in state_fips_codes:
# 	gee_path = f"users/angel314/{fips}_2017_2024_embeddings"
	
# 	save_to = csv_destination / f"{fips}-avg-embeddings-2017-2024.csv"

# 	convert_to_df(gee_path, True, save_to)

In [7]:
# =============CONVERT GEE CT PLANNING REGIONS ASSET TO CSV============= #
# ONLY RUN ONCE TO CONVERT THE CT PLANNING REGIONS ASSET TO CSV #

csv_destination = Path("all_embeddings_csvs")
csv_destination.mkdir(parents=True, exist_ok=True)

gee_path = "users/angel314/09_2017_2024_embeddings_ct_new"

save_to = csv_destination / "ct-planning-regions-avg-embeddings-2017-2024.csv"

convert_to_df(gee_path, True, save_to)

all_embeddings_csvs\ct-planning-regions-avg-embeddings-2017-2024.csv does not exist, creating all_embeddings_csvs\ct-planning-regions-avg-embeddings-2017-2024.csv

retrieved asset at 'users/angel314/09_2017_2024_embeddings_ct_new'
saved as CSV to: all_embeddings_csvs\ct-planning-regions-avg-embeddings-2017-2024.csv


Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,A01_2018,...,A62_2024,A63_2017,A63_2018,A63_2019,A63_2020,A63_2021,A63_2022,A63_2023,A63_2024,GEOID
0,-0.057855,-0.055518,-0.082817,-0.085064,-0.081177,-0.080661,-0.077676,-0.083811,-0.053182,-0.062204,...,-0.116039,-0.000793,0.003175,0.007493,-0.026212,0.004062,-0.004939,-0.013921,-0.002128,9140
1,-0.029856,-0.043176,-0.066744,-0.084748,-0.070108,-0.07072,-0.070404,-0.070954,0.000243,-0.007592,...,-0.08266,0.009964,0.012139,0.016574,-0.027655,0.010644,0.003197,0.013202,0.005719,9150
2,-0.032224,-0.037584,-0.051237,-0.059114,-0.029675,-0.046795,-0.051063,-0.050507,0.004208,-0.007728,...,-0.074208,0.009697,0.015314,0.025111,-0.008877,0.038699,0.001854,0.014391,0.003663,9160
3,-0.022924,-0.021824,-0.036791,-0.033231,-0.031408,-0.036615,-0.020607,-0.032033,-0.102085,-0.102969,...,-0.090937,0.125624,0.113925,0.116631,0.10206,0.119224,0.105919,0.108467,0.125621,9170
4,-0.061269,-0.06053,-0.083934,-0.095966,-0.087213,-0.094672,-0.081715,-0.084212,-0.027524,-0.032526,...,-0.104167,0.030031,0.035055,0.030881,-0.000648,0.026488,0.019322,0.025119,0.030293,9180
5,-0.022919,-0.015612,-0.027782,-0.022987,-0.023637,-0.025878,-0.013778,-0.020955,-0.115328,-0.116547,...,-0.088271,0.117155,0.112799,0.114882,0.101677,0.122768,0.108292,0.104062,0.121962,9120
6,-0.045174,-0.040372,-0.060937,-0.055804,-0.057942,-0.050842,-0.05337,-0.055748,-0.067533,-0.079472,...,-0.091251,0.050927,0.053534,0.05369,0.036795,0.061234,0.046251,0.037806,0.048922,9190
7,-0.041947,-0.052342,-0.077795,-0.088983,-0.078226,-0.084367,-0.07401,-0.080236,-0.06003,-0.074454,...,-0.094936,0.010738,0.012544,0.020791,-0.023066,0.012597,-0.002332,0.003826,0.005892,9110
8,-0.043662,-0.049823,-0.07599,-0.077881,-0.07413,-0.080885,-0.069529,-0.074261,-0.02838,-0.037815,...,-0.098223,0.056882,0.058098,0.051141,0.025071,0.052313,0.032332,0.039943,0.049042,9130
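
`convert_to_df()` itself is not shown in this notebook. As a rough sketch of the general idea only (not the actual implementation in `utils/utils.py`), a feature collection pulled from Earth Engine with `getInfo()` is a GeoJSON-like dictionary, and its per-feature `properties` can be flattened into one row per region. The helper name `features_to_df` and the sample values below are hypothetical.

```python
import pandas as pd


def features_to_df(feature_collection: dict) -> pd.DataFrame:
    """Flatten a GeoJSON-like FeatureCollection dict into one row per feature."""
    # Each feature keeps its attributes under "properties"; geometry is ignored here.
    rows = [feature["properties"] for feature in feature_collection["features"]]
    return pd.DataFrame(rows)


# Hand-made stand-in for a downloaded asset (values are made up for illustration).
fake_collection = {
    "type": "FeatureCollection",
    "features": [
        {"type": "Feature", "geometry": None,
         "properties": {"GEOID": "09110", "A00_2017": -0.04, "A00_2018": -0.05}},
        {"type": "Feature", "geometry": None,
         "properties": {"GEOID": "09120", "A00_2017": -0.02, "A00_2018": -0.01}},
    ],
}

demo_df = features_to_df(fake_collection)
print(demo_df)
```

Calling `demo_df.to_csv(path, index=False)` on such a frame would give a CSV with one row per region and one column per embedding band and year.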


### 2. Appending Yearly WNV Case Data

##### Getting WNV Case Data:
* Source: https://www.cdc.gov/west-nile-virus/data-maps/historic-data.html  
* Section: "Explore county level data for 1999-2024" - "Yearly data"
	* Returns: one CSV with case data at a county level for 1999-2024
* `Location` column represents the FIPS county code for that row.
* WNV Case data is cleaned to only include relevant years and rows with at least one human disease case. 
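
Because `Location` is read as an integer, leading zeros are lost (Alabama's `01001` becomes `1001`). The sketch below shows one way to restore the 5-digit form and split it into state and county parts; the helper names `normalize_fips` and `split_fips` are hypothetical and not part of this project.

```python
def normalize_fips(location) -> str:
    """Return a 5-character county FIPS string from an int-like or str code."""
    return str(int(location)).zfill(5)


def split_fips(fips: str) -> tuple[str, str]:
    """Split a 5-character county FIPS code into (state, county) parts."""
    return fips[:2], fips[2:]


examples = [1001, 9110, 46041, "56015"]
normalized = [normalize_fips(code) for code in examples]

for code in normalized:
    state, county = split_fips(code)
    print(f"{code}: state={state}, county={county}")
```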

This is a preview of WNV County Cases from 1999 to 2024.

In [8]:
cases = pd.read_csv("./national_wnv_case_data/wnv_county_cases_1999_2024.csv")
cases.sample(5)

Unnamed: 0,FullGeoName,Year,Location,Activity,Total human disease cases,Neuroinvasive disease cases,**Presumptive viremic blood donors,Notes
18528,"SD, Dewey",2006,46041,Human infections,0.0,0.0,1.0,
15636,"KY, Christian",2008,21047,Non-human activity,0.0,0.0,0.0,
7998,"LA, Calcasieu Parish",2016,22019,Human infections and non-human activity,6.0,5.0,1.0,
3743,"SC, Richland",2021,45079,Human infections and non-human activity,1.0,1.0,0.0,
13107,"PA, Dauphin",2012,42043,Human infections and non-human activity,2.0,2.0,0.0,


In [9]:
###### filtering ######

# remove entries that come before 2017
cases = cases[cases["Year"]>=2017]
# remove any rows with 0 total human disease cases
cases = cases[cases["Total human disease cases"]>0]
# only keep relevant columns
cases = cases.drop(columns=["FullGeoName", "Activity", "Neuroinvasive disease cases", "**Presumptive viremic blood donors", "Notes"]).reset_index(drop=True)
cases

Unnamed: 0,Year,Location,Total human disease cases
0,2024,1001,2.0
1,2024,1003,2.0
2,2024,1021,1.0
3,2024,1043,2.0
4,2024,1047,1.0
...,...,...,...
4006,2017,55141,2.0
4007,2017,56003,1.0
4008,2017,56013,3.0
4009,2017,56015,2.0


In [10]:
cases = cases.groupby(["Year","Location"]).agg("sum").reset_index()
cases

Unnamed: 0,Year,Location,Total human disease cases
0,2017,1001,6.0
1,2017,1003,3.0
2,2017,1007,1.0
3,2017,1011,1.0
4,2017,1015,2.0
...,...,...,...
4006,2024,55133,1.0
4007,2024,55139,1.0
4008,2024,55141,1.0
4009,2024,56015,1.0


In [11]:
# convert from long format to wide format
# each row represents one location
# each location has one case-count column per year (2017 - 2024)

# columns="Year" -> each unique year becomes a column
# values="Total human disease cases" -> numbers that fill the pivot table
# aggfunc="sum" -> sum all entries for the same location and year
# fill_value=0 -> location-year pairs with no reported cases become 0

# reset_index turns the "Location" index back into a regular column

cases_wide = (
    cases.pivot_table(index="Location", columns="Year", values="Total human disease cases",
                      aggfunc="sum", fill_value=0)
    .add_prefix("Cases_")
    .reset_index()
)
cases_wide

Year,Location,Cases_2017,Cases_2018,Cases_2019,Cases_2020,Cases_2021,Cases_2022,Cases_2023,Cases_2024
0,1001,6.0,0.0,0.0,1.0,1.0,0.0,1.0,2.0
1,1003,3.0,2.0,1.0,0.0,2.0,1.0,0.0,2.0
2,1007,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,1011,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1015,2.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0
...,...,...,...,...,...,...,...,...,...
1607,56025,0.0,0.0,1.0,0.0,0.0,1.0,3.0,1.0
1608,56029,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
1609,56031,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0
1610,56033,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
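
To make the reshaping step easier to follow, here is the same `pivot_table` pattern on a three-row toy frame (values are made up). Location 1003 has no 2017 row, so `fill_value=0` supplies the zero.

```python
import pandas as pd

# Long format: one row per (Year, Location) with at least one case.
toy_long = pd.DataFrame({
    "Year": [2017, 2018, 2018],
    "Location": [1001, 1001, 1003],
    "Total human disease cases": [6.0, 1.0, 2.0],
})

# Wide format: one row per Location, one Cases_<year> column per year.
toy_wide = (
    toy_long.pivot_table(index="Location", columns="Year",
                         values="Total human disease cases",
                         aggfunc="sum", fill_value=0)
    .add_prefix("Cases_")
    .reset_index()
)

print(toy_wide)
```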


In [24]:
cases_wide[(cases_wide["Location"]>=9000) & (cases_wide["Location"]<=10000)]

Year,Location,Cases_2017,Cases_2018,Cases_2019,Cases_2020,Cases_2021,Cases_2022,Cases_2023,Cases_2024
171,9001,2.0,11.0,1.0,5.0,3.0,0.0,0.0,0.0
172,9003,0.0,5.0,0.0,1.0,2.0,0.0,0.0,0.0
173,9007,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0
174,9009,1.0,3.0,0.0,2.0,2.0,0.0,0.0,0.0
175,9015,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
176,9110,0.0,0.0,0.0,0.0,0.0,0.0,3.0,2.0
177,9120,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
178,9130,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
179,9170,0.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0
180,9180,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


Saving this cleaned dataframe to a csv for future use.

In [13]:
cases_wide.to_csv("./national_wnv_case_data/agg_wnv_county_cases_2017_2024.csv")

##### Iterating Over `all_embeddings_csvs` to Add WNV Human Cases for Each Year:

In [26]:
files = glob("../national_embeddings/all_embeddings_csvs/*.csv")
dfs = [pd.read_csv(f) for f in files]

df_all = pd.concat(dfs, ignore_index=True)
df_merged = pd.merge(df_all, cases_wide, left_on="GEOID", right_on="Location", how="left").fillna(0).drop(columns=["Location"])

df_merged.to_csv("./all_embeddings_with_cases/cleaned-avg-embeddings-2017-2024.csv")

In [27]:
df_merged

Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,A01_2018,...,A63_2024,GEOID,Cases_2017,Cases_2018,Cases_2019,Cases_2020,Cases_2021,Cases_2022,Cases_2023,Cases_2024
0,0.020481,0.011058,-0.000501,0.011396,0.000753,-0.004479,-0.022577,-0.004919,-0.077097,-0.053880,...,0.105248,1053,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,-0.013062,-0.030307,-0.034806,-0.026667,-0.040167,-0.049618,-0.055469,-0.042281,-0.011958,0.010793,...,0.052994,1123,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2,-0.105284,-0.107447,-0.129861,-0.119771,-0.122789,-0.130565,-0.146850,-0.115087,0.003709,0.017785,...,0.000531,1009,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,-0.072484,-0.081692,-0.097903,-0.090730,-0.091891,-0.100606,-0.111368,-0.091791,-0.012406,-0.000298,...,0.017798,1115,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,-0.054914,-0.063882,-0.079282,-0.075227,-0.075802,-0.082101,-0.091138,-0.075997,-0.037458,-0.030478,...,0.029815,1117,4.0,1.0,0.0,0.0,0.0,0.0,2.0,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3146,-0.061269,-0.060530,-0.083934,-0.095966,-0.087213,-0.094672,-0.081715,-0.084212,-0.027524,-0.032526,...,0.030293,9180,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3147,-0.022919,-0.015612,-0.027782,-0.022987,-0.023637,-0.025878,-0.013778,-0.020955,-0.115328,-0.116547,...,0.121962,9120,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3148,-0.045174,-0.040372,-0.060937,-0.055804,-0.057942,-0.050842,-0.053370,-0.055748,-0.067533,-0.079472,...,0.048922,9190,0.0,0.0,0.0,0.0,0.0,0.0,3.0,4.0
3149,-0.041947,-0.052342,-0.077795,-0.088983,-0.078226,-0.084367,-0.074010,-0.080236,-0.060030,-0.074454,...,0.005892,9110,0.0,0.0,0.0,0.0,0.0,0.0,3.0,2.0
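
One optional refinement to the merge above, sketched on toy data: `indicator=True` records which GEOIDs found no matching case row, `validate="many_to_one"` checks that each `Location` appears only once on the case side, and limiting `fillna(0)` to the `Cases_` columns avoids zero-filling anything else that happens to be missing. The toy frames and values are made up.

```python
import pandas as pd

toy_embeddings = pd.DataFrame({"GEOID": [1001, 1003, 1005],
                               "A00_2017": [0.02, -0.01, -0.10]})
toy_cases = pd.DataFrame({"Location": [1001, 1003],
                          "Cases_2017": [6.0, 3.0]})

merged = toy_embeddings.merge(
    toy_cases, left_on="GEOID", right_on="Location",
    how="left", validate="many_to_one", indicator=True,
)

# GEOIDs with no case rows at all (they get 0 cases below).
unmatched = merged.loc[merged["_merge"] == "left_only", "GEOID"].tolist()

# Zero-fill only the case columns, then drop the helper columns.
case_cols = [c for c in merged.columns if c.startswith("Cases_")]
merged[case_cols] = merged[case_cols].fillna(0)
merged = merged.drop(columns=["Location", "_merge"])

print(unmatched)
print(merged)
```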


### 3. Obtaining and Appending County Population Data:

* Using Data Commons API:

	* https://docs.datacommons.org/what_is.html 

	* It lets us query statistical variables (here, county population) across sources and get one unified result.

	* There is an option to query for counties as well using FIPS codes: https://datacommons.org/browser/County 

County population data is needed for each year to normalize case counts using this formula:

$\text{Cases per 100k} = \frac{\text{Number of disease cases}}{\text{County population}} \times 100{,}000$

Normalized cases (cases per 100k) will be the target variable when measuring machine learning models' performance.
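
A minimal sketch of applying this formula to the wide layout used here, assuming `Cases_<year>` and `Popln_<year>` columns exist for every requested year. The helper `add_cases_per_100k` and the output column name `CasesPer100k_<year>` are hypothetical. Populations of 0 are treated as missing so the division yields NaN rather than infinity.

```python
import numpy as np
import pandas as pd


def add_cases_per_100k(df: pd.DataFrame, years) -> pd.DataFrame:
    """Add one CasesPer100k_<year> column per year; NaN where population is unknown."""
    out = df.copy()
    for year in years:
        # Treat a population of 0 as missing to avoid dividing by zero.
        popln = out[f"Popln_{year}"].replace(0, np.nan).astype(float)
        out[f"CasesPer100k_{year}"] = out[f"Cases_{year}"] * 100_000 / popln
    return out


# Toy rows with made-up numbers.
toy = pd.DataFrame({
    "GEOID": ["01001", "02261"],
    "Cases_2020": [2.0, 0.0],
    "Popln_2020": [50_000, 0],
})

toy_rates = add_cases_per_100k(toy, [2020])
print(toy_rates)
```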

Note: api.census.gov does not have consistent, up-to-date county population data for 2017 - 2024.

In [None]:
# ----- Load the merged dataframe we created in the previous section ----- #
# One row per GEOID (counties + CT planning regions) with embedding data and yearly case counts

df_all = pd.read_csv("./all_embeddings_with_cases/cleaned-avg-embeddings-2017-2024.csv")

load_dotenv()

client = DataCommonsClient(api_key=os.getenv("COMMONS_API_KEY"))

In [29]:
# ----- Population fetching with PKL cache ----- #

CACHE_PKL = "../national_embeddings/population_cache.pkl"

def get_populations_batch_single_year(client, geoids: list[str], year: int):
    """
    Fetch population for multiple GEOIDs for a SINGLE year in one API call.
    """
    try:
        entity_dcids = [f"geoId/{geoid}" for geoid in geoids]
        
        df = client.observations_dataframe(
            variable_dcids="Count_Person",
            entity_dcids=entity_dcids,
            date=str(year)
        )
        
        if df.empty:
            return []
        
        results = []
        for _, row in df.iterrows():
            entity = row.get("entity", "")
            geoid = entity.replace("geoId/", "")
            value = row.get("value", None)
            
            if value is not None and not pd.isna(value):
                pop = int(np.ceil(float(value)))
                results.append((geoid, year, pop))
        
        return results
    
    except Exception as e:
        print(f"      Batch API error for year {year}: {e}")
        return []


def fetch_populations_for_geoids(client, geoids: list[str], years: list[int], batch_size: int = 500):
    """
    Fetch population data for given GEOIDs and years.
    Returns a DataFrame with columns: GEOID, year, population
    """
    batches = [geoids[i:i + batch_size] for i in range(0, len(geoids), batch_size)]
    
    all_results = []
    total_calls = len(batches) * len(years)
    call_count = 0
    
    for year in years:
        print(f"\n  Fetching data for year {year}...")
        for i, batch in enumerate(batches):
            call_count += 1
            print(f"    Batch {i+1}/{len(batches)} ({len(batch)} GEOIDs) - Call {call_count}/{total_calls}")
            results = get_populations_batch_single_year(client, batch, year)
            all_results.extend(results)
            
            time.sleep(0.3)
    
    # Convert to DataFrame
    pop_df = pd.DataFrame(all_results, columns=["GEOID", "year", "population"])
    
    return pop_df


def load_cache(cache_path: str):
    """Load pickle cache if it exists."""
    if Path(cache_path).exists():
        print(f"Loading cache from {cache_path}...")
        with open(cache_path, 'rb') as f:
            cache = pickle.load(f)
        print(f"  Loaded {len(cache)} GEOID-year pairs from cache")
        return cache
    else:
        print("No cache found, starting fresh.")
        return {}


def save_cache(cache: dict, cache_path: str):
    """Save cache to pickle file."""
    print(f"Saving cache to {cache_path}...")
    with open(cache_path, 'wb') as f:
        pickle.dump(cache, f)
    print(f"  Saved {len(cache)} GEOID-year pairs to cache")


def fetch_populations_with_pkl_cache(df_all: pd.DataFrame, client, years: list[int], 
                                       cache_path: str = CACHE_PKL, 
                                       batch_size: int = 500,
                                       max_retries: int = 2) -> pd.DataFrame:
    """
    Fetch population data with pickle caching and retry logic for missing GEOIDs.
    
    Cache structure: {(geoid, year): population}
    
    Args:
        df_all: Main dataframe with GEOID column
        client: DataCommons client
        years: List of years to fetch
        cache_path: Path to pickle cache file
        batch_size: Number of GEOIDs per API call
        max_retries: Number of retry attempts for missing GEOIDs
    
    Returns:
        DataFrame with population columns added
    """
    # Load cache
    cache = load_cache(cache_path)
    
    # Get all unique GEOIDs we need
    all_geoids = df_all["GEOID"].unique().tolist()
    all_geoids = [str(g).zfill(5) for g in all_geoids]
    
    print(f"\nTotal unique GEOIDs needed: {len(all_geoids)}")
    print(f"Years needed: {list(years)}")
    
    # Determine what's missing from cache
    needed_pairs = set((geoid, year) for geoid in all_geoids for year in years)
    cached_pairs = set(cache.keys())
    missing_pairs = needed_pairs - cached_pairs
    
    print(f"\nCache status:")
    print(f"  Cached pairs: {len(cached_pairs)}")
    print(f"  Needed pairs: {len(needed_pairs)}")
    print(f"  Missing pairs: {len(missing_pairs)}")
    
    # Fetch missing data with retries
    retry_count = 0
    while missing_pairs and retry_count < max_retries:
        retry_count += 1
        
        # Group missing pairs by year for efficient fetching
        missing_by_year = {}
        for geoid, year in missing_pairs:
            missing_by_year.setdefault(year, set()).add(geoid)
        
        print(f"\n{'='*60}")
        print(f"Retry attempt {retry_count}/{max_retries}")
        print(f"{'='*60}")
        
        for year in sorted(missing_by_year.keys()):
            geoids_to_fetch = list(missing_by_year[year])
            print(f"\nYear {year}: Fetching {len(geoids_to_fetch)} missing GEOIDs")
            
            # Fetch data
            pop_df = fetch_populations_for_geoids(client, geoids_to_fetch, [year], batch_size)
            
            # Update cache with results
            for _, row in pop_df.iterrows():
                geoid = str(row['GEOID']).zfill(5)
                year_val = int(row['year'])
                pop = int(row['population'])
                cache[(geoid, year_val)] = pop
            
            print(f"  Retrieved {len(pop_df)} values for year {year}")
        
        # Save cache after each retry
        save_cache(cache, cache_path)
        
        # Recalculate what's still missing
        cached_pairs = set(cache.keys())
        missing_pairs = needed_pairs - cached_pairs
        
        print(f"\nAfter retry {retry_count}:")
        print(f"  Still missing: {len(missing_pairs)} pairs")
        
        if missing_pairs:
            print(f"  Sample missing pairs: {list(missing_pairs)[:10]}")
            
            # If we have retries left, wait a bit before retrying
            if retry_count < max_retries:
                print(f"\nWaiting 2 seconds before retry {retry_count + 1}...")
                time.sleep(2)
    
    # Convert cache to wide-format DataFrame
    print(f"\n{'='*60}")
    print("Converting cache to DataFrame...")
    print(f"{'='*60}")
    
    # Build DataFrame from cache
    rows = []
    for geoid in all_geoids:
        row = {"GEOID": geoid}
        for year in years:
            pop = cache.get((geoid, year), None)
            row[f"Popln_{year}"] = pop
        rows.append(row)
    
    pop_wide = pd.DataFrame(rows)
    
    # Check coverage
    total_cells = len(all_geoids) * len(years)
    filled_cells = pop_wide[[f"Popln_{y}" for y in years]].notna().sum().sum()
    coverage = (filled_cells / total_cells) * 100
    
    print(f"\nFinal coverage: {filled_cells}/{total_cells} ({coverage:.1f}%)")
    
    if filled_cells < total_cells:
        missing_count = total_cells - filled_cells
        print(f"Warning: {missing_count} GEOID-year pairs still missing (will be NaN)")
    
    # Merge with original dataframe
    df_all["GEOID"] = df_all["GEOID"].astype(str).str.zfill(5)
    pop_wide["GEOID"] = pop_wide["GEOID"].astype(str).str.zfill(5)
    
    out = df_all.merge(pop_wide, on="GEOID", how="left")
    
    print(f"\nMerge complete:")
    print(f"  Original df_all shape: {df_all.shape}")
    print(f"  Population data shape: {pop_wide.shape}")
    print(f"  Merged df shape: {out.shape}")
    
    return out


def diagnose_missing_populations(df_all: pd.DataFrame, years: list[int]):
    """
    Show diagnostic info about missing population data.
    """
    pop_cols = [f"Popln_{y}" for y in years]
    
    # Find rows with ANY missing population data
    missing_mask = df_all[pop_cols].isna().any(axis=1)
    missing_df = df_all[missing_mask].copy()
    
    if len(missing_df) == 0:
        print("\n✓ No missing population data!")
        return None
    
    print(f"\n{'='*60}")
    print(f"Found {len(missing_df)} rows with missing population data")
    print(f"{'='*60}")
    
    # Get unique GEOIDs with missing data
    missing_geoids = missing_df["GEOID"].unique()
    print(f"\nUnique GEOIDs with missing data: {len(missing_geoids)}")
    print(f"Sample GEOIDs: {missing_geoids[:10].tolist()}")
    
    # Show some examples
    print(f"\nSample rows with missing data:")
    display_cols = ["GEOID"] + pop_cols
    print(missing_df[display_cols].head(10))
    
    return missing_geoids


# ----- MAIN EXECUTION ----- #

print("="*60)
print("FETCHING POPULATION DATA WITH CACHING")
print("="*60)

# Fetch populations with caching and retry
YEARS = list(range(2017, 2025))  # 2017 through 2024 inclusive

df_all = fetch_populations_with_pkl_cache(
    df_all, 
    client, 
    years=YEARS,
    cache_path=CACHE_PKL,
    batch_size=500,
    max_retries=2  # Will retry missing GEOIDs twice
)

# Diagnose any remaining missing data
missing_geoids = diagnose_missing_populations(df_all, YEARS)

# Save final result
df_all.to_csv("../national_embeddings/national_wnv_case_data/long_cases_popln_embs.csv", index=False)

print("\n" + "="*60)
print("COMPLETE!")
print("="*60)
print(f"Final dataframe shape: {df_all.shape}")
print(f"Saved to: ../national_embeddings/national_wnv_case_data/long_cases_popln_embs.csv")
print(f"Cache saved to: {CACHE_PKL}")
print("="*60)

FETCHING POPULATION DATA WITH CACHING
Loading cache from ../national_embeddings/population_cache.pkl...
  Loaded 25108 GEOID-year pairs from cache

Total unique GEOIDs needed: 3151
Years needed: [2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024]

Cache status:
  Cached pairs: 25108
  Needed pairs: 25208
  Missing pairs: 100

Retry attempt 1/2

Year 2017: Fetching 9 missing GEOIDs

  Fetching data for year 2017...
    Batch 1/1 (9 GEOIDs) - Call 1/1
  Retrieved 0 values for year 2017

Year 2018: Fetching 9 missing GEOIDs

  Fetching data for year 2018...
    Batch 1/1 (9 GEOIDs) - Call 1/1
  Retrieved 0 values for year 2018

Year 2019: Fetching 9 missing GEOIDs

  Fetching data for year 2019...
    Batch 1/1 (9 GEOIDs) - Call 1/1
  Retrieved 0 values for year 2019

Year 2020: Fetching 9 missing GEOIDs

  Fetching data for year 2020...
    Batch 1/1 (9 GEOIDs) - Call 1/1
  Retrieved 18 values for year 2020

Year 2021: Fetching 10 missing GEOIDs

  Fetching data for year 2021...
    Batch 1
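
The heart of the caching logic is a set difference between the GEOID-year pairs that are needed and the pairs already cached, repeated for a bounded number of attempts. The stand-alone sketch below replaces the API call with a fake fetcher, so `fake_fetch` and all numbers are made up for illustration.

```python
def missing_pairs(cache: dict, geoids: list[str], years: list[int]) -> set:
    """Return the (geoid, year) pairs that are needed but not yet cached."""
    needed = {(g, y) for g in geoids for y in years}
    return needed - set(cache)


def fake_fetch(pairs) -> dict:
    """Stand-in for the API: has no data for GEOID 09110 before 2020."""
    return {(g, y): 1000 + y for (g, y) in pairs if not (g == "09110" and y < 2020)}


cache = {("01001", 2019): 55000}  # pretend this was loaded from the pickle
geoids, years = ["01001", "09110"], [2019, 2020]

for attempt in range(2):  # bounded retries, like max_retries=2
    todo = missing_pairs(cache, geoids, years)
    if not todo:
        break
    cache.update(fake_fetch(todo))

still_missing = missing_pairs(cache, geoids, years)
print(still_missing)
```

Pairs the source never returns stay missing after every retry, which matches the pattern in the log above where several years retrieve 0 values.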

In [39]:
# single quotes inside the f-string keep this valid on Python versions before 3.12
missing_rows = df_all[df_all.isnull().any(axis=1)]
print(f"As seen below, the only regions with missing populations are: \n{list(missing_rows['GEOID'])}")

missing_rows

As seen below, the only regions with missing populations are: 
['02261', '09001', '09009', '09005', '09015', '09007', '09013', '09003', '09011', '09140', '09150', '09160', '09170', '09180', '09120', '09190', '09110', '09130']


Unnamed: 0.1,Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,...,Cases_2023,Cases_2024,Popln_2017,Popln_2018,Popln_2019,Popln_2020,Popln_2021,Popln_2022,Popln_2023,Popln_2024
85,85,-0.067889,-0.053618,-0.057702,-0.065675,-0.067103,-0.07059,-0.068902,-0.063097,-0.232306,...,0.0,0.0,9224.0,9301.0,9243.0,0,,,,
308,308,-0.03731,-0.031442,-0.049655,-0.043741,-0.046134,-0.042171,-0.039296,-0.043806,-0.088164,...,0.0,0.0,949921.0,943823.0,943332.0,944306,956446.0,,,
309,309,-0.033067,-0.031561,-0.049956,-0.047735,-0.045903,-0.048856,-0.036343,-0.046602,-0.0902,...,0.0,0.0,860435.0,857620.0,854757.0,855733,864751.0,,,
310,310,-0.040196,-0.043661,-0.059519,-0.067283,-0.04189,-0.054963,-0.059617,-0.059496,0.000758,...,0.0,0.0,182177.0,181111.0,180333.0,181143,185175.0,,,
311,311,-0.035494,-0.04828,-0.072797,-0.090926,-0.077159,-0.076595,-0.075972,-0.078254,-0.005003,...,0.0,0.0,116359.0,117027.0,116782.0,116657,116503.0,,,
312,312,-0.043372,-0.051041,-0.077507,-0.079033,-0.075957,-0.081894,-0.072071,-0.076353,-0.027307,...,0.0,0.0,163410.0,162682.0,162436.0,162742,164568.0,,,
313,313,-0.040297,-0.055253,-0.078879,-0.09799,-0.077941,-0.082462,-0.082011,-0.082412,-0.013475,...,0.0,0.0,151461.0,150921.0,150721.0,150947,150120.0,,,
314,314,-0.038071,-0.045666,-0.071302,-0.077414,-0.070555,-0.077963,-0.064387,-0.072232,-0.079555,...,0.0,0.0,895388.0,892697.0,891720.0,892153,898636.0,,,
315,315,-0.056698,-0.056489,-0.079609,-0.090634,-0.081786,-0.090594,-0.076504,-0.078844,-0.024375,...,0.0,0.0,269033.0,266784.0,265206.0,266868,269131.0,,,
3142,3142,-0.057855,-0.055518,-0.082817,-0.085064,-0.081177,-0.080661,-0.077676,-0.083811,-0.053182,...,0.0,0.0,,,,449055,452095.0,451887.0,452303.0,462220.0


In [51]:
poplns = ["Popln_2017","Popln_2018","Popln_2019","Popln_2020","Popln_2021","Popln_2022","Popln_2023","Popln_2024"]
mask = (df_all[poplns] == 0).any(axis=1)
rows_with_zeros = df_all[mask]

print(f"popln of zero identified: {rows_with_zeros['GEOID']}")
rows_with_zeros

popln of zero identified: 85    02261
Name: GEOID, dtype: object


Unnamed: 0.1,Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,...,Cases_2023,Cases_2024,Popln_2017,Popln_2018,Popln_2019,Popln_2020,Popln_2021,Popln_2022,Popln_2023,Popln_2024
85,85,-0.067889,-0.053618,-0.057702,-0.065675,-0.067103,-0.07059,-0.068902,-0.063097,-0.232306,...,0.0,0.0,9224.0,9301.0,9243.0,0,,,,


In [59]:
df_all["Popln_2020"] = df_all["Popln_2020"].replace(0, np.nan)
df_all

Unnamed: 0.1,Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,...,Cases_2023,Cases_2024,Popln_2017,Popln_2018,Popln_2019,Popln_2020,Popln_2021,Popln_2022,Popln_2023,Popln_2024
0,0,0.020481,0.011058,-0.000501,0.011396,0.000753,-0.004479,-0.022577,-0.004919,-0.077097,...,0.0,0.0,36993.0,36524.0,36633.0,36281.0,36879.0,36755.0,36695.0,36630.0
1,1,-0.013062,-0.030307,-0.034806,-0.026667,-0.040167,-0.049618,-0.055469,-0.042281,-0.011958,...,1.0,0.0,40613.0,40535.0,40367.0,40133.0,41284.0,41251.0,41070.0,40699.0
2,2,-0.105284,-0.107447,-0.129861,-0.119771,-0.122789,-0.130565,-0.146850,-0.115087,0.003709,...,0.0,0.0,57787.0,57771.0,57826.0,57879.0,58884.0,59077.0,59292.0,60163.0
3,3,-0.072484,-0.081692,-0.097903,-0.090730,-0.091891,-0.100606,-0.111368,-0.091791,-0.012406,...,0.0,0.0,88199.0,88690.0,89512.0,90739.0,90412.0,91719.0,92903.0,96927.0
4,4,-0.054914,-0.063882,-0.079282,-0.075227,-0.075802,-0.082101,-0.091138,-0.075997,-0.037458,...,2.0,2.0,213605.0,215707.0,217702.0,221428.0,220780.0,223916.0,226955.0,235969.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3146,3146,-0.061269,-0.060530,-0.083934,-0.095966,-0.087213,-0.094672,-0.081715,-0.084212,-0.027524,...,0.0,1.0,,,,278379.0,278855.0,280293.0,279025.0,282602.0
3147,3147,-0.022919,-0.015612,-0.027782,-0.022987,-0.023637,-0.025878,-0.013778,-0.020955,-0.115328,...,0.0,1.0,,,,324397.0,326799.0,326381.0,326296.0,335666.0
3148,3148,-0.045174,-0.040372,-0.060937,-0.055804,-0.057942,-0.050842,-0.053370,-0.055748,-0.067533,...,3.0,4.0,,,,618762.0,623927.0,620666.0,621232.0,637013.0
3149,3149,-0.041947,-0.052342,-0.077795,-0.088983,-0.078226,-0.084367,-0.074010,-0.080236,-0.060030,...,3.0,2.0,,,,964088.0,971938.0,977165.0,969029.0,991508.0


In [60]:
poplns = ["Popln_2017","Popln_2018","Popln_2019","Popln_2020","Popln_2021","Popln_2022","Popln_2023","Popln_2024"]
mask = (df_all[poplns] == 0).any(axis=1)
rows_with_zeros = df_all[mask]

print(f"popln of zero identified: {rows_with_zeros['GEOID']}")
rows_with_zeros

popln of zero identified: Series([], Name: GEOID, dtype: object)


Unnamed: 0.1,Unnamed: 0,A00_2017,A00_2018,A00_2019,A00_2020,A00_2021,A00_2022,A00_2023,A00_2024,A01_2017,...,Cases_2023,Cases_2024,Popln_2017,Popln_2018,Popln_2019,Popln_2020,Popln_2021,Popln_2022,Popln_2023,Popln_2024


In [61]:
df_all = pd.read_csv("../national_embeddings/national_wnv_case_data/long_cases_popln_embs.csv")
df_all.isna().sum()

Unnamed: 0    0
A00_2017      0
A00_2018      0
A00_2019      0
A00_2020      0
             ..
Popln_2020    0
Popln_2021    1
Popln_2022    9
Popln_2023    9
Popln_2024    9
Length: 530, dtype: int64
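
In the earlier output, the gaps cluster in a few GEOIDs: the older Connecticut county codes (09001 to 09015) lack the most recent years, while the planning-region codes (09110 to 09190) lack 2017 to 2019. A compact way to see which year columns are missing for which GEOID is to melt the population columns and keep only the NaN cells. This is a sketch on a toy frame with made-up populations, not a cell from the original analysis.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    "GEOID": ["09001", "09110", "01001"],
    "Popln_2019": [943332.0, np.nan, 55000.0],
    "Popln_2022": [np.nan, 977165.0, 56000.0],
})

pop_cols = [c for c in toy.columns if c.startswith("Popln_")]

# One row per (GEOID, population column), keeping only the missing cells.
missing_long = (
    toy.melt(id_vars="GEOID", value_vars=pop_cols,
             var_name="column", value_name="population")
       .loc[lambda d: d["population"].isna(), ["GEOID", "column"]]
)

# Map each affected GEOID to the list of columns it is missing.
missing_by_geoid = missing_long.groupby("GEOID")["column"].apply(list).to_dict()
print(missing_by_geoid)
```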

### 4. Visualizations - WNV Case Count Per Year (National Level)

See `maps.py` for the map generation script.

Below, I create a GIF from the maps generated by `maps.py`, which are stored in `/notebooks/national_embeddings/wnv_case_maps`.

In [13]:
# creating a gif with the generated maps
# https://propolis.io/articles/make-animated-gif-using-python.html

# load the yearly maps in filename order
images = [Image.open(img) for img in sorted(glob('../national_embeddings/wnv_case_maps/*.png'))]

# 5 extra frames for the last map - 2024
images.extend([images[-1]] * 5)

# save as a gif   
images[0].save('../national_embeddings/wnv_case_maps/cases_2017_to_2024.gif',
               save_all=True, append_images=images[1:], optimize=False, duration=750, loop=0)
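
An alternative to repeating the final frame is to pass a list of per-frame durations, which Pillow's GIF writer accepts. The sketch below uses three solid-color frames generated in memory instead of the real maps, and writes to a `BytesIO` buffer so it touches no files; the colors and sizes are arbitrary.

```python
from io import BytesIO

from PIL import Image

# Three stand-in frames (the real maps would be loaded with Image.open).
frames = [Image.new("RGB", (16, 16), color) for color in ("red", "green", "blue")]

# 750 ms per frame, with the last frame held as long as six regular frames.
durations = [750, 750, 750 * 6]

buffer = BytesIO()
frames[0].save(buffer, format="GIF", save_all=True,
               append_images=frames[1:], duration=durations, loop=0)

# Re-open the result to confirm how many frames were written.
buffer.seek(0)
with Image.open(buffer) as gif:
    frame_count = gif.n_frames

print(frame_count)
```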

### 5. Model Evaluation