## Predicting WNV Prevalence at a County Level in the United States Using AlphaEarth Embedding Data (2017 - 2024)

## Contents:


##### 1. Converting Google Earth Assets to CSV files  

##### 2. Appending WNV case data to each file

##### 3. Obtatining population data from 2017 to 2024 for all counties

##### 4. WNV case data visualized at a national level   

##### 5. Machine learning model evaluation 

In [30]:
from pathlib import Path
from glob import glob
import sys
from dotenv import load_dotenv
import os
from datacommons_client.client import DataCommonsClient

# utils import error: add wnv_embeddings as root
PROJECT_ROOT = Path.cwd().parents[1]  # <-- wnv_embeddings
sys.path.insert(0, str(PROJECT_ROOT))
import ee

from utils.utils import convert_to_df
import pandas as pd
import numpy as np
import geopandas as gpd
from geopandas import clip

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, FancyArrow, Patch
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib as mpl
import contextily as ctx

from shapely.ops import unary_union
from matplotlib.colors import LogNorm, LinearSegmentedColormap


### 1. Converting Google Earth Assets to ~50 CSV Files 

* National average embeddings data per county for all states (2017 to 2024).
* Each asset represents one state (according to the FIPS code).
* State FIPS Codes available here: https://transition.fcc.gov/oet/info/maps/census/fips/fips.txt

#### Saving as CSV to `.\notebooks\national_embeddings\all_embeddings_csvs`

Using the `convert_to_df()` function from `utils.py`: 

In [None]:
# will prompt you to authorize access to GEE
# this is needed to obtain assets from the cloud saved under your account
ee.Authenticate()

# enter your own registered project name here
ee.Initialize(project="wnv-embeddings")

In [None]:
state_fips_codes = [
    "01", "02", "04", "05", "06", "08", "09", "10", "11", "12",
    "13", "15", "16", "17", "18", "19", "20", "21", "22", "23",
    "24", "25", "26", "27", "28", "29", "30", "31", "32", "33",
    "34", "35", "36", "37", "38", "39", "40", "41", "42", "44",
    "45", "46", "47", "48", "49", "50", "51", "53", "54", "55", "56"
  ]

In [None]:
# =============CONVERT GEE ASSETS TO CSVS============= #
# ONLY RUN ONCE TO CONVERT ALL 56 ASSETS AS CSV #

# now obtaining the csvs
# csv_destination = Path("all_embeddings_csvs")
# csv_destination.mkdir(parents=True, exist_ok=True)

# for fips in state_fips_codes:
# 	gee_path = f"users/angel314/{fips}_2017_2024_embeddings"
	
# 	save_to = csv_destination / f"{fips}-avg-embeddings-2017-2024.csv"

# 	convert_to_df(gee_path, True, save_to)

### 2. Appending Yearly WNV Case Data

##### Getting WNV Case Data:
* Source: https://www.cdc.gov/west-nile-virus/data-maps/historic-data.html  
* Section: "Explore county level data for 1999-2024" - "Yearly data"
	* Returns: one CSV with case data at a county level for 1999-2024
* `Location` column represents the FIPS county code for that row.
* WNV Case data is cleaned to only include relevant years and rows with at least one human disease case. 

This is a preview of WNV County Cases from 1999 to 2024.

In [None]:
cases = pd.read_csv("./national_wnv_case_data/wnv_county_cases_1999_2024.csv")
cases.sample(5)

Unnamed: 0,FullGeoName,Year,Location,Activity,Total human disease cases,Neuroinvasive disease cases,**Presumptive viremic blood donors,Notes
16252,"CA, Santa Clara",2007,6085,Human infections and non-human activity,4.0,1.0,0.0,
4456,"CO, Pitkin",2019,8097,Human infections,1.0,0.0,0.0,
26347,"VA, Scott",2002,51169,Non-human activity,0.0,0.0,0.0,
13116,"PA, Indiana",2012,42063,Non-human activity,0.0,0.0,0.0,
7243,"OH, Meigs",2017,39105,Human infections and non-human activity,1.0,1.0,0.0,


