# COVID-19 Deaths Analysis: United States


In [14]:
# %pip install pandas plotly scikit-learn statsmodels matplotlib

In [15]:
# Import libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from statsmodels.tsa.api import ExponentialSmoothing
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# 🔹 Data Cleaning & Preprocessing

In [16]:
# Read the raw CSV export into the notebook-wide frame
df = pd.read_csv("COVID_19_Dataset.csv")

# Parse the snapshot timestamp; values that cannot be parsed become NaT
df["data_as_of"] = pd.to_datetime(df["data_as_of"], errors="coerce")

# 'year' may contain commas: strip them from the text form, then convert to a number
df["year"] = pd.to_numeric(
    df["year"].astype(str).str.replace(",", "", regex=True), errors="coerce"
)

# 'month' only needs the numeric conversion
df["month"] = pd.to_numeric(df["month"], errors="coerce")

# Assemble a first-of-month timestamp from (year, month, day=1)
df["date"] = pd.to_datetime(df[["year", "month"]].assign(day=1), errors="coerce")

# Same comma clean-up for the death counts; non-numeric entries become NaN
df["COVID_deaths"] = pd.to_numeric(
    df["COVID_deaths"].astype(str).str.replace(",", "", regex=True), errors="coerce"
)

# Rows for the national aggregate only
df_us = df.loc[df["jurisdiction_residence"] == "United States"]

# Keep rows that have a death count, ordered chronologically
df_us_clean = df_us.dropna(subset=["COVID_deaths"]).sort_values(by="date")


# 🔹 Time Series Line Plot — Total Deaths Over Time

In [17]:
# %pip install nbformat
# Mime type rendering requires nbformat>=4.2.0


In [18]:
# National monthly series: one row per date with the summed death count
df_monthly = df_us_clean.groupby("date", as_index=False)["COVID_deaths"].sum()

# Line chart (with point markers) of the monthly totals
fig = px.line(
    data_frame=df_monthly,
    x="date",
    y="COVID_deaths",
    markers=True,
    labels={"date": "Date", "COVID_deaths": "Deaths"},
    title="Monthly COVID-19 Deaths in the US",
)

# Switch to the white Plotly template, then render the figure
fig.update_layout(template="plotly_white")
fig.show()


# 🔹 Choropleth Map — Region-wise Crude Death Rates

In [19]:
# Snapshot of the most recent date, with the national aggregate removed
latest_date = df["date"].max()
df_latest = df.loc[df["date"] == latest_date]
df_latest_state = df_latest.loc[df_latest["jurisdiction_residence"] != "United States"]

# Region label -> list of member states
region_to_states = {
    "Region 1": [
        "Connecticut",
        "Maine",
        "Massachusetts",
        "New Hampshire",
        "Rhode Island",
        "Vermont",
    ],
    "Region 2": ["New Jersey", "New York"],
    "Region 3": [
        "Delaware",
        "District of Columbia",
        "Maryland",
        "Pennsylvania",
        "Virginia",
        "West Virginia",
    ],
    "Region 4": [
        "Alabama",
        "Florida",
        "Georgia",
        "Kentucky",
        "Mississippi",
        "North Carolina",
        "South Carolina",
        "Tennessee",
    ],
    "Region 5": ["Illinois", "Indiana", "Michigan", "Minnesota", "Ohio", "Wisconsin"],
    "Region 6": ["Arkansas", "Louisiana", "New Mexico", "Oklahoma", "Texas"],
    "Region 7": ["Iowa", "Kansas", "Missouri", "Nebraska"],
    "Region 8": [
        "Colorado",
        "Montana",
        "North Dakota",
        "South Dakota",
        "Utah",
        "Wyoming",
    ],
    "Region 9": ["Arizona", "California", "Hawaii", "Nevada"],
    "Region 10": ["Alaska", "Idaho", "Oregon", "Washington"],
}