In [None]:
###### filtering ######

# remove entries that come before 2017
cases = cases[cases["Year"]>=2017]
# remove any rows with 0 total human disease cases
cases = cases[cases["Total human disease cases"]>0]
# only keep relevant columns
cases = cases.drop(columns=["FullGeoName", "Activity", "Neuroinvasive disease cases", "**Presumptive viremic blood donors", "Notes"]).reset_index(drop=True)
cases

Unnamed: 0,Year,Location,Total human disease cases
0,2024,1001,2.0
1,2024,1003,2.0
2,2024,1021,1.0
3,2024,1043,2.0
4,2024,1047,1.0
...,...,...,...
4006,2017,55141,2.0
4007,2017,56003,1.0
4008,2017,56013,3.0
4009,2017,56015,2.0


In [None]:
cases = cases.groupby(["Year","Location"]).agg("sum").reset_index()
cases

Unnamed: 0,Year,Location,Total human disease cases
0,2017,1001,6.0
1,2017,1003,3.0
2,2017,1007,1.0
3,2017,1011,1.0
4,2017,1015,2.0
...,...,...,...
4006,2024,55133,1.0
4007,2024,55139,1.0
4008,2024,55141,1.0
4009,2024,56015,1.0


In [None]:
# convert from long format to wide format
# each row represents one location
# each location has sum of cases for 2017 - 2024.

# columns="Year" -> each unique year is a column
# values="cases" -> numbers to fill pivot table
# take sum of all entries for the the same location and year

# reset_index to move "Location" column to the right.

cases_wide = (cases.pivot_table(index="Location", columns="Year", values="Total human disease cases", aggfunc="sum", fill_value=0).add_prefix("Cases_").reset_index())
cases_wide

Year,Location,Cases_2017,Cases_2018,Cases_2019,Cases_2020,Cases_2021,Cases_2022,Cases_2023,Cases_2024
0,1001,6.0,0.0,0.0,1.0,1.0,0.0,1.0,2.0
1,1003,3.0,2.0,1.0,0.0,2.0,1.0,0.0,2.0
2,1007,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,1011,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1015,2.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0
...,...,...,...,...,...,...,...,...,...
1607,56025,0.0,0.0,1.0,0.0,0.0,1.0,3.0,1.0
1608,56029,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
1609,56031,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0
1610,56033,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


Saving this cleaned dataframe to a csv for future use.

In [None]:
cases_wide.to_csv("./national_wnv_case_data/agg_wnv_county_cases_2017_2024.csv")

##### Iterating Over `all_embeddings_csvs` to add WNV Human cases for each year.

In [None]:
for code in state_fips_codes:
  # get csv for current fips code
	path = f"./all_embeddings_csvs/{code}-avg-embeddings-2017-2024.csv"
  # load in the csv 
	df = pd.read_csv(path)
	df_merged = pd.merge(df, cases_wide, left_on="GEOID", right_on="Location", how="left").fillna(0).drop(columns=["Location"])

	df_merged.to_csv(f"./all_embeddings_with_cases/cleaned-{code}-avg-embeddings-2017-2024.csv")

### 3. Obtaining and Appending County Population Data:

* Using Data Commons API:

	* https://docs.datacommons.org/what_is.html 

	* Basically allows us to query specific statistical questions and get one unified result.

	* There is an option to query for counties as well using FIPS codes: https://datacommons.org/browser/County 

County population data is needed for each year to normalize based on this formula:

$\textnormal{Cases per 100k} = \frac{\textnormal{Number of disease cases}}{\textnormal{County population}} \times 100,000$

Normalized cases (cases per 100k) will be the target variable when measuring machine learning models' performance.

note: api.census.gov does not have consistent and updated data for 2017 - 2024 county populations.

In [27]:
load_dotenv()

client = DataCommonsClient(api_key=os.getenv("COMMONS_API_KEY"))

In [33]:
# the datacommons observations_datafrmae function will fetch: population for the given geoid and year. 
# This way, I can get the mean returned for all values and then convert to an integer. 