# Attach each region's state list, explode to one row per state, then let the
# state name replace the region label. Every state therefore carries a copy of
# its region's rows. Jurisdictions that are not keys of the mapping end up with
# a missing name and are left out by the groupby further down.
df_latest_state_expanded = (
    df_latest_state
    .assign(states=df_latest_state["jurisdiction_residence"].map(region_to_states))
    .explode("states")
    .assign(jurisdiction_residence=lambda frame: frame["states"])
    .drop(columns=["states"])
)

# Make sure both measures are numeric: strip commas from the text form first,
# anything still unparseable becomes NaN
df_latest_state_expanded["COVID_deaths"] = pd.to_numeric(
    df_latest_state_expanded["COVID_deaths"].astype(str).str.replace(",", "", regex=True),
    errors="coerce",
)
df_latest_state_expanded["crude_COVID_rate"] = pd.to_numeric(
    df_latest_state_expanded["crude_COVID_rate"].astype(str).str.replace(",", "", regex=True),
    errors="coerce",
)

# Collapse to one row per state: mean crude rate and summed deaths
df_state_level = df_latest_state_expanded.groupby(
    "jurisdiction_residence", as_index=False
).agg(
    crude_COVID_rate=("crude_COVID_rate", "mean"),
    COVID_deaths=("COVID_deaths", "sum"),
)

# Full state name -> two-letter postal code (the "USA-states" location mode uses codes)
state_abbreviations = {
    "Alabama": "AL", "Alaska": "AK", "Arizona": "AZ",
    "Arkansas": "AR", "California": "CA", "Colorado": "CO",
    "Connecticut": "CT", "Delaware": "DE", "District of Columbia": "DC",
    "Florida": "FL", "Georgia": "GA", "Hawaii": "HI",
    "Idaho": "ID", "Illinois": "IL", "Indiana": "IN",
    "Iowa": "IA", "Kansas": "KS", "Kentucky": "KY",
    "Louisiana": "LA", "Maine": "ME", "Maryland": "MD",
    "Massachusetts": "MA", "Michigan": "MI", "Minnesota": "MN",
    "Mississippi": "MS", "Missouri": "MO", "Montana": "MT",
    "Nebraska": "NE", "Nevada": "NV", "New Hampshire": "NH",
    "New Jersey": "NJ", "New Mexico": "NM", "New York": "NY",
    "North Carolina": "NC", "North Dakota": "ND", "Ohio": "OH",
    "Oklahoma": "OK", "Oregon": "OR", "Pennsylvania": "PA",
    "Rhode Island": "RI", "South Carolina": "SC", "South Dakota": "SD",
    "Tennessee": "TN", "Texas": "TX", "Utah": "UT",
    "Vermont": "VT", "Virginia": "VA", "Washington": "WA",
    "West Virginia": "WV", "Wisconsin": "WI", "Wyoming": "WY",
}

# Swap the state names for their postal codes
df_state_level = df_state_level.assign(
    jurisdiction_residence=lambda frame: frame["jurisdiction_residence"].map(
        state_abbreviations
    )
)

# Draw the US map coloured by crude death rate
fig = px.choropleth(
    df_state_level,
    scope="usa",
    locationmode="USA-states",
    locations="jurisdiction_residence",
    color="crude_COVID_rate",
    color_continuous_scale="blues",
    hover_name="jurisdiction_residence",
    title="Crude COVID-19 Death Rate by State (Latest Month)",
)
fig.show()


# 🔹 Demographic Comparison — Race & Hispanic Origin

In [20]:
# Short display labels for the long race / Hispanic-origin category names,
# so the axis ticks stay readable
race_map = {
    "Hispanic": "Hispanic",
    "Non-Hispanic American Indian or Alaska Native": "AI/AN",
    "Non-Hispanic Asian": "Asian",
    "Non-Hispanic Asian, Native Hawaiian or Other Pacific Islander": "Asian/PI",
    "Non-Hispanic Black": "Black",
    "Non-Hispanic Native Hawaiian or Other Pacific Islander": "Native Hawaiian/PI",
    "Non-Hispanic White": "White",
}

# National rows of the "Race" group, summed per category, plus the short label;
# a category that is missing from race_map keeps its original name
race_df = (
    df_us_clean.loc[df_us_clean["group"] == "Race"]
    .groupby("subgroup1", as_index=False)["COVID_deaths"]
    .sum()
    .assign(race=lambda frame: frame["subgroup1"].map(race_map).fillna(frame["subgroup1"]))
)

# Bar chart: one bar per race label, bar colour follows the death count
fig = px.bar(
    data_frame=race_df,
    x="race",
    y="COVID_deaths",
    color="COVID_deaths",
    color_continuous_scale="viridis",
    labels={"race": "Race", "COVID_deaths": "Deaths"},
    title="COVID-19 Deaths by Race",
)
fig.update_layout(legend_title="Deaths", xaxis_tickangle=45)
fig.show()


# 🔹 COVID Deaths by Age Group

In [21]:
# Keep only the national rows that are broken down by age group
df_age = df_us_clean[df_us_clean["group"] == "Age"]

# Total deaths per age-group label
df_age_grouped = df_age.groupby("subgroup1")["COVID_deaths"].sum().reset_index()

# Sort key: the first run of digits found in the label.
# `expand=False` returns a Series (not a one-column DataFrame), and
# `pd.to_numeric(..., errors="coerce")` turns a label without any digit into NaN
# instead of raising the way `.astype(int)` does; such labels are sorted last.
df_age_grouped["age_sort"] = pd.to_numeric(
    df_age_grouped["subgroup1"].astype(str).str.extract(r"(\d+)", expand=False),
    errors="coerce",
)
df_age_grouped = df_age_grouped.sort_values("age_sort", na_position="last")

# Visualize: bars appear in the sorted row order, coloured by death count
fig = px.bar(
    df_age_grouped,
    x="subgroup1",
    y="COVID_deaths",
    title="COVID Deaths by Age Group",
    labels={"subgroup1": "Age Group", "COVID_deaths": "COVID Deaths"},
    color="COVID_deaths",
    color_continuous_scale="viridis",
)
fig.update_layout(xaxis_tickangle=-45)
fig.show()


# 🔹 Clustering Regions - by COVID Deaths & Crude Rate

In [22]:
# Most recent date only, without the national aggregate.
# `.copy()` yields an independent frame, so the column assignments below do not
# write into a filtered slice of `df` (chained assignment).
df_latest = df[
    (df["date"] == df["date"].max())
    & (df["jurisdiction_residence"] != "United States")
].copy()

# Ensure numeric columns. Commas are stripped first, as the other cells of this
# notebook do for the same columns; otherwise a value written with a comma
# would be coerced to NaN and silently dropped from the clustering input.
for col in ["COVID_deaths", "crude_COVID_rate"]:
    df_latest[col] = pd.to_numeric(
        df_latest[col].astype(str).str.replace(",", "", regex=False),
        errors="coerce",
    )

# Average deaths and crude rate per jurisdiction; rows with a missing value are dropped
cluster_df = (
    df_latest.groupby("jurisdiction_residence", as_index=False)
    .agg({"COVID_deaths": "mean", "crude_COVID_rate": "mean"})
    .dropna()
)

# Standardize both features so neither dominates the distance computation
X = StandardScaler().fit_transform(cluster_df[["COVID_deaths", "crude_COVID_rate"]])

# K-means with 3 clusters; the fixed random_state keeps the initialisation repeatable
kmeans = KMeans(n_clusters=3, random_state=42)
cluster_df["cluster"] = kmeans.fit_predict(X)

# Scatter plot of the two (unscaled) features, coloured by assigned cluster
fig = px.scatter(
    cluster_df,
    x="COVID_deaths",
    y="crude_COVID_rate",
    color="cluster",
    hover_name="jurisdiction_residence",
    title="Clustering Regions by COVID-19 Deaths and Crude COVID Rate",
    labels={"COVID_deaths": "COVID Deaths", "crude_COVID_rate": "Crude COVID Rate"},
)
fig.show()


# 🔹 COVID-19 Deaths Forecast (6 Months)

In [23]:
# Holt-Winters exponential smoothing on the national monthly series:
# multiplicative trend, multiplicative seasonality, 12 observations per season
model = ExponentialSmoothing(
    df_monthly["COVID_deaths"],
    seasonal_periods=12,
    seasonal="mul",
    trend="mul",
)
model_fit = model.fit()