def get_popln(client, geoid: str, year: int):
    try:
        df = client.observations_dataframe(
            variable_dcids="Count_Person",
            entity_dcids=f"geoId/{geoid}",
            date=str(year)
        )

        if df.empty:
            return geoid, year, None

        pop = int(np.ceil(df["value"].astype(float).mean()))
        return geoid, year, pop

    except Exception as e:
        # API sometimes fails per-entity; do not kill the whole run
        return geoid, year, None

##### Cook County 2017 Population Test #####
print(get_popln(client=client, geoid="17031", year="2017"))

('17031', '2017', 5216669)


In [34]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Iterable

In [None]:
def fetch_populations_multithreaded(df_all: pd.DataFrame, client, years: Iterable[int], max_workers: int = 16) -> pd.DataFrame:
	"""
	Returns a new dataframe with Popln_YYYY columns added.
	"""

	geoids = df_all["GEOID"].unique().tolist()

	# Build all jobs
	jobs = [
			(geoid, year)
			for geoid in geoids
			for year in years
	]

	results = []

	with ThreadPoolExecutor(max_workers=max_workers) as executor:
			future_to_job = {executor.submit(get_popln, client, geoid, year): (geoid, year) for geoid, year in jobs}

			for future in as_completed(future_to_job):
					# result is tuple[str, int, int] 
					results.append(future.result())

	# results -> tidy dataframe
	pop_df = pd.DataFrame(results, columns=["GEOID", "year", "population"])

	pop_wide = (pop_df.pivot(index="GEOID", columns="year", values="population").reset_index())

	pop_wide.columns = [
			"GEOID" if c == "GEOID" else f"Popln_{c}"
			for c in pop_wide.columns
	]

	# merge back
	out = df_all.merge(pop_wide, on="GEOID", how="left")

	return out


In [None]:
df_all = fetch_populations_multithreaded(
    df_all,
    client,
    years=range(2017, 2025),
    max_workers=16
)

In [None]:
# ----- Convert all CSVs in `all_embeddings_with_cases` to one long dataframe ----- #

files = glob("../national_embeddings/all_embeddings_with_cases/*.csv")
dfs = [pd.read_csv(f) for f in files]

df_all = pd.concat(dfs, ignore_index=True)
df_all = df_all.drop(df_all.columns[0], axis=1)

df_all.to_csv("../national_embeddings/national_wnv_case_data/long_format_emb_and_cases.csv", index=False)

### 4. Visualizations - WNV Case Count Per Year (National Level)

In [None]:
# NOTE: CONNECTICUT ADJUSTMENT (MISSING CASES) #
# another case: Connecticut uses new "planning region" FIPS codes, rather than the older FIPS codes used in `df_all`
# this issue stems from the TIGER 2018 Counties dataset used in the GEE scripts (GEE does not have a newer dataset available)
# thus, it may not show up when plotting, this needs to be accounted for:
# I changed the tiger shapefiles to also be 2018 for plotting purposes

# This includes the conversions from the old county codes to the new planning region codes.
# https://developer.ap.org/ap-elections-api/docs/CT_FIPS_Codes_forPlanningRegions.htm

# I mapped legacy county FIPS to all matching updated planning region codes:

# legacy_county_to_planning_regions = {
#     "09001": ["09120", "09140", "09190"],   # Fairfield County
#     "09003": ["09110", "09140", "09160"],   # Hartford County
#     "09005": ["09140", "09160", "09190"],   # Litchfield County
#     "09007": ["09130"],                     # Middlesex County
#     "09009": ["09140", "09170"],             # New Haven County
#     "09011": ["09130", "09150", "09180"],    # New London County
#     "09013": ["09110", "09150"],             # Tolland County
#     "09015": ["09150", "09180"],             # Windham County
# }

# from "/national_wnv_case_data/agg_wnv_county_cases_2017_2024.csv"
# Planning Region | Cases 2023 | Cases 2024
# 09110 | 3 | 2
# 09120 | 0 | 1
# 09130 | 0 | 1
# 09140 | 0 | 0
# 09150 | 0 | 0
# 09160 | 0 | 0
# 09170 | 1 | 3
# 09180 | 0 | 1
# 09190 | 3 | 4