# Project six periods past the end of the fitted series
forecast = model_fit.forecast(6)

# Month-start dates for the forecast horizon, beginning one month after the
# last row of df_monthly
forecast_dates = pd.date_range(
    start=df_monthly["date"].iloc[-1] + pd.DateOffset(months=1),
    periods=6,
    freq="MS",
)

# One figure holding the historical series and the forecast as separate traces
fig = go.Figure(
    data=[
        go.Scatter(
            name="Actual",
            mode="lines+markers",
            x=df_monthly["date"],
            y=df_monthly["COVID_deaths"],
        ),
        go.Scatter(
            name="Forecast",
            mode="lines+markers",
            x=forecast_dates,
            y=forecast,
        ),
    ]
)

# Title and axis labels, then render
fig.update_layout(
    title="COVID-19 Death Forecast (6 months)",
    xaxis_title="Date",
    yaxis_title="Deaths",
)
fig.show()


# 🔹 Dashboard Visualization using Dash

In [24]:
# %pip install dash

In [25]:
# Note: The Dash app will run on http://127.0.0.1:8050/

# Import the dash framework
from dash import Dash, html, dcc, Input, Output

# Load data
df = pd.read_csv("COVID_19_Dataset.csv")

# Parse the snapshot timestamp once; unparseable values become NaT
df["data_as_of"] = pd.to_datetime(df["data_as_of"], errors="coerce")

# 'year' may contain commas and surrounding whitespace: clean the text form,
# convert, and store as nullable integers so missing years stay missing
df["year"] = pd.to_numeric(
    df["year"].astype(str).str.replace(",", "", regex=False).str.strip(),
    errors="coerce",
).astype("Int64")

df["month"] = pd.to_numeric(df["month"], errors="coerce").astype("Int64")

# First-of-month timestamp built from (year, month, day=1)
df["date"] = pd.to_datetime(
    df[["year", "month"]].assign(day=1), errors="coerce"
)

# --- Clean numeric columns ---
# Strip commas like "1,234" (treated as a literal, not a regex) and whitespace;
# anything still non-numeric becomes NaN
for col in ["COVID_deaths", "crude_COVID_rate", "aa_COVID_rate"]:
    df[col] = pd.to_numeric(
        df[col].astype(str).str.replace(",", "", regex=False).str.strip(),
        errors="coerce",
    )

# Rows without a death count are dropped from the dashboard data
df = df.dropna(subset=["COVID_deaths"])

# Create dropdown options
# Missing names are dropped before sorting: sorted() cannot order NaN against strings
jurisdictions = sorted(df["jurisdiction_residence"].dropna().unique())
# Plain Python ints (rather than NumPy scalars) are used as the dropdown values
years = sorted(int(y) for y in df["year"].dropna().unique())

# Clustering input: rows of the most recent date, national aggregate excluded
df_latest = df.loc[
    (df["date"] == df["date"].max())
    & (df["jurisdiction_residence"] != "United States")
]

# Per-jurisdiction means of the two features; incomplete rows are discarded
cluster_df = (
    df_latest.groupby("jurisdiction_residence", as_index=False)
    .agg(
        COVID_deaths=("COVID_deaths", "mean"),
        crude_COVID_rate=("crude_COVID_rate", "mean"),
    )
    .dropna()
)

# Standardize the features, fit 3-cluster K-means, and store each row's label
X = StandardScaler().fit_transform(cluster_df[["COVID_deaths", "crude_COVID_rate"]])
kmeans = KMeans(n_clusters=3, random_state=42).fit(X)
cluster_df["cluster"] = kmeans.labels_


# Forecast model (always based on national data)
df_monthly = (
    df[df["jurisdiction_residence"] == "United States"]
    .groupby("date")["COVID_deaths"]
    .sum()
    .reset_index()
)