# I manually went in and updated 2022 data, as it still used old FIPS codes
# however, 2023 and 2024 have remained untouched due to a proper way to equally distribute the cases
# (see "View the 2022 Connecticut Data" at https://www.cdc.gov/west-nile-virus/data-maps/historic-data.html)

# this exemplifies what I mentioned in my earlier comment about connecticut
df_ct = df_all[df_all["GEOID"].str.startswith("09")]
df_ct

In [None]:
# ----- Configurations ----- #
FONT_FAMILY = "DejaVu Sans"
BASE_FONTSIZE = 9
TITLE_FONTSIZE = 14
COUNTRY_LABEL_FONTSIZE = 13
mpl.rcParams.update({
    "font.family": FONT_FAMILY,
    "font.size": BASE_FONTSIZE,
    "axes.titlesize": TITLE_FONTSIZE,
    "axes.titleweight": "bold"
})

LEFT_ANCHOR = 0.01

# ----- State and county shapefile paths ----- #

# shapefiles obtained from: https://www.census.gov/cgi-bin/geo/shapefiles/index.php on Feb. 4, 2026
STATES_PATH = "../national_embeddings/shapefiles/tl_2018_us_state/tl_2018_us_state.shp"
COUNTIES_PATH = "../national_embeddings/shapefiles/tl_2018_us_county/tl_2018_us_county.shp"
WATER_PATH = "../national_embeddings/shapefiles/great_lakes_usnic/GL260205_lam.shp"
COUNTRIES = "../national_embeddings/shapefiles/ne_110m_admin_0_countries"
ALL_DATA = "../national_embeddings/national_wnv_case_data/long_format_emb_and_cases.csv"

# ----- Get state and county geographies (and great lakes), neighboring countries, embeddings + WNV data ----- #
states = gpd.read_file(STATES_PATH)
counties = gpd.read_file(COUNTIES_PATH)
water = gpd.read_file(WATER_PATH)

df_all = pd.read_csv(ALL_DATA)
df_all["GEOID"] = df_all["GEOID"].astype(str).str.zfill(5)

world = gpd.read_file(COUNTRIES)

# ----- Project to EPSG:3857 and adjust boundaries ----- #

# ignore Alaska, Hawaii, Guam, Puerto Rico, Commonwealth of the Northern Mariana Islands, American Samoa, Virgin Islands (no cases)
# including these also unneccessarily enlarge the zoom on the US map 
exclude = ["AK","HI","GU","PR","MP","AS","VI"]
exclude_sfips = ['02', '60', '15', '78', '72', '69', '66']
# historically have had no cases: # https://health.hawaii.gov/docd/disease_listing/west-nile-virus/
# https://www.usgs.gov/faqs/where-united-states-has-west-nile-virus-been-detected-wildlife
# no geoid matched with CNMI in the cases data frame

canada = world[world['NAME_EN'] == "Canada"].to_crs(3857)
mexico = world[world['NAME_EN'] == "Mexico"].to_crs(3857)

states = states.to_crs(3857)
water = water.to_crs(3857)

states = states[~states["STUSPS"].isin(exclude)]
counties = counties[~counties["STATEFP"].isin(exclude_sfips)]
counties = counties.to_crs(3857)

# get united states outline
us_outline = gpd.GeoDataFrame(geometry=[unary_union(states.geometry)], crs=states.crs)

# only keep stuff inside the US OUTLINE (no hanging great lakes region)
water_clipped = gpd.clip(water, us_outline)

# ----- Merge previous long dataframe with geographies ----- #

# both use "GEOID" as unique identifier
counties_geom = counties[["GEOID","geometry"]]
df_merged = pd.merge(df_all, counties_geom, on="GEOID", how="inner")
# convert df_merged to a GeoDataFrame (I need to inspect it visually)
df_merged = gpd.GeoDataFrame(df_merged, geometry=df_merged.geometry, crs=states.crs)