# Best-effort forecast: if fitting or forecasting fails, the dashboard still
# starts and simply shows no forecast trace (the callback checks len(forecast)).
# `except Exception` is used instead of a bare `except:` so KeyboardInterrupt /
# SystemExit are not swallowed, and the reason is printed rather than lost.
try:
    model = ExponentialSmoothing(
        df_monthly["COVID_deaths"], trend="mul", seasonal="mul", seasonal_periods=12
    )
    model_fit = model.fit()
    forecast = model_fit.forecast(6)
    forecast_dates = pd.date_range(
        df_monthly["date"].iloc[-1] + pd.DateOffset(months=1), periods=6, freq="MS"
    )
except Exception as exc:
    forecast = []
    forecast_dates = []
    print(f"Forecast unavailable ({type(exc).__name__}): {exc}")

# Choropleth prep
# Region label -> list of member states
region_to_states = {
    "Region 1": [
        "Connecticut", "Maine", "Massachusetts", "New Hampshire", "Rhode Island", "Vermont",
    ],
    "Region 2": ["New Jersey", "New York"],
    "Region 3": [
        "Delaware", "District of Columbia", "Maryland", "Pennsylvania", "Virginia", "West Virginia",
    ],
    "Region 4": [
        "Alabama", "Florida", "Georgia", "Kentucky", "Mississippi",
        "North Carolina", "South Carolina", "Tennessee",
    ],
    "Region 5": ["Illinois", "Indiana", "Michigan", "Minnesota", "Ohio", "Wisconsin"],
    "Region 6": ["Arkansas", "Louisiana", "New Mexico", "Oklahoma", "Texas"],
    "Region 7": ["Iowa", "Kansas", "Missouri", "Nebraska"],
    "Region 8": ["Colorado", "Montana", "North Dakota", "South Dakota", "Utah", "Wyoming"],
    "Region 9": ["Arizona", "California", "Hawaii", "Nevada"],
    "Region 10": ["Alaska", "Idaho", "Oregon", "Washington"],
}

# Rows of the most recent date. The national aggregate is excluded here (as the
# clustering code does): "United States" has no entry in state_abbr, so it would
# otherwise reach the map data as a row with a missing location code.
latest_date = df["date"].max()
df_latest_state = df[
    (df["date"] == latest_date) & (df["jurisdiction_residence"] != "United States")
].copy()

# Expand region rows to one row per member state; jurisdictions that are not
# regions keep their own name via combine_first.
# NOTE(review): if the data holds rows for a state AND for its region, the
# aggregation below blends both for that state - confirm this is intended.
df_latest_state["states"] = df_latest_state["jurisdiction_residence"].map(
    region_to_states
)
df_latest_state = df_latest_state.explode("states")
df_latest_state["jurisdiction_residence"] = df_latest_state["states"].combine_first(
    df_latest_state["jurisdiction_residence"]
)

# One row per jurisdiction name: mean crude rate and summed deaths
df_state_level = df_latest_state.groupby("jurisdiction_residence", as_index=False).agg(
    {"crude_COVID_rate": "mean", "COVID_deaths": "sum"}
)

# Full state name -> two-letter postal code
state_abbr = {
    "Alabama": "AL", "Alaska": "AK", "Arizona": "AZ", "Arkansas": "AR", "California": "CA",
    "Colorado": "CO", "Connecticut": "CT", "Delaware": "DE", "District of Columbia": "DC",
    "Florida": "FL", "Georgia": "GA", "Hawaii": "HI", "Idaho": "ID", "Illinois": "IL",
    "Indiana": "IN", "Iowa": "IA", "Kansas": "KS", "Kentucky": "KY", "Louisiana": "LA",
    "Maine": "ME", "Maryland": "MD", "Massachusetts": "MA", "Michigan": "MI", "Minnesota": "MN",
    "Mississippi": "MS", "Missouri": "MO", "Montana": "MT", "Nebraska": "NE", "Nevada": "NV",
    "New Hampshire": "NH", "New Jersey": "NJ", "New Mexico": "NM", "New York": "NY",
    "North Carolina": "NC", "North Dakota": "ND", "Ohio": "OH", "Oklahoma": "OK", "Oregon": "OR",
    "Pennsylvania": "PA", "Rhode Island": "RI", "South Carolina": "SC", "South Dakota": "SD",
    "Tennessee": "TN", "Texas": "TX", "Utah": "UT", "Vermont": "VT", "Virginia": "VA",
    "Washington": "WA", "West Virginia": "WV", "Wisconsin": "WI", "Wyoming": "WY",
}

# Replace names by postal codes and drop anything without a code, since such
# rows cannot be placed on the "USA-states" map
df_state_level["jurisdiction_residence"] = df_state_level["jurisdiction_residence"].map(
    state_abbr
)
df_state_level = df_state_level.dropna(subset=["jurisdiction_residence"])

# App initialization
app = Dash(__name__)
app.title = "COVID-19 Deaths - United States"


# ===== LAYOUT BUILDERS =====
# The KPI tiles and the chart containers repeat the same markup; these small
# builders keep each look defined in one place.
def _labelled_dropdown(label_text, dropdown_id, options, value):
    """Return a flex cell holding a bold label above a dcc.Dropdown."""
    return html.Div(
        [
            html.Label(label_text, style={"fontWeight": "bold"}),
            dcc.Dropdown(
                id=dropdown_id,
                options=options,
                value=value,
                style={"backgroundColor": "white"},
            ),
        ],
        style={"flex": 1},
    )


def _kpi_card(title, value_id):
    """Return a KPI tile: an H3 caption and an (initially empty) H1 with id `value_id`."""
    return html.Div(
        className="kpi-card",
        style={
            "flex": "1",
            "padding": "20px",
            "backgroundColor": "white",
            "borderRadius": "12px",
            "boxShadow": "0 4px 8px rgba(0,0,0,0.1)",
            "textAlign": "center",
        },
        children=[html.H3(title), html.H1(id=value_id)],
    )


def _chart_card(graph_id, full_width=False):
    """Return a white card wrapping a dcc.Graph; `full_width` spans both grid columns."""
    style = {
        "backgroundColor": "white",
        "padding": "15px",
        "borderRadius": "12px",
        "boxShadow": "0 4px 8px rgba(0,0,0,0.1)",
    }
    if full_width:
        style["gridColumn"] = "span 2"
    return html.Div(className="card", style=style, children=[dcc.Graph(id=graph_id)])


# ===== GLOBAL STYLE =====
app.layout = html.Div(
    style={
        "fontFamily": "Arial, sans-serif",
        "backgroundColor": "#f8f9fa",
        "color": "#212529",
        "padding": "20px",
    },
    children=[
        # Header
        html.Div(
            className="header",
            style={
                "textAlign": "center",
                "marginBottom": "20px",
                "padding": "10px",
                "backgroundColor": "#2c3e50",
                "color": "white",
                "borderRadius": "10px",
                "boxShadow": "0 4px 6px rgba(0,0,0,0.2)",
            },
            children=[html.H1("COVID-19 Deaths - United States")],
        ),

        # Filters
        html.Div(
            className="filters-row",
            style={"display": "flex", "gap": "20px", "marginBottom": "20px"},
            children=[
                _labelled_dropdown(
                    "Select Jurisdiction:",
                    "jurisdiction-dropdown",
                    [{"label": j, "value": j} for j in jurisdictions],
                    "United States",
                ),
                _labelled_dropdown(
                    "Select Year:",
                    "year-dropdown",
                    [{"label": "Overall", "value": "Overall"}]
                    + [{"label": y, "value": y} for y in years],
                    "Overall",
                ),
            ],
        ),

        # KPI Row
        html.Div(
            className="kpi-row",
            style={
                "display": "flex",
                "justifyContent": "space-between",
                "flexWrap": "wrap",
                "gap": "15px",
                "marginBottom": "30px",
            },
            children=[
                _kpi_card("Total Deaths", "total-deaths"),
                _kpi_card("Crude Rate", "crude-rate"),
                _kpi_card("Age-Adjusted Rate", "age-adjusted-rate"),
                _kpi_card("Crude Rate (Annualized)", "crude-rate-ann"),
                _kpi_card("Age-Adjusted Rate (Annualized)", "age-adjusted-rate-ann"),
            ],
        ),

        # Charts Grid (two columns; the last two cards span the full width)
        html.Div(
            className="chart-grid",
            style={
                "display": "grid",
                "gridTemplateColumns": "repeat(2, 1fr)",
                "gap": "20px",
            },
            children=[
                _chart_card("deaths-over-time"),
                _chart_card("deaths-by-race"),
                _chart_card("deaths-by-age"),
                _chart_card("cluster-chart"),
                _chart_card("choropleth-map", full_width=True),
                _chart_card("forecast-chart", full_width=True),
            ],
        ),
    ]
)