# only keep cases columns (embedding data is irrelevant for these visualizations)
df_merged = df_merged[['GEOID', 'Cases_2017', 'Cases_2018', 'Cases_2019', 'Cases_2020',
       'Cases_2021', 'Cases_2022', 'Cases_2023', 'Cases_2024', 'geometry']]

# ----- Helper functions ----- #

def add_scalebar_miles_left_endlabel(
    ax,
    anchor=(0.03, 0.055),
    bar_h=26_000,
    tick_len_frac=0.5,
    label_fontsize=BASE_FONTSIZE,
    width_frac=0.42
):
    """Scale bar with non-uniform segments (0–250–500–1000 miles for national map)."""
    M_PER_MILE = 1609.344
    tick_values = [0, 250, 500, 1000]
    seg_lengths = np.diff(tick_values)
    
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    width_m = x1 - x0
    height_m = y1 - y0
    
    total_m_draw = width_m * width_frac
    total_miles = tick_values[-1] - tick_values[0]
    unit_scale = total_m_draw / total_miles
    tick_len = bar_h * tick_len_frac
    
    axfx, axfy = anchor
    x_left = x0 + axfx * width_m
    y_base = y0 + axfy * height_m
    
    edge = "#1E2933"
    dark = "#2F3B46"
    light = "#FFFFFF"
    
    # Outer frame
    ax.add_patch(Rectangle(
        (x_left, y_base), total_m_draw, bar_h,
        facecolor="none", edgecolor=edge, linewidth=1.6, zorder=60
    ))
    
    # Draw segments
    x_curr = x_left
    for i, seg_len_miles in enumerate(seg_lengths):
        seg_m = seg_len_miles * unit_scale
        face = dark if i % 2 == 0 else light
        
        ax.add_patch(Rectangle(
            (x_curr, y_base), seg_m, bar_h,
            facecolor=face, edgecolor=edge, linewidth=1.2, zorder=61
        ))
        
        # Tick at start
        ax.plot([x_curr, x_curr], [y_base, y_base - tick_len],
                color=edge, lw=2.0, solid_capstyle="round", zorder=62)
        ax.text(x_curr, y_base - tick_len - 11_000, f"{tick_values[i]}",
                ha="center", va="top", fontsize=label_fontsize, 
                color=edge, zorder=63)
        x_curr += seg_m
    
    # Final tick
    ax.plot([x_curr, x_curr], [y_base, y_base - tick_len],
            color=edge, lw=2.0, solid_capstyle="round", zorder=62)
    ax.text(x_curr, y_base - tick_len - 11_000, f"{tick_values[-1]}",
            ha="center", va="top", fontsize=label_fontsize, 
            color=edge, zorder=63)
    
    # "Miles" label (move text further to the right)
    ax.text(x_curr + unit_scale * 90, y_base - bar_h * 0.5, "Miles",
            ha="left", va="center", fontsize=label_fontsize, 
            color=edge, zorder=63)

def add_compass(ax, center_frac=(0.12, 0.18), size=180_000, color="#2F3B46"):
    """Compass rose."""
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    cx = x0 + center_frac[0] * (x1 - x0)
    cy = y0 + center_frac[1] * (y1 - y0)
    
    # Cross lines
    ax.plot([cx - size*0.8, cx + size*0.8], [cy, cy], 
            color=color, lw=1.1, zorder=70)
    ax.plot([cx, cx], [cy - size*0.8, cy + size*0.8], 
            color=color, lw=1.1, zorder=70)
    
    # North arrow
    ax.add_patch(FancyArrow(
        cx, cy, 0, size*0.95, 
        width=size*0.12, head_width=size*0.35, head_length=size*0.35,
        color=color, length_includes_head=True, zorder=71
    ))
    
    fs = BASE_FONTSIZE
    ax.text(cx, cy + size*1.05, "N", ha="center", va="bottom", 
            fontsize=fs, color=color, zorder=71)
    ax.text(cx, cy - size*0.95, "S", ha="center", va="top", 
            fontsize=fs, color=color, zorder=71)
    ax.text(cx + size*0.95, cy, "E", ha="left", va="center", 
            fontsize=fs, color=color, zorder=71)
    ax.text(cx - size*0.95, cy, "W", ha="right", va="center", 
            fontsize=fs, color=color, zorder=71)

# ----------------------------- COLORBAR CLASSES ----------------------------
class _ColorbarProxy:
    def __init__(self, cmap, norm, ticks, ticklabels=None, nsteps=64):
        self.cmap = cmap
        self.norm = norm
        self.ticks = ticks
        self.ticklabels = ticklabels if ticklabels else [f"{t:g}" for t in ticks]
        self.nsteps = nsteps

class _VerticalColorbarHandler(mpl.legend_handler.HandlerBase):
    def create_artists(self, legend, orig_handle, xdescent, ydescent, 
                      width, height, fontsize, trans):
        cmap = orig_handle.cmap
        norm = orig_handle.norm
        ticks = orig_handle.ticks
        n = orig_handle.nsteps
        artists = []
        
        # Geometry
        pad_x = 0.68 * width
        # adjusted to make it less wide
        bar_w = 2.0 * width
        bar_h = height * 7.5
        x0 = xdescent + pad_x
        top_pad_frac = -0.4
        y0 = ydescent + (height - bar_h) / 2 + top_pad_frac * bar_h
        
        # Gradient rectangles
        for i in range(n):
            y = y0 + (i / n) * bar_h
            y2 = y0 + ((i + 1) / n) * bar_h
            if isinstance(norm, LogNorm):
                frac = (i + 0.5) / n
                val = np.exp(np.log(norm.vmin) * (1 - frac) + 
                           np.log(norm.vmax) * frac)
            else:
                val = norm.vmin + ((i + 0.5) / n) * (norm.vmax - norm.vmin)
            rect = Rectangle(
                (x0, y), bar_w, y2 - y, 
                facecolor=cmap(norm(val)), edgecolor="none", lw=0
            )
            rect.set_transform(trans)
            artists.append(rect)
            
				# ADDED BACK: Vertical "WNV Cases" label on left side
        title = mpl.text.Text(
            x=x0 - 0.5 * width, y=y0 + bar_h / 2,
            text="WNV Cases", rotation=90,
            va="center", ha="center",
            fontsize=fontsize - 3.2, color="#000000"
        )
        title.set_transform(trans)
        artists.append(title)
        
        # Tick labels
        label_x = x0 + bar_w + 0.14 * width
        PAD_TOP_FRAC = 0.09
        PAD_LOW_FRAC = 0.06
        
        for i, (lab, t) in enumerate(zip(orig_handle.ticklabels, ticks)):
            frac = norm(t)
            if i == 0:
                frac = min(1.0, frac + PAD_LOW_FRAC)
            elif i == len(ticks) - 1:
                frac = max(0.0, frac - PAD_TOP_FRAC)
            ytick = y0 + frac * bar_h
            txt = mpl.text.Text(
                x=label_x, y=ytick, text=lab,
                va="center", ha="left",
                fontsize=fontsize-1, color="#111"
            )
            txt.set_transform(trans)
            artists.append(txt)
        
        return artists

In [52]:
# ----------------------------- CREATE COLORMAP ----------------------------
cmap = LinearSegmentedColormap.from_list(
    "wnv_cases", ["#FFF7EF", "#8C1A3C"]
)