# KPI Callback
@app.callback(
    Output("total-deaths", "children"),
    Output("crude-rate", "children"),
    Output("age-adjusted-rate", "children"),
    Output("crude-rate-ann", "children"),
    Output("age-adjusted-rate-ann", "children"),
    [Input("jurisdiction-dropdown", "value"), Input("year-dropdown", "value")],
)
def update_kpis(jurisdiction, year):
    """Compute the five KPI texts for the selected jurisdiction and year.

    Args:
        jurisdiction: Value of the jurisdiction dropdown.
        year: Value of the year dropdown; the string "Overall" means no year filter.

    Returns:
        A 5-tuple of strings, in the order of the callback outputs: total
        deaths (thousands-separated), mean crude rate, mean age-adjusted rate,
        and both means multiplied by 12. A rate with no data is shown as "N/A".
    """
    months_per_year = 12  # factor used for the "(Annualized)" tiles

    def _format_rate(value, factor=1):
        # A mean over zero rows (or over all-NaN values) is NaN. Formatting it
        # used to show "nan", and the annualized variant was None, which made
        # the f-string raise TypeError. Show a placeholder instead.
        if pd.isna(value):
            return "N/A"
        return f"{value * factor:.2f}"

    # Filter data based on jurisdiction and year
    if year == "Overall":
        dff = df[df["jurisdiction_residence"] == jurisdiction]
    else:
        dff = df[(df["jurisdiction_residence"] == jurisdiction) & (df["year"] == year)]

    # Calculate KPIs
    total = int(dff["COVID_deaths"].sum())
    crude = dff["crude_COVID_rate"].mean()
    adj = dff["aa_COVID_rate"].mean()

    return (
        f"{total:,}",
        _format_rate(crude),
        _format_rate(adj),
        _format_rate(crude, months_per_year),
        _format_rate(adj, months_per_year),
    )