# ----------------------------- PLOT LOOP ----------------------------
for year in range(2017, 2018):
    col = f"Cases_{year}"
    
    # Prepare data
    plot_df = df_merged.copy()
    plot_df["plot_col"] = plot_df[col].copy()
    
    # Handle zeros/missing
    has_pos = (plot_df["plot_col"] > 0).any()
    if has_pos:
        vmax = int(plot_df["plot_col"].max())
        eps = 0.8
        plot_df["plot_col_adj"] = np.where(
            plot_df["plot_col"] <= 0, eps, plot_df["plot_col"].astype(float)
        )
        norm = LogNorm(vmin=eps, vmax=vmax)
        data_min = int(plot_df.loc[plot_df["plot_col"] > 0, "plot_col"].min())
        data_max = vmax
    else:
        plot_df["plot_col_adj"] = 1.0
        norm = mpl.colors.Normalize(vmin=0, vmax=1)
        data_min = 0
        data_max = 0
    
    # Create figure
    fig, ax = plt.subplots(figsize=(15, 8), dpi=300)
    
    # Basemap
    ctx.add_basemap(
        ax, crs=df_merged.crs, 
        source=ctx.providers.CartoDB.Positron, 
        zoom=4, alpha=0.9, attribution=False
    )
    
    # Canada and Mexico backdrops
    if canada is not None and not canada.empty:
        canada.plot(ax=ax, facecolor="#E6E8EB", edgecolor="#D1D5DB", 
                   lw=0.6, zorder=1)
    if mexico is not None and not mexico.empty:
        mexico.plot(ax=ax, facecolor="#E6E8EB", edgecolor="#D1D5DB", 
                   lw=0.6, zorder=1)
    
    # US outline backdrop
    us_outline.plot(ax=ax, facecolor="#E6E8EB", edgecolor="#D1D5DB", 
                   lw=0.6, zorder=1)
    
    # Choropleth - census tracts
    plot_df.plot(
        ax=ax, column="plot_col_adj", cmap=cmap, norm=norm,
        linewidth=0.05, edgecolor="#9AA3AD", alpha=0.9, zorder=2.1
    )
    
    # Water bodies
    if water_clipped is not None and not water_clipped.empty:
        water_clipped.plot(ax=ax, color="lightblue", linewidth=0, zorder=2.5)
    
    # Connecticut hatching for 2023-2024
    if year in [2023, 2024]:
        ct_mask = plot_df["GEOID"].str.startswith("09")
        plot_df[ct_mask].plot(
            ax=ax, facecolor="#D1D5DB", edgecolor="#AAAAAA",
            linewidth=0.3, zorder=3, hatch="///", alpha=0.6
        )
    
    # State boundaries
    states.boundary.plot(ax=ax, edgecolor="#333333", linewidth=0.8, zorder=4)
    
    # County boundaries (lighter)
    counties.boundary.plot(ax=ax, edgecolor="#9AA3AD", linewidth=0.2, zorder=3.5)
    
    # Country labels
    xmin, ymin, xmax, ymax = us_outline.total_bounds
    dx, dy = xmax - xmin, ymax - ymin
    xmid, ymid = (xmin + xmax) / 2, (ymin + ymax) / 2
    
    txt_can = ax.text(
        xmid, ymax + 0.025*dy, "CANADA",
        fontsize=COUNTRY_LABEL_FONTSIZE, fontweight="bold",
        color="#6B7280", ha="center", va="center", zorder=5
    )
    # fix -> shift to the left more
    txt_mex = ax.text(
        xmid - 0.12*dx, ymin - 0.01*dy, "MEXICO",
        fontsize=COUNTRY_LABEL_FONTSIZE, fontweight="bold",
        color="#6B7280", ha="center", va="center", zorder=5
    )
    
    for t in [txt_can, txt_mex]:
        t.set_path_effects([
            pe.withStroke(linewidth=2.2, foreground="white", alpha=0.9)
        ])
    
    # Set extent
    pad_left = 0.05 * dx
    pad_right = 0.05 * dx
    pad_bottom = 0.08 * dy
    pad_top = 0.05 * dy
    
    ax.set_xlim(xmin - pad_left, xmax + pad_right)
    ax.set_ylim(ymin - pad_bottom, ymax + pad_top)
    ax.set_aspect("equal", adjustable="box")
    ax.set_axis_off()
    
    # ----------------------------- LEGEND ----------------------------
    # Build colorbar proxy
    if has_pos:
        low_pos = float(norm.vmin)
        high_pos = float(norm.vmax)
        ticks = [low_pos, high_pos]
        ticklabels = [f"Low: {data_min}", f"High: {data_max}"]
    else:
        ticks = [0.0, 1.0]
        ticklabels = [f"Low: {data_min}", f"High: {data_max}"]
    
    colorbar_proxy = _ColorbarProxy(
        cmap=cmap, norm=norm, ticks=ticks, ticklabels=ticklabels
    )
		
		# adding to legend - county box handle 
    county_handle = Rectangle(
        (0, 0), width=1.0, height=0.6,
        facecolor='none', edgecolor='#9AA3AD', linewidth=0.8
    )
    
    # Legend items
    num_colorbar_rows = 2
    invisible_handles = [Patch(alpha=0)] * num_colorbar_rows
    
    handles = [county_handle] + [colorbar_proxy] + invisible_handles
    labels = ["County"] + [""] * num_colorbar_rows + [""] 
    
    # Remove old legend if present
    for child in ax.get_children():
        if isinstance(child, mpl.legend.Legend):
            child.remove()
    
    leg = ax.legend(
        handles=handles, labels=labels,
        handler_map={_ColorbarProxy: _VerticalColorbarHandler()},
        title="Legend",
        loc="lower left",
        bbox_to_anchor=(LEFT_ANCHOR, 0.081, 0.12, 0.6),
        frameon=True, framealpha=1.0,
        edgecolor="#B8BEC5", facecolor="#FFFFFF",
        fontsize=BASE_FONTSIZE, title_fontsize=10,
        alignment="left", mode="expand",
        borderpad=1.0, labelspacing=1.0,
        handlelength=1.6, handletextpad=0.5
    )
    
    # Bold legend title and header
    if leg.get_title() is not None:
        leg.get_title().set_fontweight("bold")
        leg.get_title().set_ha("left")
        
		# MISSING CONNECTICUT DATA WARNING
    if year in [2023, 2024]:
        # Get legend position
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()
        bbox_px = leg.get_window_extent(renderer=renderer)
        (x0_ax, y0_ax) = ax.transAxes.inverted().transform((bbox_px.x0, bbox_px.y0))
        
        # Place note below legend
        ax.text(
            x0_ax, y0_ax - 0.02,
            "Note: Connecticut county data\nunavailable for 2023–2024 due to planning region change.",
            transform=ax.transAxes,
            fontsize=7.5, color="#444444",
            ha="left", va="top",
            zorder=200,
            style='italic'
        )
    
    # Title above legend
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    bbox_px = leg.get_window_extent(renderer=renderer)
    (x0_ax, y0_ax) = ax.transAxes.inverted().transform((bbox_px.x0, bbox_px.y0))
    (x1_ax, y1_ax) = ax.transAxes.inverted().transform((bbox_px.x1, bbox_px.y1))
    
    title_x = x0_ax
    title_y = y1_ax - 0.02
    title_text = f"West Nile Virus\nCases by County\n({year})"
    
    t = ax.text(
        title_x, title_y, title_text,
        transform=ax.transAxes,
        ha="left", va="bottom",
        fontsize=10.5, fontweight="bold",
        color="#111", zorder=200
    )
    t.set_path_effects([
        pe.withStroke(linewidth=3, foreground="white", alpha=0.9)
    ])
    
    # Connecticut note
    if year in [2023, 2024]:
        ax.text(
            0.02, 0.02,
            "Note: Connecticut county data unavailable for 2023–2024\n"
            "(state switched from counties to planning regions).",
            transform=ax.transAxes,
            fontsize=7.5, color="#444444",
            ha="left", va="bottom", zorder=200
        )
    
    # Scale bar and compass
    add_scalebar_miles_left_endlabel(
        ax, anchor=(0.02, 0.04), 
        bar_h=40_000, width_frac=0.15
    )
    add_compass(
        ax, center_frac=(0.035, 0.45), 
        size=150_000
    )
    
    # Save
    plt.subplots_adjust(left=0.01, right=0.99, top=0.96, bottom=0.06)
    plt.savefig(
        f"../national_embeddings/wnv_case_maps/wnv_cases_map_{year}.png",
        dpi=300, bbox_inches="tight", facecolor="white"
    )
    plt.close(fig)
    print(f"Saved map for {year}")

Saved map for 2017


### 5. Model Evaluation