# Charts Callback
@app.callback(
    Output("deaths-over-time", "figure"),
    Output("deaths-by-race", "figure"),
    Output("deaths-by-age", "figure"),
    Output("cluster-chart", "figure"),
    Output("choropleth-map", "figure"),
    Output("forecast-chart", "figure"),
    [Input("jurisdiction-dropdown", "value"), Input("year-dropdown", "value")],
)
def update_dashboard(jurisdiction, year):
    """Build the six dashboard figures for the selected jurisdiction and year.

    Args:
        jurisdiction: Value of the jurisdiction dropdown.
        year: Value of the year dropdown; the string "Overall" means no year filter.

    Returns:
        A 6-tuple of figures in the order of the callback outputs: time series,
        race bars, age bars, cluster scatter, choropleth, forecast. Only the
        first three use the filtered rows; the last three are drawn from the
        module-level frames (cluster_df, df_state_level, df_monthly / forecast)
        and do not change with the selection.
    """
    # Filter data based on jurisdiction and year
    if year == "Overall":
        dff = df[df["jurisdiction_residence"] == jurisdiction]
    else:
        dff = df[(df["jurisdiction_residence"] == jurisdiction) & (df["year"] == year)]

    # Time Series: deaths summed per date
    ts_data = dff.groupby("date")["COVID_deaths"].sum().reset_index()
    ts_fig = px.line(
        ts_data,
        x="date",
        y="COVID_deaths",
        title="COVID-19 Deaths Over Time",
        labels={"date": "Date", "COVID_deaths": "Deaths"},
        markers=True,
    )

    # Race: deaths per category, shown with short labels
    race_map = {
        "Hispanic": "Hispanic",
        "Non-Hispanic American Indian or Alaska Native": "AI/AN",
        "Non-Hispanic Asian": "Asian",
        "Non-Hispanic Asian, Native Hawaiian or Other Pacific Islander": "Asian/PI",
        "Non-Hispanic Black": "Black",
        "Non-Hispanic Native Hawaiian or Other Pacific Islander": "Native Hawaiian/PI",
        "Non-Hispanic White": "White",
    }
    race_df = (
        dff[dff["group"] == "Race"]
        .groupby("subgroup1")["COVID_deaths"]
        .sum()
        .reset_index()
    )
    # Categories missing from race_map keep their original name
    race_df["race"] = race_df["subgroup1"].map(race_map).fillna(race_df["subgroup1"])
    race_fig = px.bar(
        race_df,
        x="race",
        y="COVID_deaths",
        title="COVID-19 Deaths by Race",
        labels={"race": "Race", "COVID_deaths": "Deaths"},
        color="COVID_deaths",
        color_continuous_scale="viridis",
    )
    race_fig.update_layout(xaxis_tickangle=45, legend_title="Deaths")

    # Age: deaths per age-group label, ordered by the first number in the label
    df_age = dff[dff["group"] == "Age"]
    df_age_grouped = df_age.groupby("subgroup1")["COVID_deaths"].sum().reset_index()

    # `expand=False` returns a Series, and to_numeric(errors="coerce") maps a
    # label without digits to NaN instead of raising as `.astype(int)` does;
    # such labels are placed after the numbered groups.
    df_age_grouped["age_sort"] = pd.to_numeric(
        df_age_grouped["subgroup1"].astype(str).str.extract(r"(\d+)", expand=False),
        errors="coerce",
    )
    df_age_grouped = df_age_grouped.sort_values("age_sort", na_position="last")

    age_fig = px.bar(
        df_age_grouped,
        x="subgroup1",
        y="COVID_deaths",
        title="COVID-19 Deaths by Age",
        labels={"subgroup1": "Age Group", "COVID_deaths": "Deaths"},
        color="COVID_deaths",
        color_continuous_scale="viridis",
    )
    age_fig.update_layout(xaxis_tickangle=45, legend_title="Deaths")

    # Cluster: scatter of the precomputed per-jurisdiction clustering
    cluster_fig = px.scatter(
        cluster_df,
        x="COVID_deaths",
        y="crude_COVID_rate",
        color="cluster",
        hover_name="jurisdiction_residence",
        title="Clustering of Regions by COVID Death Rates",
        labels={"COVID_deaths": "COVID Deaths", "crude_COVID_rate": "Crude COVID Rate"},
    )

    # Choropleth: precomputed state-level frame keyed by postal code
    choropleth_fig = px.choropleth(
        df_state_level,
        locations="jurisdiction_residence",
        locationmode="USA-states",
        color="crude_COVID_rate",
        hover_name="jurisdiction_residence",
        title="State-wise COVID Death Rates",
        labels={"crude_COVID_rate": "Crude Rate"},
        color_continuous_scale="blues",
        scope="usa",
    )

    # Forecast: national history, plus the forecast trace when one is available
    forecast_fig = go.Figure()
    forecast_fig.add_trace(
        go.Scatter(
            x=df_monthly["date"],
            y=df_monthly["COVID_deaths"],
            mode="lines+markers",
            name="Actual",
        )
    )
    # `forecast` is an empty list when the model could not be fitted
    if len(forecast) > 0:
        forecast_fig.add_trace(
            go.Scatter(
                x=forecast_dates, y=forecast, mode="lines+markers", name="Forecast"
            )
        )
    forecast_fig.update_layout(
        title="COVID-19 Death Forecast", xaxis_title="Date", yaxis_title="Deaths"
    )

    return ts_fig, race_fig, age_fig, cluster_fig, choropleth_fig, forecast_fig


# Start the Dash server only when this code runs as the main module
# (debug=True switches on Dash's debug mode).
if __name__ == "__main__":
    app.run(debug=True)


# Thank You