<a href="https://colab.research.google.com/github/DelMashiry-dev/DelMashiry-dev/blob/main/RTMRP%20PROTOTYPE%20ver.2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load NewsAPI key
# The key is read from the Colab secret store via `userdata.get`. Any failure is
# printed and a placeholder string is stored instead, so later NewsAPI requests
# are sent with that placeholder unless it is replaced by hand.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
    print("NewsAPI key loaded successfully")
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets

# Load OWID data
# The CSV is downloaded when this statement runs, so network access is required
# at startup.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)  # Ensure timezone-naive
# Keep only the columns that the forecasting, resource and map code reads.
df = df[['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
         'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']]
# NOTE(review): other parts of this file select rows with location == 'USA';
# confirm that label actually occurs in this dataset's `location` column.

# Simulate healthcare capacity data
# One record per country: (location, hospital beds, ICU beds, ventilators).
# These figures are simulated placeholders, not measured capacity statistics.
_capacity_fields = ('location', 'hospital_beds', 'icu_beds', 'ventilators')
_capacity_records = [
    ('USA', 10000, 2000, 500),
    ('India', 5000, 1000, 300),
    ('Brazil', 3000, 500, 200),
    ('United Kingdom', 8000, 1600, 400),
    ('Germany', 9000, 1800, 450),
    ('France', 7500, 1500, 375),
    ('Italy', 7000, 1400, 350),
    ('Spain', 6000, 1200, 300),
    ('China', 12000, 2400, 600),
    ('Japan', 9500, 1900, 475),
    ('South Korea', 8500, 1700, 425),
    ('Australia', 6500, 1300, 325),
    ('Canada', 7800, 1560, 390),
    ('Mexico', 4000, 800, 200),
]
# Pivot the records into the column-oriented dict the rest of the file expects.
capacity_data = {
    field: [record[pos] for record in _capacity_records]
    for pos, field in enumerate(_capacity_fields)
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def _fallback_news_signals(date_range_start=None, date_range_end=None):
    """Build the synthetic signal frame used when NewsAPI data is unavailable.

    With a usable date range the last (up to) 30 days of that range are used;
    otherwise the 30 days ending today. Every row is labelled "USA" and gets a
    random signal in [10, 100).
    """
    # pd.isna treats both None and NaT as "missing", so an empty history
    # (min/max == NaT) takes the default branch instead of crashing date_range.
    has_range = not (pd.isna(date_range_start) or pd.isna(date_range_end))
    if has_range:
        dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
    else:
        today = pd.Timestamp(datetime.now().date())
        dates = pd.DatetimeIndex([today - timedelta(days=offset) for offset in range(30)])
    return pd.DataFrame({
        "date": pd.to_datetime(dates).normalize(),
        "location": ["USA"] * len(dates),
        "disease_signal": np.random.randint(10, 100, len(dates)),
    })


def _article_signal_row(article):
    """Turn one NewsAPI article dict into a date/location/signal record.

    Returns None when the article has no parseable publication date.
    """
    published = pd.to_datetime(article.get("publishedAt"), errors="coerce")
    if pd.isna(published):
        return None
    published = published.tz_localize(None).normalize()  # Date-only, timezone-naive

    # Any of these fields may be null in the response; treat null as empty so
    # a single sparse article does not abort the whole fetch.
    title = article.get("title") or ""
    body = article.get("content") or article.get("description") or ""
    text = f"{title} {body}".lower()

    # NOTE(review): matching is plain substring search, so e.g. "usa" also
    # matches inside "thousand"; tighten if precise tagging matters.
    if any(word in text for word in ["outbreak", "epidemic", "pandemic"]):
        signal = 100
    elif any(word in text for word in ["health", "disease", "emergency"]):
        signal = 50
    else:
        signal = 10

    if any(word in text for word in ["usa", "united states", "america"]):
        location = "USA"
    elif any(word in text for word in ["india", "indian"]):
        location = "India"
    elif any(word in text for word in ["brazil", "brazilian"]):
        location = "Brazil"
    else:
        location = "Unknown"

    return {"date": published, "location": location, "disease_signal": signal}


def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch outbreak-related headlines from NewsAPI and score them.

    Args:
        api_key: NewsAPI key sent as the ``apiKey`` query parameter.
        date_range_start: Optional first day of the query window.
        date_range_end: Optional last day of the query window. The window is
            only applied when both bounds are present (not None/NaT).

    Returns:
        DataFrame with columns ``date`` (timezone-naive, midnight),
        ``location`` and ``disease_signal``. On any failure (request error,
        non-200 status, no articles, nothing parseable) a synthetic fallback
        frame with random signals is returned instead; the failure is printed,
        never raised.
    """
    has_range = not (pd.isna(date_range_start) or pd.isna(date_range_end))
    try:
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if has_range:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get("https://newsapi.org/v2/everything", params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise ValueError("No articles found")

        rows = [row for row in map(_article_signal_row, articles) if row is not None]
        df = pd.DataFrame(rows)
        if df.empty:
            raise ValueError("No relevant data parsed")

        df["date"] = pd.to_datetime(df["date"])
        print("NewsAPI data fetched:", df.head())
        return df

    except Exception as e:
        # Deliberate best-effort: the dashboard keeps working on synthetic data.
        print(f"Error fetching NewsAPI data: {e}")
        fallback = _fallback_news_signals(date_range_start, date_range_end)
        print("Using fallback data")
        return fallback

# Prepare features
def _placeholder_features(df, location, target):
    """Return the two-row stand-in frame used when a location has too little data.

    Dates are the location's last two dates in ``df`` when available, otherwise
    today and yesterday. The target values (1000, 1100) are fabricated.
    """
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the modelling frame for one location.

    Args:
        df: History frame with at least ``date``, ``location`` and ``target``.
        location: Value of ``location`` to select.
        target: Column to forecast; rows where it is NaN are dropped.
        lookback: Number of lag columns (``lag_1`` .. ``lag_<lookback>``) to add.

    Returns:
        The location's rows sorted by date with lag columns and a
        ``disease_signal`` column (0 where no news matched the date). When
        fewer than two usable rows exist, before or after lagging, a two-row
        placeholder frame without lag columns is returned instead.
    """
    # .copy() so the column assignments below never write into a view of `df`.
    df_loc = df[df['location'] == location].sort_values('date')
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _placeholder_features(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    news_data = fetch_promed_data(newsapi_key, df_loc['date'].min(), df_loc['date'].max())
    if 'disease_signal' not in news_data.columns:
        news_data = news_data.assign(disease_signal=0)

    df_loc['date'] = df_loc['date'].dt.normalize()

    # Several articles can share a day; collapse to one row per date (strongest
    # signal) so the left merge cannot multiply history rows.
    daily_signal = (
        news_data.assign(date=news_data['date'].dt.normalize())
        .groupby('date', as_index=False)['disease_signal']
        .max()
    )

    # If the history already carries a disease_signal column, merging would
    # produce *_x/*_y suffixes and lose the plain column; replace it instead.
    df_loc = df_loc.drop(columns='disease_signal', errors='ignore')
    df_loc = df_loc.merge(daily_signal, on='date', how='left')
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # The first `lookback` rows have incomplete lags and cannot be used.
    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _placeholder_features(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random-forest regressor on the lag and news-signal columns of ``df``.

    Args:
        df: Frame holding the predictor columns (any column whose name contains
            ``lag_``, plus ``disease_signal``) and the ``target`` column.
        target: Name of the column to regress on.

    Returns:
        ``(model, importance)`` where ``importance`` is a DataFrame with one
        row per predictor (``feature``, ``importance``).

    Raises:
        ValueError: If there are no rows or no predictor columns to train on.
    """
    predictor_cols = []
    for name in df.columns:
        if 'lag_' in name or name == 'disease_signal':
            predictor_cols.append(name)

    inputs = df[predictor_cols]
    response = df[target]

    n_rows, n_cols = inputs.shape
    if n_rows == 0 or n_cols == 0:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, response)

    ranking = pd.DataFrame({
        'feature': predictor_cols,
        'importance': forest.feature_importances_,
    })
    return forest, ranking

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date``/``target`` columns of ``df``.

    Args:
        df: Frame with a ``date`` column and the ``target`` column.
        target: Name of the column to forecast.

    Returns:
        The fitted Prophet model (yearly, weekly and daily seasonality enabled).

    Raises:
        ValueError: If fewer than two rows have both a date and a non-NaN
            target value.
    """
    prophet_df = df[['date', target]].rename(columns={'date': 'ds', target: 'y'})

    # Count only rows that are actually usable, so the guard matches its
    # message instead of letting NaN-only frames through.
    usable_rows = prophet_df.dropna(subset=['ds', 'y']).shape[0]
    if usable_rows < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    model = Prophet(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=True)
    model.fit(prophet_df)
    return model

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest estimate with Prophet's first forecast step.

    Args:
        rf_model: Fitted regressor exposing ``predict``.
        prophet_model: Fitted model exposing ``make_future_dataframe`` and
            ``predict`` (the result must contain a ``yhat`` column).
        df: Feature frame; its last row supplies the regressor input.
        forecast_days: Number of future periods requested from Prophet.
        target: Unused here; kept for a signature parallel to the trainers.

    Returns:
        A single number: the mean of the regressor's prediction for the last
        row and the first of the ``forecast_days`` Prophet predictions.
    """
    predictor_cols = [
        name for name in df.columns
        if 'lag_' in name or name == 'disease_signal'
    ]
    latest_row = df[predictor_cols].iloc[-1:]
    forest_estimates = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_path = prophet_forecast['yhat'].tail(forecast_days).values

    # Only the first forecast step of each model enters the blend.
    return (forest_estimates[0] + prophet_path[0]) / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Produce a forecast for one location with the chosen model.

    Args:
        df: Full history frame.
        location: Location to forecast.
        forecast_days: Number of future periods for Prophet-based forecasts.
        model_type: ``'Random Forest'``, ``'Prophet'``; any other value selects
            the ensemble.
        target: Column to forecast.

    Returns:
        ``(prediction, importance, intervals)``:
        * Random Forest: one-element prediction array, importance frame, None.
        * Prophet: array of ``forecast_days`` values, None, and a frame with
          ``ds``/``yhat_lower``/``yhat_upper``.
        * Ensemble: a single number, importance frame, None.
        On any exception the error is printed and ``(np.array([0]), None, None)``
        is returned.
    """
    try:
        prepared = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet = train_prophet(prepared, target)
            horizon = prophet.make_future_dataframe(periods=forecast_days)
            projection = prophet.predict(horizon)
            future_values = projection['yhat'].tail(forecast_days).values
            return future_values, None, projection[['ds', 'yhat_lower', 'yhat_upper']]

        # Both remaining modes need the random forest.
        forest, importance = train_random_forest(prepared, target)

        if model_type == 'Random Forest':
            predictor_cols = [
                name for name in prepared.columns
                if 'lag_' in name or name == 'disease_signal'
            ]
            latest_row = prepared[predictor_cols].tail(1)
            return forest.predict(latest_row), importance, None

        # Ensemble
        prophet = train_prophet(prepared, target)
        blended = ensemble_predict(forest, prophet, prepared, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score the random forest and Prophet on a chronological 80/20 split.

    Args:
        df: Full history frame.
        location: Location to evaluate.
        target: Column being forecast.

    Returns:
        ``{'Random Forest': {'RMSE', 'R2'}, 'Prophet': {'RMSE', 'R2'}}``.
        All four values are 0 when there is too little data, the hold-out part
        is empty, or any step raises (the error is printed).
    """
    def _zero_scores():
        return {'Random Forest': {'RMSE': 0, 'R2': 0}, 'Prophet': {'RMSE': 0, 'R2': 0}}

    try:
        prepared = prepare_features(df, location, target)
        if prepared.shape[0] < 2:
            return _zero_scores()

        # First 80% of rows train the models, the remainder is held out.
        cutoff = int(0.8 * len(prepared))
        fit_part = prepared[:cutoff]
        holdout = prepared[cutoff:]
        if holdout.empty:
            return _zero_scores()

        forest, _ = train_random_forest(fit_part, target)
        predictor_cols = [
            name for name in holdout.columns
            if 'lag_' in name or name == 'disease_signal'
        ]
        forest_guess = forest.predict(holdout[predictor_cols])
        scores = {
            'Random Forest': {
                'RMSE': np.sqrt(mean_squared_error(holdout[target], forest_guess)),
                'R2': r2_score(holdout[target], forest_guess),
            }
        }

        prophet = train_prophet(fit_part, target)
        horizon = prophet.make_future_dataframe(periods=len(holdout))
        projection = prophet.predict(horizon)
        prophet_guess = projection['yhat'].tail(len(holdout)).values
        scores['Prophet'] = {
            'RMSE': np.sqrt(mean_squared_error(holdout[target], prophet_guess)),
            'R2': r2_score(holdout[target], prophet_guess),
        }
        return scores
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return _zero_scores()

# Calculate resource allocation
def calculate_resources(predicted_cases, country):
    """Translate a case forecast into resource needs and shortages.

    Args:
        predicted_cases: A number or any array-like of numbers; only the first
            element is used. Missing, non-finite or negative forecasts are
            treated as 0 so the dashboard never shows negative needs or fails
            on ``int(nan)``.
        country: Location name looked up in ``capacity_df``; unknown names use
            default capacities.

    Returns:
        ``(resources_needed, capacity, shortages)`` — three dicts of ints.
        ``resources_needed`` covers hospital_beds, icu_beds, ventilators,
        medical_staff, ppe_kits and test_kits; ``capacity`` and ``shortages``
        cover hospital_beds, icu_beds and ventilators.
    """
    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }

    # Use default values if country not found
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300
    }

    # Get country capacity
    country_capacity = capacity_df[capacity_df['location'] == country]
    if country_capacity.empty:
        capacity = dict(default_capacity)
    else:
        first_match = country_capacity.iloc[0]
        capacity = {name: int(first_match[name]) for name in default_capacity}

    # Accept scalars, lists, Series and arrays of any rank alike.
    try:
        flat_values = np.ravel(np.asarray(predicted_cases, dtype=float))
    except (TypeError, ValueError):
        flat_values = np.array([])
    predicted_value = float(flat_values[0]) if flat_values.size else 0.0
    if not np.isfinite(predicted_value) or predicted_value < 0:
        predicted_value = 0.0

    # Calculate needed resources
    resources_needed = {
        'hospital_beds': int(predicted_value * resource_ratios['hospital_beds']),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['medical_staff']),
        'ppe_kits': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits'])
    }

    # Calculate shortages (never negative: spare capacity is not a shortage)
    shortages = {
        name: max(0, resources_needed[name] - capacity[name])
        for name in default_capacity
    }

    return resources_needed, capacity, shortages

# Run retraining
# NOTE(review): `retrain_model` is defined further down in this file, after the
# dashboard layout. Executed top-to-bottom as a plain script this loop would
# raise NameError; it only works when that definition has already been run
# (e.g. a notebook cell executed earlier) — confirm the intended order.
# Each pass may replace `df` with the frame returned by `retrain_model`.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# Create filtered data for map
# One row per location (GroupBy.last over the columns), keeping only locations
# that have a total_cases value.
map_data = df.groupby('location').last().reset_index()
map_data = map_data.dropna(subset=['total_cases'])

# Dashboard
# Dash application instance styled with the Bootstrap theme from
# dash_bootstrap_components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
# The page template below embeds an inline <style> block for cards, tabs, the
# resource alerts and the dashboard title. The {%...%} tokens are template
# placeholders that Dash substitutes when it renders the page.
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Page layout: a title followed by three tabs. The component ids declared here
# are the ones the callbacks bind to, so they must stay in sync with the
# @app.callback declarations.
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        # Tab 1: forecast controls, prediction plot, resource needs,
        # feature importance and model metrics.
        dbc.Tab(label="Predictions & Resources", children=[
            # Control panel: country, forecast horizon, model and target.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    # Options are whatever locations `df` holds when the layout is built.
                                    # NOTE(review): the default 'USA' is only selectable if that label
                                    # is among df['location'] — confirm.
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    # Horizon from 7 to 90 days, default 30, tick marks every 14 days.
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    # Values match the model_type strings checked in forecast_resource.
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            # Values are column names of `df`.
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            # Clicking this button is the Input that triggers the prediction callback.
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            # Prediction plot next to the resource summary.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            # Feature importance next to the model performance metrics.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        # Tab 2: world map of the selected metric.
        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            # Values are column names of `map_data`, including the
                            # derived per-million columns.
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        # Tab 3: news-derived disease signal plot and headlines.
        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            # Values are the window length in days.
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Prepare models and data
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Refresh the history with news rows and refit both models for a location.

    Args:
        df: Current history frame (needs ``date`` and ``location`` columns).
        location: Location whose models are trained.
        api_key: NewsAPI key forwarded to ``fetch_promed_data``.
        target: Column to forecast.
        interval_hours: Accepted but not used by this function.

    Returns:
        ``(rf_model, prophet_model, df)``. The models are None when no valid
        news data is available, when too little data remains for the location,
        or when any step raises (the error is printed). The returned ``df``
        includes appended news rows only if they are newer than the history.
    """
    try:
        earliest_known = df['date'].min()
        newest_known = df['date'].max()
        news = fetch_promed_data(api_key, earliest_known, newest_known)

        dates_are_datetimes = (
            not news.empty and pd.api.types.is_datetime64_any_dtype(news['date'])
        )
        if not dates_are_datetimes:
            print("No valid new data to concatenate")
            return None, None, df

        news['date'] = pd.to_datetime(news['date']).dt.tz_localize(None).dt.normalize()
        newest_news = news['date'].max()

        # Append only when the news reaches past the end of the history.
        if pd.notna(newest_news) and newest_news > newest_known:
            news = news[['date', 'location', 'disease_signal']]
            absent_cols = [name for name in df.columns if name not in news.columns]
            for name in absent_cols:
                news[name] = np.nan
            df = pd.concat([df, news], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        prepared = prepare_features(df, location, target)
        if len(prepared) < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        forest, _ = train_random_forest(prepared, target)
        prophet = train_prophet(prepared, target)
        return forest, prophet, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Calculate derived metrics
# Per-million rates: cumulative count divided by population, scaled by one million.
map_data['cases_per_million'] = map_data['total_cases'].div(map_data['population']).mul(1000000)
map_data['deaths_per_million'] = map_data['total_deaths'].div(map_data['population']).mul(1000000)

# Callbacks
@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # Create prediction plot
        pred_fig = go.Figure()
        pred_dates = [datetime.now() + timedelta(days=i) for i in range(forecast_days)]
        pred_values = pred if not isinstance(pred, float) else [pred] * forecast_days

        # Add historical data to the plot
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())]
            if not historical.empty:
                pred_fig.add_trace(go.Scatter(
                    x=historical['date'][-30:],  # Last 30 days
                    y=historical[target_variable][-30:],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                dbc.Col([
                    html.H5("Random Forest", className="text-center"),
                    html.Div([
                        html.P(f"RMSE: {metrics['Random Forest']['RMSE']:.2f}", className="mb-1"),
                        html.P(f"R²: {metrics['Random Forest']['R2']:.2f}", className="mb-1"),
                        html.Div(className="progress mb-2", children=[
                            html.Div(
                                className="progress-bar",
                                style={
                                    "width": f"{max(0, min(100, metrics['Random Forest']['R2'] * 100))}%",
                                    "backgroundColor": "#2ecc71"
                                }
                            )
                        ])
                    ])
                ], width=6),
                dbc.Col([
                    html.H5("Prophet", className="text-center"),
                    html.Div([
                        html.P(f"RMSE: {metrics['Prophet']['RMSE']:.2f}", className="mb-1"),
                        html.P(f"R²: {metrics['Prophet']['R2']:.2f}", className="mb-1"),
                        html.Div(className="progress mb-2", children=[
                            html.Div(
                                className="progress-bar",
                                style={
                                    "width": f"{max(0, min(100, metrics['Prophet']['R2'] * 100))}%",
                                    "backgroundColor": "#3498db"
                                }
                            )
                        ])
                    ])
                ], width=6)
            ])
        ], body=True)

        # Resource calculation
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        def get_alert_color(used, capacity):
            utilization = (used / capacity) if capacity > 0 else 1
            if utilization < 0.7:
                return "success"
            elif utilization < 0.9:
                return "warning"
            else:
                return "danger"

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                dbc.Alert([
                    html.Div([
                        html.Strong("Hospital Beds"),
                        html.Span(f": {resources_needed['hospital_beds']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['hospital_beds']} beds"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['hospital_beds'] / capacity['hospital_beds'] * 100) if capacity['hospital_beds'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['hospital_beds'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['hospital_beds']}", className="text-danger")
                    if shortages['hospital_beds'] > 0 else None
                ], color=get_alert_color(resources_needed['hospital_beds'], capacity['hospital_beds']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("ICU Beds"),
                        html.Span(f": {resources_needed['icu_beds']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['icu_beds']} beds"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['icu_beds'] / capacity['icu_beds'] * 100) if capacity['icu_beds'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['icu_beds'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['icu_beds']}", className="text-danger")
                    if shortages['icu_beds'] > 0 else None
                ], color=get_alert_color(resources_needed['icu_beds'], capacity['icu_beds']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("Ventilators"),
                        html.Span(f": {resources_needed['ventilators']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['ventilators']} units"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['ventilators'] / capacity['ventilators'] * 100) if capacity['ventilators'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['ventilators'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['ventilators']}", className="text-danger")
                    if shortages['ventilators'] > 0 else None
                ], color=get_alert_color(resources_needed['ventilators'], capacity['ventilators']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Draw the world map for the selected metric.

    Args:
        map_variable: Name of the ``map_data`` column to display.
        map_type: ``'choropleth'`` for a filled-country map; any other value
            produces a bubble map.

    Returns:
        The Plotly figure, or an empty figure titled with the error message
        if anything fails while building it.
    """
    try:
        # Calculate per million values if not already done (cached on the shared frame)
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display; unknown columns fall back to their raw name
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Create appropriate map
        if map_type == 'choropleth':
            fig = px.choropleth(
                map_data,
                locations='location',
                locationmode='country names',
                color=map_variable,
                hover_name='location',
                hover_data=[map_variable, 'total_cases', 'total_deaths'],
                color_continuous_scale='YlOrRd',
                title=f'Global {variable_name} Distribution'
            )
        else:  # bubble map
            # Marker sizes must be non-negative numbers; a NaN in the size
            # column makes Plotly reject the whole trace, so only rows with a
            # plottable value are kept (NaN >= 0 evaluates to False).
            bubble_data = map_data[map_data[map_variable] >= 0]
            fig = px.scatter_geo(
                bubble_data,
                locations='location',
                locationmode='country names',
                size=map_variable,
                color=map_variable,
                hover_name='location',
                hover_data=[map_variable, 'total_cases', 'total_deaths'],
                size_max=50,
                color_continuous_scale='YlOrRd',
                title=f'Global {variable_name} Distribution'
            )

        # Both map types share the same base-map styling.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        # Callback boundary: report the failure in the figure instead of crashing the app
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal and list headlines for a country.

    Args:
        country: Location to filter on; ``'Global'`` keeps every article.
        days: Length of the look-back window in days, counted back from now.

    Returns:
        ``(figure, headlines)``. When there is no news data the figure is
        empty with an explanatory title; on any exception both outputs carry
        the error message.
    """
    try:
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        # Get disease signal data
        news_data = fetch_promed_data(newsapi_key, start_date, end_date)
        # An empty result may have no columns at all, so filter by location
        # only when there is something to filter; otherwise the lookup would
        # raise and hide the "no data" message below.
        if country != 'Global' and not news_data.empty:
            news_data = news_data[news_data['location'] == country]

        # If no data, return empty figures
        if news_data.empty:
            fig = go.Figure()
            fig.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return fig, "No news headlines available"

        # Summarize by date: one averaged signal value per day
        news_by_date = news_data.groupby('date')['disease_signal'].mean().reset_index()

        # Create signal plot
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=news_by_date['date'],
            y=news_by_date['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        ))

        # Add actual cases if available
        cases_data = df[(df['location'] == country) &
                    (df['date'] >= start_date) &
                    (df['date'] <= end_date) &
                    (df['new_cases'].notna())]

        if not cases_data.empty:
            # Normalize for display: rescale cases so their peak matches the signal's peak
            max_signal = news_by_date['disease_signal'].max()
            max_cases = cases_data['new_cases'].max()
            scale_factor = max_signal / max_cases if max_cases > 0 else 1

            fig.add_trace(go.Scatter(
                x=cases_data['date'],
                y=cases_data['new_cases'] * scale_factor,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            ))

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            # The secondary axis is only configured when the cases trace exists
            yaxis2=dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            ) if not cases_data.empty else {}
        )

        # Create simulated news headlines (static placeholders, not taken from news_data)
        headlines = [
            {"title": "New outbreak reported in northeastern region", "score": 95},
            {"title": "Health officials monitoring unusual respiratory symptoms", "score": 80},
            {"title": "Vaccination campaign launched to combat emerging strain", "score": 70},
            {"title": "Hospital capacity strained amid rising cases", "score": 85},
            {"title": "Travel restrictions considered as cases increase", "score": 75}
        ]

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup([
                dbc.ListGroupItem([
                    html.Div([
                        html.Strong(headline["title"]),
                        html.Span(
                            f" - Signal Score: {headline['score']}",
                            className="ml-2",
                            # red above 80, orange above 60, green otherwise
                            style={"color": "#e74c3c" if headline["score"] > 80 else "#f39c12" if headline["score"] > 60 else "#2ecc71"}
                        )
                    ])
                ]) for headline in headlines
            ])
        ])

        return fig, headlines_output

    except Exception as e:
        # Callback boundary: surface the failure in both outputs
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
if __name__ == '__main__':
    # `Dash.run` chooses the notebook display through `jupyter_mode`; a bare
    # `mode=` keyword is not a Dash parameter and would be forwarded to the
    # underlying Flask server. `run` is used instead of the legacy
    # `run_server` alias.
    try:
        app.run(jupyter_mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(jupyter_mode='external', debug=True, port=8050)


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
import os
from flask import send_from_directory

# Configuration
# The NewsAPI key is read from the NEWSAPI_KEY environment variable; the
# placeholder string is used when the variable is not set.
newsapi_key = os.environ.get('NEWSAPI_KEY', 'your_api_key_here')  # Replace with your key

# Load COVID-19 data
# The OWID CSV is downloaded each time this cell runs.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
# Parse dates and make sure they carry no timezone so they compare cleanly
# with the timezone-naive dates produced elsewhere in this cell.
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)
# Keep only the columns the dashboard uses.
df = df[['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
         'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']]

# Healthcare capacity data
# Hard-coded per-location capacities; the four lists are aligned by position.
# NOTE(review): confirm these location names match the values found in
# df['location'] (e.g. 'USA') before joining the two tables.
capacity_data = {
    'location': ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany', 'France', 'Italy', 'Spain',
                'China', 'Japan', 'South Korea', 'Australia', 'Canada', 'Mexico'],
    'hospital_beds': [10000, 5000, 3000, 8000, 9000, 7500, 7000, 6000, 12000, 9500, 8500, 6500, 7800, 4000],
    'icu_beds': [2000, 1000, 500, 1600, 1800, 1500, 1400, 1200, 2400, 1900, 1700, 1300, 1560, 800],
    'ventilators': [500, 300, 200, 400, 450, 375, 350, 300, 600, 475, 425, 325, 390, 200]
}
capacity_df = pd.DataFrame(capacity_data)

# News API integration
def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Query NewsAPI for outbreak-related articles and score each one.

    Every article receives a keyword-based ``disease_signal`` (100, 50 or 10)
    and a coarse ``location`` tag (USA, India, Brazil or Unknown).

    Args:
        api_key: NewsAPI key, sent as the ``apiKey`` query parameter.
        date_range_start: Optional start of the search window; only applied
            together with ``date_range_end``.
        date_range_end: Optional end of the search window.

    Returns:
        DataFrame with ``date``, ``location`` and ``disease_signal`` columns,
        or an empty DataFrame (no columns) when the request fails or no
        article is returned.
    """
    import re  # local import: only needed for the word-boundary match below

    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if date_range_start and date_range_end:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)
        # Report HTTP failures in the log rather than reading an error payload
        # as "no articles".
        response.raise_for_status()
        articles = response.json().get("articles", [])

        data = []
        for article in articles:
            # Keep the calendar day only, without timezone information
            date = pd.to_datetime(article["publishedAt"]).tz_localize(None).normalize()

            # Any of the text fields may be missing or null; substitute an
            # empty string so one sparse article cannot discard the whole batch.
            body = article.get("content") or article.get("description") or ""
            text = f"{article.get('title') or ''} {body}".lower()

            if any(word in text for word in ("outbreak", "epidemic", "pandemic")):
                signal = 100
            elif any(word in text for word in ("health", "disease", "emergency")):
                signal = 50
            else:
                signal = 10

            # "usa" is matched as a whole word: as a plain substring it also
            # fires on words such as "thousand".
            if re.search(r"\busa\b", text) or "united states" in text or "america" in text:
                location = "USA"
            elif any(word in text for word in ("india", "indian")):
                location = "India"
            elif any(word in text for word in ("brazil", "brazilian")):
                location = "Brazil"
            else:
                location = "Unknown"

            data.append({"date": date, "location": location, "disease_signal": signal})

        return pd.DataFrame(data).dropna() if data else pd.DataFrame()
    except Exception as e:
        # Best-effort source: any failure yields an empty frame
        print(f"News API Error: {e}")
        return pd.DataFrame()

# Model training functions
def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the lagged feature table for one location.

    Args:
        df: Case data with at least ``date``, ``location`` and ``target`` columns.
        location: Value of ``location`` to select.
        target: Column to forecast; rows where it is null are dropped.
        lookback: Number of lag columns (``lag_1`` .. ``lag_<lookback>``) to add.

    Returns:
        One row per retained date with the lag columns and a
        ``disease_signal`` column; remaining nulls are filled with 0.
    """
    df_loc = df[df['location'] == location].sort_values('date').dropna(subset=[target])
    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    news_data = fetch_promed_data(newsapi_key, df_loc['date'].min(), df_loc['date'].max())

    # The news fetch may return a frame with no columns; fall back to a zero signal
    if news_data.empty or 'disease_signal' not in news_data.columns:
        df_loc['disease_signal'] = 0.0
        return df_loc.reset_index(drop=True).fillna(0)

    # Collapse to one averaged value per date so that several articles on the
    # same day do not duplicate case rows in the merge.
    daily_signal = news_data.groupby('date', as_index=False)['disease_signal'].mean()
    return df_loc.merge(daily_signal, on='date', how='left').fillna(0)

def train_random_forest(df, target):
    """Fit a random forest on the lag and news-signal columns.

    Args:
        df: Feature table; every column whose name contains ``lag_`` or
            equals ``disease_signal`` is used as a predictor.
        target: Name of the column to predict.

    Returns:
        ``(model, importance)`` where ``importance`` is a DataFrame with
        ``feature`` and ``importance`` columns.
    """
    predictor_columns = []
    for column_name in df.columns:
        if 'lag_' in column_name or column_name == 'disease_signal':
            predictor_columns.append(column_name)

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(df[predictor_columns], df[target])

    importance_table = pd.DataFrame({
        'feature': predictor_columns,
        'importance': forest.feature_importances_,
    })
    return forest, importance_table

def train_prophet(df, target):
    """Fit a Prophet model on the ``date``/``target`` series of ``df``.

    Args:
        df: Table containing a ``date`` column and the ``target`` column.
        target: Name of the column to forecast.

    Returns:
        The fitted Prophet model.
    """
    # Prophet expects the time column as ``ds`` and the value column as ``y``
    column_mapping = {'date': 'ds', target: 'y'}
    training_frame = df[['date', target]].rename(columns=column_mapping)

    forecaster = Prophet(yearly_seasonality=True, weekly_seasonality=True)
    forecaster.fit(training_frame)
    return forecaster

# Dashboard setup
# The app loads the Bootstrap theme plus the Font Awesome stylesheet (the
# fullscreen button written below uses a `fas fa-expand` icon) and reads its
# static assets from the local `assets` folder.
app = Dash(__name__,
          external_stylesheets=[dbc.themes.BOOTSTRAP, 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css'],
          assets_folder='assets')

# Underlying Flask server, used further down to register a static-file route.
server = app.server

# Create assets directory if not exists
if not os.path.exists('assets'):
    os.makedirs('assets')

# Fullscreen CSS
# The stylesheet is (re)written every time this cell runs.
# `.fullscreen-enabled` pins a panel over the whole viewport;
# `.fullscreen-button` styles the toggle placed in a panel's top-right corner.
with open('assets/fullscreen.css', 'w') as f:
    f.write('''
    .fullscreen-enabled {
        background: white;
        position: fixed !important;
        top: 0 !important;
        left: 0 !important;
        width: 100vw !important;
        height: 100vh !important;
        z-index: 9999;
        padding: 20px;
        overflow: auto;
    }
    .fullscreen-button {
        position: absolute;
        top: 10px;
        right: 10px;
        z-index: 1000;
        background: rgba(255,255,255,0.7);
        border: none;
        border-radius: 4px;
        padding: 5px 10px;
        cursor: pointer;
    }
    .fullscreen-button:hover {
        background: rgba(255,255,255,0.9);
    }
    ''')

# Fullscreen JavaScript
# Adds a toggle button to every `.dashboard-panel` element. The observer
# re-runs the setup whenever the page changes, so the setup must skip panels
# that already carry a button: appending a button is itself a DOM mutation,
# and without the guard every run would add more buttons and re-trigger the
# observer indefinitely.
with open('assets/fullscreen.js', 'w') as f:
    f.write('''
    document.addEventListener('DOMContentLoaded', function() {
        function initFullscreen() {
            document.querySelectorAll('.dashboard-panel').forEach(panel => {
                // Already initialised: adding another button would re-trigger the observer.
                if (panel.querySelector(':scope > .fullscreen-button')) {
                    return;
                }
                const btn = document.createElement('button');
                btn.className = 'fullscreen-button';
                btn.innerHTML = '<i class="fas fa-expand"></i>';
                btn.onclick = () => panel.classList.toggle('fullscreen-enabled');
                panel.style.position = 'relative';
                panel.appendChild(btn);
            });
        }
        new MutationObserver(initFullscreen).observe(document, {childList: true, subtree: true});
    });
    ''')

# Serve static files
@server.route('/assets/<path:path>')
def serve_static(path):
    """Return the requested file from the local ``assets`` directory."""
    assets_directory = 'assets'
    return send_from_directory(assets_directory, path)

# Dashboard layout
# Two tabs: "Predictions" (controls + forecast plot + resource panel) and
# "Global View" (metric selector + world map). Elements with the
# `dashboard-panel` class are the ones the fullscreen script targets.
app.layout = dbc.Container([
    html.H1("Medical Resource Forecasting Dashboard", className="text-center my-4"),

    dbc.Tabs([
        dbc.Tab(label="Predictions", children=[
            # Control row: country, forecast horizon and model selection
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Controls"),
                        dbc.CardBody([
                            dbc.Row([
                                # One option per distinct location in the loaded data.
                                # NOTE(review): confirm the default 'USA' is one of the
                                # values in df['location'].
                                dbc.Col(dcc.Dropdown(
                                    id='country-dropdown',
                                    options=[{'label': l, 'value': l} for l in df['location'].unique()],
                                    value='USA'
                                ), md=4),
                                # Horizon of 7-90 days, with a mark every 14 days starting at 7
                                dbc.Col(dcc.Slider(
                                    id='forecast-days',
                                    min=7, max=90, value=30,
                                    marks={i: str(i) for i in range(7, 91, 14)}
                                ), md=4),
                                dbc.Col(dcc.Dropdown(
                                    id='model-type',
                                    options=[
                                        {'label': 'Random Forest', 'value': 'RF'},
                                        {'label': 'Prophet', 'value': 'Prophet'},
                                        {'label': 'Ensemble', 'value': 'Ensemble'}
                                    ], value='Ensemble'
                                ), md=4)
                            ]),
                            # The forecast callback is triggered by this button only;
                            # the controls above are read as State.
                            dbc.Button("Run Forecast", id='run-btn', color='primary', className='mt-3')
                        ])
                    ])
                ])
            ]),
            # Output row: forecast plot and resource summary
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Forecast Results"),
                    dbc.CardBody(dcc.Graph(id='forecast-plot', className='dashboard-panel'))
                ]), lg=8),
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Resource Needs"),
                    dbc.CardBody(html.Div(id='resource-output', className='dashboard-panel'))
                ]), lg=4)
            ], className='my-4')
        ]),

        dbc.Tab(label="Global View", children=[
            dbc.Card([
                dbc.CardHeader("Worldwide Impact"),
                dbc.CardBody([
                    # NOTE(review): 'deaths_per_million' is not among the columns kept
                    # in `df` when the data is loaded; the map callback has to derive it.
                    dcc.Dropdown(
                        id='map-metric',
                        options=[
                            {'label': 'Total Cases', 'value': 'total_cases'},
                            {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                        ], value='total_cases'
                    ),
                    dcc.Graph(id='world-map', className='dashboard-panel')
                ])
            ])
        ])
    ])
], fluid=True)

# Callbacks
@app.callback(
    [Output('forecast-plot', 'figure'),
     Output('resource-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value')]
)
def update_forecast(n, country, days, model):
    """Build the forecast plot and resource summary when the button is clicked.

    Args:
        n: Click count of the run button; ``None`` before the first click.
        country: Selected location.
        days: Forecast horizon in days.
        model: Selected model type (currently not used by the placeholder forecast).

    Returns:
        ``(figure, resources)``. Before the first click an empty figure and a
        prompt are returned; on failure both outputs carry the error message.
    """
    if n is None:
        return go.Figure(), "Select parameters and click Run Forecast"

    try:
        # Model prediction logic
        df_loc = prepare_features(df, country)
        if df_loc.empty:
            fig = go.Figure()
            fig.update_layout(title=f"No historical data available for {country}")
            return fig, f"No historical data available for {country}"

        pred = np.random.randint(1000, 5000)  # Replace with actual model prediction

        # Continue the time axis from the last observed date so the forecast
        # joins the historical curve even when the data ends before today.
        forecast_start = df_loc['date'].max() + timedelta(days=1)
        forecast_dates = [forecast_start + timedelta(days=i) for i in range(days)]

        # Create forecast plot
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=df_loc['date'], y=df_loc['total_cases'],
            name='Historical', line=dict(color='blue')))
        fig.add_trace(go.Scatter(
            x=forecast_dates,
            y=[pred] * days,
            name='Forecast', line=dict(color='red', dash='dash')))

        # Resource calculation: fixed ratios applied to the predicted value
        resources = dbc.ListGroup([
            dbc.ListGroupItem(f"Hospital Beds Needed: {pred//100}"),
            dbc.ListGroupItem(f"ICU Capacity: {pred//500}"),
            dbc.ListGroupItem(f"Ventilators Required: {pred//1000}")
        ])

        return fig, resources

    except Exception as e:
        # Callback boundary: show the failure instead of leaving the outputs blank
        print(f"Error in update_forecast: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating forecast: {str(e)}")
        return fig, f"Error: {str(e)}"

@app.callback(
    Output('world-map', 'figure'),
    [Input('map-metric', 'value')]
)
def update_world_map(metric):
    """Draw a choropleth of the latest per-location value of ``metric``.

    Args:
        metric: Column to colour by (``'total_cases'`` or
            ``'deaths_per_million'`` from the dropdown).

    Returns:
        The Plotly choropleth figure.
    """
    # Latest available value of every column for each location
    map_df = df.groupby('location').last().reset_index()

    # The loaded data has no per-million column, so derive it when selected.
    if metric == 'deaths_per_million' and metric not in map_df.columns:
        map_df['deaths_per_million'] = (map_df['total_deaths'] / map_df['population']) * 1000000

    return px.choropleth(
        map_df, locations='location', locationmode='country names',
        color=metric, hover_data=['population'],
        color_continuous_scale='Viridis'
    )

    # Prepared by the assistant based on the original code

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Retraining Model Function (moved to the top)
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Refresh the news data and retrain both models for one location.

    Args:
        df: Case data with at least ``date``, ``location`` and ``target`` columns.
        location: Location to train on.
        api_key: NewsAPI key passed to ``fetch_promed_data``.
        target: Column to forecast.
        interval_hours: Accepted for interface compatibility; not used here.

    Returns:
        ``(rf_model, prophet_model, df)``. Both models are ``None`` when there
        is no usable news data, too little training data, or an error occurs.
        ``df`` is returned with the news rows appended when they extend past
        the last known date, otherwise unchanged.
    """
    try:
        last_date = df['date'].max()
        new_data = fetch_promed_data(api_key, df['date'].min(), last_date)

        if new_data.empty or not pd.api.types.is_datetime64_any_dtype(new_data['date']):
            print("No valid new data to concatenate")
            return None, None, df

        new_data['date'] = pd.to_datetime(new_data['date']).dt.tz_localize(None).dt.normalize()
        new_data_max = new_data['date'].max()

        if pd.notna(new_data_max) and new_data_max > last_date:
            # concat aligns on column names and fills the columns the news
            # rows lack with NaN, so no manual padding is needed.
            new_data = new_data[['date', 'location', 'disease_signal']]
            df = pd.concat([df, new_data], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        # prepare_features merges in its own `disease_signal` column. If the
        # frame already has one (from the concat above or a previous run), the
        # merge would produce suffixed `_x`/`_y` columns and the signal would
        # be lost, so the feature step gets a view without it.
        features_input = df.drop(columns=['disease_signal'], errors='ignore')
        df_loc = prepare_features(features_input, location, target)
        if df_loc.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        rf_model, _ = train_random_forest(df_loc, target)
        prophet_model = train_prophet(df_loc, target)
        return rf_model, prophet_model, df
    except Exception as e:
        # Best-effort retraining: report and hand back the data untouched by the failure
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Rest of the code remains the same as in the previous version
# ... (include the entire remaining code from the previous implementation)

# Load NewsAPI key
# Falls back to a placeholder string when the Colab secret cannot be read.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
    print("NewsAPI key loaded successfully")
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets

# Load OWID data
# The CSV is downloaded each time this cell runs; only the columns used by the
# dashboard are kept.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)  # Ensure timezone-naive
df = df[['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
         'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']]

# Run retraining
# Each pass feeds the frame returned by the previous pass back in; the models
# from the last pass are the ones left in `rf_model` / `prophet_model`.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# NOTE(review): `app` is not created in this cell — presumably it relies on
# the app object defined by an earlier cell; confirm before running standalone.
if __name__ == '__main__':
    app.run(debug=True)

In [None]:
!pip install dash_bootstrap_components

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load NewsAPI key
# Falls back to a placeholder string when the Colab secret cannot be read.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
    print("NewsAPI key loaded successfully")
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets

# Load OWID data
# The CSV is downloaded each time this cell runs; only the columns used by the
# dashboard are kept.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)  # Ensure timezone-naive
df = df[['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
         'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']]

# Simulate healthcare capacity data
# Hard-coded per-location values; the four lists are aligned by position.
# NOTE(review): confirm these location names match the values found in
# df['location'] (e.g. 'USA') before joining the two tables.
capacity_data = {
    'location': ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany', 'France', 'Italy', 'Spain',
                'China', 'Japan', 'South Korea', 'Australia', 'Canada', 'Mexico'],
    'hospital_beds': [10000, 5000, 3000, 8000, 9000, 7500, 7000, 6000, 12000, 9500, 8500, 6500, 7800, 4000],
    'icu_beds': [2000, 1000, 500, 1600, 1800, 1500, 1400, 1200, 2400, 1900, 1700, 1300, 1560, 800],
    'ventilators': [500, 300, 200, 400, 450, 375, 350, 300, 600, 475, 425, 325, 390, 200]
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch outbreak-related articles from NewsAPI, with a synthetic fallback.

    Every article receives a keyword-based ``disease_signal`` (100, 50 or 10)
    and a coarse ``location`` tag (USA, India, Brazil or Unknown).

    Args:
        api_key: NewsAPI key, sent as the ``apiKey`` query parameter.
        date_range_start: Optional start of the search window; only applied
            together with ``date_range_end``.
        date_range_end: Optional end of the search window.

    Returns:
        DataFrame with ``date``, ``location`` and ``disease_signal`` columns.
        When the request fails or yields nothing, a fallback frame is
        returned instead: up to 30 daily rows tagged ``USA`` with random
        signal values.
    """
    import re  # local import: only needed for the word-boundary match below

    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if date_range_start and date_range_end:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)

        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise Exception("No articles found")

        data = []
        for article in articles:
            date = pd.to_datetime(article["publishedAt"]).tz_localize(None).normalize()  # Date-only

            # Any of the text fields may be missing or null; substitute an
            # empty string so one sparse article does not push the whole
            # batch onto the random fallback.
            body = article.get("content") or article.get("description") or ""
            text = f"{article.get('title') or ''} {body}".lower()

            if any(word in text for word in ("outbreak", "epidemic", "pandemic")):
                signal = 100
            elif any(word in text for word in ("health", "disease", "emergency")):
                signal = 50
            else:
                signal = 10

            # "usa" is matched as a whole word: as a plain substring it also
            # fires on words such as "thousand".
            if re.search(r"\busa\b", text) or "united states" in text or "america" in text:
                location = "USA"
            elif any(word in text for word in ("india", "indian")):
                location = "India"
            elif any(word in text for word in ("brazil", "brazilian")):
                location = "Brazil"
            else:
                location = "Unknown"

            data.append({"date": date, "location": location, "disease_signal": signal})

        # Named news_df so the module-level `df` is not shadowed
        news_df = pd.DataFrame(data)
        if news_df.empty:
            raise Exception("No relevant data parsed")

        news_df["date"] = pd.to_datetime(news_df["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", news_df.head())
        return news_df

    except Exception as e:
        # Deliberate best-effort behaviour: any failure above is replaced by
        # synthetic data so downstream code always receives a populated frame.
        print(f"Error fetching NewsAPI data: {e}")
        if date_range_start and date_range_end:
            # Last 30 days of the requested window
            owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
        else:
            # Last 30 days counted back from today
            owid_dates = [
                pd.to_datetime(datetime.now().date() - timedelta(days=x)).tz_localize(None).normalize()
                for x in range(30)
            ]
        fallback = pd.DataFrame({
            "date": pd.to_datetime(owid_dates).normalize(),
            "location": ["USA"] * len(owid_dates),
            "disease_signal": np.random.randint(10, 100, len(owid_dates))
        })
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_features(df, location, target):
    """Return a two-row placeholder frame used when a location has too little data.

    The dates are the last two dates recorded for ``location`` in ``df``; when
    fewer than two exist, today and yesterday are used instead. The target
    column is filled with the fixed placeholder values 1000 and 1100 and the
    news signal with zeros. No lag columns are present in this frame.
    """
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the modelling frame for one location.

    The frame contains the rows of ``df`` for ``location`` with a non-null
    ``target``, ``lookback`` lagged copies of the target (``lag_1`` ..
    ``lag_<lookback>``) and a daily ``disease_signal`` column taken from
    ``fetch_promed_data``.

    Args:
        df: Frame with at least ``date``, ``location`` and ``target`` columns.
        location: Value of the ``location`` column to select.
        target: Name of the column to forecast.
        lookback: Number of lag features to create.

    Returns:
        The prepared frame, or a two-row placeholder frame (see
        ``_fallback_features``) when fewer than two usable rows exist either
        before or after the lag rows with missing values are dropped.
    """
    df_loc = df[df['location'] == location].sort_values('date')

    # Copy so that the column assignments below never write into a view of ``df``.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_features(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    date_start = df_loc['date'].min()
    date_end = df_loc['date'].max()
    news_data = fetch_promed_data(newsapi_key, date_start, date_end)
    if 'disease_signal' not in news_data.columns:
        news_data['disease_signal'] = 0

    df_loc['date'] = df_loc['date'].dt.normalize()

    # The news frame holds one row per article, so a plain left merge on date
    # would repeat every case row once per same-day article. Collapse it to a
    # single value per day first (the strongest signal seen that day).
    daily_signal = (
        news_data.assign(date=news_data['date'].dt.normalize())
        .groupby('date', as_index=False)['disease_signal']
        .max()
    )

    # A 'disease_signal' column may already exist on ``df`` (retrain_model
    # concatenates news rows into it). Drop it so the merge does not produce
    # 'disease_signal_x' / 'disease_signal_y' and lose the real signal.
    df_loc = df_loc.drop(columns=['disease_signal'], errors='ignore')
    df_loc = df_loc.merge(daily_signal, on='date', how='left')

    # Days without any article carry no signal.
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # The first ``lookback`` rows have incomplete lags and cannot be used.
    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_features(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and news-signal columns of ``df``.

    Feature columns are every column whose name contains ``'lag_'`` plus the
    ``'disease_signal'`` column when present.

    Args:
        df: Prepared feature frame.
        target: Name of the column to regress on.

    Returns:
        Tuple ``(model, importance)`` where ``importance`` is a frame with the
        columns ``feature`` and ``importance``.

    Raises:
        ValueError: If the feature matrix has no rows or no columns.
    """
    feature_names = []
    for column in df.columns:
        if 'lag_' in column or column == 'disease_signal':
            feature_names.append(column)

    design_matrix = df[feature_names]
    response = df[target]

    # A zero on either axis means there is nothing to fit.
    if 0 in design_matrix.shape:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(design_matrix, response)

    importance_table = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_,
    })
    return forest, importance_table

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date`` / ``target`` columns of ``df``.

    Args:
        df: Frame containing a ``date`` column and the ``target`` column.
        target: Name of the column to forecast.

    Returns:
        The fitted Prophet model.

    Raises:
        ValueError: If ``df`` has fewer than two rows.
    """
    # Prophet expects the timestamp column to be called 'ds' and the value 'y'.
    column_map = {'date': 'ds', target: 'y'}
    training_frame = df.loc[:, ['date', target]].rename(columns=column_map)

    if len(training_frame) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    seasonality = {
        'yearly_seasonality': True,
        'weekly_seasonality': True,
        'daily_seasonality': True,
    }
    forecaster = Prophet(**seasonality)
    forecaster.fit(training_frame)
    return forecaster

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest and Prophet forecasts into one value.

    The random forest is evaluated on the last row of ``df``; Prophet is asked
    for ``forecast_days`` future periods and only the first of those forecast
    values is used.

    Args:
        rf_model: Fitted model exposing ``predict``.
        prophet_model: Fitted model exposing ``make_future_dataframe`` and ``predict``.
        df: Prepared feature frame.
        forecast_days: Number of future periods requested from Prophet.
        target: Unused here; kept for a signature parallel to the other helpers.

    Returns:
        A single number: the mean of the two point forecasts.
    """
    feature_names = [
        name for name in df.columns
        if 'lag_' in name or name == 'disease_signal'
    ]
    latest_row = df[feature_names].tail(1)
    forest_output = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_output = prophet_forecast['yhat'].tail(forecast_days).values

    combined = forest_output[0] + prophet_output[0]
    return combined / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Forecast ``target`` for ``location`` with the requested model.

    Args:
        df: Full data frame passed on to ``prepare_features``.
        location: Location to forecast for.
        forecast_days: Number of future periods requested from Prophet.
        model_type: ``'Random Forest'``, ``'Prophet'``; any other value selects
            the ensemble of both.
        target: Name of the column to forecast.

    Returns:
        Tuple ``(prediction, importance, confidence)``:
        * Random Forest: one-element prediction array, importance frame, ``None``.
        * Prophet: array of ``forecast_days`` values, ``None``, and a frame with
          ``ds``, ``yhat_lower`` and ``yhat_upper`` for every forecast row.
        * Ensemble: the value from ``ensemble_predict``, importance frame, ``None``.
        On any exception the error is printed and ``(np.array([0]), None, None)``
        is returned.
    """
    try:
        prepared = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet_model = train_prophet(prepared, target)
            horizon = prophet_model.make_future_dataframe(periods=forecast_days)
            forecast = prophet_model.predict(horizon)
            point_forecast = forecast['yhat'].tail(forecast_days).values
            interval = forecast[['ds', 'yhat_lower', 'yhat_upper']]
            return point_forecast, None, interval

        if model_type == 'Random Forest':
            forest, importance = train_random_forest(prepared, target)
            feature_names = [
                name for name in prepared.columns
                if 'lag_' in name or name == 'disease_signal'
            ]
            latest_row = prepared[feature_names].tail(1)
            return forest.predict(latest_row), importance, None

        # Anything else is treated as the ensemble of both models.
        forest, importance = train_random_forest(prepared, target)
        prophet_model = train_prophet(prepared, target)
        blended = ensemble_predict(forest, prophet_model, prepared, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score the random forest and Prophet on a chronological 80/20 hold-out.

    Args:
        df: Full data frame passed on to ``prepare_features``.
        location: Location to evaluate.
        target: Name of the column being forecast.

    Returns:
        ``{'Random Forest': {'RMSE': ..., 'R2': ...}, 'Prophet': {...}}``.
        All four values are 0 when there is too little data or when any step
        raises (the error is printed).
    """
    zero_metrics = {'Random Forest': {'RMSE': 0, 'R2': 0}, 'Prophet': {'RMSE': 0, 'R2': 0}}
    try:
        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            return zero_metrics

        # The split is chronological and Prophet returns its predictions in
        # date order, so make sure the rows are ordered by date as well.
        df_loc = df_loc.sort_values('date', kind='stable').reset_index(drop=True)

        train_size = int(0.8 * len(df_loc))
        train_df, test_df = df_loc.iloc[:train_size], df_loc.iloc[train_size:]

        if test_df.empty:
            return zero_metrics

        rf_model, _ = train_random_forest(train_df, target)
        features = [col for col in test_df.columns if 'lag_' in col or col == 'disease_signal']
        rf_pred = rf_model.predict(test_df[features])
        rf_rmse = np.sqrt(mean_squared_error(test_df[target], rf_pred))
        rf_r2 = r2_score(test_df[target], rf_pred)

        prophet_model = train_prophet(train_df, target)
        # Predict on the actual hold-out dates. Generating len(test_df) future
        # periods instead assumes the test rows are exactly the next
        # consecutive days, which breaks when dates are missing or repeated
        # and pairs predictions with the wrong targets.
        test_dates = test_df[['date']].rename(columns={'date': 'ds'})
        forecast = prophet_model.predict(test_dates)
        prophet_pred = forecast['yhat'].values
        prophet_rmse = np.sqrt(mean_squared_error(test_df[target], prophet_pred))
        prophet_r2 = r2_score(test_df[target], prophet_pred)

        return {
            'Random Forest': {'RMSE': rf_rmse, 'R2': rf_r2},
            'Prophet': {'RMSE': prophet_rmse, 'R2': prophet_r2}
        }
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return zero_metrics

# Calculate resource allocation
def calculate_resources(predicted_cases, country):
    """Translate a case forecast into resource needs, capacity and shortages.

    Args:
        predicted_cases: Forecast case count. May be a scalar or any
            array-like; for array-likes the first element is used. Missing,
            non-finite or negative forecasts are treated as 0 cases.
        country: Location name looked up in the module-level ``capacity_df``.
            Unknown locations fall back to fixed default capacities.

    Returns:
        Tuple ``(resources_needed, capacity, shortages)`` of dicts.
        ``resources_needed`` has the keys ``hospital_beds``, ``icu_beds``,
        ``ventilators``, ``medical_staff``, ``ppe_kits`` and ``test_kits``;
        ``capacity`` and ``shortages`` cover the first three.
    """
    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }
    capacity_keys = ('hospital_beds', 'icu_beds', 'ventilators')

    # Get country capacity
    country_capacity = capacity_df[capacity_df['location'] == country]
    if country_capacity.empty:
        # Use default values if country not found
        capacity = {
            'hospital_beds': 5000,
            'icu_beds': 1000,
            'ventilators': 300
        }
    else:
        first_match = country_capacity.iloc[0]
        # Plain ints keep the returned dicts free of NumPy scalar types.
        capacity = {key: int(first_match[key]) for key in capacity_keys}

    # Normalise the forecast to a single non-negative, finite number. int() of
    # NaN raises, and a negative forecast would yield negative resource counts.
    forecast = np.asarray(predicted_cases, dtype=float).ravel()
    predicted_value = float(forecast[0]) if forecast.size else 0.0
    if not np.isfinite(predicted_value) or predicted_value < 0:
        predicted_value = 0.0

    # Calculate needed resources
    hospitalized = predicted_value * resource_ratios['hospital_beds']
    resources_needed = {
        'hospital_beds': int(hospitalized),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(hospitalized * resource_ratios['medical_staff']),
        'ppe_kits': int(hospitalized * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits'])
    }

    # Calculate shortages
    shortages = {
        key: max(0, resources_needed[key] - capacity[key])
        for key in capacity_keys
    }

    return resources_needed, capacity, shortages

# Run retraining
# Each pass refreshes ``df`` with news rows and refits both models for 'USA';
# ``df`` is rebound to whatever retrain_model returns.
# NOTE(review): retrain_model is defined further down in this file (after the
# dashboard layout). Confirm it has already been defined when this loop runs;
# on a fresh top-to-bottom execution this call would raise NameError.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# Create filtered data for map
# One row per location: groupby(...).last() keeps the last non-null value of
# every column within each location; locations without any total_cases value
# are then dropped.
map_data = df.groupby('location').last().reset_index()
map_data = map_data.dropna(subset=['total_cases'])

# Dashboard
# Dash application styled with the stock Bootstrap theme shipped with
# dash-bootstrap-components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
# The page template below is the standard Dash index template with an extra
# <style> block. The {%...%} tokens are placeholders that Dash substitutes
# when it serves the page, so they must be kept as they are.
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Page layout: a title followed by three tabs.
#   1. "Predictions & Resources" - control panel, forecast plot, resource
#      needs, feature importance and model metrics.
#   2. "Global Impact Map"       - world map of a selectable metric.
#   3. "News & Disease Signals"  - news-derived signal plot and headlines.
# The component ids declared here are the ones the callbacks refer to.
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        # ---- Tab 1: forecast controls and results ----
        dbc.Tab(label="Predictions & Resources", children=[
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    # NOTE(review): options come from df['location'] while the
                                    # default is the literal 'USA' - confirm that value exists
                                    # in the location column of the loaded data.
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    # Forecast horizon: 7 to 90 days, tick marks every 14 days.
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    # Values are the model_type strings that
                                    # forecast_resource switches on.
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            # Values are column names of the data frame.
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            # The prediction callback is triggered by this button only.
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            # Forecast plot (left) and resource needs (right).
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            # Feature importance (left) and hold-out metrics (right).
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        # ---- Tab 2: world map ----
        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            # Values are column names of map_data.
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        # ---- Tab 3: news-derived disease signals ----
        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            # Values are the look-back window in days.
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Refresh the data with the latest news signals and retrain both models
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Append fresh news rows to ``df`` and refit both models for ``location``.

    News rows are fetched for the full date span of ``df``. They are appended
    only when their newest date is later than the newest date already in
    ``df``; columns the news rows lack are filled with NaN.

    Args:
        df: Current data frame.
        location: Location to retrain for.
        api_key: Key passed to ``fetch_promed_data``.
        target: Name of the column to forecast.
        interval_hours: Not used by this function.

    Returns:
        Tuple ``(rf_model, prophet_model, df)``. Both models are ``None`` when
        the news data is unusable, the prepared frame has fewer than two rows,
        or any step raises (the error is printed); ``df`` is returned in every
        case, including any rows appended before the failure.
    """
    try:
        latest_known = df['date'].max()
        news = fetch_promed_data(api_key, df['date'].min(), df['date'].max())

        usable = (not news.empty) and pd.api.types.is_datetime64_any_dtype(news['date'])
        if not usable:
            print("No valid new data to concatenate")
            return None, None, df

        news['date'] = pd.to_datetime(news['date']).dt.tz_localize(None).dt.normalize()
        newest_news = news['date'].max()

        if pd.notna(newest_news) and newest_news > latest_known:
            news = news[['date', 'location', 'disease_signal']]
            # Give the news rows every column of df so the frames line up.
            absent = [name for name in df.columns if name not in news.columns]
            for name in absent:
                news[name] = np.nan
            df = pd.concat([df, news], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        prepared = prepare_features(df, location, target)
        if prepared.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        forest, _ = train_random_forest(prepared, target)
        prophet = train_prophet(prepared, target)
        return forest, prophet, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Calculate derived metrics
# Per-capita rates for the map: cumulative counts divided by the location's
# population and scaled to "per one million inhabitants".
# NOTE(review): locations whose population is missing or zero will presumably
# end up with NaN/inf here - confirm the map rendering tolerates that.
map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

# Callbacks
def _forecast_values(pred, forecast_days):
    """Return the forecast as a 1-D float array of at most ``forecast_days`` points.

    Scalars and one-element arrays (Random Forest, Ensemble, error fallback)
    are repeated across the horizon so the trace has one value per date.
    """
    values = np.atleast_1d(np.asarray(pred, dtype=float)).ravel()
    if values.size == 0:
        return np.zeros(forecast_days)
    if values.size == 1:
        return np.full(forecast_days, values[0])
    return values[:forecast_days]


def _utilization_color(used, capacity):
    """Map utilisation (used / capacity) to a Bootstrap alert colour."""
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _resource_alert(title, unit, needed, available, shortage):
    """Build the alert for one capacity-limited resource.

    Shows the amount needed, the capacity, a utilisation bar (only when the
    capacity is positive) and a shortage line (only when there is a shortage).
    """
    progress = None
    if available > 0:
        progress = html.Div(className="progress mt-1", children=[
            html.Div(
                className="progress-bar",
                style={"width": f"{min(100, needed / available * 100)}%"}
            )
        ])

    shortage_note = None
    if shortage > 0:
        shortage_note = html.Div(f"Potential shortage: {shortage}", className="text-danger")

    return dbc.Alert([
        html.Div([
            html.Strong(title),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {available} {unit}"),
        progress,
        shortage_note
    ], color=_utilization_color(needed, available), className="mb-2")


def _metric_column(model_name, scores, bar_color):
    """Build one column of the metrics card (RMSE, R² and an R² bar)."""
    # The bar shows R² as a percentage, clipped to the 0-100 range.
    bar_width = max(0, min(100, scores['R2'] * 100))
    return dbc.Col([
        html.H5(model_name, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={
                        "width": f"{bar_width}%",
                        "backgroundColor": bar_color
                    }
                )
            ])
        ])
    ], width=6)


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Run a forecast and refresh the four outputs of the prediction tab.

    Triggered by the "Run Prediction" button; the other controls are read as
    state.

    Args:
        n_clicks: Click count of the button, ``None`` before the first click.
        country: Selected location.
        forecast_days: Forecast horizon in days.
        model_type: Selected model name, passed to ``forecast_resource``.
        target_variable: Column to forecast.

    Returns:
        Tuple ``(prediction figure, feature-importance figure, metrics
        component, resources component)``. Before the first click placeholder
        content is returned; on any exception two empty figures and the error
        text are returned.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # One value per forecast day, whatever shape the model returned.
        pred_values = _forecast_values(pred, forecast_days)
        horizon = len(pred_values)

        # Create prediction plot
        pred_fig = go.Figure()

        # Add historical data to the plot
        historical = None
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())].sort_values('date')
            if not historical.empty:
                recent = historical.tail(30)  # Last 30 days
                pred_fig.add_trace(go.Scatter(
                    x=recent['date'],
                    y=recent[target_variable],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Choose the dates the forecast is drawn at.
        if confidence is not None and len(confidence) >= horizon:
            # Prophet: keep only the forecast window of the interval frame and
            # reuse its dates so the line and the band coincide.
            confidence = confidence.tail(horizon)
            pred_dates = list(confidence['ds'])
        elif historical is not None and not historical.empty and pd.notna(historical['date'].max()):
            # The forecast continues from the last observation, not from today.
            last_observed = historical['date'].max()
            pred_dates = [last_observed + timedelta(days=i) for i in range(1, horizon + 1)]
        else:
            pred_dates = [datetime.now() + timedelta(days=i) for i in range(horizon)]

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metric_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metric_column("Prophet", metrics['Prophet'], "#3498db")
            ])
        ], body=True)

        # Resource calculation (based on the first forecast value)
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _resource_alert("Hospital Beds", "beds", resources_needed['hospital_beds'],
                                capacity['hospital_beds'], shortages['hospital_beds']),
                _resource_alert("ICU Beds", "beds", resources_needed['icu_beds'],
                                capacity['icu_beds'], shortages['icu_beds']),
                _resource_alert("Ventilators", "units", resources_needed['ventilators'],
                                capacity['ventilators'], shortages['ventilators']),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Build the global map figure for the selected metric and map style.

    Args:
        map_variable: Name of the ``map_data`` column to visualise (for
            example ``'total_cases'`` or ``'cases_per_million'``).
        map_type: ``'choropleth'`` draws a filled country map; any other
            value draws a bubble (scatter-geo) map.

    Returns:
        A plotly figure. If anything fails, an empty figure whose title
        carries the error message is returned instead of raising.
    """
    try:
        # Calculate per million values if not already done
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        title = f'Global {variable_name} Distribution'
        hover_columns = [map_variable, 'total_cases', 'total_deaths']

        # Create appropriate map
        if map_type == 'choropleth':
            fig = px.choropleth(
                map_data,
                locations='location',
                locationmode='country names',
                color=map_variable,
                hover_name='location',
                hover_data=hover_columns,
                color_continuous_scale='YlOrRd',
                title=title
            )
        else:  # bubble map
            # Marker sizes must be finite and non-negative; a single NaN/inf
            # in the column makes plotly reject the whole trace, so rows
            # without a usable value are left off the bubble map.
            sizes = pd.to_numeric(map_data[map_variable], errors='coerce')
            bubble_data = map_data[np.isfinite(sizes) & (sizes >= 0)]
            fig = px.scatter_geo(
                bubble_data,
                locations='location',
                locationmode='country names',
                size=map_variable,
                color=map_variable,
                hover_name='location',
                hover_data=hover_columns,
                size_max=50,
                color_continuous_scale='YlOrRd',
                title=title
            )

        # Both map styles share the same geographic styling.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal and list headline cards.

    Args:
        country: Location to filter on; ``'Global'`` keeps every location
            in the news data.
        days: Length of the look-back window, in days, ending now.

    Returns:
        ``(figure, headlines)``. When no signal rows remain after filtering,
        a titled empty figure and a plain message are returned. On any
        exception an error-titled figure and an error string are returned.
    """
    def score_colour(score):
        # Red above 80, amber above 60, green otherwise.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Get disease signal data
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        # Nothing to plot: hand back a titled placeholder.
        if signals.empty:
            placeholder = go.Figure()
            placeholder.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return placeholder, "No news headlines available"

        # One averaged signal value per date.
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig = go.Figure()
        fig.add_trace(signal_trace)

        # Reported cases for the same location and window, if any exist.
        in_window = (
            (df['location'] == country) &
            (df['date'] >= window_start) &
            (df['date'] <= window_end) &
            (df['new_cases'].notna())
        )
        cases_data = df[in_window]
        has_cases = not cases_data.empty

        if has_cases:
            # Rescale cases so their peak lines up with the signal's peak.
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = cases_data['new_cases'].max()
            if peak_cases > 0:
                scale_factor = peak_signal / peak_cases
            else:
                scale_factor = 1

            fig.add_trace(go.Scatter(
                x=cases_data['date'],
                y=cases_data['new_cases'] * scale_factor,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            ))

        # The secondary axis is only configured when the cases trace exists.
        if has_cases:
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )
        else:
            secondary_axis = {}

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Simulated headlines: (title, signal score) pairs.
        simulated = [
            ("New outbreak reported in northeastern region", 95),
            ("Health officials monitoring unusual respiratory symptoms", 80),
            ("Vaccination campaign launched to combat emerging strain", 70),
            ("Hospital capacity strained amid rising cases", 85),
            ("Travel restrictions considered as cases increase", 75)
        ]

        cards = []
        for headline_title, headline_score in simulated:
            cards.append(dbc.ListGroupItem([
                html.Div([
                    html.Strong(headline_title),
                    html.Span(
                        f" - Signal Score: {headline_score}",
                        className="ml-2",
                        style={"color": score_colour(headline_score)}
                    )
                ])
            ]))

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(cards)
        ])

        return fig, headlines_output

    except Exception as e:
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
if __name__ == '__main__':
    # ``dash.Dash.run`` selects the notebook display mode through
    # ``jupyter_mode``; a ``mode=`` keyword is not one of its options and is
    # forwarded to the underlying server, which rejects it.
    # NOTE(review): assumes ``app`` is a ``dash.Dash`` instance on a Dash
    # release with built-in Jupyter support -- confirm the installed version.
    try:
        app.run(jupyter_mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(jupyter_mode='external', debug=True, port=8050)



In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
import random

# Mock function to simulate NewsAPI data when real API is unavailable
def mock_fetch_promed_data(start_date, end_date):
    """
    Generate mock disease signal data when real API data is unavailable.

    Args:
        start_date: First day of the range (anything ``pd.date_range`` accepts).
        end_date: Last day of the range, inclusive.

    Returns:
        DataFrame with columns ``date``, ``location`` and ``disease_signal``:
        one row per day and per hard-coded location, with a random integer
        signal between 10 and 90 inclusive. An empty range yields an empty
        frame that still carries those three columns, so callers can filter
        on ``location`` without a ``KeyError``.
    """
    columns = ['date', 'location', 'disease_signal']

    # Fixed set of locations the mock data covers
    locations = ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany',
                 'France', 'Italy', 'Spain', 'China', 'Japan']

    # One random signal per (day, location) pair
    mock_data = [
        {'date': day, 'location': loc, 'disease_signal': random.randint(10, 90)}
        for day in pd.date_range(start=start_date, end=end_date)
        for loc in locations
    ]

    # Naming the columns explicitly keeps the schema when there are no rows.
    return pd.DataFrame(mock_data, columns=columns)

# Fallback function for data fetching
def fetch_promed_data(api_key, start_date, end_date):
    """
    Return disease-signal data for the date range.

    No real NewsAPI request is made here: both the normal path and the
    error path return the output of ``mock_fetch_promed_data``. ``api_key``
    is accepted but not used by this body.
    """
    try:
        # Placeholder for a real NewsAPI call; mock data stands in for it.
        print("Attempting to fetch NewsAPI data...")
        signals = mock_fetch_promed_data(start_date, end_date)
    except Exception as err:
        print(f"Error fetching NewsAPI data: {err}")
        print("Falling back to mock data generation")
        signals = mock_fetch_promed_data(start_date, end_date)
    return signals

# Prepare features function
def prepare_features(df, location, target='total_cases', lookback=7):
    """
    Prepare features for machine learning models.

    Args:
        df: Frame with at least ``location``, ``date`` and ``target`` columns.
        location: Value of ``location`` to keep.
        target: Column to build lag features from.
        lookback: Number of lag columns (``lag_1`` .. ``lag_<lookback>``).

    Returns:
        The location's rows sorted by date, with lag columns and a random
        ``disease_signal`` column added. Rows lacking the target or any lag
        value are dropped. An empty frame is returned when the location has
        no rows at all.
    """
    # Filter for specific location and sort by date
    df_loc = df[df['location'] == location].sort_values('date')

    # Ensure we have data
    if df_loc.empty:
        print(f"No data available for {location}")
        return pd.DataFrame()

    # Drop rows with NaN in target column; copy so the column assignments
    # below never write into a view of the caller's frame.
    df_loc = df_loc.dropna(subset=[target]).copy()

    # Create lagged features
    lag_cols = [f'lag_{i}' for i in range(1, lookback + 1)]
    for i, lag_col in enumerate(lag_cols, start=1):
        df_loc[lag_col] = df_loc[target].shift(i)

    # Add disease signal (mock or from API)
    df_loc['disease_signal'] = np.random.randint(10, 100, len(df_loc))

    # Only the modelling columns decide whether a row is usable. A bare
    # dropna() would also discard rows whose unrelated columns are missing.
    df_loc = df_loc.dropna(subset=[target] + lag_cols)

    return df_loc

# Random Forest training
def train_random_forest(df, target='total_cases'):
    """
    Fit a random forest on the lag and disease-signal columns of ``df``.

    Returns:
        ``(model, importance)`` where ``importance`` is a frame of
        ``feature`` / ``importance`` rows sorted from most to least
        important. ``(None, None)`` is returned when ``df`` is empty or has
        no usable feature columns.
    """
    if df.empty:
        print("Empty dataframe provided for Random Forest training")
        return None, None

    # Feature columns are every lag column plus the disease signal.
    feature_names = [
        name for name in df.columns
        if 'lag_' in name or name == 'disease_signal'
    ]
    if len(feature_names) == 0:
        print("No valid features found for training")
        return None, None

    inputs = df[feature_names]
    outputs = df[target]

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, outputs)

    ranking = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_
    })
    ranking = ranking.sort_values('importance', ascending=False)

    return forest, ranking

# Prophet training
def train_prophet(df, target='total_cases'):
    """
    Fit a Prophet model on the ``date`` / ``target`` columns of ``df``.

    Returns the fitted model, or ``None`` when ``df`` is empty.
    """
    if df.empty:
        print("Empty dataframe provided for Prophet training")
        return None

    # Prophet expects the columns to be named 'ds' and 'y'.
    history = df[['date', target]].rename(columns={'date': 'ds', target: 'y'})

    seasonality = dict(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=False
    )
    forecaster = Prophet(**seasonality)
    forecaster.fit(history)

    return forecaster

# Retraining Model Function
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """
    Retrain the random forest and Prophet models for one location.

    Args:
        df: Case data with ``date`` and ``location`` columns.
        location: Location to train on.
        api_key: Passed through to ``fetch_promed_data``.
        target: Column to predict.
        interval_hours: Accepted but not used by this body.

    Returns:
        ``(rf_model, prophet_model, features_df)`` on success. On any
        failure, or when no news data or no features are available,
        ``(None, None, df)`` with the original frame.
    """
    try:
        # News window: the year leading up to the newest case date.
        window_end = df['date'].max()
        window_start = window_end - timedelta(days=365)  # Go back one year

        news = fetch_promed_data(api_key, window_start, window_end)

        have_news = news is not None and not news.empty
        if not have_news:
            print("No new data available for retraining")
            return None, None, df

        # Normalize dates. NOTE(review): the news frame is not used after
        # this point in the visible code.
        news_dates = pd.to_datetime(news['date'])
        news['date'] = news_dates.dt.tz_localize(None).dt.normalize()

        features_df = prepare_features(df, location, target)
        if features_df.empty:
            print(f"No valid features for {location}")
            return None, None, df

        # Feature importance is computed by the trainer but not returned here.
        forest, _ = train_random_forest(features_df, target)
        forecaster = train_prophet(features_df, target)

        return forest, forecaster, features_df

    except Exception as e:
        print(f"Error in model retraining: {e}")
        return None, None, df

# Simulate a basic dataset if real data is unavailable
def create_mock_covid_data():
    """
    Build a random daily COVID-style dataset for five countries.

    Returns a single frame covering 2020-01-01 through 2023-12-31 for each
    location, with cumulative and daily case, death, hospital, ICU and
    vaccination columns plus one population value per location.
    """
    dates = pd.date_range(start='2020-01-01', end='2023-12-31', freq='D')
    n_days = len(dates)

    def noisy(mean):
        # Normal samples with a standard deviation of 20% of the mean.
        return np.random.normal(mean, mean * 0.2, n_days)

    frames = []
    for location in ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany']:
        # Simulate case progression with some randomness
        base_cases = np.random.randint(1000, 100000)
        daily_increase = np.random.randint(10, 500)

        # Draws are made in column order so a seeded run is reproducible.
        total_cases = np.cumsum(noisy(daily_increase)) + base_cases
        new_cases = noisy(daily_increase)
        total_deaths = np.cumsum(noisy(daily_increase*0.02))
        new_deaths = noisy(daily_increase*0.02)
        hosp_patients = noisy(daily_increase*0.1)
        icu_patients = noisy(daily_increase*0.03)
        people_vaccinated = np.cumsum(noisy(daily_increase*0.5))
        population = np.random.randint(50000000, 330000000)

        frames.append(pd.DataFrame({
            'date': dates,
            'location': location,
            'total_cases': total_cases,
            'new_cases': new_cases,
            'total_deaths': total_deaths,
            'new_deaths': new_deaths,
            'hosp_patients': hosp_patients,
            'icu_patients': icu_patients,
            'people_vaccinated': people_vaccinated,
            'population': population
        }))

    # Stack the per-location frames under a fresh index.
    return pd.concat(frames, ignore_index=True)

# Main execution
def main():
    """Generate mock data and retrain models for each country/target pair."""
    # Set a mock NewsAPI key
    newsapi_key = "MOCK_API_KEY"

    # Create mock COVID-19 dataset
    print("Generating mock COVID-19 dataset...")
    df = create_mock_covid_data()

    countries = ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany']
    targets = ['total_cases', 'new_cases', 'total_deaths']

    # Every country is retrained once for each target, country-major.
    jobs = [(country, target) for country in countries for target in targets]
    for country, target in jobs:
        print(f"\nRetraining for {country} - Target: {target}")
        rf_model, prophet_model, _ = retrain_model(df, country, newsapi_key, target)

        # Both models must exist for the run to count as a success.
        if rf_model is None or prophet_model is None:
            print(f"Failed to train models for {country} - {target}")
        else:
            print(f"Successfully trained models for {country} - {target}")

    print("\nModel training and evaluation complete.")

# Run the main function
if __name__ == "__main__":
    main()

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import os

print("Starting medical resource forecasting dashboard setup...")

# Create assets directory if not exists
if not os.path.exists('assets'):
    os.makedirs('assets')
    print("Created assets directory")

# Sample data (to avoid dependency on external APIs during testing)
print("Loading sample data...")
# Create sample data instead of fetching
dates = pd.date_range(start='2020-01-01', end='2023-12-31', freq='D')
locations = ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany']
sample_data = []

for loc in locations:
    base_cases = np.random.randint(1000, 10000)
    growth_factor = np.random.uniform(1.001, 1.005)
    # Population is a property of the location, so it is drawn once here
    # rather than once per day; every row of a country then agrees.
    population = np.random.randint(10000000, 300000000)
    for i, date in enumerate(dates):
        # Cases grow exponentially from the base; other columns are random
        # fractions of cases or deaths.
        cases = int(base_cases * (growth_factor ** i))
        deaths = int(cases * np.random.uniform(0.01, 0.05))
        sample_data.append({
            'date': date,
            'location': loc,
            'total_cases': cases,
            'new_cases': int(cases * np.random.uniform(0.01, 0.1)),
            'total_deaths': deaths,
            'new_deaths': int(deaths * np.random.uniform(0.05, 0.2)),
            'hosp_patients': int(cases * np.random.uniform(0.05, 0.2)),
            'icu_patients': int(cases * np.random.uniform(0.01, 0.05)),
            'people_vaccinated': int(np.random.uniform(0, 0.8) * 1000000),
            'population': population
        })

df = pd.DataFrame(sample_data)
print(f"Sample data created with {len(df)} rows for {len(locations)} locations")

# Healthcare capacity data (one entry per location, in the same order)
capacity_data = {
    'location': locations,
    'hospital_beds': [10000, 5000, 3000, 8000, 9000],
    'icu_beds': [2000, 1000, 500, 1600, 1800],
    'ventilators': [500, 300, 200, 400, 450]
}
capacity_df = pd.DataFrame(capacity_data)
print("Capacity data created")

# Model training functions
def prepare_features(df, location, target='total_cases', lookback=7):
    """Build lag features and a random disease signal for one location.

    Args:
        df: Frame with ``location``, ``date`` and ``target`` columns.
        location: Value of ``location`` to keep.
        target: Column the lag features are derived from.
        lookback: Number of lag columns (``lag_1`` .. ``lag_<lookback>``).

    Returns:
        The location's rows sorted by date with the extra columns added.
        Rows lacking the target or any lag value are dropped.
    """
    print(f"Preparing features for {location}")
    # Copy so the column assignments never write into a view of ``df``.
    df_loc = df[df['location'] == location].sort_values('date').copy()
    lag_cols = [f'lag_{i}' for i in range(1, lookback + 1)]
    for i, lag_col in enumerate(lag_cols, start=1):
        df_loc[lag_col] = df_loc[target].shift(i)
    # Add a disease signal from sample data
    df_loc['disease_signal'] = np.random.randint(10, 100, size=len(df_loc))
    # Only the modelling columns decide whether a row is usable. A bare
    # dropna() would also discard rows whose unrelated columns are missing.
    return df_loc.dropna(subset=[target] + lag_cols)

# Dashboard setup
# Creates the Dash application with the Bootstrap stylesheet and points it at
# the local 'assets' folder.
print("Setting up Dash application...")
app = Dash(__name__,
          external_stylesheets=[dbc.themes.BOOTSTRAP],
          assets_folder='assets')

# Underlying server object exposed under a module-level name.
server = app.server

# Dashboard layout
# Two rows: a controls card (country dropdown, forecast-days slider, run
# button) followed by the forecast graph next to the resource-needs panel.
# The component ids below are the ones the forecast callback binds to.
app.layout = dbc.Container([
    html.H1("Medical Resource Forecasting Dashboard", className="text-center my-4"),
    html.P("This is a test dashboard. Select parameters below and click 'Run Forecast'", className="text-center"),

    # Row 1: input controls
    dbc.Row([
        dbc.Col([
            dbc.Card([
                dbc.CardHeader("Controls"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col(html.Label("Select Country:"), width=3),
                        dbc.Col(dcc.Dropdown(
                            id='country-dropdown',
                            options=[{'label': l, 'value': l} for l in locations],
                            value='USA'
                        ), width=9)
                    ], className="mb-3"),
                    dbc.Row([
                        dbc.Col(html.Label("Forecast Days:"), width=3),
                        dbc.Col(dcc.Slider(
                            id='forecast-days',
                            min=7, max=90, value=30,
                            marks={i: str(i) for i in [7, 30, 60, 90]}
                        ), width=9)
                    ], className="mb-3"),
                    dbc.Button("Run Forecast", id='run-btn', color='primary', className='mt-3')
                ])
            ])
        ])
    ]),

    # Row 2: forecast plot (8 columns wide) and resource summary (4 columns)
    dbc.Row([
        dbc.Col(dbc.Card([
            dbc.CardHeader("Forecast Results"),
            dbc.CardBody(dcc.Graph(id='forecast-plot'))
        ]), width=8),
        dbc.Col(dbc.Card([
            dbc.CardHeader("Resource Needs"),
            dbc.CardBody(html.Div(id='resource-output'))
        ]), width=4)
    ], className='my-4')
], fluid=True)

# Callbacks
@app.callback(
    [Output('forecast-plot', 'figure'),
     Output('resource-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value')]
)
def update_forecast(n, country, days):
    """Produce the case forecast plot and resource estimates for a country.

    Args:
        n: Click count of the run button; ``None`` before the first click.
        country: Selected location.
        days: Forecast horizon in days.

    Returns:
        ``(figure, children)``. Before the first click, or when no data or
        horizon is available, a titled empty figure and a message are
        returned. Otherwise the figure shows the last 90 historical points
        plus the forecast, and ``children`` is a list of resource lines.
    """
    print(f"Callback triggered with n_clicks={n}, country={country}, days={days}")

    if n is None:
        return go.Figure(layout=dict(title="Select parameters and click Run Forecast")), "Select parameters and click Run Forecast"

    # Simple forecasting logic
    df_loc = prepare_features(df, country)

    # A cleared dropdown or slider leaves nothing to extrapolate from;
    # report that instead of failing on the positional lookups below.
    if df_loc.empty or not days:
        message = f"No data available to forecast for {country}"
        return go.Figure(layout=dict(title=message)), message

    days = int(days)
    last_value = df_loc['total_cases'].iloc[-1]
    last_date = df_loc['date'].iloc[-1]
    growth_rate = 1.02  # Simple 2% daily growth

    # Step k is k days after the last observation, so its value carries k
    # days of growth. Dates and values share one range to stay aligned.
    horizon = range(1, days + 1)
    forecast_values = [last_value * (growth_rate ** step) for step in horizon]
    forecast_dates = [last_date + timedelta(days=step) for step in horizon]

    # Create forecast plot
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df_loc['date'].iloc[-90:],
        y=df_loc['total_cases'].iloc[-90:],
        name='Historical',
        line=dict(color='blue')
    ))

    fig.add_trace(go.Scatter(
        x=forecast_dates,
        y=forecast_values,
        name='Forecast',
        line=dict(color='red', dash='dash')
    ))

    fig.update_layout(
        title=f"COVID-19 Case Forecast for {country}",
        xaxis_title="Date",
        yaxis_title="Total Cases"
    )

    # Resource calculation: fixed fractions of the final forecast value
    final_cases = forecast_values[-1]
    resource_needs = [
        f"Hospital Beds Needed: {int(final_cases * 0.05)}",
        f"ICU Beds Needed: {int(final_cases * 0.01)}",
        f"Ventilators Required: {int(final_cases * 0.005)}"
    ]

    print(f"Forecast complete. Predicting {int(final_cases)} cases by day {days}")

    return fig, [html.P(item) for item in resource_needs]

# Start the Dash server in debug mode when this cell/module is the entry point.
if __name__ == '__main__':
    print("\nStarting the Dash server. Once running, open http://127.0.0.1:8050/ in your web browser")
    print("Press Ctrl+C to stop the server\n")
    app.run(debug=True)

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load NewsAPI key
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
    print("NewsAPI key loaded successfully")
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets

# Load OWID data
# Only the columns used downstream are parsed. ``usecols`` stops pandas from
# materialising the rest of the CSV just to drop it again.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
_owid_columns = ['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
                 'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']
df = pd.read_csv(url, usecols=_owid_columns)
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)  # Ensure timezone-naive
# ``usecols`` keeps the file's column order; reselect to fix the order used here.
df = df[_owid_columns]

# Simulate healthcare capacity data
# NOTE(review): these location labels must match the values in
# df['location'] for any lookup by country to succeed -- confirm against
# the dataset's naming.
capacity_data = {
    'location': ['USA', 'India', 'Brazil', 'United Kingdom', 'Germany', 'France', 'Italy', 'Spain',
                'China', 'Japan', 'South Korea', 'Australia', 'Canada', 'Mexico'],
    'hospital_beds': [10000, 5000, 3000, 8000, 9000, 7500, 7000, 6000, 12000, 9500, 8500, 6500, 7800, 4000],
    'icu_beds': [2000, 1000, 500, 1600, 1800, 1500, 1400, 1200, 2400, 1900, 1700, 1300, 1560, 800],
    'ventilators': [500, 300, 200, 400, 450, 375, 350, 300, 600, 475, 425, 325, 390, 200]
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def _mentions_any(text, terms):
    """Return True when any of ``terms`` occurs in ``text`` as a whole word or phrase."""
    import re  # local import: the file's import block does not provide ``re``
    return any(re.search(rf"\b{re.escape(term)}\b", text) for term in terms)


def _article_to_signal(article):
    """Convert one NewsAPI article dict into a date/location/disease_signal row."""
    published = pd.to_datetime(article["publishedAt"]).tz_localize(None).normalize()  # Date-only

    # Title, content and description may each be missing or null. Join only
    # what is present, so one incomplete article cannot raise and send the
    # whole batch to the fallback.
    body = article.get("content") or article.get("description")
    text = " ".join(part for part in (article.get("title"), body) if part).lower()

    # Severity keywords stay substring matches so plurals still count.
    if any(word in text for word in ["outbreak", "epidemic", "pandemic"]):
        signal = 100
    elif any(word in text for word in ["health", "disease", "emergency"]):
        signal = 50
    else:
        signal = 10

    # Location terms are matched as whole words: a plain substring test
    # finds "usa" inside "thousand" and "india" inside "indiana".
    if _mentions_any(text, ["usa", "united states", "america", "american", "americans"]):
        location = "USA"
    elif _mentions_any(text, ["india", "indian", "indians"]):
        location = "India"
    elif _mentions_any(text, ["brazil", "brazilian", "brazilians"]):
        location = "Brazil"
    else:
        location = "Unknown"

    return {"date": published, "location": location, "disease_signal": signal}


def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch outbreak-related articles from NewsAPI and score them.

    Args:
        api_key: NewsAPI key sent with the request.
        date_range_start: Optional start of the search window.
        date_range_end: Optional end of the search window. The window is
            only applied when both bounds are given.

    Returns:
        DataFrame with ``date`` (normalized, timezone-naive), ``location``
        and ``disease_signal`` columns, one row per article. If the request
        or parsing fails for any reason, a fallback frame of random signals
        labelled ``"USA"`` is returned instead: up to the last 30 days of
        the requested window, or the 30 days ending today when no window
        was given.
    """
    try:
        endpoint = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if date_range_start and date_range_end:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(endpoint, params=params, timeout=10)

        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise Exception("No articles found")

        signals = pd.DataFrame([_article_to_signal(article) for article in articles])
        if signals.empty:
            raise Exception("No relevant data parsed")

        signals["date"] = pd.to_datetime(signals["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", signals.head())
        return signals

    except Exception as e:
        # Best-effort by design: any failure degrades to synthetic data.
        print(f"Error fetching NewsAPI data: {e}")
        owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:] if date_range_start and date_range_end else \
                     [pd.to_datetime(datetime.now().date() - timedelta(days=x)).tz_localize(None).normalize() for x in range(30)]
        fallback = pd.DataFrame({
            "date": pd.to_datetime(owid_dates).normalize(),
            "location": ["USA"] * len(owid_dates),
            "disease_signal": np.random.randint(10, 100, len(owid_dates))
        })
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_feature_frame(df, location, target):
    """Return the two-row placeholder frame used when a location has too little data."""
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the modelling frame for one location.

    Adds ``lag_1`` .. ``lag_<lookback>`` columns of ``target`` and a
    ``disease_signal`` column taken from the news data (0 where no news
    exists for a date).

    Args:
        df: Case data with ``date``, ``location`` and ``target`` columns.
        location: Location to keep.
        target: Column to predict and derive lags from.
        lookback: Number of lag columns.

    Returns:
        The feature frame, or a two-row placeholder (target values 1000 and
        1100, signal 0, no lag columns) when fewer than two usable rows
        exist before or after processing.
    """
    df_loc = df[df['location'] == location].sort_values('date')

    # Copy so later column assignments never write into a view of ``df``.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_feature_frame(df, location, target)

    lag_cols = [f'lag_{i}' for i in range(1, lookback + 1)]
    for i, lag_col in enumerate(lag_cols, start=1):
        df_loc[lag_col] = df_loc[target].shift(i)

    date_start = df_loc['date'].min()
    date_end = df_loc['date'].max()
    news_data = fetch_promed_data(newsapi_key, date_start, date_end)
    if 'disease_signal' not in news_data.columns:
        news_data['disease_signal'] = 0

    df_loc['date'] = df_loc['date'].dt.normalize()
    news_data['date'] = news_data['date'].dt.normalize()

    # News arrives as one row per article, so a date can appear many times.
    # Collapse to one mean signal per date before merging; otherwise the
    # left merge repeats every case row once per article.
    daily_signal = news_data.groupby('date', as_index=False)['disease_signal'].mean()
    df_loc = df_loc.merge(daily_signal, on='date', how='left')

    if 'disease_signal' not in df_loc.columns:
        df_loc['disease_signal'] = 0
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # Only the target and lag columns decide whether a row is usable.
    df_loc = df_loc.dropna(subset=[target] + lag_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_feature_frame(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and disease-signal columns.

    Returns:
        ``(model, importance)`` where ``importance`` lists each feature
        with its importance, in feature order.

    Raises:
        ValueError: If there are no rows or no feature columns to train on.
    """
    feature_names = [
        name for name in df.columns
        if 'lag_' in name or name == 'disease_signal'
    ]
    inputs = df[feature_names]
    outputs = df[target]

    # ``empty`` is true when either axis has length zero, which covers
    # both "no rows" and "no feature columns".
    if inputs.empty:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, outputs)

    ranking = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_
    })
    return forest, ranking

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date`` / ``target`` columns of ``df``.

    Args:
        df: Frame with a ``date`` column and the ``target`` column.
        target: Column to forecast.

    Returns:
        The fitted Prophet model.

    Raises:
        ValueError: If fewer than two rows have a non-NaN target value.
    """
    prophet_df = df[['date', target]].rename(columns={'date': 'ds', target: 'y'})
    # Count usable observations, not rows, so the guard matches its message:
    # rows with a missing target give the model nothing to fit.
    if prophet_df['y'].notna().sum() < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")
    # NOTE(review): daily_seasonality=True on what looks like one-row-per-day
    # data -- confirm this is intended.
    model = Prophet(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=True)
    model.fit(prophet_df)
    return model

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest and Prophet predictions into one value.

    The random forest predicts from the last row of ``df``. Prophet
    forecasts ``forecast_days`` ahead and only the first forecast value is
    used. ``target`` is accepted but not used by this body.

    Returns:
        The mean of the two predictions, as a single number.
    """
    feature_cols = [
        name for name in df.columns
        if 'lag_' in name or name == 'disease_signal'
    ]
    latest_row = df[feature_cols].tail(1)
    rf_estimates = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_estimates = prophet_forecast['yhat'].tail(forecast_days).values

    # Equal-weight blend of the first value from each model.
    return (rf_estimates[0] + prophet_estimates[0]) / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Forecast ``target`` for a location with the chosen model.

    Args:
        df: Case data passed to ``prepare_features``.
        location: Location to forecast.
        forecast_days: Horizon used by the Prophet and ensemble paths.
        model_type: ``'Random Forest'``, ``'Prophet'``, or anything else
            for the ensemble.
        target: Column to forecast.

    Returns:
        ``(prediction, importance, intervals)``. ``importance`` is set for
        the random-forest and ensemble paths, ``intervals`` (``ds``,
        ``yhat_lower``, ``yhat_upper``) only for Prophet. On any error the
        result is ``(np.array([0]), None, None)``.
    """
    def feature_columns(frame):
        return [name for name in frame.columns if 'lag_' in name or name == 'disease_signal']

    try:
        df_loc = prepare_features(df, location, target)

        if model_type == 'Prophet':
            forecaster = train_prophet(df_loc, target)
            horizon = forecaster.make_future_dataframe(periods=forecast_days)
            forecast = forecaster.predict(horizon)
            yhat = forecast['yhat'].tail(forecast_days).values
            return yhat, None, forecast[['ds', 'yhat_lower', 'yhat_upper']]

        if model_type == 'Random Forest':
            forest, importance = train_random_forest(df_loc, target)
            # Predict from the most recent feature row only.
            latest_row = df_loc[feature_columns(df_loc)].tail(1)
            return forest.predict(latest_row), importance, None

        # Ensemble
        forest, importance = train_random_forest(df_loc, target)
        forecaster = train_prophet(df_loc, target)
        blended = ensemble_predict(forest, forecaster, df_loc, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score Random Forest and Prophet on a chronological 80/20 hold-out split.

    Args:
        df: Full dataset handed to ``prepare_features``.
        location: Location name to evaluate.
        target: Column being predicted.

    Returns:
        ``{'Random Forest': {'RMSE': ..., 'R2': ...},
        'Prophet': {'RMSE': ..., 'R2': ...}}``. All four values are 0 when
        there is too little data or when any step raises (the error is
        printed rather than propagated).
    """
    def _zero_metrics():
        # Fresh dict on every call so callers can never share mutable state.
        return {'Random Forest': {'RMSE': 0, 'R2': 0},
                'Prophet': {'RMSE': 0, 'R2': 0}}

    try:
        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            return _zero_metrics()

        # First 80% of rows train the models, the remainder is held out.
        train_size = int(0.8 * len(df_loc))
        train_df, test_df = df_loc[:train_size], df_loc[train_size:]
        if test_df.empty:
            return _zero_metrics()

        rf_model, _ = train_random_forest(train_df, target)
        features = [col for col in test_df.columns if 'lag_' in col or col == 'disease_signal']
        rf_pred = rf_model.predict(test_df[features])
        rf_rmse = np.sqrt(mean_squared_error(test_df[target], rf_pred))
        rf_r2 = r2_score(test_df[target], rf_pred)

        prophet_model = train_prophet(train_df, target)
        # Forecast on the hold-out rows' own dates. Generating
        # len(test_df) consecutive future days would pair predictions with the
        # wrong observations whenever the prepared frame has date gaps.
        holdout_dates = pd.DataFrame({'ds': pd.to_datetime(test_df['date']).to_numpy()})
        forecast = prophet_model.predict(holdout_dates)
        # Align by date instead of by position so the forecast's row order
        # (and any repeated dates) cannot shift predictions.
        yhat_by_date = forecast.drop_duplicates('ds').set_index('ds')['yhat']
        prophet_pred = holdout_dates['ds'].map(yhat_by_date).to_numpy()
        prophet_rmse = np.sqrt(mean_squared_error(test_df[target], prophet_pred))
        prophet_r2 = r2_score(test_df[target], prophet_pred)

        return {
            'Random Forest': {'RMSE': rf_rmse, 'R2': rf_r2},
            'Prophet': {'RMSE': prophet_rmse, 'R2': prophet_r2}
        }
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return _zero_metrics()

# Calculate resource allocation
def calculate_resources(predicted_cases, country, capacity_table=None):
    """Estimate resource demand, available capacity and shortages.

    Args:
        predicted_cases: Predicted case count. A scalar or any array-like
            (list, NumPy array, Series); only the first element is used.
            Missing, non-finite or negative values are treated as 0 so a bad
            forecast yields zero demand instead of an exception or negative
            counts.
        country: Location name looked up in the capacity table.
        capacity_table: Optional DataFrame with ``location``,
            ``hospital_beds``, ``icu_beds`` and ``ventilators`` columns.
            When omitted, the module-level ``capacity_df`` is used.

    Returns:
        Tuple ``(resources_needed, capacity, shortages)`` of dicts with plain
        ``int`` values. ``resources_needed`` has the keys ``hospital_beds``,
        ``icu_beds``, ``ventilators``, ``medical_staff``, ``ppe_kits`` and
        ``test_kits``; ``capacity`` and ``shortages`` cover the first three.
    """
    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }
    # Fallback capacity for locations missing from the capacity table.
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300
    }

    table = capacity_df if capacity_table is None else capacity_table
    country_capacity = table[table['location'] == country]
    if country_capacity.empty:
        capacity = dict(default_capacity)
    else:
        first_match = country_capacity.iloc[0]
        # Plain ints keep the result free of NumPy scalar types.
        capacity = {name: int(first_match[name]) for name in default_capacity}

    # Accept scalars and any array-like; only the first value is used.
    values = np.asarray(predicted_cases, dtype=float).ravel()
    predicted_value = float(values[0]) if values.size else 0.0
    if not np.isfinite(predicted_value) or predicted_value < 0:
        predicted_value = 0.0

    hospitalized = predicted_value * resource_ratios['hospital_beds']
    resources_needed = {
        'hospital_beds': int(hospitalized),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(hospitalized * resource_ratios['medical_staff']),
        'ppe_kits': int(hospitalized * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits'])
    }

    # A shortage is the demand that exceeds capacity, never negative.
    shortages = {
        name: max(0, resources_needed[name] - capacity[name])
        for name in capacity
    }

    return resources_needed, capacity, shortages

# Warm-up retraining passes (kept at 2 for faster startup). Each pass feeds the
# frame returned by the previous one back into retrain_model.
# NOTE(review): retrain_model is defined further down in this file, after this
# call; confirm it is already defined when this code runs.
for _ in range(2):
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    print(
        "Retraining complete"
        if rf_model is not None and prophet_model is not None
        else "Retraining skipped due to invalid data"
    )

# Map source data: one row per location built with groupby().last(), keeping
# only locations that have a total_cases value.
map_data = (
    df.groupby('location')
    .last()
    .reset_index()
    .dropna(subset=['total_cases'])
)

# Dashboard: create the Dash application and load the Bootstrap stylesheet
# provided by dash-bootstrap-components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Custom HTML page template. The {%...%} markers are placeholders that Dash
# replaces when it serves the page; the inline <style> block adds the
# dashboard-specific look (card shadows and rounded corners, blue card
# headers and active tabs, bordered tab content, and the classes
# "resource-alert" and "dashboard-title").
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

def _bold_label(text, **label_props):
    """Return a form label rendered in bold."""
    return html.Label(text, style={'fontWeight': 'bold'}, **label_props)


def _choice_list(pairs):
    """Turn ``(label, value)`` pairs into Dash option dictionaries."""
    return [{'label': label, 'value': value} for label, value in pairs]


def _location_options():
    """Return one dropdown option per distinct location in ``df``."""
    return [{'label': loc, 'value': loc} for loc in df['location'].unique()]


def _card(title, body):
    """Wrap ``body`` components in a card with a header."""
    return dbc.Card([dbc.CardHeader(title), dbc.CardBody(body)])


def _spinner(spinner_id, content):
    """Show a circular loading indicator while ``content`` updates."""
    return dcc.Loading(id=spinner_id, type="circle", children=content)


def _build_prediction_tab():
    """Assemble the "Predictions & Resources" tab."""
    country_col = dbc.Col([
        _bold_label("Select Country"),
        dcc.Dropdown(
            id='country-dropdown',
            options=_location_options(),
            value='USA',
            style={'width': '100%'}
        )
    ], width=12, md=4)

    horizon_col = dbc.Col([
        _bold_label("Forecast Days"),
        dcc.Slider(
            id='forecast-days',
            min=7,
            max=90,
            step=1,
            value=30,
            marks={i: str(i) for i in range(7, 91, 14)},
            tooltip={"placement": "bottom", "always_visible": True}
        )
    ], width=12, md=4)

    model_col = dbc.Col([
        _bold_label("Model Type"),
        dcc.Dropdown(
            id='model-type',
            options=_choice_list([
                ('Random Forest', 'Random Forest'),
                ('Prophet', 'Prophet'),
                ('Ensemble', 'Ensemble'),
            ]),
            value='Ensemble',
            style={'width': '100%'}
        )
    ], width=12, md=4)

    control_panel = _card("Control Panel", [
        dbc.Row([country_col, horizon_col, model_col]),
        _bold_label("Target Variable", className="mt-3"),
        dcc.Dropdown(
            id='target-variable',
            options=_choice_list([
                ('Total Cases', 'total_cases'),
                ('New Cases', 'new_cases'),
                ('Total Deaths', 'total_deaths'),
                ('New Deaths', 'new_deaths'),
                ('Hospital Patients', 'hosp_patients'),
                ('ICU Patients', 'icu_patients'),
            ]),
            value='total_cases',
            style={'width': '100%'}
        ),
        dbc.Button("Run Prediction", id='run-btn', color='primary',
                   className='mt-3', style={'width': '100%'})
    ])

    forecast_card = _card("Prediction Results", [
        _spinner("loading-prediction", dcc.Graph(id='prediction-plot'))
    ])
    resources_card = _card("Required Medical Resources", [
        html.Div(id='resources-output')
    ])
    importance_card = _card("Feature Importance", [
        _spinner("loading-features", dcc.Graph(id='feature-importance-plot'))
    ])
    metrics_card = _card("Model Performance Metrics", [
        html.Div(id='performance-metrics')
    ])

    return dbc.Tab(label="Predictions & Resources", children=[
        dbc.Row([dbc.Col([control_panel], width=12)], className="mb-4"),
        dbc.Row([
            dbc.Col([forecast_card], width=12, lg=8),
            dbc.Col([resources_card], width=12, lg=4)
        ], className="mb-4"),
        dbc.Row([
            dbc.Col([importance_card], width=12, lg=6),
            dbc.Col([metrics_card], width=12, lg=6)
        ])
    ])


def _build_map_tab():
    """Assemble the "Global Impact Map" tab."""
    variable_col = dbc.Col([
        _bold_label("Select Visualization"),
        dcc.Dropdown(
            id='map-variable',
            options=_choice_list([
                ('Total Cases', 'total_cases'),
                ('Total Deaths', 'total_deaths'),
                ('Cases per Million', 'cases_per_million'),
                ('Deaths per Million', 'deaths_per_million'),
            ]),
            value='total_cases',
            style={'width': '100%'}
        )
    ], width=12, md=6)

    map_type_col = dbc.Col([
        _bold_label("Map Type"),
        dcc.RadioItems(
            id='map-type',
            options=_choice_list([
                ('Choropleth Map', 'choropleth'),
                ('Bubble Map', 'bubble'),
            ]),
            value='choropleth',
            inline=True,
            style={'width': '100%'}
        )
    ], width=12, md=6)

    return dbc.Tab(label="Global Impact Map", children=[
        _card("Global Disease Distribution", [
            dbc.Row([variable_col, map_type_col], className="mb-3"),
            _spinner("loading-map",
                     dcc.Graph(id='global-map', style={'height': '70vh'}))
        ])
    ])


def _build_news_tab():
    """Assemble the "News & Disease Signals" tab."""
    country_col = dbc.Col([
        _bold_label("Select Country"),
        dcc.Dropdown(
            id='news-country-dropdown',
            options=_location_options(),
            value='USA',
            style={'width': '100%'}
        )
    ], width=12, md=6)

    timeframe_col = dbc.Col([
        _bold_label("Date Range"),
        dcc.RadioItems(
            id='news-timeframe',
            options=_choice_list([
                ('Last 30 days', 30),
                ('Last 90 days', 90),
                ('Last 180 days', 180),
            ]),
            value=30,
            inline=True,
            style={'width': '100%'}
        )
    ], width=12, md=6)

    return dbc.Tab(label="News & Disease Signals", children=[
        _card("Real-time Disease Signal Analysis", [
            dbc.Row([country_col, timeframe_col], className="mb-3"),
            _spinner("loading-news", [
                dcc.Graph(id='news-signal-plot'),
                html.Div(id='news-headlines', className="mt-4")
            ])
        ])
    ])


# Page layout: a title followed by three tabs (predictions, map, news).
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard",
            className="dashboard-title mt-4 mb-4"),
    dbc.Tabs([
        _build_prediction_tab(),
        _build_map_tab(),
        _build_news_tab()
    ])
], fluid=True)

# Retrain both models after merging newly fetched disease-signal rows
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Fetch news-derived signal rows, merge newer ones into ``df`` and retrain.

    Args:
        df: Dataset with at least ``date`` and ``location`` columns.
        location: Location whose models are trained.
        api_key: Key forwarded to ``fetch_promed_data``.
        target: Column the models are trained to predict.
        interval_hours: Accepted for interface compatibility; not used by
            this function.

    Returns:
        Tuple ``(rf_model, prophet_model, df)``. ``df`` is the (possibly
        extended) dataset. Both models are ``None`` when no usable news data
        was fetched, when fewer than two prepared rows exist for
        ``location``, or when any step raises (the error is printed).
    """
    try:
        last_date = df['date'].max()
        new_data = fetch_promed_data(api_key, df['date'].min(), last_date)

        if new_data.empty or not pd.api.types.is_datetime64_any_dtype(new_data['date']):
            print("No valid new data to concatenate")
            return None, None, df

        # Work on a private copy so the fetched frame is never modified in
        # place and no assignment lands on a slice of another frame.
        new_data = new_data.copy()
        new_data['date'] = pd.to_datetime(new_data['date']).dt.tz_localize(None).dt.normalize()
        new_data_max = new_data['date'].max()

        # Only extend the dataset when the news data reaches past its last date.
        if pd.notna(new_data_max) and new_data_max > last_date:
            signal_rows = new_data[['date', 'location', 'disease_signal']]
            # concat outer-joins the columns, so every column missing from
            # signal_rows is filled with NaN without adding it by hand.
            df = pd.concat([df, signal_rows], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        rf_model, _ = train_random_forest(df_loc, target)
        prophet_model = train_prophet(df_loc, target)
        return rf_model, prophet_model, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Per-million rates for the map: totals divided by population, scaled by 1e6.
_population = map_data['population']
map_data['cases_per_million'] = map_data['total_cases'].div(_population).mul(1000000)
map_data['deaths_per_million'] = map_data['total_deaths'].div(_population).mul(1000000)

# Callbacks
def _utilization_alert_color(used, capacity):
    """Pick the alert color for a demand-to-capacity ratio."""
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _capacity_alert(title, unit, needed, capacity, shortage):
    """Build one alert comparing the demand for a resource with its capacity."""
    children = [
        html.Div([
            html.Strong(title),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {capacity} {unit}"),
    ]
    if capacity > 0:
        # Bar width is the utilization percentage, clamped to 0-100.
        children.append(html.Div(className="progress mt-1", children=[
            html.Div(
                className="progress-bar",
                style={"width": f"{max(0, min(100, needed / capacity * 100))}%"}
            )
        ]))
    if shortage > 0:
        children.append(
            html.Div(f"Potential shortage: {shortage}", className="text-danger")
        )
    return dbc.Alert(children, color=_utilization_alert_color(needed, capacity),
                     className="mb-2")


def _metrics_column(model_name, scores, bar_color):
    """Build the RMSE / R² summary column for one model."""
    return dbc.Col([
        html.H5(model_name, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={
                        "width": f"{max(0, min(100, scores['R2'] * 100))}%",
                        "backgroundColor": bar_color
                    }
                )
            ])
        ])
    ], width=6)


def _forecast_values(pred, forecast_days):
    """Return the forecast as a 1-D float array covering the whole horizon.

    A single predicted value (scalar or one-element array) is repeated for
    every forecast day so it can be plotted against the forecast dates.
    """
    values = np.atleast_1d(np.asarray(pred, dtype=float)).ravel()
    if values.size == 0:
        values = np.zeros(1)
    if values.size == 1:
        values = np.repeat(values, forecast_days)
    return values


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Run a forecast and refresh the plots, metrics and resource panel.

    Triggered by the "Run Prediction" button; the other controls are read as
    state. Before the first click, placeholder figures and messages are
    returned. If anything raises, the error is printed and two empty figures
    plus two error messages are returned.

    Returns:
        Tuple ``(prediction figure, feature-importance figure,
        performance-metrics component, resources component)``.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)
        # Random Forest and the ensemble yield a single value; spread it over
        # the horizon so the x and y series always have the same length.
        pred_values = _forecast_values(pred, forecast_days)

        historical = None
        if target_variable in df.columns:
            historical = df[(df['location'] == country) & (df[target_variable].notna())].sort_values('date')
        has_history = historical is not None and not historical.empty

        if confidence is not None:
            # Keep only the forecast horizon; the interval frame also covers
            # the training history. Its dates are the forecast dates.
            confidence = confidence.tail(len(pred_values))
            pred_dates = confidence['ds']
        else:
            # Forecasts continue from the last observation, not from today.
            last_observed = historical['date'].max() if has_history else pd.NaT
            start = last_observed + timedelta(days=1) if pd.notna(last_observed) else datetime.now()
            pred_dates = [start + timedelta(days=i) for i in range(len(pred_values))]

        # Create prediction plot
        pred_fig = go.Figure()

        # Add historical data to the plot
        if has_history:
            recent = historical.tail(30)  # Last 30 observations
            pred_fig.add_trace(go.Scatter(
                x=recent['date'],
                y=recent[target_variable],
                name='Historical Data',
                line=dict(color='royalblue')
            ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available. The upper trace fills down to
        # the lower trace, so the lower bound must be added first.
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metrics_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metrics_column("Prophet", metrics['Prophet'], "#3498db")
            ])
        ], body=True)

        # Resource calculation is based on the first forecast value.
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _capacity_alert("Hospital Beds", "beds",
                                resources_needed['hospital_beds'],
                                capacity['hospital_beds'],
                                shortages['hospital_beds']),
                _capacity_alert("ICU Beds", "beds",
                                resources_needed['icu_beds'],
                                capacity['icu_beds'],
                                shortages['icu_beds']),
                _capacity_alert("Ventilators", "units",
                                resources_needed['ventilators'],
                                capacity['ventilators'],
                                shortages['ventilators']),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Draw the world map for the selected variable and map type.

    Args:
        map_variable: Column of ``map_data`` to visualize.
        map_type: ``'choropleth'``; any other value draws the bubble map.

    Returns:
        The map figure, or an empty figure whose title carries the error
        message when anything raises (the error is also printed).
    """
    try:
        plot_data = map_data

        # Compute per-million values if they are missing. This is done on a
        # copy so the callback never modifies the shared module-level frame.
        per_million_sources = {
            'cases_per_million': 'total_cases',
            'deaths_per_million': 'total_deaths'
        }
        missing = [col for col in per_million_sources if col not in plot_data.columns]
        if missing:
            plot_data = plot_data.copy()
            for col in missing:
                plot_data[col] = (plot_data[per_million_sources[col]] / plot_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Avoid listing the selected variable twice in the hover box.
        hover_columns = list(dict.fromkeys([map_variable, 'total_cases', 'total_deaths']))

        shared_args = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=hover_columns,
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        # Create appropriate map
        if map_type == 'choropleth':
            fig = px.choropleth(plot_data, **shared_args)
        else:  # bubble map
            # Marker sizes must be numeric, so rows without a value for the
            # selected variable cannot be drawn and are left out.
            fig = px.scatter_geo(
                plot_data.dropna(subset=[map_variable]),
                size=map_variable,
                size_max=50,
                **shared_args
            )

        # Both map types share the same base-map styling.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Render the news tab: a disease-signal chart plus a headline list.

    Args:
        country: Value of the news country dropdown. 'Global' skips the
            per-location filter on the signal rows.
        days: Length of the look-back window in days, counted back from now.

    Returns:
        A ``(figure, children)`` pair matching the two callback outputs. When
        no signal rows match, or when anything raises, the figure carries only
        a title and the second element is a plain message string.
    """
    def score_color(score):
        # Red above 80, amber above 60, green for everything else.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Signal rows for the window, narrowed to one location unless 'Global'.
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        # Nothing to plot: hand back a titled placeholder and a message.
        if signals.empty:
            placeholder = go.Figure()
            placeholder.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return placeholder, "No news headlines available"

        # One averaged signal value per calendar date.
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        fig = go.Figure()
        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig.add_trace(signal_trace)

        # Reported new cases for the same location and window, if any exist.
        in_window = (df['date'] >= window_start) & (df['date'] <= window_end)
        reported = df[(df['location'] == country) & in_window & (df['new_cases'].notna())]
        has_cases = not reported.empty

        if has_cases:
            # Rescale cases so their peak lines up with the signal's peak.
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = reported['new_cases'].max()
            if peak_cases > 0:
                scale_factor = peak_signal / peak_cases
            else:
                scale_factor = 1

            cases_trace = go.Scatter(
                x=reported['date'],
                y=reported['new_cases'] * scale_factor,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            )
            fig.add_trace(cases_trace)

        # The right-hand axis is only configured when the cases trace exists.
        secondary_axis = {}
        if has_cases:
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Hard-coded sample headlines; they are not derived from fetched data.
        headlines = [
            {"title": "New outbreak reported in northeastern region", "score": 95},
            {"title": "Health officials monitoring unusual respiratory symptoms", "score": 80},
            {"title": "Vaccination campaign launched to combat emerging strain", "score": 70},
            {"title": "Hospital capacity strained amid rising cases", "score": 85},
            {"title": "Travel restrictions considered as cases increase", "score": 75}
        ]

        headline_items = []
        for entry in headlines:
            score_label = html.Span(
                f" - Signal Score: {entry['score']}",
                className="ml-2",
                style={"color": score_color(entry["score"])}
            )
            headline_items.append(
                dbc.ListGroupItem([
                    html.Div([html.Strong(entry["title"]), score_label])
                ])
            )

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(headline_items)
        ])

        return fig, headlines_output

    except Exception as e:
        # Keep the dashboard responsive: report the failure in both outputs.
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
if __name__ == '__main__':
    # `dash.Dash.run` chooses the notebook display with `jupyter_mode`; the
    # former `mode=` keyword belonged to the separate JupyterDash class and is
    # not a parameter of `Dash.run`, so it would be forwarded to Flask as an
    # unknown option when running outside a notebook.
    # NOTE(review): `jupyter_mode` requires a Dash release with built-in
    # Jupyter support - confirm the installed version.
    try:
        app.run(jupyter_mode='inline', debug=True)
    except Exception as e:
        # Inline rendering failed; retry by serving on a fixed port instead.
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(jupyter_mode='external', debug=True, port=8050)


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Resolve the NewsAPI key from Colab Secrets; any failure leaves a placeholder
# string in `newsapi_key` so the rest of the cell can still run.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
except Exception as exc:
    print(f"Failed to load NewsAPI key: {exc}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets
else:
    print("NewsAPI key loaded successfully")

# Download the OWID COVID-19 CSV and keep only the columns the dashboard reads.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
owid_columns = [
    'date', 'location',
    'total_cases', 'new_cases',
    'total_deaths', 'new_deaths',
    'hosp_patients', 'icu_patients',
    'people_vaccinated', 'population',
]
df = pd.read_csv(url)
# Parse dates and strip any timezone so later comparisons use naive timestamps.
df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None)
df = df[owid_columns]

# Simulated healthcare capacity, one row per country:
# (location, hospital_beds, icu_beds, ventilators).
# NOTE(review): locations here must match the spelling used in df['location']
# to be found by lookups - verify 'USA' against the OWID naming.
_capacity_rows = [
    ('USA', 10000, 2000, 500),
    ('India', 5000, 1000, 300),
    ('Brazil', 3000, 500, 200),
    ('United Kingdom', 8000, 1600, 400),
    ('Germany', 9000, 1800, 450),
    ('France', 7500, 1500, 375),
    ('Italy', 7000, 1400, 350),
    ('Spain', 6000, 1200, 300),
    ('China', 12000, 2400, 600),
    ('Japan', 9500, 1900, 475),
    ('South Korea', 8500, 1700, 425),
    ('Australia', 6500, 1300, 325),
    ('Canada', 7800, 1560, 390),
    ('Mexico', 4000, 800, 200),
]
# Transpose the rows into the column-oriented dict the DataFrame is built from.
_capacity_columns = ('location', 'hospital_beds', 'icu_beds', 'ventilators')
capacity_data = {
    column: list(values)
    for column, values in zip(_capacity_columns, zip(*_capacity_rows))
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def _score_signal(text):
    """Return 100, 50 or 10 depending on which keyword group occurs in *text*."""
    if any(word in text for word in ("outbreak", "epidemic", "pandemic")):
        return 100
    if any(word in text for word in ("health", "disease", "emergency")):
        return 50
    return 10


def _guess_location(text):
    """Return the first of USA / India / Brazil whose keywords occur in *text*, else "Unknown"."""
    keyword_map = (
        ("USA", ("usa", "united states", "america")),
        ("India", ("india", "indian")),
        ("Brazil", ("brazil", "brazilian")),
    )
    for label, keywords in keyword_map:
        if any(word in text for word in keywords):
            return label
    return "Unknown"


def _fallback_signal_frame(date_range_start, date_range_end):
    """Build the random placeholder frame used when no NewsAPI data is usable."""
    if date_range_start and date_range_end:
        # At most the last 30 days of the requested range.
        owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
    else:
        # No range given: the 30 days ending today, newest first.
        today = datetime.now().date()
        owid_dates = [pd.to_datetime(today - timedelta(days=x)).normalize() for x in range(30)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": ["USA"] * len(owid_dates),
        "disease_signal": np.random.randint(10, 100, len(owid_dates))
    })


def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch disease-related articles from NewsAPI and score them.

    Despite its name, this queries the NewsAPI ``/v2/everything`` endpoint.
    Each article becomes one row holding its publication date (midnight,
    timezone-naive), a keyword-derived location label and a keyword-derived
    signal score.

    Args:
        api_key: NewsAPI key sent as the ``apiKey`` query parameter.
        date_range_start: Optional start of the query window; only used when
            ``date_range_end`` is also given.
        date_range_end: Optional end of the query window.

    Returns:
        A DataFrame with columns ``date``, ``location`` and ``disease_signal``.
        On any failure (request error, non-200 status, no usable articles) a
        placeholder frame of random signals labelled "USA" is returned instead;
        the failure is printed, never raised.
    """
    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if date_range_start and date_range_end:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)
        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise Exception("No articles found")

        data = []
        for article in articles:
            published = article.get("publishedAt")
            if not published:
                # An undated article cannot be placed on the time axis.
                continue
            date = pd.to_datetime(published).tz_localize(None).normalize()  # Date-only

            # Any of these fields may be null in the API response; coalesce to
            # "" so a single sparse article does not discard the whole batch.
            title = article.get("title") or ""
            body = article.get("content") or article.get("description") or ""
            text = f"{title} {body}".lower()

            data.append({
                "date": date,
                "location": _guess_location(text),
                "disease_signal": _score_signal(text)
            })

        signals = pd.DataFrame(data)
        if signals.empty:
            raise Exception("No relevant data parsed")

        signals["date"] = pd.to_datetime(signals["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", signals.head())
        return signals

    except Exception as e:
        # Deliberate best-effort: every failure degrades to placeholder data.
        print(f"Error fetching NewsAPI data: {e}")
        fallback = _fallback_signal_frame(date_range_start, date_range_end)
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_feature_frame(df, location, target):
    """Build the two-row placeholder returned when a location lacks usable data."""
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        # The location has fewer than two rows at all: use today and yesterday.
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the modelling frame for one location.

    Rows with a null ``target`` are dropped, ``lag_1`` .. ``lag_<lookback>``
    columns are added from shifted target values, and a per-date
    ``disease_signal`` column is joined in from ``fetch_promed_data``.

    Args:
        df: Case data with at least ``date``, ``location`` and ``target``.
        location: Value of ``df['location']`` to select.
        target: Column to forecast.
        lookback: Number of lag columns to create.

    Returns:
        The feature frame, or a two-row placeholder (without lag columns) when
        fewer than two usable rows exist before or after processing.
    """
    df_loc = df[df['location'] == location].sort_values('date')
    # Copy so the column assignments below never write into a view of `df`.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_feature_frame(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    date_start = df_loc['date'].min()
    date_end = df_loc['date'].max()
    news_data = fetch_promed_data(newsapi_key, date_start, date_end)
    if 'disease_signal' not in news_data.columns:
        news_data['disease_signal'] = 0

    df_loc['date'] = df_loc['date'].dt.normalize()
    news_data['date'] = news_data['date'].dt.normalize()

    # News data holds one row per article, so a date can repeat. Collapse it to
    # one averaged value per date; merging the raw rows would duplicate every
    # case row once per article published that day.
    daily_signal = news_data.groupby('date', as_index=False)['disease_signal'].mean()

    # A `disease_signal` column already present on the input would make the
    # merge emit `_x`/`_y` suffixed columns and lose the joined values.
    df_loc = df_loc.drop(columns=['disease_signal'], errors='ignore')
    df_loc = df_loc.merge(daily_signal, on='date', how='left')

    # Dates without any article get a neutral signal of 0.
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # The first `lookback` rows have incomplete lags and cannot be used.
    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_feature_frame(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and disease-signal columns of *df*.

    Args:
        df: Feature frame; every column whose name contains ``lag_`` or equals
            ``disease_signal`` is used as an input.
        target: Column to regress on.

    Returns:
        ``(model, importance)`` where *importance* is a DataFrame with columns
        ``feature`` and ``importance``.

    Raises:
        ValueError: If the selected inputs contain no rows or no columns.
    """
    feature_names = []
    for column in df.columns:
        if 'lag_' in column or column == 'disease_signal':
            feature_names.append(column)

    inputs, labels = df[feature_names], df[target]

    if inputs.empty or inputs.shape[0] < 1:
        raise ValueError("No samples available for training RandomForestRegressor")

    # Fixed seed so repeated fits on the same data give the same forest.
    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, labels)

    importance = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_,
    })
    return forest, importance

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date`` / *target* columns of *df*.

    The two columns are renamed to ``ds`` and ``y`` before fitting. Yearly,
    weekly and daily seasonality are all switched on.

    Raises:
        ValueError: If *df* has fewer than two rows.
    """
    column_map = {'date': 'ds', target: 'y'}
    training = df[['date', target]].rename(columns=column_map)

    if len(training.index) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    forecaster = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
    )
    forecaster.fit(training)
    return forecaster

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest and Prophet predictions into one number.

    The forest predicts from the last row of *df* (lag and disease-signal
    columns); Prophet forecasts ``forecast_days`` periods ahead and only the
    first of those forecast values is used. ``target`` is accepted but not
    read here.

    Returns:
        The mean of the two point predictions, as a scalar.
    """
    feature_names = [
        column for column in df.columns
        if 'lag_' in column or column == 'disease_signal'
    ]
    latest_row = df[feature_names].tail(1)
    forest_estimates = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_estimates = prophet_forecast['yhat'].tail(forecast_days).values

    return (forest_estimates[0] + prophet_estimates[0]) / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Produce a forecast for one location with the chosen model.

    Args:
        df: Case data passed on to ``prepare_features``.
        location: Location to forecast for.
        forecast_days: Number of future periods requested from Prophet.
        model_type: 'Random Forest', 'Prophet', or anything else for the
            ensemble.
        target: Column to forecast.

    Returns:
        ``(prediction, importance, intervals)``:
        - Random Forest: one-element prediction array, importance frame, None.
        - Prophet: array of ``forecast_days`` values, None, and a frame with
          ``ds`` / ``yhat_lower`` / ``yhat_upper``.
        - Ensemble: scalar blend, importance frame, None.
        Any exception is printed and ``(np.array([0]), None, None)`` returned.
    """
    try:
        prepared = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet_model = train_prophet(prepared, target)
            horizon = prophet_model.make_future_dataframe(periods=forecast_days)
            forecast = prophet_model.predict(horizon)
            point_forecast = forecast['yhat'].tail(forecast_days).values
            return point_forecast, None, forecast[['ds', 'yhat_lower', 'yhat_upper']]

        # Both remaining modes start from a fitted random forest.
        rf_model, importance = train_random_forest(prepared, target)

        if model_type == 'Random Forest':
            feature_names = [
                column for column in prepared.columns
                if 'lag_' in column or column == 'disease_signal'
            ]
            latest_row = prepared[feature_names].tail(1)
            return rf_model.predict(latest_row), importance, None

        # Any other model_type value is handled as the ensemble.
        prophet_model = train_prophet(prepared, target)
        blended = ensemble_predict(rf_model, prophet_model, prepared, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score the random forest and Prophet on a chronological 80/20 split.

    Both models are trained on the first 80% of the prepared rows and scored
    on the remaining rows.

    Returns:
        ``{'Random Forest': {'RMSE', 'R2'}, 'Prophet': {'RMSE', 'R2'}}``. All
        four values are 0 when there are too few rows, when the test part is
        empty, or when any step raises (the error is printed).
    """
    def no_scores():
        # Fresh dict on every call so callers never share a mutable result.
        return {'Random Forest': {'RMSE': 0, 'R2': 0}, 'Prophet': {'RMSE': 0, 'R2': 0}}

    def score(actual, predicted):
        rmse = np.sqrt(mean_squared_error(actual, predicted))
        return {'RMSE': rmse, 'R2': r2_score(actual, predicted)}

    try:
        prepared = prepare_features(df, location, target)
        if prepared.shape[0] < 2:
            return no_scores()

        split_at = int(0.8 * len(prepared))
        train_df = prepared[:split_at]
        test_df = prepared[split_at:]
        if test_df.empty:
            return no_scores()

        # Random forest: predict the held-out rows from their feature columns.
        rf_model, _ = train_random_forest(train_df, target)
        feature_names = [
            column for column in test_df.columns
            if 'lag_' in column or column == 'disease_signal'
        ]
        rf_scores = score(test_df[target], rf_model.predict(test_df[feature_names]))

        # Prophet: forecast as many periods as there are held-out rows.
        holdout_size = len(test_df)
        prophet_model = train_prophet(train_df, target)
        horizon = prophet_model.make_future_dataframe(periods=holdout_size)
        prophet_pred = prophet_model.predict(horizon)['yhat'].tail(holdout_size).values
        prophet_scores = score(test_df[target], prophet_pred)

        return {'Random Forest': rf_scores, 'Prophet': prophet_scores}
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return no_scores()

# Calculate resource allocation
def calculate_resources(predicted_cases, country):
    """Translate a case prediction into resource needs and shortages.

    Args:
        predicted_cases: A number, or a NumPy array whose first element is
            used.
        country: Location looked up in ``capacity_df``; unknown locations get
            default capacities (5000 beds, 1000 ICU beds, 300 ventilators).

    Returns:
        ``(resources_needed, capacity, shortages)`` - three dicts. Needs cover
        six resources; capacity and shortages cover hospital beds, ICU beds
        and ventilators only. Shortages are never negative.
    """
    # Share of cases (or per-patient multiplier) assumed for each resource.
    ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }
    # Resources for which a capacity figure exists.
    tracked = ('hospital_beds', 'icu_beds', 'ventilators')

    matches = capacity_df[capacity_df['location'] == country]
    if matches.empty:
        # Use default values if country not found
        capacity = {'hospital_beds': 5000, 'icu_beds': 1000, 'ventilators': 300}
    else:
        capacity = {name: matches[name].iloc[0] for name in tracked}

    if isinstance(predicted_cases, np.ndarray):
        cases = predicted_cases[0]
    else:
        cases = predicted_cases

    # Direct per-case needs first, then the ones derived from hospitalizations.
    resources_needed = {name: int(cases * ratios[name]) for name in tracked}
    hospitalized = cases * ratios['hospital_beds']
    resources_needed['medical_staff'] = int(hospitalized * ratios['medical_staff'])
    resources_needed['ppe_kits'] = int(hospitalized * ratios['ppe_kits'])
    resources_needed['test_kits'] = int(cases * ratios['test_kits'])

    shortages = {
        name: max(0, resources_needed[name] - capacity[name])
        for name in tracked
    }

    return resources_needed, capacity, shortages

# Run retraining
# NOTE(review): retrain_model is defined further down in this cell, after the
# layout. On a fresh kernel this loop raises NameError unless an earlier cell
# already defined it - consider moving the definition above this point.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# Create filtered data for map: the last non-null value per location, keeping
# only locations that have a total_cases figure.
map_data = df.groupby('location').last().reset_index()
map_data = map_data.dropna(subset=['total_cases'])

# Dashboard
# `__name__` (double underscores) is the module name Dash expects here; the
# previous `_name_` was an undefined variable and raised NameError.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
# Replaces the page template Dash serves so an inline <style> block can be
# added. The {%...%} tokens are template placeholders and are left as-is.
# The CSS styles the cards, card headers, active tab, tab content area,
# `.resource-alert` boxes and the `.dashboard-title` heading used by the layout.
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Page layout: a title followed by three tabs.
#   1. "Predictions & Resources" - control panel, prediction plot, resource
#      summary, feature importance and model metrics.
#   2. "Global Impact Map" - variable / map-type selectors and the world map.
#   3. "News & Disease Signals" - country / window selectors, signal plot and
#      headline list.
# Component ids declared here are the ones the callbacks bind to.
# NOTE(review): both country dropdowns default to 'USA' while their options
# come from df['location'].unique() - confirm that value exists in the data.
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        # --- Tab 1: prediction controls and outputs ---
        dbc.Tab(label="Predictions & Resources", children=[
            # Control panel: country, horizon, model type, target, run button.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            # Prediction plot beside the required-resources summary.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            # Feature importance plot beside the model performance metrics.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        # --- Tab 2: world map ---
        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        # --- Tab 3: news-derived disease signal ---
        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Prepare models and data
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Refresh *df* with fetched news rows and refit both models.

    News rows are appended to *df* only when their newest date is later than
    the newest date already in *df*. ``interval_hours`` is accepted but not
    read here.

    Returns:
        ``(rf_model, prophet_model, df)``. Both models are None when the
        fetched data is unusable, when too few feature rows remain, or when
        any step raises (the error is printed); *df* is returned either way.
    """
    try:
        latest_known = df['date'].max()
        fetched = fetch_promed_data(api_key, df['date'].min(), df['date'].max())

        usable = (not fetched.empty) and pd.api.types.is_datetime64_any_dtype(fetched['date'])
        if not usable:
            print("No valid new data to concatenate")
            return None, None, df

        fetched['date'] = pd.to_datetime(fetched['date']).dt.tz_localize(None).dt.normalize()
        newest_fetched = fetched['date'].max()

        if pd.notna(newest_fetched) and newest_fetched > latest_known:
            fetched = fetched[['date', 'location', 'disease_signal']]
            # Pad the news rows with NaN for every case column they lack so
            # the two frames line up when concatenated.
            absent = [column for column in df.columns if column not in fetched.columns]
            for column in absent:
                fetched[column] = np.nan
            df = pd.concat([df, fetched], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        features = prepare_features(df, location, target)
        if features.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        rf_model, _ = train_random_forest(features, target)
        prophet_model = train_prophet(features, target)
        return rf_model, prophet_model, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Per-capita columns offered by the map's variable selector: each total is
# divided by population and scaled to "per million".
for per_million_col, total_col in (('cases_per_million', 'total_cases'),
                                   ('deaths_per_million', 'total_deaths')):
    map_data[per_million_col] = (map_data[total_col] / map_data['population']) * 1000000

# Callbacks
def _alert_color(used, capacity):
    """Return the Bootstrap alert color for a resource's utilization.

    Utilization below 0.7 is "success", below 0.9 is "warning", anything
    else is "danger". A non-positive capacity counts as fully utilized.
    """
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _resource_alert(label, key, unit, resources_needed, capacity, shortages):
    """Build the alert card (need, capacity, usage bar, shortage) for one resource."""
    needed = resources_needed[key]
    available = capacity[key]
    shortage = shortages[key]
    fill_pct = min(100, (needed / available * 100) if available > 0 else 100)
    return dbc.Alert([
        html.Div([
            html.Strong(label),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {available} {unit}"),
        # The usage bar is only meaningful when there is capacity to fill.
        html.Div(className="progress mt-1", children=[
            html.Div(className="progress-bar", style={"width": f"{fill_pct}%"})
        ]) if available > 0 else None,
        html.Div(f"Potential shortage: {shortage}", className="text-danger")
        if shortage > 0 else None
    ], color=_alert_color(needed, available), className="mb-2")


def _metric_column(model_name, scores, bar_color):
    """Build one half-width column showing RMSE, R² and an R² progress bar."""
    # R² can be negative or exceed 1 numerically; clamp it to a valid CSS width.
    r2_width = max(0, min(100, scores['R2'] * 100))
    return dbc.Col([
        html.H5(model_name, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={"width": f"{r2_width}%", "backgroundColor": bar_color}
                )
            ])
        ])
    ], width=6)


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Run a forecast and render the prediction tab.

    Args:
        n_clicks: Click count of the run button; ``None`` before the first click.
        country: Location to forecast.
        forecast_days: Number of days to forecast.
        model_type: Model name passed through to ``forecast_resource``.
        target_variable: Column of the global ``df`` to forecast.

    Returns:
        ``(prediction figure, feature-importance figure, metrics component,
        resources component)``. Before the first click placeholders are
        returned; on any exception two empty figures and two error strings.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # forecast_resource yields a scalar (Ensemble), a one-element array
        # (Random Forest and its error fallback) or one value per day
        # (Prophet). Broadcast single values so y always matches the x axis.
        pred_values = np.ravel(np.asarray(pred, dtype=float))
        if pred_values.size == 1:
            pred_values = np.repeat(pred_values, forecast_days)

        pred_dates = [datetime.now() + timedelta(days=i) for i in range(forecast_days)]

        # The confidence frame spans the whole fitted history; keep only the
        # forecast horizon and put the prediction line on the same dates.
        horizon_ci = None
        if confidence is not None:
            horizon_ci = confidence.tail(forecast_days)
            if len(horizon_ci) == len(pred_values):
                pred_dates = horizon_ci['ds']

        pred_fig = go.Figure()

        # Add historical data to the plot
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())]
            if not historical.empty:
                pred_fig.add_trace(go.Scatter(
                    x=historical['date'][-30:],  # Last 30 rows
                    y=historical[target_variable][-30:],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if horizon_ci is not None:
            pred_fig.add_trace(go.Scatter(
                x=horizon_ci['ds'],
                y=horizon_ci['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            # 'tonexty' shades the band between this trace and the lower bound.
            pred_fig.add_trace(go.Scatter(
                x=horizon_ci['ds'],
                y=horizon_ci['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metric_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metric_column("Prophet", metrics['Prophet'], "#3498db")
            ])
        ], body=True)

        # Resource calculation is based on the first predicted value.
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _resource_alert("Hospital Beds", 'hospital_beds', "beds",
                                resources_needed, capacity, shortages),
                _resource_alert("ICU Beds", 'icu_beds', "beds",
                                resources_needed, capacity, shortages),
                _resource_alert("Ventilators", 'ventilators', "units",
                                resources_needed, capacity, shortages),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        # Callback boundary: surface the error in the UI instead of crashing.
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Render the global map for the selected variable.

    Args:
        map_variable: Column of the global ``map_data`` to visualize.
        map_type: ``'choropleth'`` for a filled map; any other value draws
            a bubble map.

    Returns:
        The map figure, or an empty figure titled with the error message if
        anything raises.
    """
    try:
        # Calculate per million values if not already done
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Arguments common to both map flavours.
        shared_args = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=[map_variable, 'total_cases', 'total_deaths'],
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        if map_type == 'choropleth':
            fig = px.choropleth(map_data, **shared_args)
        else:  # bubble map
            # Marker sizes must be finite and non-negative; rows with missing,
            # infinite (zero population) or negative values would make
            # scatter_geo reject the whole figure, so leave them out.
            values = pd.to_numeric(map_data[map_variable], errors='coerce')
            bubble_data = map_data[np.isfinite(values) & (values >= 0)]
            fig = px.scatter_geo(
                bubble_data,
                size=map_variable,
                size_max=50,
                **shared_args
            )

        # Same base-map styling for both flavours.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        # Callback boundary: show the failure in the figure title.
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal and list headline cards.

    Args:
        country: Location to show; ``'Global'`` keeps every location's news.
        days: Length of the look-back window in days, counted from now.

    Returns:
        ``(figure, headlines component)``. When no signal rows remain, an
        empty titled figure and a placeholder string; on any exception, an
        error-titled figure and an error string.
    """
    def score_color(score):
        # Red above 80, orange above 60, green otherwise.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Fetch the signal rows and narrow them to the chosen location.
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        # Nothing to plot: return a titled placeholder.
        if signals.empty:
            empty_fig = go.Figure()
            empty_fig.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return empty_fig, "No news headlines available"

        # One averaged signal value per date.
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        fig = go.Figure()
        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig.add_trace(signal_trace)

        # Reported new cases for the same location and window, if any.
        in_window = (
            (df['location'] == country) &
            (df['date'] >= window_start) &
            (df['date'] <= window_end) &
            (df['new_cases'].notna())
        )
        cases = df[in_window]
        has_cases = not cases.empty

        if has_cases:
            # Rescale cases so their peak matches the peak of the signal.
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = cases['new_cases'].max()
            if peak_cases > 0:
                scale = peak_signal / peak_cases
            else:
                scale = 1

            fig.add_trace(go.Scatter(
                x=cases['date'],
                y=cases['new_cases'] * scale,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            ))

        # The secondary axis is only configured when the cases trace exists.
        if has_cases:
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )
        else:
            secondary_axis = {}

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Simulated headlines: fixed (title, score) pairs, not fetched data.
        simulated_headlines = [
            ("New outbreak reported in northeastern region", 95),
            ("Health officials monitoring unusual respiratory symptoms", 80),
            ("Vaccination campaign launched to combat emerging strain", 70),
            ("Hospital capacity strained amid rising cases", 85),
            ("Travel restrictions considered as cases increase", 75),
        ]

        headline_items = []
        for title, score in simulated_headlines:
            headline_items.append(dbc.ListGroupItem([
                html.Div([
                    html.Strong(title),
                    html.Span(
                        f" - Signal Score: {score}",
                        className="ml-2",
                        style={"color": score_color(score)}
                    )
                ])
            ]))

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(headline_items)
        ])

        return fig, headlines_output

    except Exception as e:
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app. The guard must use the double-underscore names: the previous
# `_name_ == '_main_'` referenced an undefined variable and raised NameError.
if __name__ == '__main__':
    try:
        # NOTE(review): the `mode` keyword looks like the JupyterDash API;
        # confirm it matches how `app` is constructed elsewhere in this file.
        app.run_server(mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(mode='external', debug=True, port=8050)
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load the NewsAPI key from Colab Secrets; fall back to a placeholder on failure
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    # Replace with your actual key if not using Colab Secrets
    newsapi_key = "YOUR_API_KEY_HERE"
else:
    print("NewsAPI key loaded successfully")

# Load the OWID COVID-19 dataset and keep only the columns the dashboard uses
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
# Parse dates as timezone-naive timestamps, then select the working columns.
df = df.assign(date=pd.to_datetime(df['date']).dt.tz_localize(None))[[
    'date', 'location',
    'total_cases', 'new_cases',
    'total_deaths', 'new_deaths',
    'hosp_patients', 'icu_patients',
    'people_vaccinated', 'population',
]]

# Simulated healthcare capacity, written one row per location:
# (location, hospital_beds, icu_beds, ventilators)
capacity_data = {
    column: list(values)
    for column, values in zip(
        ('location', 'hospital_beds', 'icu_beds', 'ventilators'),
        zip(
            ('USA', 10000, 2000, 500),
            ('India', 5000, 1000, 300),
            ('Brazil', 3000, 500, 200),
            ('United Kingdom', 8000, 1600, 400),
            ('Germany', 9000, 1800, 450),
            ('France', 7500, 1500, 375),
            ('Italy', 7000, 1400, 350),
            ('Spain', 6000, 1200, 300),
            ('China', 12000, 2400, 600),
            ('Japan', 9500, 1900, 475),
            ('South Korea', 8500, 1700, 425),
            ('Australia', 6500, 1300, 325),
            ('Canada', 7800, 1560, 390),
            ('Mexico', 4000, 800, 200),
        ),
    )
}
capacity_df = pd.DataFrame.from_dict(capacity_data)

# Fetch NewsAPI data
def _article_signal(text):
    """Score lower-cased article text: 100 for outbreak terms, 50 for general health terms, else 10."""
    if any(word in text for word in ("outbreak", "epidemic", "pandemic")):
        return 100
    if any(word in text for word in ("health", "disease", "emergency")):
        return 50
    return 10


def _article_location(text):
    """Return the first country whose keywords occur as whole words in the lower-cased text."""
    import re  # local import: only this helper needs it

    # Word boundaries keep "thousand" from matching "usa" and "indiana"
    # from matching "india", which plain substring tests did.
    patterns = (
        ("USA", r"\b(?:usa|united states|americans?|america)\b"),
        ("India", r"\b(?:india|indians?)\b"),
        ("Brazil", r"\b(?:brazil|brazilians?)\b"),
    )
    for name, pattern in patterns:
        if re.search(pattern, text):
            return name
    return "Unknown"


def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch outbreak-related news from NewsAPI as per-article signal rows.

    Args:
        api_key: NewsAPI key.
        date_range_start: Optional start of the query window; must support
            ``strftime``.
        date_range_end: Optional end of the query window. The window is
            only applied when both bounds are given and are not NaT.

    Returns:
        DataFrame with columns ``date`` (normalized, timezone-naive),
        ``location`` and ``disease_signal``, one row per article. If the
        request or parsing fails for any reason, a fallback frame is
        returned instead: up to 30 daily rows labelled ``"USA"`` with
        random signal values, so callers always receive the same schema.
    """
    # NaT bounds count as "no range": they cannot be formatted for the query
    # and pd.date_range rejects them in the fallback.
    has_range = (
        date_range_start is not None and date_range_end is not None
        and pd.notna(date_range_start) and pd.notna(date_range_end)
    )

    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if has_range:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)
        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise Exception("No articles found")

        data = []
        for article in articles:
            date = pd.to_datetime(article["publishedAt"]).tz_localize(None).normalize()  # Date-only
            # Title, content and description may each be missing or null;
            # join whatever is present instead of failing the whole batch.
            body = article.get("content") or article.get("description")
            text = " ".join(part for part in (article.get("title"), body) if part).lower()
            data.append({
                "date": date,
                "location": _article_location(text),
                "disease_signal": _article_signal(text)
            })

        articles_df = pd.DataFrame(data)
        if articles_df.empty:
            raise Exception("No relevant data parsed")

        articles_df["date"] = pd.to_datetime(articles_df["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", articles_df.head())
        return articles_df

    except Exception as e:
        # Deliberate best-effort: any failure degrades to synthetic data.
        print(f"Error fetching NewsAPI data: {e}")
        if has_range:
            owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
        else:
            owid_dates = [
                pd.to_datetime(datetime.now().date() - timedelta(days=x)).tz_localize(None).normalize()
                for x in range(30)
            ]
        fallback = pd.DataFrame({
            "date": pd.to_datetime(owid_dates).normalize(),
            "location": ["USA"] * len(owid_dates),
            "disease_signal": np.random.randint(10, 100, len(owid_dates))
        })
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_features(df, location, target):
    """Return the two-row placeholder frame used when a location lacks usable data."""
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        # No dates to borrow from the data: use today and yesterday.
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Build the training frame for one location.

    Keeps the location's rows with a non-null ``target``, adds ``lag_1`` ..
    ``lag_<lookback>`` columns, and joins a per-day ``disease_signal``
    obtained from ``fetch_promed_data`` (0 where no news exists for a day).

    Args:
        df: Case data with ``date``, ``location`` and ``target`` columns.
        location: Location to select.
        target: Column to model and to derive lag features from.
        lookback: Number of lag columns to create.

    Returns:
        The feature frame sorted by date. If fewer than two usable rows
        exist, before or after feature construction, a two-row placeholder
        frame with columns ``date``, ``location``, ``target`` and
        ``disease_signal`` (and no lag columns) is returned instead.
    """
    df_loc = df[df['location'] == location].sort_values('date')
    # Copy so the column assignments below never write into a view of df.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_features(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    news_data = fetch_promed_data(newsapi_key, df_loc['date'].min(), df_loc['date'].max())
    if 'disease_signal' not in news_data.columns:
        news_data['disease_signal'] = 0

    df_loc['date'] = df_loc['date'].dt.normalize()
    news_data['date'] = news_data['date'].dt.normalize()

    # News arrives one row per article. Collapse it to one averaged value per
    # day first; otherwise the left merge repeats a case row for every
    # article published that day.
    daily_signal = news_data.groupby('date', as_index=False)['disease_signal'].mean()

    # If df already carries a disease_signal column, drop it so the merge
    # does not split it into suffixed _x/_y columns and lose the signal.
    df_loc = df_loc.drop(columns=['disease_signal'], errors='ignore')
    df_loc = df_loc.merge(daily_signal, on='date', how='left')

    if 'disease_signal' not in df_loc.columns:
        df_loc['disease_signal'] = 0
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # The first `lookback` rows have incomplete lags and are discarded.
    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_features(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and disease-signal columns of ``df``.

    Args:
        df: Feature frame; every column containing ``lag_`` plus the
            ``disease_signal`` column is used as an input.
        target: Column to predict.

    Returns:
        ``(model, importance)`` where ``importance`` is a DataFrame with
        ``feature`` and ``importance`` columns.

    Raises:
        ValueError: If there are no input rows or no input columns.
    """
    feature_names = [name for name in df.columns if 'lag_' in name or name == 'disease_signal']
    inputs = df[feature_names]
    labels = df[target]

    # An empty frame covers both "no rows" and "no feature columns".
    if inputs.empty:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, labels)

    ranking = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_,
    })
    return forest, ranking

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date`` and ``target`` columns of ``df``.

    Args:
        df: Frame with a ``date`` column and the ``target`` column.
        target: Column to forecast.

    Returns:
        The fitted Prophet model, with yearly, weekly and daily seasonality
        enabled.

    Raises:
        ValueError: If ``df`` has fewer than two rows.
    """
    # Prophet expects the time column as 'ds' and the value column as 'y'.
    training = pd.DataFrame({'ds': df['date'], 'y': df[target]})
    if len(training) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
    )
    model.fit(training)
    return model

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest and Prophet predictions into one value.

    Args:
        rf_model: Fitted model exposing ``predict``.
        prophet_model: Fitted model exposing ``make_future_dataframe`` and
            ``predict``.
        df: Feature frame; its last row feeds the random forest.
        forecast_days: Number of periods Prophet forecasts ahead.
        target: Not read by this function.

    Returns:
        The mean of the random-forest prediction for the last row and the
        first of Prophet's ``forecast_days`` forecast values.
    """
    feature_names = [name for name in df.columns if 'lag_' in name or name == 'disease_signal']
    latest_row = df[feature_names].tail(1)
    forest_estimates = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_estimates = prophet_forecast['yhat'].tail(forecast_days).to_numpy()

    # Only the first forecast day takes part in the average.
    return (forest_estimates[0] + prophet_estimates[0]) / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Forecast ``target`` for one location with the selected model.

    Args:
        df: Case data passed to ``prepare_features``.
        location: Location to forecast.
        forecast_days: Forecast horizon in days (used by Prophet).
        model_type: ``'Random Forest'``, ``'Prophet'``; any other value
            selects the ensemble.
        target: Column to forecast.

    Returns:
        ``(prediction, importance, confidence)``:
        Random Forest gives a one-element array, the importance frame and
        ``None``. Prophet gives ``forecast_days`` values, ``None`` and the
        ``ds``/``yhat_lower``/``yhat_upper`` frame. The ensemble gives a
        single value, the importance frame and ``None``. On any exception
        the result is ``(np.array([0]), None, None)``.
    """
    try:
        df_loc = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet_model = train_prophet(df_loc, target)
            horizon = prophet_model.make_future_dataframe(periods=forecast_days)
            forecast = prophet_model.predict(horizon)
            bounds = forecast[['ds', 'yhat_lower', 'yhat_upper']]
            return forecast['yhat'].tail(forecast_days).values, None, bounds

        # Both remaining model types start from a fitted random forest.
        rf_model, importance = train_random_forest(df_loc, target)

        if model_type == 'Random Forest':
            feature_names = [name for name in df_loc.columns if 'lag_' in name or name == 'disease_signal']
            latest_row = df_loc[feature_names].tail(1)
            return rf_model.predict(latest_row), importance, None

        # Ensemble
        prophet_model = train_prophet(df_loc, target)
        blended = ensemble_predict(rf_model, prophet_model, df_loc, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score Random Forest and Prophet on a chronological 80/20 hold-out.

    Args:
        df: Case data passed to ``prepare_features``.
        location: Location to evaluate.
        target: Column being predicted.

    Returns:
        ``{'Random Forest': {'RMSE', 'R2'}, 'Prophet': {'RMSE', 'R2'}}``.
        All four values are 0 when there is too little data or when any
        step raises.
    """
    def zero_metrics():
        # Fresh dicts on every call so callers may mutate the result.
        return {'Random Forest': {'RMSE': 0, 'R2': 0}, 'Prophet': {'RMSE': 0, 'R2': 0}}

    try:
        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            return zero_metrics()

        # Chronological split: rows are ordered by date, so the last 20%
        # form the hold-out.
        train_size = int(0.8 * len(df_loc))
        train_df, test_df = df_loc[:train_size], df_loc[train_size:]

        if test_df.empty:
            return zero_metrics()

        rf_model, _ = train_random_forest(train_df, target)
        features = [col for col in test_df.columns if 'lag_' in col or col == 'disease_signal']
        rf_pred = rf_model.predict(test_df[features])
        rf_rmse = np.sqrt(mean_squared_error(test_df[target], rf_pred))
        rf_r2 = r2_score(test_df[target], rf_pred)

        prophet_model = train_prophet(train_df, target)
        # Predict on the held-out dates themselves. make_future_dataframe
        # would produce consecutive days after the training set, which do
        # not line up with the test rows when the series has date gaps.
        holdout = test_df[['date']].rename(columns={'date': 'ds'})
        prophet_pred = prophet_model.predict(holdout)['yhat'].values
        prophet_rmse = np.sqrt(mean_squared_error(test_df[target], prophet_pred))
        prophet_r2 = r2_score(test_df[target], prophet_pred)

        return {
            'Random Forest': {'RMSE': rf_rmse, 'R2': rf_r2},
            'Prophet': {'RMSE': prophet_rmse, 'R2': prophet_r2}
        }
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return zero_metrics()

# Calculate resource allocation
def calculate_resources(predicted_cases, country, capacity_table=None):
    """Translate a case forecast into resource needs and capacity shortfalls.

    Args:
        predicted_cases: Forecast case count. May be a scalar or any
            array-like (ndarray, list, Series); for array-likes only the
            first element is used.
        country: Location name looked up in the capacity table.
        capacity_table: Optional DataFrame with columns ``location``,
            ``hospital_beds``, ``icu_beds`` and ``ventilators``. Defaults to
            the module-level ``capacity_df``.

    Returns:
        Tuple ``(resources_needed, capacity, shortages)`` of plain-int dicts.
        ``resources_needed`` has keys hospital_beds, icu_beds, ventilators,
        medical_staff, ppe_kits and test_kits; the other two cover
        hospital_beds, icu_beds and ventilators only.

    Raises:
        ValueError: If ``predicted_cases`` is empty or its first value is
            NaN/infinite.
    """
    # Resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }
    # Fallback capacity for locations missing from the capacity table
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300,
    }

    table = capacity_df if capacity_table is None else capacity_table
    country_capacity = table[table['location'] == country]
    if country_capacity.empty:
        capacity = dict(default_capacity)
    else:
        first_match = country_capacity.iloc[0]
        # Cast numpy scalars to int so the values serialise cleanly.
        capacity = {key: int(first_match[key]) for key in default_capacity}

    # Normalise scalar / array-like input to one finite, non-negative number.
    flat = np.ravel(np.asarray(predicted_cases, dtype=float))
    if flat.size == 0:
        raise ValueError("predicted_cases is empty")
    predicted_value = float(flat[0])
    if not np.isfinite(predicted_value):
        raise ValueError(f"predicted_cases must be finite, got {predicted_value}")
    # Models can extrapolate below zero; negative demand is meaningless.
    predicted_value = max(0.0, predicted_value)

    hospitalized = predicted_value * resource_ratios['hospital_beds']
    resources_needed = {
        'hospital_beds': int(hospitalized),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(hospitalized * resource_ratios['medical_staff']),
        'ppe_kits': int(hospitalized * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits']),
    }

    # Shortage = demand above capacity, never negative
    shortages = {
        key: max(0, resources_needed[key] - capacity[key])
        for key in default_capacity
    }

    return resources_needed, capacity, shortages

# Prepare models and data
# NOTE: defined before the retraining loop below, which calls it at import
# time; defining it later raises NameError on a fresh kernel.
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Merge fresh news-signal rows into ``df`` and refit both models.

    Args:
        df: Data frame with at least ``date`` and ``location`` columns.
        location: Location whose rows are used for training.
        api_key: Key passed to ``fetch_promed_data``.
        target: Name of the column being predicted.
        interval_hours: Accepted for interface compatibility; not used here.

    Returns:
        ``(rf_model, prophet_model, df)``. ``df`` is the possibly extended
        frame. Both models are ``None`` when no valid news data is returned,
        when fewer than two prepared rows exist, or when any step raises.
    """
    try:
        last_date = df['date'].max()
        new_data = fetch_promed_data(api_key, df['date'].min(), df['date'].max())

        if new_data.empty or not pd.api.types.is_datetime64_any_dtype(new_data['date']):
            print("No valid new data to concatenate")
            return None, None, df

        new_data['date'] = pd.to_datetime(new_data['date']).dt.tz_localize(None).dt.normalize()
        new_data_max = new_data['date'].max()

        # Only append when the news data reaches past the newest row we hold.
        if pd.notna(new_data_max) and new_data_max > last_date:
            # Copy so the column assignments below act on an independent
            # frame rather than on a slice of the fetched data.
            new_data = new_data[['date', 'location', 'disease_signal']].copy()
            for col in df.columns:
                if col not in new_data.columns:
                    new_data[col] = np.nan
            df = pd.concat([df, new_data], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        rf_model, _ = train_random_forest(df_loc, target)
        prophet_model = train_prophet(df_loc, target)
        return rf_model, prophet_model, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Run retraining
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# Create filtered data for map
map_data = df.groupby('location').last().reset_index()
map_data = map_data.dropna(subset=['total_cases'])

# Calculate derived metrics
map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

# Dashboard
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Location choices shared by the prediction and news tabs
country_options = [{'label': loc, 'value': loc} for loc in df['location'].unique()]

app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        dbc.Tab(label="Predictions & Resources", children=[
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=country_options,
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=country_options,
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Callbacks
def _alert_color(used, capacity):
    """Return the Bootstrap alert colour for a used/capacity ratio."""
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _resource_alert(title, unit, needed, available, shortage):
    """Build the alert card for one capacity-limited resource."""
    fill_pct = min(100, (needed / available * 100) if available > 0 else 100)
    return dbc.Alert([
        html.Div([
            html.Strong(title),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {available} {unit}"),
        # The utilisation bar is only meaningful with a positive capacity.
        html.Div(className="progress mt-1", children=[
            html.Div(className="progress-bar", style={"width": f"{fill_pct}%"})
        ]) if available > 0 else None,
        html.Div(f"Potential shortage: {shortage}", className="text-danger")
        if shortage > 0 else None
    ], color=_alert_color(needed, available), className="mb-2")


def _metrics_column(title, scores, bar_color):
    """Build one half-width column showing RMSE, R² and an R² progress bar."""
    return dbc.Col([
        html.H5(title, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={
                        # R² can be negative or exceed 1; clamp to 0-100 %.
                        "width": f"{max(0, min(100, scores['R2'] * 100))}%",
                        "backgroundColor": bar_color
                    }
                )
            ])
        ])
    ], width=6)


def _forecast_dates(n_points, confidence, historical):
    """Return ``n_points`` x-axis dates for the prediction trace.

    Preference order: the trailing dates of the interval frame (so the line
    sits on its own confidence band), then the days following the last
    historical observation, then the days starting today.
    """
    if confidence is not None and len(confidence) >= n_points:
        return list(confidence['ds'].tail(n_points))

    anchor = datetime.now()
    if historical is not None and not historical.empty:
        last_seen = historical['date'].max()
        if pd.notna(last_seen):
            # Assumes the forecast continues from the day after the last
            # observed row for this location - TODO confirm for RF/Ensemble.
            anchor = last_seen + timedelta(days=1)
    return [anchor + timedelta(days=i) for i in range(n_points)]


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Run a forecast and refresh the four panels of the prediction tab.

    Args:
        n_clicks: Click count of the run button; ``None`` before first click.
        country: Selected location.
        forecast_days: Forecast horizon in days.
        model_type: 'Random Forest', 'Prophet' or 'Ensemble'.
        target_variable: Column to forecast.

    Returns:
        Tuple ``(prediction figure, feature-importance figure, metrics
        component, resources component)``. Before the first click these are
        placeholders; if any step raises, two empty figures and two error
        strings are returned.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)
        pred_values = pred if not isinstance(pred, float) else [pred] * forecast_days

        # Historical observations for the selected target, if it is a known column
        historical = None
        if target_variable in ('total_cases', 'total_deaths', 'new_cases', 'new_deaths',
                               'hosp_patients', 'icu_patients'):
            historical = df[(df['location'] == country) & (df[target_variable].notna())]

        # One date per predicted value, so x and y always have equal length
        pred_dates = _forecast_dates(len(pred_values), confidence, historical)

        # Create prediction plot
        pred_fig = go.Figure()
        if historical is not None and not historical.empty:
            pred_fig.add_trace(go.Scatter(
                x=historical['date'].tail(30),  # Last 30 observations
                y=historical[target_variable].tail(30),
                name='Historical Data',
                line=dict(color='royalblue')
            ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metrics_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metrics_column("Prophet", metrics['Prophet'], "#3498db")
            ])
        ], body=True)

        # Resource calculation (based on the first predicted value)
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _resource_alert("Hospital Beds", "beds", resources_needed['hospital_beds'],
                                capacity['hospital_beds'], shortages['hospital_beds']),
                _resource_alert("ICU Beds", "beds", resources_needed['icu_beds'],
                                capacity['icu_beds'], shortages['icu_beds']),
                _resource_alert("Ventilators", "units", resources_needed['ventilators'],
                                capacity['ventilators'], shortages['ventilators']),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Render the global map for the chosen metric and map style.

    Args:
        map_variable: Column of ``map_data`` to visualise.
        map_type: 'choropleth' for a filled map; any other value draws a
            bubble map.

    Returns:
        A Plotly figure. If anything raises, an empty figure whose title
        carries the error message.
    """
    try:
        # Calculate per million values if not already done
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Avoid listing the same column twice when it is also the selected metric
        hover_cols = list(dict.fromkeys([map_variable, 'total_cases', 'total_deaths']))

        # Arguments shared by both map styles
        shared_args = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=hover_cols,
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        # Create appropriate map
        if map_type == 'choropleth':
            fig = px.choropleth(map_data, **shared_args)
        else:  # bubble map
            # Marker sizes must be finite and non-negative; map_data is only
            # NaN-filtered on total_cases, so drop unusable rows here.
            sizes = pd.to_numeric(map_data[map_variable], errors='coerce')
            bubble_data = map_data[np.isfinite(sizes) & (sizes >= 0)]
            fig = px.scatter_geo(
                bubble_data,
                size=map_variable,
                size_max=50,
                **shared_args
            )

        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal and list headline cards.

    Args:
        country: Selected location; 'Global' skips the location filter.
        days: Size of the look-back window in days, counted back from now.

    Returns:
        ``(figure, headlines component)``. With no signal rows, a titled
        empty figure and a plain message; on any exception, an error-titled
        figure and an error string.
    """
    def _score_colour(score):
        # Red above 80, orange above 60, green otherwise.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Fetch the signal rows and narrow them to the chosen location
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        # Nothing to plot: return a placeholder figure and message
        if signals.empty:
            placeholder = go.Figure()
            placeholder.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return placeholder, "No news headlines available"

        # Mean signal per date
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        fig = go.Figure()
        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig.add_trace(signal_trace)

        # Reported new cases for the same location and window, if any
        cases = df[(df['location'] == country) &
                   (df['date'] >= window_start) &
                   (df['date'] <= window_end) &
                   (df['new_cases'].notna())]
        has_cases = not cases.empty

        if has_cases:
            # Rescale cases so their peak matches the peak of the signal
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = cases['new_cases'].max()
            if peak_cases > 0:
                factor = peak_signal / peak_cases
            else:
                factor = 1

            cases_trace = go.Scatter(
                x=cases['date'],
                y=cases['new_cases'] * factor,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            )
            fig.add_trace(cases_trace)

        # The secondary axis is only configured when the cases trace exists
        if has_cases:
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )
        else:
            secondary_axis = {}

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Simulated (hard-coded) news headlines
        headlines = [
            {"title": "New outbreak reported in northeastern region", "score": 95},
            {"title": "Health officials monitoring unusual respiratory symptoms", "score": 80},
            {"title": "Vaccination campaign launched to combat emerging strain", "score": 70},
            {"title": "Hospital capacity strained amid rising cases", "score": 85},
            {"title": "Travel restrictions considered as cases increase", "score": 75}
        ]

        headline_items = []
        for entry in headlines:
            score_label = html.Span(
                f" - Signal Score: {entry['score']}",
                className="ml-2",
                style={"color": _score_colour(entry["score"])}
            )
            headline_items.append(
                dbc.ListGroupItem([
                    html.Div([html.Strong(entry["title"]), score_label])
                ])
            )

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(headline_items)
        ])

        return fig, headlines_output

    except Exception as e:
        print(f"Error in update_news_signal: {e}")
        error_fig = go.Figure()
        error_fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return error_fig, f"Error: {str(e)}"

# Run the app
# `__name__` / '__main__' need double underscores; the single-underscore
# spelling is an undefined name and raises NameError.
if __name__ == '__main__':
    try:
        # `Dash.run` selects the notebook display style via `jupyter_mode`;
        # `mode=` is not one of its parameters.
        app.run(jupyter_mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(jupyter_mode='external', debug=True, port=8050)


In [None]:
!pip install dash_bootstrap_components


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load NewsAPI key
# The key is read from the Colab secret store; any failure there falls back to
# a placeholder string so the rest of the notebook can still be executed.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
except Exception as secret_error:
    print(f"Failed to load NewsAPI key: {secret_error}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets
else:
    print("NewsAPI key loaded successfully")

# Load OWID data
# Download the CSV, make the date column timezone-naive datetimes, and keep
# only the columns that the models and the dashboard read.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = (
    pd.read_csv(url)
    .assign(date=lambda raw: pd.to_datetime(raw['date']).dt.tz_localize(None))
    .loc[:, [
        'date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
        'hosp_patients', 'icu_patients', 'people_vaccinated', 'population',
    ]]
)

# Simulate healthcare capacity data
# One tuple per country: (location, hospital_beds, icu_beds, ventilators).
# The tuples are transposed into the column-oriented dict that feeds the frame.
capacity_data = {
    column: list(values)
    for column, values in zip(
        ('location', 'hospital_beds', 'icu_beds', 'ventilators'),
        zip(
            ('USA', 10000, 2000, 500),
            ('India', 5000, 1000, 300),
            ('Brazil', 3000, 500, 200),
            ('United Kingdom', 8000, 1600, 400),
            ('Germany', 9000, 1800, 450),
            ('France', 7500, 1500, 375),
            ('Italy', 7000, 1400, 350),
            ('Spain', 6000, 1200, 300),
            ('China', 12000, 2400, 600),
            ('Japan', 9500, 1900, 475),
            ('South Korea', 8500, 1700, 425),
            ('Australia', 6500, 1300, 325),
            ('Canada', 7800, 1560, 390),
            ('Mexico', 4000, 800, 200),
        ),
    )
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch outbreak-related headlines from NewsAPI and score each article.

    Despite the name, the data source is the NewsAPI ``everything`` endpoint.
    Every usable article becomes one row holding its publication day
    (timezone-naive, midnight), a coarse location guess and a keyword-based
    ``disease_signal`` of 100, 50 or 10.

    If the request fails, returns a non-200 status, yields no usable article,
    or anything raises while parsing, the error is printed and a randomly
    generated fallback frame is returned instead, so callers always receive
    the same three columns.

    Args:
        api_key: NewsAPI key, sent as the ``apiKey`` query parameter.
        date_range_start: Optional start of the window. Sent as ``from`` only
            when both bounds are given.
        date_range_end: Optional end of the window. Sent as ``to`` only when
            both bounds are given.

    Returns:
        DataFrame with columns ``date``, ``location`` and ``disease_signal``.
        The fallback frame covers at most the last 30 days of the requested
        range (or the last 30 days up to today), is labelled ``"USA"`` and
        carries random integer signals in [10, 100).
    """
    import re  # local import: the file's import block does not provide `re`

    # Treat missing or NaT bounds as "no range requested".
    has_range = all(
        bound is not None and not pd.isna(bound)
        for bound in (date_range_start, date_range_end)
    )

    # Whole-word patterns: plain substring tests tagged "thousand" as USA and
    # "indiana" as India.
    location_patterns = (
        ("USA", re.compile(r"\b(?:usa|united states|america|americans?)\b")),
        ("India", re.compile(r"\b(?:india|indians?)\b")),
        ("Brazil", re.compile(r"\b(?:brazil|brazilians?)\b")),
    )

    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if has_range:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise RuntimeError("No articles found")

        data = []
        for article in articles:
            # Skip articles without a usable timestamp instead of aborting
            # the whole fetch.
            published = pd.to_datetime(article.get("publishedAt"), errors="coerce")
            if pd.isna(published):
                continue
            date = published.tz_localize(None).normalize()  # Date-only

            # Any of these fields may be null in the response; coalesce to ""
            # so one sparse article cannot raise TypeError and discard the rest.
            title = article.get("title") or ""
            body = article.get("content") or article.get("description") or ""
            text = f"{title} {body}".lower()

            signal = (
                100 if any(word in text for word in ["outbreak", "epidemic", "pandemic"]) else
                50 if any(word in text for word in ["health", "disease", "emergency"]) else
                10
            )

            # First matching country wins, in the order USA, India, Brazil.
            location = next(
                (name for name, pattern in location_patterns if pattern.search(text)),
                "Unknown",
            )

            data.append({"date": date, "location": location, "disease_signal": signal})

        news_df = pd.DataFrame(data)
        if news_df.empty:
            raise RuntimeError("No relevant data parsed")

        news_df["date"] = pd.to_datetime(news_df["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", news_df.head())
        return news_df

    except Exception as e:
        print(f"Error fetching NewsAPI data: {e}")
        if has_range:
            # Last (up to) 30 days of the requested window.
            fallback_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
        else:
            # Today and the 29 days before it, newest first.
            today = pd.Timestamp(datetime.now().date())
            fallback_dates = pd.DatetimeIndex([today - timedelta(days=x) for x in range(30)])
        fallback = pd.DataFrame({
            "date": pd.to_datetime(fallback_dates).normalize(),
            "location": ["USA"] * len(fallback_dates),
            "disease_signal": np.random.randint(10, 100, len(fallback_dates))
        })
        print("Using fallback data")
        return fallback

# Prepare features
def prepare_features(df, location, target='total_cases', lookback=7, news_data=None):
    """Build the per-location modelling frame: target, lag features, news signal.

    Rows for ``location`` are sorted by date, rows with a missing ``target``
    are dropped, ``lag_1`` .. ``lag_<lookback>`` columns are added, and a daily
    ``disease_signal`` is joined on by date (0 where there is no news). Rows
    that lack any lag value are dropped at the end.

    Args:
        df: Frame with at least ``date``, ``location`` and ``target`` columns.
        location: Value of the ``location`` column to select.
        target: Column to model.
        lookback: Number of lag columns to create.
        news_data: Optional frame with ``date`` and ``disease_signal`` columns.
            When omitted, headlines are fetched with ``fetch_promed_data`` for
            the date span of the selected rows. The frame is not modified.

    Returns:
        The feature frame. If fewer than two usable rows exist (before or
        after the lag step), a two-row placeholder frame is returned instead,
        with columns ``date``, ``location``, ``target`` (values 1000 and 1100)
        and ``disease_signal`` (zeros) and no lag columns.
    """
    at_location = df[df['location'] == location]

    def placeholder_frame():
        # Reuse the location's last two dates when available; otherwise use
        # today and yesterday.
        owid_dates = at_location['date'].tail(2).values
        if len(owid_dates) < 2:
            owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
        return pd.DataFrame({
            "date": pd.to_datetime(owid_dates).normalize(),
            "location": [location] * len(owid_dates),
            target: [1000, 1100],
            "disease_signal": [0, 0]
        })

    # Copy so the column assignments below never write into a view of `df`.
    df_loc = at_location.sort_values('date').dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return placeholder_frame()

    # A signal column already present on `df` would make the merge below emit
    # disease_signal_x / disease_signal_y and lose the real signal.
    df_loc = df_loc.drop(columns=['disease_signal'], errors='ignore')

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    if news_data is None:
        news_data = fetch_promed_data(newsapi_key, df_loc['date'].min(), df_loc['date'].max())

    df_loc['date'] = df_loc['date'].dt.normalize()

    if 'disease_signal' in news_data.columns:
        # Collapse to one row per day (strongest signal wins). Merging the raw
        # per-article rows would repeat every matching target row once per
        # article published that day.
        daily_signal = (
            news_data.assign(date=pd.to_datetime(news_data['date']).dt.normalize())
            .groupby('date', as_index=False)['disease_signal']
            .max()
        )
        df_loc = df_loc.merge(daily_signal, on='date', how='left')
        df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)
    else:
        df_loc['disease_signal'] = 0

    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return placeholder_frame()

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and news-signal columns of ``df``.

    Args:
        df: Feature frame; every column whose name contains ``lag_`` plus the
            ``disease_signal`` column is used as a predictor.
        target: Name of the column to predict.

    Returns:
        Tuple ``(model, importance)`` where ``importance`` is a frame with one
        row per predictor (columns ``feature`` and ``importance``).

    Raises:
        ValueError: If the predictor matrix has no rows or no columns.
    """
    features = [col for col in df.columns if 'lag_' in col or col == 'disease_signal']
    predictors = df[features]
    response = df[target]

    # A zero-length axis in either direction leaves nothing to fit on.
    if 0 in predictors.shape:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(predictors, response)

    importance = pd.DataFrame({'feature': features, 'importance': forest.feature_importances_})
    return forest, importance

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date``/``target`` columns of ``df``.

    Args:
        df: Frame containing a ``date`` column and the ``target`` column.
        target: Name of the column to forecast.

    Returns:
        The fitted Prophet model (yearly, weekly and daily seasonality on).

    Raises:
        ValueError: If ``df`` has fewer than two rows.
    """
    # Prophet expects the time column to be called "ds" and the value "y".
    history = pd.DataFrame({'ds': df['date'], 'y': df[target]})
    if len(history) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    forecaster = Prophet(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=True)
    forecaster.fit(history)
    return forecaster

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    """Average the random-forest and Prophet forecasts into one number.

    The forest predicts from the last row of ``df``; Prophet forecasts
    ``forecast_days`` periods ahead and only the first of those forecast
    values enters the average.

    Args:
        rf_model: Fitted model exposing ``predict``.
        prophet_model: Fitted model exposing ``make_future_dataframe`` and
            ``predict``.
        df: Feature frame holding the ``lag_*`` and ``disease_signal`` columns.
        forecast_days: Number of periods Prophet forecasts ahead.
        target: Unused; kept for signature compatibility.

    Returns:
        The mean of the two single-value predictions.
    """
    feature_names = [col for col in df.columns if 'lag_' in col or col == 'disease_signal']
    latest_row = df[feature_names].tail(1)
    forest_output = rf_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_output = prophet_forecast['yhat'].tail(forecast_days).values

    return (forest_output[0] + prophet_output[0]) / 2

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Produce a forecast for one location with the requested model.

    Args:
        df: Source frame passed on to ``prepare_features``.
        location: Location to forecast.
        forecast_days: Horizon used by the Prophet-based paths.
        model_type: ``'Random Forest'``, ``'Prophet'``; any other value selects
            the ensemble.
        target: Column to forecast.

    Returns:
        Tuple ``(prediction, importance, confidence)``:
        Random Forest gives the forest's prediction array for the last feature
        row, its feature-importance frame and ``None``; Prophet gives the last
        ``forecast_days`` ``yhat`` values, ``None`` and the ``ds`` /
        ``yhat_lower`` / ``yhat_upper`` columns of the forecast; the ensemble
        gives the averaged value, the forest importances and ``None``.
        Any exception is printed and ``(np.array([0]), None, None)`` returned.
    """
    try:
        df_loc = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet_only = train_prophet(df_loc, target)
            horizon = prophet_only.make_future_dataframe(periods=forecast_days)
            forecast = prophet_only.predict(horizon)
            point_forecast = forecast['yhat'].tail(forecast_days).values
            return point_forecast, None, forecast[['ds', 'yhat_lower', 'yhat_upper']]

        # Both remaining paths start from a fitted forest.
        forest, importance = train_random_forest(df_loc, target)

        if model_type == 'Random Forest':
            feature_names = [col for col in df_loc.columns if 'lag_' in col or col == 'disease_signal']
            latest_row = df_loc[feature_names].tail(1)
            return forest.predict(latest_row), importance, None

        # Ensemble
        prophet_part = train_prophet(df_loc, target)
        blended = ensemble_predict(forest, prophet_part, df_loc, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score the random forest and Prophet on a hold-out tail of the data.

    The feature frame from ``prepare_features`` is split positionally: the
    first 80% of rows train both models and the remaining rows are the test
    set. RMSE and R2 are computed for each model on the test rows.

    Args:
        df: Source frame passed on to ``prepare_features``.
        location: Location to evaluate.
        target: Column being predicted.

    Returns:
        ``{'Random Forest': {'RMSE': ..., 'R2': ...}, 'Prophet': {...}}``.
        All four values are 0 when there is too little data or when anything
        raises (the error is printed).
    """
    def zero_metrics():
        # Fresh dict per call so callers can never share mutable state.
        return {'Random Forest': {'RMSE': 0, 'R2': 0}, 'Prophet': {'RMSE': 0, 'R2': 0}}

    try:
        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            return zero_metrics()

        train_size = int(0.8 * len(df_loc))
        train_df, test_df = df_loc[:train_size], df_loc[train_size:]

        if test_df.empty:
            return zero_metrics()

        rf_model, _ = train_random_forest(train_df, target)
        features = [col for col in test_df.columns if 'lag_' in col or col == 'disease_signal']
        rf_pred = rf_model.predict(test_df[features])
        rf_rmse = np.sqrt(mean_squared_error(test_df[target], rf_pred))
        rf_r2 = r2_score(test_df[target], rf_pred)

        prophet_model = train_prophet(train_df, target)
        # Predict on the test rows' own dates. Extending the training calendar
        # by len(test_df) daily steps only lines up with the test rows when the
        # series has no gaps, which does not hold for sparsely reported targets.
        # Assumes test_df is in ascending date order, as produced by
        # prepare_features' sort, because Prophet returns predictions sorted
        # by ds.
        holdout = test_df[['date']].rename(columns={'date': 'ds'})
        prophet_pred = prophet_model.predict(holdout)['yhat'].values
        prophet_rmse = np.sqrt(mean_squared_error(test_df[target], prophet_pred))
        prophet_r2 = r2_score(test_df[target], prophet_pred)

        return {
            'Random Forest': {'RMSE': rf_rmse, 'R2': rf_r2},
            'Prophet': {'RMSE': prophet_rmse, 'R2': prophet_r2}
        }
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return zero_metrics()

# Calculate resource allocation
def calculate_resources(predicted_cases, country):
    """Translate a case forecast into resource needs, capacity and shortages.

    Args:
        predicted_cases: Forecast case count. May be a scalar or any
            array-like; for array-likes only the first element is used.
            Negative, NaN, infinite or empty input is treated as zero demand.
        country: Location looked up in the module-level ``capacity_df``.
            Unknown locations fall back to default capacity figures.

    Returns:
        Tuple ``(resources_needed, capacity, shortages)`` of dicts.
        ``resources_needed`` has hospital_beds, icu_beds, ventilators,
        medical_staff, ppe_kits and test_kits; ``capacity`` and ``shortages``
        cover hospital_beds, icu_beds and ventilators. Shortages are never
        negative.
    """
    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }

    # Use default values if country not found
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300
    }

    # Get country capacity
    country_capacity = capacity_df[capacity_df['location'] == country]
    if country_capacity.empty:
        capacity = dict(default_capacity)
    else:
        first_match = country_capacity.iloc[0]
        # Plain ints keep the returned dicts free of numpy scalar types.
        capacity = {name: int(first_match[name]) for name in default_capacity}

    # Accept scalars, lists and arrays of any rank; take the first value.
    flattened = np.ravel(np.asarray(predicted_cases, dtype=float))
    predicted_value = float(flattened[0]) if flattened.size else 0.0
    # A negative or non-finite forecast cannot be a meaningful demand and would
    # otherwise give negative needs or make int() raise.
    if not np.isfinite(predicted_value) or predicted_value < 0:
        predicted_value = 0.0

    # Calculate needed resources
    resources_needed = {
        'hospital_beds': int(predicted_value * resource_ratios['hospital_beds']),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['medical_staff']),
        'ppe_kits': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits'])
    }

    # Calculate shortages
    shortages = {
        name: max(0, resources_needed[name] - capacity[name])
        for name in capacity
    }

    return resources_needed, capacity, shortages

# Run retraining
# Calls retrain_model for 'USA' a fixed number of times at start-up and rebinds
# the module-level df to the frame each call returns.
# NOTE(review): retrain_model is defined further down in this cell, after the
# dashboard layout. On a fresh kernel this loop therefore raises NameError
# unless an earlier cell already defined it — TODO move the definition above.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    # retrain_model signals "nothing trained" by returning None for the models.
    if rf_model is None or prophet_model is None:
        print("Retraining skipped due to invalid data")
    else:
        print("Retraining complete")

# Create filtered data for map
# One row per location (the last value of each column within the group),
# keeping only locations that have a total_cases figure.
map_data = (
    df.groupby('location')
    .last()
    .reset_index()
    .dropna(subset=['total_cases'])
)

# Dashboard
# The first argument must be the module's import name, spelled with double
# underscores; `_name_` is an undefined variable and raises NameError.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
# Replaces the app's HTML page template. The {%...%} tokens are template
# placeholders and must be kept; the <style> block adds the custom look for
# cards, card headers, active tabs, tab content, resource alerts and the
# dashboard title (the .dashboard-title class is used by the H1 in the layout).
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Page layout: a title followed by three tabs. Component ids declared here are
# the contract with the callbacks; renaming one requires updating the matching
# Input/State/Output.
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        # --- Tab 1: forecast controls, prediction plot, resources, diagnostics ---
        dbc.Tab(label="Predictions & Resources", children=[
            # Control panel: these four inputs are read as State by
            # update_prediction_dashboard when 'run-btn' is clicked.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    # NOTE(review): the default 'USA' is only selectable if it
                                    # occurs in df['location'] — confirm against the dataset.
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    # Horizon of 7 to 90 days, one mark every 14 days.
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    # Values match the model_type strings tested in forecast_resource.
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            # Option values are column names of df.
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            # Outputs of update_prediction_dashboard: forecast plot and resource summary.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            # Outputs of update_prediction_dashboard: feature importances and metrics.
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        # --- Tab 2: world map of the latest per-location figures ---
        # NOTE(review): presumably driven by a callback on 'map-variable' and
        # 'map-type' that is not part of this section — verify.
        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            # The two per-million values are columns added to map_data
                            # further down in this file.
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        # --- Tab 3: news-derived disease signal and headlines ---
        # NOTE(review): presumably filled by the news-signal callback, whose
        # decorator is not part of this section — verify the ids match.
        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            # Option values are window lengths in days.
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Prepare models and data
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Pull fresh news rows, append them when newer, and refit both models.

    Args:
        df: Current data frame with a ``date`` column.
        location: Location whose models are trained.
        api_key: NewsAPI key handed to ``fetch_promed_data``.
        target: Column to model.
        interval_hours: Unused; kept for signature compatibility.

    Returns:
        Tuple ``(rf_model, prophet_model, df)``. ``df`` is the input frame, or
        the input with the news rows appended when the news is newer than the
        latest date in ``df``. Both models are ``None`` when the news frame is
        unusable, when too few feature rows remain, or when anything raises
        (the error is printed).
    """
    try:
        latest_known = df['date'].max()
        news = fetch_promed_data(api_key, df['date'].min(), df['date'].max())

        usable = (not news.empty) and pd.api.types.is_datetime64_any_dtype(news['date'])
        if not usable:
            print("No valid new data to concatenate")
            return None, None, df

        news['date'] = pd.to_datetime(news['date']).dt.tz_localize(None).dt.normalize()
        latest_news = news['date'].max()

        if pd.notna(latest_news) and latest_news > latest_known:
            news = news[['date', 'location', 'disease_signal']]
            # Pad the news rows with NaN for every column of df they lack so
            # the two frames share one schema before concatenation.
            missing = [col for col in df.columns if col not in news.columns]
            news = news.reindex(columns=list(news.columns) + missing)
            df = pd.concat([df, news], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        forest, _ = train_random_forest(df_loc, target)
        prophet_fit = train_prophet(df_loc, target)
        return forest, prophet_fit, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Calculate derived metrics
# Add population-normalised rates to map_data in place: count / population,
# scaled to "per one million inhabitants".
for rate_column, count_column in (
    ('cases_per_million', 'total_cases'),
    ('deaths_per_million', 'total_deaths'),
):
    map_data[rate_column] = (map_data[count_column] / map_data['population']) * 1000000

# Callbacks
@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # Create prediction plot
        pred_fig = go.Figure()
        pred_dates = [datetime.now() + timedelta(days=i) for i in range(forecast_days)]
        pred_values = pred if not isinstance(pred, float) else [pred] * forecast_days

        # Add historical data to the plot
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())]
            if not historical.empty:
                pred_fig.add_trace(go.Scatter(
                    x=historical['date'][-30:],  # Last 30 days
                    y=historical[target_variable][-30:],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                dbc.Col([
                    html.H5("Random Forest", className="text-center"),
                    html.Div([
                        html.P(f"RMSE: {metrics['Random Forest']['RMSE']:.2f}", className="mb-1"),
                        html.P(f"R²: {metrics['Random Forest']['R2']:.2f}", className="mb-1"),
                        html.Div(className="progress mb-2", children=[
                            html.Div(
                                className="progress-bar",
                                style={
                                    "width": f"{max(0, min(100, metrics['Random Forest']['R2'] * 100))}%",
                                    "backgroundColor": "#2ecc71"
                                }
                            )
                        ])
                    ])
                ], width=6),
                dbc.Col([
                    html.H5("Prophet", className="text-center"),
                    html.Div([
                        html.P(f"RMSE: {metrics['Prophet']['RMSE']:.2f}", className="mb-1"),
                        html.P(f"R²: {metrics['Prophet']['R2']:.2f}", className="mb-1"),
                        html.Div(className="progress mb-2", children=[
                            html.Div(
                                className="progress-bar",
                                style={
                                    "width": f"{max(0, min(100, metrics['Prophet']['R2'] * 100))}%",
                                    "backgroundColor": "#3498db"
                                }
                            )
                        ])
                    ])
                ], width=6)
            ])
        ], body=True)

        # Resource calculation
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        def get_alert_color(used, capacity):
            utilization = (used / capacity) if capacity > 0 else 1
            if utilization < 0.7:
                return "success"
            elif utilization < 0.9:
                return "warning"
            else:
                return "danger"

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                dbc.Alert([
                    html.Div([
                        html.Strong("Hospital Beds"),
                        html.Span(f": {resources_needed['hospital_beds']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['hospital_beds']} beds"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['hospital_beds'] / capacity['hospital_beds'] * 100) if capacity['hospital_beds'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['hospital_beds'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['hospital_beds']}", className="text-danger")
                    if shortages['hospital_beds'] > 0 else None
                ], color=get_alert_color(resources_needed['hospital_beds'], capacity['hospital_beds']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("ICU Beds"),
                        html.Span(f": {resources_needed['icu_beds']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['icu_beds']} beds"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['icu_beds'] / capacity['icu_beds'] * 100) if capacity['icu_beds'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['icu_beds'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['icu_beds']}", className="text-danger")
                    if shortages['icu_beds'] > 0 else None
                ], color=get_alert_color(resources_needed['icu_beds'], capacity['icu_beds']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("Ventilators"),
                        html.Span(f": {resources_needed['ventilators']} needed", className="ml-2")
                    ]),
                    html.Small(f"Capacity: {capacity['ventilators']} units"),
                    html.Div(className="progress mt-1", children=[
                        html.Div(
                            className="progress-bar",
                            style={
                                "width": f"{min(100, (resources_needed['ventilators'] / capacity['ventilators'] * 100) if capacity['ventilators'] > 0 else 100)}%"
                            }
                        )
                    ]) if capacity['ventilators'] > 0 else None,
                    html.Div(f"Potential shortage: {shortages['ventilators']}", className="text-danger")
                    if shortages['ventilators'] > 0 else None
                ], color=get_alert_color(resources_needed['ventilators'], capacity['ventilators']), className="mb-2"),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Render the global map for the selected metric and map style.

    Args:
        map_variable: Column of ``map_data`` to visualise (for example
            ``'total_cases'`` or ``'deaths_per_million'``).
        map_type: ``'choropleth'`` draws a filled country map; any other
            value draws a bubble map.

    Returns:
        A Plotly figure. If anything fails, an empty figure whose title
        carries the error message is returned instead.
    """
    try:
        # Calculate per million values if not already done
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Arguments shared by both map flavours.
        shared_args = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=[map_variable, 'total_cases', 'total_deaths'],
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        # Create appropriate map
        if map_type == 'choropleth':
            fig = px.choropleth(map_data, **shared_args)
        else:  # bubble map
            # Marker sizes must be finite and non-negative. Only total_cases
            # is NaN-filtered when map_data is built, so the other metrics
            # (and the per-million ratios) can still hold NaN/inf; drop those
            # rows here rather than letting the whole figure fail.
            values = pd.to_numeric(map_data[map_variable], errors='coerce')
            bubble_data = map_data[np.isfinite(values) & (values >= 0)]
            fig = px.scatter_geo(
                bubble_data,
                size=map_variable,
                size_max=50,
                **shared_args
            )

        # Identical base-map styling for both flavours.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Build the news-signal chart and the headline list for one country.

    Args:
        country: Location used to filter the news signal; ``'Global'``
            keeps every location.
        days: Length of the look-back window in days, ending now.

    Returns:
        A ``(figure, headlines)`` pair. When no signal rows remain after
        filtering, the figure only carries a "no data" title and the second
        element is a plain message. If an exception occurs, both elements
        describe the error.
    """
    def score_colour(score):
        # Red above 80, amber above 60, green otherwise.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Pull the signal rows and narrow them to the requested country.
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        # Nothing to plot: hand back a titled placeholder.
        if signals.empty:
            placeholder = go.Figure()
            placeholder.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return placeholder, "No news headlines available"

        # One averaged signal value per calendar date.
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        fig = go.Figure()
        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig.add_trace(signal_trace)

        # Reported new cases for the same country and window, if any.
        row_filter = (
            (df['location'] == country) &
            (df['date'] >= window_start) &
            (df['date'] <= window_end) &
            (df['new_cases'].notna())
        )
        cases = df[row_filter]
        has_cases = not cases.empty

        if has_cases:
            # Rescale cases so their peak matches the signal's peak.
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = cases['new_cases'].max()
            if peak_cases > 0:
                scale = peak_signal / peak_cases
            else:
                scale = 1

            cases_trace = go.Scatter(
                x=cases['date'],
                y=cases['new_cases'] * scale,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            )
            fig.add_trace(cases_trace)

        # The secondary axis is only configured when the cases trace exists.
        if has_cases:
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )
        else:
            secondary_axis = {}

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Simulated (hard-coded) headlines with their signal scores.
        simulated_headlines = (
            ("New outbreak reported in northeastern region", 95),
            ("Health officials monitoring unusual respiratory symptoms", 80),
            ("Vaccination campaign launched to combat emerging strain", 70),
            ("Hospital capacity strained amid rising cases", 85),
            ("Travel restrictions considered as cases increase", 75),
        )

        list_items = []
        for title, score in simulated_headlines:
            score_label = html.Span(
                f" - Signal Score: {score}",
                className="ml-2",
                style={"color": score_colour(score)}
            )
            list_items.append(
                dbc.ListGroupItem([html.Div([html.Strong(title), score_label])])
            )

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(list_items)
        ])

        return fig, headlines_output

    except Exception as e:
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
# The script guard must use the double-underscore names; the single-underscore
# spelling is an undefined name and raises NameError before the app starts.
if __name__ == '__main__':
    try:
        # NOTE(review): `mode=` looks like the JupyterDash API; recent plain
        # Dash releases spell this `jupyter_mode=` on `app.run` — confirm
        # against the installed Dash version.
        app.run_server(mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(mode='external', debug=True, port=8050)
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata

# Load NewsAPI key
# Read the key from Colab Secrets; if the lookup fails for any reason, report
# it and continue with a placeholder value.
try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
except Exception as lookup_error:
    print(f"Failed to load NewsAPI key: {lookup_error}")
    newsapi_key = "YOUR_API_KEY_HERE"  # Replace with your actual key if not using Colab Secrets
else:
    print("NewsAPI key loaded successfully")

# Load OWID data
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
# Parse the date column and drop any timezone so it stays timezone-naive.
df = df.assign(date=pd.to_datetime(df['date']).dt.tz_localize(None))
# Keep only the columns the dashboard reads.
df = df[[
    'date', 'location',
    'total_cases', 'new_cases',
    'total_deaths', 'new_deaths',
    'hosp_patients', 'icu_patients',
    'people_vaccinated', 'population',
]]

# Simulate healthcare capacity data
# Each inner tuple is one country: (location, hospital_beds, icu_beds,
# ventilators). The rows are transposed into one list per column.
capacity_data = {
    column: list(values)
    for column, values in zip(
        ('location', 'hospital_beds', 'icu_beds', 'ventilators'),
        zip(
            ('USA', 10000, 2000, 500),
            ('India', 5000, 1000, 300),
            ('Brazil', 3000, 500, 200),
            ('United Kingdom', 8000, 1600, 400),
            ('Germany', 9000, 1800, 450),
            ('France', 7500, 1500, 375),
            ('Italy', 7000, 1400, 350),
            ('Spain', 6000, 1200, 300),
            ('China', 12000, 2400, 600),
            ('Japan', 9500, 1900, 475),
            ('South Korea', 8500, 1700, 425),
            ('Australia', 6500, 1300, 325),
            ('Canada', 7800, 1560, 390),
            ('Mexico', 4000, 800, 200),
        ),
    )
}
capacity_df = pd.DataFrame(capacity_data)

# Fetch NewsAPI data
def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch disease-related articles from NewsAPI and score each one.

    Each article becomes one row holding its publication date (normalised to
    midnight, timezone-naive), a location guessed from keywords in the text,
    and a keyword-based ``disease_signal`` of 100, 50 or 10.

    Args:
        api_key: NewsAPI key sent as the ``apiKey`` query parameter.
        date_range_start: Optional start of the query window; only used when
            ``date_range_end`` is also given.
        date_range_end: Optional end of the query window.

    Returns:
        A DataFrame with columns ``date``, ``location`` and
        ``disease_signal``. If the request or parsing fails for any reason,
        a fallback frame of up to 30 daily rows labelled ``"USA"`` with
        random signal values is returned instead.
    """
    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if date_range_start and date_range_end:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)

        if response.status_code != 200:
            raise Exception(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise Exception("No articles found")

        data = []
        for article in articles:
            published = article.get("publishedAt")
            if not published:
                # An undated article cannot be placed on the timeline; skip
                # it instead of discarding the whole response.
                continue
            date = pd.to_datetime(published).tz_localize(None).normalize()  # Date-only

            # Any of these fields may be null in the response; treat a
            # missing one as empty text so one sparse article does not
            # force the random fallback for everything.
            title = article.get("title") or ""
            body = article.get("content") or article.get("description") or ""
            text = f"{title} {body}".lower()

            signal = (
                100 if any(word in text for word in ["outbreak", "epidemic", "pandemic"]) else
                50 if any(word in text for word in ["health", "disease", "emergency"]) else
                10
            )

            location = (
                "USA" if any(word in text for word in ["usa", "united states", "america"]) else
                "India" if any(word in text for word in ["india", "indian"]) else
                "Brazil" if any(word in text for word in ["brazil", "brazilian"]) else
                "Unknown"
            )

            data.append({"date": date, "location": location, "disease_signal": signal})

        news_df = pd.DataFrame(data)
        if news_df.empty:
            raise Exception("No relevant data parsed")

        news_df["date"] = pd.to_datetime(news_df["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", news_df.head())
        return news_df

    except Exception as e:
        # Best-effort by design: any failure degrades to simulated data.
        print(f"Error fetching NewsAPI data: {e}")
        if date_range_start and date_range_end:
            # Last (up to) 30 days of the requested window.
            owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
        else:
            # Most recent 30 days counting back from today.
            owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).tz_localize(None).normalize()
                          for x in range(30)]
        fallback = pd.DataFrame({
            "date": pd.to_datetime(owid_dates).normalize(),
            "location": ["USA"] * len(owid_dates),
            "disease_signal": np.random.randint(10, 100, len(owid_dates))
        })
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_features(df, location, target):
    """Build the two-row placeholder frame used when real data is too sparse."""
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7):
    """Assemble the modelling frame for one location.

    Filters ``df`` to ``location``, drops rows with a missing target, adds
    ``lag_1`` .. ``lag_<lookback>`` columns of the target, and attaches a
    daily ``disease_signal`` from the news feed (0 where no news exists).

    Args:
        df: Frame with at least ``date``, ``location`` and ``target`` columns.
        location: Value of the ``location`` column to keep.
        target: Name of the column to model.
        lookback: Number of lagged target columns to create.

    Returns:
        The prepared frame, one row per date. When fewer than two usable
        rows exist (before or after lagging), a two-row placeholder frame
        with ``date``, ``location``, ``target`` and ``disease_signal`` is
        returned instead; it has no lag columns.
    """
    df_loc = df[df['location'] == location].sort_values('date')

    # Copy so the column assignments below never write into a slice of `df`.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_features(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    date_start = df_loc['date'].min()
    date_end = df_loc['date'].max()
    news_data = fetch_promed_data(newsapi_key, date_start, date_end)
    if 'disease_signal' not in news_data.columns:
        news_data = news_data.assign(disease_signal=0)

    df_loc['date'] = df_loc['date'].dt.normalize()
    news_data = news_data.assign(date=news_data['date'].dt.normalize())

    # The news feed can hold several articles per day. Collapse it to one
    # averaged value per date first; merging the raw rows would duplicate
    # time-series rows and corrupt the lag features and model history.
    daily_signal = news_data.groupby('date', as_index=False)['disease_signal'].mean()
    df_loc = df_loc.merge(daily_signal, on='date', how='left')

    if 'disease_signal' not in df_loc.columns:
        df_loc['disease_signal'] = 0
    df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)

    # The first `lookback` rows have incomplete lags; drop them.
    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_features(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest on the lag and disease-signal columns of ``df``.

    Args:
        df: Prepared frame; every column containing ``'lag_'`` plus the
            ``disease_signal`` column (if present) is used as a feature.
        target: Name of the column to predict.

    Returns:
        ``(model, importance)`` where ``importance`` is a DataFrame with
        ``feature`` and ``importance`` columns.

    Raises:
        ValueError: If the feature matrix has no rows or no columns.
    """
    feature_names = []
    for name in df.columns:
        if 'lag_' in name or name == 'disease_signal':
            feature_names.append(name)

    inputs = df[feature_names]
    labels = df[target]

    # `.empty` is true when there are no rows or no feature columns.
    if inputs.empty:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, labels)

    ranking = pd.DataFrame({
        'feature': feature_names,
        'importance': forest.feature_importances_,
    })
    return forest, ranking

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model on the ``date`` / ``target`` columns of ``df``.

    Args:
        df: Frame containing a ``date`` column and the target column.
        target: Name of the column to forecast.

    Returns:
        The fitted Prophet model (yearly, weekly and daily seasonality on).

    Raises:
        ValueError: If ``df`` has fewer than two rows.
    """
    # Prophet expects the time column as `ds` and the value column as `y`.
    history = df[['date', target]].rename(columns={'date': 'ds', target: 'y'})
    if len(history) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    seasonality = dict(yearly_seasonality=True, weekly_seasonality=True, daily_seasonality=True)
    fitted = Prophet(**seasonality)
    fitted.fit(history)
    return fitted

# Ensemble
def ensemble_predict(rf_model, prophet_model, df, forecast_days, target='total_cases'):
    features = [col for col in df.columns if 'lag_' in col or col == 'disease_signal']
    X = df[features].tail(1)
    rf_pred = rf_model.predict(X)

    future = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_pred = prophet_model.predict(future)['yhat'].tail(forecast_days).values

    ensemble_pred = (rf_pred[0] + prophet_pred[0]) / 2
    return ensemble_pred

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Train the requested model for a location and return its prediction.

    Args:
        df: Source frame passed to ``prepare_features``.
        location: Location to model.
        forecast_days: Number of future periods for Prophet-based output.
        model_type: ``'Random Forest'``, ``'Prophet'``; anything else is
            treated as the ensemble.
        target: Name of the column to predict.

    Returns:
        ``(prediction, importance, intervals)``. ``importance`` is ``None``
        for Prophet; ``intervals`` (``ds``, ``yhat_lower``, ``yhat_upper``)
        is only provided for Prophet. On any exception the result is
        ``(np.array([0]), None, None)``.
    """
    try:
        prepared = prepare_features(df, location, target)

        if model_type == 'Prophet':
            prophet_model = train_prophet(prepared, target)
            horizon = prophet_model.make_future_dataframe(periods=forecast_days)
            forecast = prophet_model.predict(horizon)
            future_values = forecast['yhat'].tail(forecast_days).values
            return future_values, None, forecast[['ds', 'yhat_lower', 'yhat_upper']]

        # Both remaining modes start from a fitted forest.
        rf_model, importance = train_random_forest(prepared, target)

        if model_type == 'Random Forest':
            feature_names = [
                name for name in prepared.columns
                if 'lag_' in name or name == 'disease_signal'
            ]
            latest_row = prepared[feature_names].tail(1)
            return rf_model.predict(latest_row), importance, None

        # Ensemble
        prophet_model = train_prophet(prepared, target)
        blended = ensemble_predict(rf_model, prophet_model, prepared, forecast_days, target)
        return blended, importance, None
    except Exception as e:
        print(f"Error in forecast_resource: {e}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score the random forest and Prophet on a chronological 80/20 split.

    Args:
        df: Source frame passed to ``prepare_features``.
        location: Location to evaluate.
        target: Name of the column being predicted.

    Returns:
        ``{'Random Forest': {'RMSE', 'R2'}, 'Prophet': {'RMSE', 'R2'}}``.
        All four numbers are 0 when there is too little data, when the
        held-out part is empty, or when any step raises.
    """
    def blank_scores():
        # A fresh dict per call so callers never share mutable state.
        return {
            model_name: {'RMSE': 0, 'R2': 0}
            for model_name in ('Random Forest', 'Prophet')
        }

    try:
        prepared = prepare_features(df, location, target)
        if prepared.shape[0] < 2:
            return blank_scores()

        split_at = int(0.8 * len(prepared))
        train_df = prepared[:split_at]
        test_df = prepared[split_at:]

        if test_df.empty:
            return blank_scores()

        # Random forest: predict the held-out rows from their features.
        rf_model, _ = train_random_forest(train_df, target)
        feature_names = [
            name for name in test_df.columns
            if 'lag_' in name or name == 'disease_signal'
        ]
        rf_pred = rf_model.predict(test_df[feature_names])
        rf_scores = {
            'RMSE': np.sqrt(mean_squared_error(test_df[target], rf_pred)),
            'R2': r2_score(test_df[target], rf_pred),
        }

        # Prophet: forecast as many periods as there are held-out rows.
        prophet_model = train_prophet(train_df, target)
        horizon = prophet_model.make_future_dataframe(periods=len(test_df))
        forecast = prophet_model.predict(horizon)
        prophet_pred = forecast['yhat'].tail(len(test_df)).values
        prophet_scores = {
            'RMSE': np.sqrt(mean_squared_error(test_df[target], prophet_pred)),
            'R2': r2_score(test_df[target], prophet_pred),
        }

        return {'Random Forest': rf_scores, 'Prophet': prophet_scores}
    except Exception as e:
        print(f"Error in evaluate_model: {e}")
        return blank_scores()

# Calculate resource allocation
def calculate_resources(predicted_cases, country):
    """Translate a predicted case count into resource needs and shortages.

    Args:
        predicted_cases: A number, or an array/sequence of numbers whose
            first element is used.
        country: Location looked up in ``capacity_df``; default capacities
            are used when it is not listed.

    Returns:
        ``(resources_needed, capacity, shortages)``: three dicts of plain
        ``int`` values. ``resources_needed`` covers beds, ICU beds,
        ventilators, medical staff, PPE kits and test kits; ``capacity`` and
        ``shortages`` cover beds, ICU beds and ventilators.

    Raises:
        ValueError: If ``predicted_cases`` is empty or NaN.
    """
    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }

    # Use default values if country not found
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300
    }

    # Get country capacity
    country_capacity = capacity_df[capacity_df['location'] == country]
    if country_capacity.empty:
        capacity = dict(default_capacity)
    else:
        first_match = country_capacity.iloc[0]
        # Cast to plain ints so callers never receive numpy scalars.
        capacity = {key: int(first_match[key]) for key in default_capacity}

    # Accept scalars, 0-d arrays, n-d arrays and sequences alike.
    flat_prediction = np.ravel(predicted_cases)
    if flat_prediction.size == 0:
        raise ValueError("predicted_cases is empty")
    predicted_value = float(flat_prediction[0])
    if np.isnan(predicted_value):
        raise ValueError("predicted_cases is NaN")
    # A model can emit a negative forecast; a negative demand is meaningless.
    predicted_value = max(0.0, predicted_value)

    # Calculate needed resources
    resources_needed = {
        'hospital_beds': int(predicted_value * resource_ratios['hospital_beds']),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['medical_staff']),
        'ppe_kits': int(predicted_value * resource_ratios['hospital_beds'] * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits'])
    }

    # Calculate shortages
    shortages = {
        key: max(0, resources_needed[key] - capacity[key])
        for key in default_capacity
    }

    return resources_needed, capacity, shortages

# Run retraining
# Two passes only, to keep startup fast; each pass replaces `df` with the
# frame handed back by retrain_model.
for _ in range(2):  # Reduced to 2 for faster startup
    rf_model, prophet_model, df = retrain_model(df, 'USA', newsapi_key)
    models_ready = rf_model is not None and prophet_model is not None
    if models_ready:
        print("Retraining complete")
    else:
        print("Retraining skipped due to invalid data")

# Create filtered data for map
# One row per location (the last grouped values), keeping only locations
# that have a total_cases figure.
map_data = (
    df.groupby('location')
    .last()
    .reset_index()
    .dropna(subset=['total_cases'])
)

# Dashboard
# Dash takes the module name here; it must be the double-underscore
# `__name__` (the single-underscore spelling is an undefined name).
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling
# Replaces the page template Dash serves. The {%...%} placeholders are kept
# so Dash can inject its metas, title, favicon, CSS, app entry point, config,
# scripts and renderer; the inline <style> block adds the card, tab,
# resource-alert and dashboard-title rules.
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),

    dbc.Tabs([
        dbc.Tab(label="Predictions & Resources", children=[
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Control Panel"),
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Label("Select Country", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='country-dropdown',
                                        options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                        value='USA',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Forecast Days", style={'fontWeight': 'bold'}),
                                    dcc.Slider(
                                        id='forecast-days',
                                        min=7,
                                        max=90,
                                        step=1,
                                        value=30,
                                        marks={i: str(i) for i in range(7, 91, 14)},
                                        tooltip={"placement": "bottom", "always_visible": True}
                                    )
                                ], width=12, md=4),
                                dbc.Col([
                                    html.Label("Model Type", style={'fontWeight': 'bold'}),
                                    dcc.Dropdown(
                                        id='model-type',
                                        options=[
                                            {'label': 'Random Forest', 'value': 'Random Forest'},
                                            {'label': 'Prophet', 'value': 'Prophet'},
                                            {'label': 'Ensemble', 'value': 'Ensemble'}
                                        ],
                                        value='Ensemble',
                                        style={'width': '100%'}
                                    )
                                ], width=12, md=4)
                            ]),
                            html.Label("Target Variable", className="mt-3", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            ),
                            dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3', style={'width': '100%'})
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Prediction Results"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(id='prediction-plot')
                            )
                        ])
                    ])
                ], width=12, lg=8),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Required Medical Resources"),
                        dbc.CardBody([
                            html.Div(id='resources-output')
                        ])
                    ])
                ], width=12, lg=4)
            ], className="mb-4"),

            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(id='feature-importance-plot')
                            )
                        ])
                    ])
                ], width=12, lg=6),
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance Metrics"),
                        dbc.CardBody([
                            html.Div(id='performance-metrics')
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ]),

        dbc.Tab(label="Global Impact Map", children=[
            dbc.Card([
                dbc.CardHeader("Global Disease Distribution"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Visualization", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='map-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'Cases per Million', 'value': 'cases_per_million'},
                                    {'label': 'Deaths per Million', 'value': 'deaths_per_million'}
                                ],
                                value='total_cases',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Map Type", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='map-type',
                                options=[
                                    {'label': 'Choropleth Map', 'value': 'choropleth'},
                                    {'label': 'Bubble Map', 'value': 'bubble'}
                                ],
                                value='choropleth',
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-map",
                        type="circle",
                        children=dcc.Graph(id='global-map', style={'height': '70vh'})
                    )
                ])
            ])
        ]),

        dbc.Tab(label="News & Disease Signals", children=[
            dbc.Card([
                dbc.CardHeader("Real-time Disease Signal Analysis"),
                dbc.CardBody([
                    dbc.Row([
                        dbc.Col([
                            html.Label("Select Country", style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='news-country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in df['location'].unique()],
                                value='USA',
                                style={'width': '100%'}
                            )
                        ], width=12, md=6),
                        dbc.Col([
                            html.Label("Date Range", style={'fontWeight': 'bold'}),
                            dcc.RadioItems(
                                id='news-timeframe',
                                options=[
                                    {'label': 'Last 30 days', 'value': 30},
                                    {'label': 'Last 90 days', 'value': 90},
                                    {'label': 'Last 180 days', 'value': 180}
                                ],
                                value=30,
                                inline=True,
                                style={'width': '100%'}
                            )
                        ], width=12, md=6)
                    ], className="mb-3"),
                    dcc.Loading(
                        id="loading-news",
                        type="circle",
                        children=[
                            dcc.Graph(id='news-signal-plot'),
                            html.Div(id='news-headlines', className="mt-4")
                        ]
                    )
                ])
            ])
        ])
    ])
], fluid=True)

# Prepare models and data
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Merge freshly fetched news signals into ``df`` and refit both models.

    Args:
        df: Case-history frame; its 'date' column bounds the news query and
            the frame is handed to ``prepare_features``.
        location: Location name forwarded to ``prepare_features``.
        api_key: Key forwarded to ``fetch_promed_data``.
        target: Name of the column to forecast.
        interval_hours: Accepted for interface compatibility; not read here.

    Returns:
        A ``(rf_model, prophet_model, df)`` tuple. Both models are ``None``
        when the news frame is empty or has a non-datetime 'date' column,
        when fewer than two feature rows are available, or when any step
        raises. ``df`` is the frame as it stood at that point (it includes
        appended news rows only if they were newer than the existing data).
    """
    try:
        newest_known = df['date'].max()
        fetched = fetch_promed_data(api_key, df['date'].min(), newest_known)

        # Short-circuit keeps us from touching fetched['date'] on an empty frame.
        usable = (not fetched.empty) and pd.api.types.is_datetime64_any_dtype(fetched['date'])
        if not usable:
            print("No valid new data to concatenate")
            return None, None, df

        fetched['date'] = pd.to_datetime(fetched['date']).dt.tz_localize(None).dt.normalize()
        newest_fetched = fetched['date'].max()

        # Only append news rows that extend past the newest date already held.
        if pd.notna(newest_fetched) and newest_fetched > newest_known:
            fetched = fetched[['date', 'location', 'disease_signal']]
            for column in df.columns:
                if column in fetched.columns:
                    continue
                fetched[column] = np.nan
            df = pd.concat([df, fetched], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        feature_frame = prepare_features(df, location, target)
        if feature_frame.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, df

        forest, _ = train_random_forest(feature_frame, target)
        seasonal = train_prophet(feature_frame, target)
        return forest, seasonal, df
    except Exception as e:
        # Best-effort: report the failure and hand back the data without models.
        print(f"Error in retrain_model: {e}")
        return None, None, df

# Derive per-capita rates: (count / population) scaled to one million people.
map_data['cases_per_million'] = map_data['total_cases'].div(map_data['population']).mul(1000000)
map_data['deaths_per_million'] = map_data['total_deaths'].div(map_data['population']).mul(1000000)

# Callbacks
def _utilization_color(used, capacity):
    """Return the alert colour for a used/capacity ratio.

    A non-positive capacity is treated as fully utilised ("danger").
    """
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _resource_alert(label, key, unit, resources_needed, capacity, shortages):
    """Build the alert for one capacity-limited resource (beds, ventilators)."""
    needed = resources_needed[key]
    available = capacity[key]
    shortage = shortages[key]

    # The utilisation bar is only meaningful when there is capacity to compare to.
    usage_bar = None
    if available > 0:
        usage_bar = html.Div(className="progress mt-1", children=[
            html.Div(
                className="progress-bar",
                style={"width": f"{min(100, needed / available * 100)}%"}
            )
        ])

    shortage_note = None
    if shortage > 0:
        shortage_note = html.Div(f"Potential shortage: {shortage}", className="text-danger")

    return dbc.Alert([
        html.Div([
            html.Strong(label),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {available} {unit}"),
        usage_bar,
        shortage_note
    ], color=_utilization_color(needed, available), className="mb-2")


def _metrics_column(title, scores, bar_color):
    """Build one model's RMSE / R² summary column with an R² progress bar."""
    return dbc.Col([
        html.H5(title, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={
                        # R² can be negative or above 1; clamp to a valid CSS width.
                        "width": f"{max(0, min(100, scores['R2'] * 100))}%",
                        "backgroundColor": bar_color
                    }
                )
            ])
        ])
    ], width=6)


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Refresh the four prediction-tab outputs when "Run Prediction" is clicked.

    Args:
        n_clicks: Click count of the run button; ``None`` before the first click.
        country: Selected location name.
        forecast_days: Number of days to forecast.
        model_type: Selected model label (used in the plot title and passed
            to ``forecast_resource``).
        target_variable: Column of ``df`` being forecast.

    Returns:
        ``(prediction_figure, importance_figure, metrics_children,
        resources_children)``. Before the first click placeholder figures
        and messages are returned; if any step raises, empty figures and
        ``"Error: ..."`` strings are returned instead.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # Create prediction plot
        pred_fig = go.Figure()
        pred_dates = [datetime.now() + timedelta(days=i) for i in range(forecast_days)]

        # Any zero-dimensional value (Python or NumPy scalar of any dtype) is a
        # single forecast repeated across the horizon; `isinstance(pred, float)`
        # missed NumPy float32/int scalars. Sequences are flattened to an array
        # so that `pred_values[0]` below is positional even for a pandas Series.
        if np.ndim(pred) == 0:
            pred_values = [pred] * forecast_days
        else:
            pred_values = np.asarray(pred).ravel()

        # Add historical data to the plot
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())]
            if not historical.empty:
                pred_fig.add_trace(go.Scatter(
                    x=historical['date'][-30:],  # Last 30 rows
                    y=historical[target_variable][-30:],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if confidence is not None:
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            # 'tonexty' shades the band between this trace and the lower bound.
            pred_fig.add_trace(go.Scatter(
                x=confidence['ds'],
                y=confidence['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metrics_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metrics_column("Prophet", metrics['Prophet'], "#3498db")
            ])
        ], body=True)

        # Resource calculation is based on the first forecast value.
        resources_needed, capacity, shortages = calculate_resources(pred_values[0], country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _resource_alert("Hospital Beds", 'hospital_beds', "beds",
                                resources_needed, capacity, shortages),
                _resource_alert("ICU Beds", 'icu_beds', "beds",
                                resources_needed, capacity, shortages),
                _resource_alert("Ventilators", 'ventilators', "units",
                                resources_needed, capacity, shortages),

                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        # Callback boundary: surface the error in the UI instead of crashing.
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Build the global choropleth or bubble map for the selected metric.

    Args:
        map_variable: Column of ``map_data`` to visualise.
        map_type: ``'choropleth'`` for a filled map; any other value draws
            a bubble map.

    Returns:
        A Plotly figure. If anything raises, an empty figure whose title
        carries the error message is returned instead.
    """
    try:
        # Calculate per million values if not already done
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Get variable name for display
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # The selected metric may itself be one of the fixed hover columns;
        # keep each column once, preserving order.
        hover_columns = list(dict.fromkeys([map_variable, 'total_cases', 'total_deaths']))

        shared_args = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=hover_columns,
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        if map_type == 'choropleth':
            fig = px.choropleth(map_data, **shared_args)
        else:  # bubble map
            # Marker sizes must be finite, non-negative numbers; rows with a
            # missing or infinite metric (e.g. unknown or zero population)
            # would make Plotly reject the whole trace, so leave them out.
            metric = pd.to_numeric(map_data[map_variable], errors='coerce')
            bubble_data = map_data[np.isfinite(metric) & (metric >= 0)]
            fig = px.scatter_geo(
                bubble_data,
                size=map_variable,
                size_max=50,
                **shared_args
            )

        # Both map types share the same base-map styling.
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        # Callback boundary: show the failure in the figure title.
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal and list headline cards.

    Args:
        country: Selected location; ``'Global'`` disables location filtering.
        days: Length of the look-back window in days.

    Returns:
        ``(figure, headlines_children)``. When no signal rows match, a
        titled empty figure and a plain message are returned; if anything
        raises, a figure whose title carries the error and an
        ``"Error: ..."`` string are returned.
    """
    def score_color(score):
        # Red above 80, orange above 60, green otherwise.
        if score > 80:
            return "#e74c3c"
        if score > 60:
            return "#f39c12"
        return "#2ecc71"

    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)

        # Fetch the signal rows and narrow them to the chosen location.
        signals = fetch_promed_data(newsapi_key, window_start, window_end)
        if country != 'Global':
            signals = signals[signals['location'] == country]

        if signals.empty:
            placeholder = go.Figure()
            placeholder.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return placeholder, "No news headlines available"

        # One averaged signal value per date.
        daily_signal = signals.groupby('date')['disease_signal'].mean().reset_index()

        fig = go.Figure()
        signal_trace = go.Scatter(
            x=daily_signal['date'],
            y=daily_signal['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        )
        fig.add_trace(signal_trace)

        # Reported new cases for the same location and window, if any.
        row_filter = (
            (df['location'] == country)
            & (df['date'] >= window_start)
            & (df['date'] <= window_end)
            & (df['new_cases'].notna())
        )
        reported = df[row_filter]
        has_reported = not reported.empty

        if has_reported:
            # Rescale cases so their peak lines up with the signal's peak.
            peak_signal = daily_signal['disease_signal'].max()
            peak_cases = reported['new_cases'].max()
            if peak_cases > 0:
                scale = peak_signal / peak_cases
            else:
                scale = 1

            cases_trace = go.Scatter(
                x=reported['date'],
                y=reported['new_cases'] * scale,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            )
            fig.add_trace(cases_trace)
            secondary_axis = dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            )
        else:
            secondary_axis = {}

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            yaxis2=secondary_axis
        )

        # Simulated (hard-coded) headlines with their signal scores.
        simulated = [
            ("New outbreak reported in northeastern region", 95),
            ("Health officials monitoring unusual respiratory symptoms", 80),
            ("Vaccination campaign launched to combat emerging strain", 70),
            ("Hospital capacity strained amid rising cases", 85),
            ("Travel restrictions considered as cases increase", 75),
        ]

        headline_items = []
        for headline_title, headline_score in simulated:
            headline_items.append(dbc.ListGroupItem([
                html.Div([
                    html.Strong(headline_title),
                    html.Span(
                        f" - Signal Score: {headline_score}",
                        className="ml-2",
                        style={"color": score_color(headline_score)}
                    )
                ])
            ]))

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup(headline_items)
        ])

        return fig, headlines_output

    except Exception as e:
        # Callback boundary: show the failure in the UI.
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
# Entry-point guard. The module attribute is `__name__` and the script value is
# '__main__' (double underscores); the single-underscore spelling raised
# NameError, so the server was never started.
# NOTE(review): `mode=` is the keyword used by the legacy JupyterDash class; if
# `app` is a plain dash.Dash (>= 2.11) the equivalent is
# `app.run(jupyter_mode=...)` - confirm against the `app = ...` definition.
if __name__ == '__main__':
    try:
        app.run_server(mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(mode='external', debug=True, port=8050)


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, mean_absolute_percentage_error
from prophet import Prophet
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
from datetime import datetime, timedelta
import requests
from google.colab import userdata
from xgboost import XGBRegressor

# Load NewsAPI key.
# Never commit a real key: read it from Colab Secrets first, then from the
# NEWSAPI_KEY environment variable, and otherwise fall back to a placeholder
# (requests made with the placeholder will be rejected by the API).
import os

try:
    newsapi_key = userdata.get('NEWSAPI_KEY')  # Set in Colab Secrets
    print("NewsAPI key loaded successfully")
except Exception as e:
    print(f"Failed to load NewsAPI key: {e}")
    newsapi_key = os.environ.get("NEWSAPI_KEY", "YOUR_API_KEY_HERE")

# Load the OWID COVID-19 dataset and keep only the columns used below.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)
# Parse dates and make sure they are timezone-naive.
df = df.assign(date=pd.to_datetime(df['date']).dt.tz_localize(None))
df = df.loc[:, [
    'date', 'location',
    'total_cases', 'new_cases',
    'total_deaths', 'new_deaths',
    'hosp_patients', 'icu_patients',
    'people_vaccinated', 'population',
]]

# Simulated healthcare capacity, written one row per location as
# (location, hospital_beds, icu_beds, ventilators) and transposed into the
# column-oriented dict of lists.
capacity_data = {
    column: list(values)
    for column, values in zip(
        ('location', 'hospital_beds', 'icu_beds', 'ventilators'),
        zip(
            ('USA', 10000, 2000, 500),
            ('India', 5000, 1000, 300),
            ('Brazil', 3000, 500, 200),
            ('United Kingdom', 8000, 1600, 400),
            ('Germany', 9000, 1800, 450),
            ('France', 7500, 1500, 375),
            ('Italy', 7000, 1400, 350),
            ('Spain', 6000, 1200, 300),
            ('China', 12000, 2400, 600),
            ('Japan', 9500, 1900, 475),
            ('South Korea', 8500, 1700, 425),
            ('Australia', 6500, 1300, 325),
            ('Canada', 7800, 1560, 390),
            ('Mexico', 4000, 800, 200),
        ),
    )
}
capacity_df = pd.DataFrame(data=capacity_data)

# Fetch NewsAPI data
def _score_article_text(text):
    """Return the keyword-based disease signal (100, 50 or 10) for lower-cased text."""
    if any(word in text for word in ("outbreak", "epidemic", "pandemic")):
        return 100
    if any(word in text for word in ("health", "disease", "emergency")):
        return 50
    return 10


def _locate_article_text(text):
    """Return the first location whose keywords occur as whole words in the text.

    Whole-word matching avoids substring false positives such as "usa" in
    "thousand" or "india" in "indiana".
    """
    import re  # local import: `re` is not among this notebook cell's imports

    location_patterns = (
        ("USA", r"\b(?:usa|united states|america(?:ns?)?)\b"),
        ("India", r"\bindia(?:ns?)?\b"),
        ("Brazil", r"\bbrazil(?:ians?)?\b"),
    )
    for location, pattern in location_patterns:
        if re.search(pattern, text):
            return location
    return "Unknown"


def _article_to_row(article):
    """Convert one NewsAPI article dict into a signal row, or None if it has no usable date."""
    published = pd.to_datetime(article.get("publishedAt"), errors="coerce")
    if pd.isna(published):
        return None

    # Any of these fields may be null in the API response; treat null as empty
    # text so one incomplete article cannot abort the whole batch.
    title = article.get("title") or ""
    body = article.get("content") or article.get("description") or ""
    text = f"{title} {body}".lower()

    return {
        "date": published.tz_localize(None).normalize(),  # Date-only
        "location": _locate_article_text(text),
        "disease_signal": _score_article_text(text),
    }


def _fallback_news_frame(date_range_start, date_range_end):
    """Build the placeholder frame used when no real news data is available.

    The frame covers the last (up to) 30 days of the requested range, or the
    30 days ending today when no range is given. Every row is labelled "USA"
    and carries a random signal in [10, 100).
    """
    if date_range_start is not None and date_range_end is not None:
        owid_dates = pd.date_range(start=date_range_start, end=date_range_end, freq='D')[-30:]
    else:
        owid_dates = [
            pd.to_datetime(datetime.now().date() - timedelta(days=x)).tz_localize(None).normalize()
            for x in range(30)
        ]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": ["USA"] * len(owid_dates),
        "disease_signal": np.random.randint(10, 100, len(owid_dates))
    })


def fetch_promed_data(api_key, date_range_start=None, date_range_end=None):
    """Fetch disease-related news and convert it into daily signal rows.

    Args:
        api_key: NewsAPI key sent with the request.
        date_range_start: Optional start of the query window (an object with
            ``strftime``). Ignored unless both bounds are given and non-null.
        date_range_end: Optional end of the query window.

    Returns:
        A DataFrame with columns ``date`` (timezone-naive, midnight),
        ``location`` and ``disease_signal``. If the request fails, returns
        no usable articles, or any step raises, a randomly generated
        placeholder frame is returned instead (see ``_fallback_news_frame``).
    """
    # NaT is truthy, so plain truthiness would let a missing date through and
    # crash later in strftime / date_range; test for real values explicitly.
    has_range = bool(pd.notna(date_range_start) and pd.notna(date_range_end))

    try:
        url = "https://newsapi.org/v2/everything"
        params = {
            "q": "disease outbreak OR epidemic OR health emergency OR infectious disease OR COVID-19",
            "language": "en",
            "sortBy": "publishedAt",
            "apiKey": api_key,
            "pageSize": 30
        }
        if has_range:
            params["from"] = date_range_start.strftime('%Y-%m-%d')
            params["to"] = date_range_end.strftime('%Y-%m-%d')

        response = requests.get(url, params=params, timeout=10)
        if response.status_code != 200:
            raise RuntimeError(f"API request failed with status {response.status_code}")

        articles = response.json().get("articles", [])
        if not articles:
            raise RuntimeError("No articles found")

        rows = [row for row in map(_article_to_row, articles) if row is not None]
        if not rows:
            raise RuntimeError("No relevant data parsed")

        news = pd.DataFrame(rows)
        news["date"] = pd.to_datetime(news["date"]).dt.tz_localize(None).dt.normalize()
        print("NewsAPI data fetched:", news.head())
        return news

    except Exception as e:
        # Best-effort by design: any failure degrades to placeholder data.
        print(f"Error fetching NewsAPI data: {e}")
        if has_range:
            fallback = _fallback_news_frame(date_range_start, date_range_end)
        else:
            fallback = _fallback_news_frame(None, None)
        print("Using fallback data")
        return fallback

# Prepare features
def _fallback_feature_frame(df, location, target):
    """Return the two-row placeholder frame used when a location lacks usable data.

    Dates are the location's last two dates in ``df`` when available,
    otherwise today and yesterday; the target values are the fixed
    placeholders 1000 and 1100 and the disease signal is 0.
    """
    owid_dates = df[df['location'] == location]['date'].tail(2).values
    if len(owid_dates) < 2:
        owid_dates = [pd.to_datetime(datetime.now().date() - timedelta(days=x)).normalize() for x in range(2)]
    return pd.DataFrame({
        "date": pd.to_datetime(owid_dates).normalize(),
        "location": [location] * len(owid_dates),
        target: [1000, 1100],
        "disease_signal": [0, 0]
    })


def prepare_features(df, location, target='total_cases', lookback=7, news_data=None):
    """Build the lag-feature frame (plus daily news signal) for one location.

    Args:
        df: Frame with at least 'date', 'location' and the target column.
        location: Location whose rows are used.
        target: Column to forecast; rows where it is null are dropped.
        lookback: Number of lag columns (``lag_1`` .. ``lag_<lookback>``).
        news_data: Optional frame with 'date' and 'disease_signal' columns.
            When omitted, it is fetched with ``fetch_promed_data`` for the
            location's date range.

    Returns:
        The location's rows sorted by date with lag columns and a
        ``disease_signal`` column (0 where no news exists for the date).
        If fewer than two usable rows exist, before or after building the
        lags, a two-row placeholder frame without lag columns is returned.
    """
    df_loc = df[df['location'] == location].sort_values('date')
    # Copy so the column assignments below never write into a view of `df`.
    df_loc = df_loc.dropna(subset=[target]).copy()

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data for {location} with non-null {target}")
        return _fallback_feature_frame(df, location, target)

    for i in range(1, lookback + 1):
        df_loc[f'lag_{i}'] = df_loc[target].shift(i)

    if news_data is None:
        news_data = fetch_promed_data(newsapi_key, df_loc['date'].min(), df_loc['date'].max())

    df_loc['date'] = df_loc['date'].dt.normalize()
    # A 'disease_signal' column already present in the input would collide
    # with the merged one (pandas would emit _x/_y suffixes and the signal
    # would then be reset to 0), so the freshly merged signal replaces it.
    df_loc = df_loc.drop(columns=['disease_signal'], errors='ignore')

    if 'disease_signal' in news_data.columns and not news_data.empty:
        daily_signal = news_data[['date', 'disease_signal']].copy()
        daily_signal['date'] = pd.to_datetime(daily_signal['date']).dt.normalize()
        # Several articles can share a date; collapse them to one mean value
        # per day so the left merge cannot duplicate case rows.
        daily_signal = daily_signal.groupby('date', as_index=False)['disease_signal'].mean()
        df_loc = df_loc.merge(daily_signal, on='date', how='left')
        df_loc['disease_signal'] = df_loc['disease_signal'].fillna(0)
    else:
        df_loc['disease_signal'] = 0

    feature_cols = [target] + [f'lag_{i}' for i in range(1, lookback + 1)]
    df_loc = df_loc.dropna(subset=feature_cols)

    if df_loc.shape[0] < 2:
        print(f"Warning: Insufficient data after processing for {location}, returning fallback")
        return _fallback_feature_frame(df, location, target)

    return df_loc

# Random Forest
def train_random_forest(df, target='total_cases'):
    """Fit a random forest regressor on the lag and disease-signal columns.

    Args:
        df: DataFrame holding the ``target`` column plus predictor columns
            whose names contain ``'lag_'`` or equal ``'disease_signal'``.
        target: Name of the column to predict.

    Returns:
        Tuple ``(model, importance)``: the fitted ``RandomForestRegressor``
        and a DataFrame with ``feature`` / ``importance`` columns.

    Raises:
        ValueError: If the predictor frame has no rows or no columns.
    """
    predictor_names = []
    for name in df.columns:
        if 'lag_' in name or name == 'disease_signal':
            predictor_names.append(name)

    inputs, labels = df[predictor_names], df[target]

    # DataFrame.empty is True when either axis has length zero, so it also
    # covers the zero-row case.
    if inputs.empty:
        raise ValueError("No samples available for training RandomForestRegressor")

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(inputs, labels)

    ranking = pd.DataFrame({
        'feature': predictor_names,
        'importance': forest.feature_importances_,
    })
    return forest, ranking

# Prophet
def train_prophet(df, target='total_cases'):
    """Fit a Prophet model to the ``date`` / ``target`` columns of ``df``.

    The two columns are renamed to ``ds`` and ``y`` before fitting. The model
    is created with yearly, weekly and daily seasonality all switched on.

    Args:
        df: DataFrame with a ``date`` column and the ``target`` column.
        target: Name of the column to forecast.

    Returns:
        The fitted ``Prophet`` model.

    Raises:
        ValueError: If the frame has fewer than two rows.
    """
    column_map = {'date': 'ds', target: 'y'}
    history = df[['date', target]].rename(columns=column_map)

    if len(history.index) < 2:
        raise ValueError("Dataframe has less than 2 non-NaN rows for training Prophet")

    seasonal_model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=True,
    )
    seasonal_model.fit(history)
    return seasonal_model

# XGBoost
def train_xgboost(df, target='total_cases'):
    """Fit an XGBoost regressor on the lag and disease-signal columns.

    Args:
        df: DataFrame holding the ``target`` column plus predictor columns
            whose names contain ``'lag_'`` or equal ``'disease_signal'``.
        target: Name of the column to predict.

    Returns:
        Tuple ``(model, importance)``: the fitted ``XGBRegressor`` and a
        DataFrame with ``feature`` / ``importance`` columns.

    Raises:
        ValueError: If the predictor frame has no rows or no columns.
    """
    selected = [
        column
        for column in df.columns
        if 'lag_' in column or column == 'disease_signal'
    ]
    design_matrix = df[selected]
    response = df[target]

    # An empty frame (no rows or no columns) cannot be used for training.
    if design_matrix.empty:
        raise ValueError("No samples available for training XGBoost")

    booster = XGBRegressor(n_estimators=100, random_state=42)
    booster.fit(design_matrix, response)

    scores = pd.DataFrame({
        'feature': selected,
        'importance': booster.feature_importances_,
    })
    return booster, scores

# Ensemble
def ensemble_predict(rf_model, prophet_model, xgb_model, df, forecast_days, target='total_cases'):
    """Average the three models' predictions into a single value.

    The tree models predict from the last row of ``df``; Prophet contributes
    the first value of its ``forecast_days``-long forecast.

    Args:
        rf_model: Fitted random forest exposing ``predict``.
        prophet_model: Fitted Prophet model exposing ``make_future_dataframe``
            and ``predict``.
        xgb_model: Fitted XGBoost model exposing ``predict``.
        df: Feature frame with ``lag_*`` / ``disease_signal`` columns.
        forecast_days: Number of future periods requested from Prophet.
        target: Unused here; kept for a signature parallel to the trainers.

    Returns:
        The unweighted mean of the three predictions (a scalar).
    """
    predictor_names = [name for name in df.columns
                       if 'lag_' in name or name == 'disease_signal']
    latest_row = df[predictor_names].tail(1)

    forest_estimate = rf_model.predict(latest_row)
    boosted_estimate = xgb_model.predict(latest_row)

    horizon = prophet_model.make_future_dataframe(periods=forecast_days)
    prophet_forecast = prophet_model.predict(horizon)
    prophet_estimate = prophet_forecast['yhat'].tail(forecast_days).values

    # Same summation order as before so the floating-point result is unchanged.
    total = forest_estimate[0] + prophet_estimate[0] + boosted_estimate[0]
    return total / 3

# Forecasting
def forecast_resource(df, location, forecast_days, model_type, target='total_cases'):
    """Train the requested model for ``location`` and return its prediction.

    Args:
        df: Full data frame passed on to ``prepare_features``.
        location: Location name to filter on.
        forecast_days: Number of future periods for Prophet-based forecasts.
        model_type: ``'Random Forest'``, ``'Prophet'`` or ``'XGBoost'``; any
            other value selects the ensemble.
        target: Name of the column to forecast.

    Returns:
        Tuple ``(prediction, importance, confidence)``:
        * Random Forest / XGBoost: one-element prediction array, feature
          importance frame, ``None``.
        * Prophet: array of ``forecast_days`` values, ``None``, and a frame
          with ``ds`` / ``yhat_lower`` / ``yhat_upper``.
        * Ensemble: scalar prediction, averaged importance frame, ``None``.
        On any exception the error is printed and
        ``(np.array([0]), None, None)`` is returned.
    """
    def is_predictor(name):
        return 'lag_' in name or name == 'disease_signal'

    try:
        prepared = prepare_features(df, location, target)

        if model_type == 'Prophet':
            seasonal_model = train_prophet(prepared, target)
            horizon = seasonal_model.make_future_dataframe(periods=forecast_days)
            forecast = seasonal_model.predict(horizon)
            point_forecast = forecast['yhat'].tail(forecast_days).values
            return point_forecast, None, forecast[['ds', 'yhat_lower', 'yhat_upper']]

        if model_type == 'Random Forest' or model_type == 'XGBoost':
            # Both tree models share the same train-then-predict-last-row flow.
            fit = train_random_forest if model_type == 'Random Forest' else train_xgboost
            fitted, importance = fit(prepared, target)
            predictor_names = [name for name in prepared.columns if is_predictor(name)]
            latest_row = prepared[predictor_names].tail(1)
            return fitted.predict(latest_row), importance, None

        # Anything else is treated as the ensemble.
        rf_model, rf_importance = train_random_forest(prepared, target)
        prophet_model = train_prophet(prepared, target)
        xgb_model, xgb_importance = train_xgboost(prepared, target)
        blended = ensemble_predict(rf_model, prophet_model, xgb_model, prepared, forecast_days, target)
        # Average the two tree models' importances per feature.
        stacked = pd.concat([rf_importance, xgb_importance])
        importance = stacked.groupby('feature').mean().reset_index()
        return blended, importance, None
    except Exception as exc:
        print(f"Error in forecast_resource: {exc}")
        return np.array([0]), None, None

# Model evaluation
def evaluate_model(df, location, target='total_cases'):
    """Score Random Forest, Prophet and XGBoost on an 80/20 ordered split.

    Args:
        df: Full data frame passed on to ``prepare_features``.
        location: Location name to filter on.
        target: Name of the column to evaluate.

    Returns:
        Dict keyed by model name (``'Random Forest'``, ``'Prophet'``,
        ``'XGBoost'``), each mapping to ``RMSE``, ``R2``, ``MAE`` and ``MAPE``.
        R2 is clamped to the range 0..1. Every metric is 0 when there is too
        little data, the test split is empty, or any exception occurs (the
        exception is printed, not raised).
    """
    model_names = ('Random Forest', 'Prophet', 'XGBoost')

    def blank_report():
        # Fresh dicts on every call so callers never share mutable state.
        return {name: {'RMSE': 0, 'R2': 0, 'MAE': 0, 'MAPE': 0} for name in model_names}

    def score(actual, predicted):
        rmse = np.sqrt(mean_squared_error(actual, predicted))
        # Ensure R2 is between 0 and 1
        r2 = max(0, min(1, r2_score(actual, predicted)))
        mae = mean_absolute_error(actual, predicted)
        mape = mean_absolute_percentage_error(actual, predicted)
        return {'RMSE': rmse, 'R2': r2, 'MAE': mae, 'MAPE': mape}

    try:
        prepared = prepare_features(df, location, target)
        if prepared.shape[0] < 2:
            return blank_report()

        cutoff = int(0.8 * len(prepared))
        train_part = prepared[:cutoff]
        test_part = prepared[cutoff:]
        if test_part.empty:
            return blank_report()

        report = {}

        # Random Forest evaluation
        forest, _ = train_random_forest(train_part, target)
        predictor_names = [name for name in test_part.columns
                           if 'lag_' in name or name == 'disease_signal']
        forest_guess = forest.predict(test_part[predictor_names])
        report['Random Forest'] = score(test_part[target], forest_guess)

        # Prophet evaluation: forecast as many periods as the test split has rows
        seasonal_model = train_prophet(train_part, target)
        horizon = seasonal_model.make_future_dataframe(periods=len(test_part))
        forecast = seasonal_model.predict(horizon)
        prophet_guess = forecast['yhat'].tail(len(test_part)).values
        report['Prophet'] = score(test_part[target], prophet_guess)

        # XGBoost evaluation
        booster, _ = train_xgboost(train_part, target)
        boosted_guess = booster.predict(test_part[predictor_names])
        report['XGBoost'] = score(test_part[target], boosted_guess)

        return report
    except Exception as exc:
        print(f"Error in evaluate_model: {exc}")
        return blank_report()

# Calculate resource allocation
def calculate_resources(predicted_cases, country, capacity_table=None):
    """Translate a predicted case count into resource needs and shortages.

    Args:
        predicted_cases: Predicted case count. May be a plain number or any
            array-like (list, tuple, NumPy array); for array-likes the first
            element is used. Negative values are treated by magnitude.
        country: Location name used to look up healthcare capacity.
        capacity_table: Optional DataFrame with ``location``,
            ``hospital_beds``, ``icu_beds`` and ``ventilators`` columns.
            Defaults to the module-level ``capacity_df``.

    Returns:
        Tuple ``(resources_needed, capacity, shortages)`` of dicts.
        ``resources_needed`` has hospital beds, ICU beds, ventilators, medical
        staff, PPE kits and test kits; ``capacity`` and ``shortages`` cover
        hospital beds, ICU beds and ventilators. Shortages are never negative.

    Raises:
        ValueError: If ``predicted_cases`` is empty or not a finite number.
    """
    # The dashboard offers dataset location names, while the capacity table
    # uses short names for some countries; map the former onto the latter.
    location_aliases = {
        'United States': 'USA',
        'United States of America': 'USA',
        'US': 'USA',
    }

    # Define resource ratios based on healthcare capacity and case severity
    resource_ratios = {
        'hospital_beds': 0.15,  # 15% of cases require hospitalization
        'icu_beds': 0.05,       # 5% of cases require ICU
        'ventilators': 0.02,    # 2% of cases require ventilators
        'medical_staff': 0.3,   # Staff needed per hospitalized patient
        'ppe_kits': 10,         # PPE kits per hospitalized patient per day
        'test_kits': 2,         # Test kits per case
    }

    # Used when the country has no row in the capacity table.
    default_capacity = {
        'hospital_beds': 5000,
        'icu_beds': 1000,
        'ventilators': 300,
    }

    # Accept scalars and any array-like uniformly; only the first value is used.
    values = np.ravel(np.asarray(predicted_cases, dtype=float))
    if values.size == 0:
        raise ValueError("predicted_cases is empty")
    if not np.isfinite(values[0]):
        raise ValueError(f"predicted_cases must be finite, got {values[0]}")
    # Ensure predicted cases is positive
    predicted_value = abs(float(values[0]))

    # Get country capacity: exact name first, then its alias (if any).
    table = capacity_df if capacity_table is None else capacity_table
    candidates = [country]
    alias = location_aliases.get(country) if isinstance(country, str) else None
    if alias is not None:
        candidates.append(alias)

    capacity = dict(default_capacity)
    for name in candidates:
        rows = table[table['location'] == name]
        if not rows.empty:
            capacity = {key: rows[key].iloc[0] for key in default_capacity}
            break

    # Calculate needed resources; staff and PPE scale with hospitalized patients.
    hospitalized = predicted_value * resource_ratios['hospital_beds']
    resources_needed = {
        'hospital_beds': int(hospitalized),
        'icu_beds': int(predicted_value * resource_ratios['icu_beds']),
        'ventilators': int(predicted_value * resource_ratios['ventilators']),
        'medical_staff': int(hospitalized * resource_ratios['medical_staff']),
        'ppe_kits': int(hospitalized * resource_ratios['ppe_kits']),
        'test_kits': int(predicted_value * resource_ratios['test_kits']),
    }

    # Calculate shortages (demand above capacity, floored at zero)
    shortages = {
        key: max(0, resources_needed[key] - capacity[key])
        for key in default_capacity
    }

    return resources_needed, capacity, shortages

# Retraining
def retrain_model(df, location, api_key, target='total_cases', interval_hours=24):
    """Merge newly fetched news data into ``df`` and retrain all three models.

    Args:
        df: Current data frame; must contain a ``date`` column.
        location: Location whose models are retrained.
        api_key: Key handed to ``fetch_promed_data``.
        target: Name of the column to forecast.
        interval_hours: Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        Tuple ``(rf_model, prophet_model, xgb_model, df)``. The three models
        are ``None`` when the fetched data is empty or has a non-datetime
        ``date`` column, when ``prepare_features`` yields fewer than two rows,
        or when any exception occurs (the exception is printed, not raised).
        ``df`` is the possibly extended frame, or the input frame on failure.
    """
    try:
        last_date = df['date'].max()
        start_date = df['date'].min()
        new_data = fetch_promed_data(api_key, start_date, last_date)

        if new_data.empty or not pd.api.types.is_datetime64_any_dtype(new_data['date']):
            print("No valid new data to concatenate")
            return None, None, None, df

        new_data['date'] = pd.to_datetime(new_data['date']).dt.tz_localize(None).dt.normalize()
        new_data_max = new_data['date'].max()

        if pd.notna(new_data_max) and new_data_max > last_date:
            # concat's default outer join fills the columns that the news rows
            # lack with NaN, so no manual padding (and no write into a slice
            # of new_data) is needed.
            news_rows = new_data[['date', 'location', 'disease_signal']]
            df = pd.concat([df, news_rows], ignore_index=True)
            df['date'] = pd.to_datetime(df['date']).dt.tz_localize(None).dt.normalize()

        df_loc = prepare_features(df, location, target)
        if df_loc.shape[0] < 2:
            print(f"Warning: Insufficient data for {location} after prepare_features")
            return None, None, None, df

        rf_model, _ = train_random_forest(df_loc, target)
        prophet_model = train_prophet(df_loc, target)
        xgb_model, _ = train_xgboost(df_loc, target)
        return rf_model, prophet_model, xgb_model, df
    except Exception as e:
        print(f"Error in retrain_model: {e}")
        return None, None, None, df


# Run retraining. This runs after retrain_model is defined so the call below
# always resolves to the definition above.
# NOTE(review): 'USA' must match a value in df['location']; confirm against
# the loaded dataset's naming.
for _ in range(2):  # Reduced to 2 for faster startup
    try:
        rf_model, prophet_model, xgb_model, df = retrain_model(df, 'USA', newsapi_key)
        if rf_model is None or prophet_model is None or xgb_model is None:
            print("Retraining skipped due to invalid data")
        else:
            print("Retraining complete")
    except Exception as e:
        print(f"Error during retraining: {e}")

# Create filtered data for map
# Take the last available value per location, keep only locations that have a
# total case count, then derive the per-million rates used by the map.
map_data = (
    df.groupby('location')
    .last()
    .reset_index()
    .dropna(subset=['total_cases'])
    .assign(
        cases_per_million=lambda latest: (latest['total_cases'] / latest['population']) * 1000000,
        deaths_per_million=lambda latest: (latest['total_deaths'] / latest['population']) * 1000000,
    )
)

# Dashboard
# Create the Dash application with the Bootstrap stylesheet from
# dash_bootstrap_components.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Add CSS for custom styling.
# The page template below is Dash's index string: the {%...%} placeholders are
# filled in by Dash, and the inline <style> block defines the look of cards,
# card headers, active tabs, tab content, resource alerts and the page title.
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
        <style>
            .card {
                box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                transition: 0.3s;
                border-radius: 10px;
                margin-bottom: 20px;
            }
            .card:hover {
                box-shadow: 0 8px 16px rgba(0,0,0,0.2);
            }
            .card-header {
                background-color: #3498db;
                color: white;
                border-radius: 10px 10px 0 0 !important;
                font-weight: bold;
            }
            .nav-tabs .nav-link.active {
                background-color: #3498db;
                color: white;
            }
            .tab-content {
                border: 1px solid #dee2e6;
                border-top: 0;
                padding: 15px;
            }
            .resource-alert {
                padding: 10px;
                border-radius: 5px;
                margin-bottom: 10px;
            }
            .dashboard-title {
                color: #2c3e50;
                text-align: center;
                margin-bottom: 30px;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

def _bold_label(text, **props):
    """Return an ``html.Label`` with bold text; ``props`` are forwarded as-is."""
    return html.Label(text, style={'fontWeight': 'bold'}, **props)


def _wide_dropdown(component_id, choices, selected):
    """Return a full-width ``dcc.Dropdown``."""
    return dcc.Dropdown(id=component_id, options=choices, value=selected, style={'width': '100%'})


def _inline_radio(component_id, choices, selected):
    """Return a full-width, inline ``dcc.RadioItems``."""
    return dcc.RadioItems(id=component_id, options=choices, value=selected,
                          inline=True, style={'width': '100%'})


def _spinner(component_id, content):
    """Wrap ``content`` in a circle-type ``dcc.Loading`` indicator."""
    return dcc.Loading(id=component_id, type="circle", children=content)


def _panel(title, body):
    """Return a card made of a header with ``title`` and a body with ``body``."""
    return dbc.Card([dbc.CardHeader(title), dbc.CardBody(body)])


def _choices(pairs):
    """Turn ``(label, value)`` pairs into Dash option dicts."""
    return [{'label': label, 'value': value} for label, value in pairs]


def _location_choices():
    """Return one option per distinct value of ``df['location']``."""
    return [{'label': loc, 'value': loc} for loc in df['location'].unique()]


# --- Tab 1: prediction controls, results and diagnostics -------------------
_control_panel = _panel("Control Panel", [
    dbc.Row([
        dbc.Col([
            _bold_label("Select Country"),
            _wide_dropdown('country-dropdown', _location_choices(), 'USA'),
        ], width=12, md=4),
        dbc.Col([
            _bold_label("Forecast Days"),
            dcc.Slider(
                id='forecast-days',
                min=7,
                max=90,
                step=1,
                value=30,
                marks={day: str(day) for day in range(7, 91, 14)},
                tooltip={"placement": "bottom", "always_visible": True},
            ),
        ], width=12, md=4),
        dbc.Col([
            _bold_label("Model Type"),
            _wide_dropdown(
                'model-type',
                _choices([
                    ('Random Forest', 'Random Forest'),
                    ('Prophet', 'Prophet'),
                    ('XGBoost', 'XGBoost'),
                    ('Ensemble', 'Ensemble'),
                ]),
                'Ensemble',
            ),
        ], width=12, md=4),
    ]),
    _bold_label("Target Variable", className="mt-3"),
    _wide_dropdown(
        'target-variable',
        _choices([
            ('Total Cases', 'total_cases'),
            ('New Cases', 'new_cases'),
            ('Total Deaths', 'total_deaths'),
            ('New Deaths', 'new_deaths'),
            ('Hospital Patients', 'hosp_patients'),
            ('ICU Patients', 'icu_patients'),
        ]),
        'total_cases',
    ),
    dbc.Button("Run Prediction", id='run-btn', color='primary', className='mt-3',
               style={'width': '100%'}),
])

_results_row = dbc.Row([
    dbc.Col([
        _panel("Prediction Results", [
            _spinner("loading-prediction", dcc.Graph(id='prediction-plot')),
        ]),
    ], width=12, lg=8),
    dbc.Col([
        _panel("Required Medical Resources", [html.Div(id='resources-output')]),
    ], width=12, lg=4),
], className="mb-4")

_diagnostics_row = dbc.Row([
    dbc.Col([
        _panel("Feature Importance", [
            _spinner("loading-features", dcc.Graph(id='feature-importance-plot')),
        ]),
    ], width=12, lg=6),
    dbc.Col([
        _panel("Model Performance Metrics", [html.Div(id='performance-metrics')]),
    ], width=12, lg=6),
])

_predictions_tab = dbc.Tab(label="Predictions & Resources", children=[
    dbc.Row([dbc.Col([_control_panel], width=12)], className="mb-4"),
    _results_row,
    _diagnostics_row,
])

# --- Tab 2: world map -------------------------------------------------------
_map_controls = dbc.Row([
    dbc.Col([
        _bold_label("Select Visualization"),
        _wide_dropdown(
            'map-variable',
            _choices([
                ('Total Cases', 'total_cases'),
                ('Total Deaths', 'total_deaths'),
                ('Cases per Million', 'cases_per_million'),
                ('Deaths per Million', 'deaths_per_million'),
            ]),
            'total_cases',
        ),
    ], width=12, md=6),
    dbc.Col([
        _bold_label("Map Type"),
        _inline_radio(
            'map-type',
            _choices([('Choropleth Map', 'choropleth'), ('Bubble Map', 'bubble')]),
            'choropleth',
        ),
    ], width=12, md=6),
], className="mb-3")

_map_tab = dbc.Tab(label="Global Impact Map", children=[
    _panel("Global Disease Distribution", [
        _map_controls,
        _spinner("loading-map", dcc.Graph(id='global-map', style={'height': '70vh'})),
    ]),
])

# --- Tab 3: news-derived disease signals -----------------------------------
_news_controls = dbc.Row([
    dbc.Col([
        _bold_label("Select Country"),
        _wide_dropdown('news-country-dropdown', _location_choices(), 'USA'),
    ], width=12, md=6),
    dbc.Col([
        _bold_label("Date Range"),
        _inline_radio(
            'news-timeframe',
            _choices([('Last 30 days', 30), ('Last 90 days', 90), ('Last 180 days', 180)]),
            30,
        ),
    ], width=12, md=6),
], className="mb-3")

_news_tab = dbc.Tab(label="News & Disease Signals", children=[
    _panel("Real-time Disease Signal Analysis", [
        _news_controls,
        _spinner("loading-news", [
            dcc.Graph(id='news-signal-plot'),
            html.Div(id='news-headlines', className="mt-4"),
        ]),
    ]),
])

# Page skeleton: title followed by the three tabs.
app.layout = dbc.Container([
    html.H1("Medical Resource Demand Prediction Dashboard", className="dashboard-title mt-4 mb-4"),
    dbc.Tabs([_predictions_tab, _map_tab, _news_tab]),
], fluid=True)

# Callbacks
def _metric_column(model_name, scores, bar_color):
    """Build one column of the metrics card: four scores plus an R² bar."""
    r2_percent = max(0, min(100, scores['R2'] * 100))
    return dbc.Col([
        html.H5(model_name, className="text-center"),
        html.Div([
            html.P(f"RMSE: {scores['RMSE']:.2f}", className="mb-1"),
            html.P(f"R²: {scores['R2']:.2f}", className="mb-1"),
            html.P(f"MAE: {scores['MAE']:.2f}", className="mb-1"),
            html.P(f"MAPE: {scores['MAPE']:.2%}", className="mb-1"),
            html.Div(className="progress mb-2", children=[
                html.Div(
                    className="progress-bar",
                    style={"width": f"{r2_percent}%", "backgroundColor": bar_color}
                )
            ])
        ])
    ], width=4)


def _utilization_color(used, capacity):
    """Map demand/capacity to an alert colour (<70% success, <90% warning)."""
    # A non-positive capacity is treated as fully utilised.
    utilization = (used / capacity) if capacity > 0 else 1
    if utilization < 0.7:
        return "success"
    if utilization < 0.9:
        return "warning"
    return "danger"


def _resource_alert(title, unit, needed, available, shortage):
    """Build the alert for one capacity-limited resource."""
    children = [
        html.Div([
            html.Strong(title),
            html.Span(f": {needed} needed", className="ml-2")
        ]),
        html.Small(f"Capacity: {available} {unit}"),
    ]
    if available > 0:
        # Utilisation bar, capped at 100%.
        children.append(html.Div(className="progress mt-1", children=[
            html.Div(
                className="progress-bar",
                style={"width": f"{min(100, needed / available * 100)}%"}
            )
        ]))
    if shortage > 0:
        children.append(html.Div(f"Potential shortage: {shortage}", className="text-danger"))
    return dbc.Alert(children, color=_utilization_color(needed, available), className="mb-2")


@app.callback(
    [Output('prediction-plot', 'figure'),
     Output('feature-importance-plot', 'figure'),
     Output('performance-metrics', 'children'),
     Output('resources-output', 'children')],
    [Input('run-btn', 'n_clicks')],
    [State('country-dropdown', 'value'),
     State('forecast-days', 'value'),
     State('model-type', 'value'),
     State('target-variable', 'value')]
)
def update_prediction_dashboard(n_clicks, country, forecast_days, model_type, target_variable):
    """Run a forecast and refresh the four outputs of the predictions tab.

    Args:
        n_clicks: Click count of the run button; ``None`` before the first click.
        country: Selected location.
        forecast_days: Forecast horizon in days.
        model_type: Selected model name.
        target_variable: Column to forecast.

    Returns:
        Tuple ``(prediction figure, feature-importance figure, metrics
        component, resources component)``. Before the first click placeholder
        figures and texts are returned; on any exception two empty figures and
        two error strings are returned.
    """
    if n_clicks is None:
        pred_fig = go.Figure()
        pred_fig.update_layout(title="No predictions yet - press 'Run Prediction'")

        imp_fig = go.Figure()
        imp_fig.update_layout(title="No feature importance data yet")

        return pred_fig, imp_fig, "No performance metrics yet", "No resource calculation yet"

    try:
        # Run prediction
        pred, importance, confidence = forecast_resource(df, country, forecast_days, model_type, target_variable)

        # Models that return a single value (a scalar or a one-element array)
        # are drawn as a flat line across the whole horizon, so the x and y
        # series always have the same length.
        pred_values = np.ravel(np.asarray(pred, dtype=float))
        if pred_values.size == 1:
            pred_values = np.repeat(pred_values, forecast_days)

        if confidence is not None:
            # Keep only the forecast horizon and reuse its dates so the
            # prediction line and the confidence band share one x axis.
            horizon = confidence.tail(len(pred_values))
            pred_dates = horizon['ds']
        else:
            horizon = None
            pred_dates = [datetime.now() + timedelta(days=i) for i in range(len(pred_values))]

        # Create prediction plot
        pred_fig = go.Figure()

        # Add historical data to the plot
        if target_variable in ['total_cases', 'total_deaths', 'new_cases', 'new_deaths', 'hosp_patients', 'icu_patients']:
            historical = df[(df['location'] == country) & (df[target_variable].notna())]
            if not historical.empty:
                recent = historical.tail(30)  # Last 30 rows with data
                pred_fig.add_trace(go.Scatter(
                    x=recent['date'],
                    y=recent[target_variable],
                    name='Historical Data',
                    line=dict(color='royalblue')
                ))

        # Add prediction line
        pred_fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_values,
            name='Prediction',
            line=dict(color='#e74c3c', dash='dash', width=3)
        ))

        # Add confidence intervals if available
        if horizon is not None:
            pred_fig.add_trace(go.Scatter(
                x=horizon['ds'],
                y=horizon['yhat_lower'],
                fill=None,
                mode='lines',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Lower CI'
            ))
            pred_fig.add_trace(go.Scatter(
                x=horizon['ds'],
                y=horizon['yhat_upper'],
                fill='tonexty',
                fillcolor='rgba(231, 76, 60, 0.2)',
                line=dict(color='rgba(231, 76, 60, 0.2)'),
                name='Upper CI'
            ))

        target_name = {
            'total_cases': 'Total Cases',
            'new_cases': 'New Cases',
            'total_deaths': 'Total Deaths',
            'new_deaths': 'New Deaths',
            'hosp_patients': 'Hospitalized Patients',
            'icu_patients': 'ICU Patients'
        }.get(target_variable, target_variable)

        pred_fig.update_layout(
            title=f"{model_type} Prediction for {target_name} in {country}",
            xaxis_title="Date",
            yaxis_title=target_name,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            template="plotly_white",
            margin=dict(l=40, r=40, t=60, b=40)
        )

        # Create feature importance plot
        imp_fig = go.Figure()
        if importance is not None and not importance.empty:
            # Sort by importance
            importance = importance.sort_values('importance', ascending=False)

            imp_fig = px.bar(
                importance,
                x='feature',
                y='importance',
                title='Feature Importance Analysis',
                labels={'feature': 'Feature', 'importance': 'Importance Score'},
                color='importance',
                color_continuous_scale='Viridis'
            )

            imp_fig.update_layout(
                xaxis_title="Feature",
                yaxis_title="Importance Score",
                template="plotly_white",
                margin=dict(l=40, r=40, t=60, b=40)
            )
        else:
            imp_fig.update_layout(title="No feature importance data available for this model")

        # Performance metrics: one column per model
        metrics = evaluate_model(df, country, target_variable)
        metrics_output = dbc.Card([
            dbc.Row([
                _metric_column("Random Forest", metrics['Random Forest'], "#2ecc71"),
                _metric_column("Prophet", metrics['Prophet'], "#3498db"),
                _metric_column("XGBoost", metrics['XGBoost'], "#9b59b6"),
            ])
        ], body=True)

        # Resource calculation, based on the first predicted value
        resources_needed, capacity, shortages = calculate_resources(float(pred_values[0]), country)

        resources_output = html.Div([
            html.H5(f"Projected Resource Needs ({forecast_days} days)"),
            html.Div([
                _resource_alert("Hospital Beds", "beds", resources_needed['hospital_beds'],
                                capacity['hospital_beds'], shortages['hospital_beds']),
                _resource_alert("ICU Beds", "beds", resources_needed['icu_beds'],
                                capacity['icu_beds'], shortages['icu_beds']),
                _resource_alert("Ventilators", "units", resources_needed['ventilators'],
                                capacity['ventilators'], shortages['ventilators']),
                dbc.Alert([
                    html.Div([
                        html.Strong("Additional Resources"),
                        html.Ul([
                            html.Li(f"Medical Staff: {resources_needed['medical_staff']} needed"),
                            html.Li(f"PPE Kits: {resources_needed['ppe_kits']} needed"),
                            html.Li(f"Test Kits: {resources_needed['test_kits']} needed")
                        ])
                    ])
                ], color="info")
            ])
        ])

        return pred_fig, imp_fig, metrics_output, resources_output

    except Exception as e:
        print(f"Error in update_prediction_dashboard: {e}")
        return go.Figure(), go.Figure(), f"Error: {str(e)}", f"Error: {str(e)}"

@app.callback(
    Output('global-map', 'figure'),
    [Input('map-variable', 'value'),
     Input('map-type', 'value')]
)
def update_map(map_variable, map_type):
    """Build the global map figure for the selected variable and map style.

    Args:
        map_variable: Name of the ``map_data`` column to visualise
            (for example ``'total_cases'`` or ``'cases_per_million'``).
        map_type: ``'choropleth'`` for a filled-country map; any other value
            produces a bubble (scatter-geo) map.

    Returns:
        A Plotly figure. If anything fails, an empty figure whose title
        carries the error message is returned instead of raising.
    """
    try:
        # Derive the per-million columns once; they are stored on the shared
        # module-level ``map_data`` frame so later calls can reuse them.
        if 'cases_per_million' not in map_data.columns:
            map_data['cases_per_million'] = (map_data['total_cases'] / map_data['population']) * 1000000
        if 'deaths_per_million' not in map_data.columns:
            map_data['deaths_per_million'] = (map_data['total_deaths'] / map_data['population']) * 1000000

        # Human-readable label; unknown variables fall back to the raw name
        variable_name = {
            'total_cases': 'Total Cases',
            'total_deaths': 'Total Deaths',
            'cases_per_million': 'Cases per Million',
            'deaths_per_million': 'Deaths per Million'
        }.get(map_variable, map_variable)

        # Arguments shared by both map styles
        shared_kwargs = dict(
            locations='location',
            locationmode='country names',
            color=map_variable,
            hover_name='location',
            hover_data=[map_variable, 'total_cases', 'total_deaths'],
            color_continuous_scale='YlOrRd',
            title=f'Global {variable_name} Distribution'
        )

        if map_type == 'choropleth':
            fig = px.choropleth(map_data, **shared_kwargs)
        else:  # bubble map
            # Marker sizes must be non-negative numbers: Plotly Express
            # raises on NaN sizes, so rows without a usable value are left
            # out of the bubble layer (NaN >= 0 evaluates to False).
            bubble_data = map_data[map_data[map_variable] >= 0]
            fig = px.scatter_geo(
                bubble_data,
                size=map_variable,
                size_max=50,
                **shared_kwargs
            )

        # Identical geo styling for both map styles
        fig.update_geos(
            showcoastlines=True,
            coastlinecolor="RebeccaPurple",
            showland=True,
            landcolor="LightGreen",
            showocean=True,
            oceancolor="LightBlue",
            projection_type="natural earth",
            showcountries=True,
            countrycolor="Black",
            showframe=False
        )

        fig.update_layout(
            height=700,  # Make the map larger
            margin=dict(l=0, r=0, t=50, b=0),
            template="plotly_white",
            coloraxis_colorbar=dict(
                title=variable_name
            )
        )

        return fig

    except Exception as e:
        # Callback boundary: report the problem in the figure itself so the
        # dashboard keeps rendering.
        print(f"Error in update_map: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating map: {str(e)}")
        return fig

@app.callback(
    [Output('news-signal-plot', 'figure'),
     Output('news-headlines', 'children')],
    [Input('news-country-dropdown', 'value'),
     Input('news-timeframe', 'value')]
)
def update_news_signal(country, days):
    """Plot the news-derived disease signal for a country and list headlines.

    Args:
        country: Location name chosen in the dropdown, or ``'Global'`` to use
            every news record without a location filter.
        days: Size of the look-back window in days, counted back from now.

    Returns:
        Tuple ``(figure, headlines)``: the Plotly figure of the daily mean
        disease signal (with scaled new cases overlaid when case data is
        available) and the headline component. The headlines are simulated
        placeholders, not fetched articles. On failure an error figure and an
        error string are returned instead of raising.
    """
    try:
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        # Get disease signal data
        news_data = fetch_promed_data(newsapi_key, start_date, end_date)
        if country != 'Global':
            news_data = news_data[news_data['location'] == country]

        # If no data, return empty figures
        if news_data.empty:
            fig = go.Figure()
            fig.update_layout(
                title=f"No disease signal data available for {country} in last {days} days",
                template="plotly_white"
            )
            return fig, "No news headlines available"

        # Summarize by date
        news_by_date = news_data.groupby('date')['disease_signal'].mean().reset_index()

        # Create signal plot
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=news_by_date['date'],
            y=news_by_date['disease_signal'],
            mode='lines+markers',
            name='Disease Signal',
            line=dict(color='#e74c3c', width=3)
        ))

        # The OWID dataset stores the worldwide aggregate under the location
        # name "World"; filtering by the dropdown's "Global" label would never
        # match any row, so translate it before looking up case counts.
        cases_location = 'World' if country == 'Global' else country

        # Add actual cases if available
        cases_data = df[(df['location'] == cases_location) &
                    (df['date'] >= start_date) &
                    (df['date'] <= end_date) &
                    (df['new_cases'].notna())]

        if not cases_data.empty:
            # Normalize for display: bring the case curve into the same
            # numeric range as the signal (factor 1 when there are no cases)
            max_signal = news_by_date['disease_signal'].max()
            max_cases = cases_data['new_cases'].max()
            scale_factor = max_signal / max_cases if max_cases > 0 else 1

            fig.add_trace(go.Scatter(
                x=cases_data['date'],
                y=cases_data['new_cases'] * scale_factor,
                mode='lines',
                name='New Cases (Scaled)',
                yaxis='y2',
                line=dict(color='#3498db', width=2, dash='dot')
            ))

        fig.update_layout(
            title=f"Disease Signal for {country} - Last {days} Days",
            xaxis_title="Date",
            yaxis_title="Disease Signal",
            template="plotly_white",
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            margin=dict(l=40, r=40, t=60, b=40),
            # The secondary axis is only configured when the overlay exists
            yaxis2=dict(
                title="New Cases (Scaled)",
                overlaying="y",
                side="right",
                showgrid=False
            ) if not cases_data.empty else {}
        )

        # Create simulated news headlines (static placeholders; they do not
        # depend on the selected country or time frame)
        headlines = [
            {"title": "New outbreak reported in northeastern region", "score": 95},
            {"title": "Health officials monitoring unusual respiratory symptoms", "score": 80},
            {"title": "Vaccination campaign launched to combat emerging strain", "score": 70},
            {"title": "Hospital capacity strained amid rising cases", "score": 85},
            {"title": "Travel restrictions considered as cases increase", "score": 75}
        ]

        headlines_output = html.Div([
            html.H5("Recent Disease-Related Headlines"),
            dbc.ListGroup([
                dbc.ListGroupItem([
                    html.Div([
                        html.Strong(headline["title"]),
                        html.Span(
                            f" - Signal Score: {headline['score']}",
                            className="ml-2",
                            # Red above 80, amber above 60, green otherwise
                            style={"color": "#e74c3c" if headline["score"] > 80 else "#f39c12" if headline["score"] > 60 else "#2ecc71"}
                        )
                    ])
                ]) for headline in headlines
            ])
        ])

        return fig, headlines_output

    except Exception as e:
        # Callback boundary: surface the error in the UI instead of raising
        print(f"Error in update_news_signal: {e}")
        fig = go.Figure()
        fig.update_layout(title=f"Error generating news signal: {str(e)}")
        return fig, f"Error: {str(e)}"

# Run the app
if __name__ == '__main__':
    # NOTE(review): assumes ``app`` is a plain ``dash.Dash`` instance
    # (Dash >= 2.11). There the notebook display mode is chosen with
    # ``jupyter_mode``; ``mode=`` is not a parameter of ``Dash.run`` and is
    # forwarded to the Flask server, which rejects it.
    try:
        app.run(jupyter_mode='inline', debug=True)
    except Exception as e:
        print(f"Error running Dash app: {e}")
        print("Trying external mode...")
        app.run(jupyter_mode='external', debug=True, port=8050)

In [None]:
!pip install dash

In [None]:
!pip install dash_bootstrap_components

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:


# Access your specific dataset
import pandas as pd
df = pd.read_csv('/content/drive/MyDrive/Medical_Resource_Prediction/owid-covid-data.csv')

# Temporal Distributions
# Daily New Cases Distribution

In [None]:
# This cell runs before the notebook's later plotting imports, so bring the
# plotting libraries into scope here.
import matplotlib.pyplot as plt
import seaborn as sns

# Keep only the 0-100,000 range before binning. Limiting the x-axis alone
# hides extreme values but still lets them stretch the histogram bins.
daily_cases = df['new_cases'].dropna()
daily_cases = daily_cases[daily_cases.between(0, 100000)]

plt.figure(figsize=(12,6))
sns.histplot(daily_cases, bins=100, kde=True)
plt.title('Global Daily New Cases Distribution')
plt.xlim(0, 100000)  # Remove extreme outliers

# Case Velocity (7-day change)

In [None]:
import numpy as np
import seaborn as sns

# 7-day relative change in new cases, computed separately for each location
# so the comparison never reaches back into another location's rows.
df['case_velocity'] = df.groupby('location')['new_cases'].pct_change(periods=7)

# A zero baseline produces +/-inf, which a histogram cannot bin; treat those
# as missing.
df['case_velocity'] = df['case_velocity'].replace([np.inf, -np.inf], np.nan)
sns.displot(df['case_velocity'].dropna(), bins=50, kde=True)

# Demographic Distributions
# Age Structure (share of population aged 65+ and 70+)

In [None]:
# Histograms of the two elderly-population share columns, drawn together
age_cols = ['aged_65_older', 'aged_70_older']
df.hist(column=age_cols, bins=30, figsize=(10,4))

# Global

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# The frame holds one row per location per day, so a histogram of the raw
# column counts records, not countries. Reduce to one value per location
# (first non-missing entry) so the "Number of Countries" axis label is true.
per_country = df.groupby('location')[['aged_65_older', 'aged_70_older']].first()

# Set up the figure
plt.figure(figsize=(12, 5))

# (subplot position, column, bar colour, title, x-axis label)
panels = [
    (1, 'aged_65_older', 'skyblue',
     'Distribution of Population Aged 65+', 'Percentage of Population Aged 65+'),
    (2, 'aged_70_older', 'salmon',
     'Distribution of Population Aged 70+', 'Percentage of Population Aged 70+'),
]

for position, column, colour, title, x_label in panels:
    plt.subplot(1, 2, position)
    per_country[column].hist(bins=30, color=colour, edgecolor='black')
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel('Number of Countries')
    plt.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Population Density vs. Cases

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Reload the OWID dataset from Drive
df = pd.read_csv('/content/drive/MyDrive/Medical_Resource_Prediction/owid-covid-data.csv')

# Two side-by-side histograms; each row of the frame is one country record
plt.figure(figsize=(14, 6))

# (subplot position, column, bar colour, panel title)
panel_specs = (
    (1, 'aged_65_older', '#1f77b4', 'A. Distribution of Population Aged 65+'),
    (2, 'aged_70_older', '#ff7f0e', 'B. Distribution of Population Aged 70+'),
)

for slot, column, colour, heading in panel_specs:
    plt.subplot(1, 2, slot)
    df[column].hist(bins=30, color=colour, edgecolor='white')
    plt.title(heading, pad=20, fontsize=14)
    plt.xlabel('Percentage of Total Population (%)', fontsize=12)
    plt.ylabel('Number of Country Records', fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.4)

plt.tight_layout()
plt.show()

In [None]:
# Calculate summary statistics. The location holding each maximum is looked
# up from the data instead of being hard-coded in the message.
print("\nAge Distribution Statistics:")
for age_label, age_col in (("65+", 'aged_65_older'), ("70+", 'aged_70_older')):
    shares = df[age_col]
    if shares.notna().any():
        top_location = df.loc[shares.idxmax(), 'location']
    else:
        # Column has no values at all, so there is no maximum to attribute
        top_location = "n/a"
    print(f"{age_label} population - Mean: {shares.mean():.1f}% | Max: {shares.max():.1f}% ({top_location})")

In [None]:
# Hexbin joint distribution of population density against new cases per
# million, with the density axis limited to 0-2000
sns.jointplot(
    data=df,
    x='population_density',
    y='new_cases_per_million',
    kind='hex',
    xlim=(0,2000),
)

In [None]:
import itertools

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Load the dataset
df = pd.read_csv('/content/drive/MyDrive/Medical_Resource_Prediction/owid-covid-data.csv', parse_dates=['date'])

# Select relevant features for correlation analysis
features = [
    'new_cases', 'new_deaths',
    'icu_patients', 'hosp_patients',
    'weekly_icu_admissions', 'weekly_hosp_admissions',
    'reproduction_rate', 'positive_rate',
    'population_density', 'median_age',
    'aged_65_older', 'gdp_per_capita',
    'hospital_beds_per_thousand'
]

# Initialize list for significant correlations
significant_corrs = []

# Create correlation matrix
corr_matrix = df[features].corr(method='pearson', numeric_only=True)

# Plot heatmap
plt.figure(figsize=(14, 10))
sns.heatmap(corr_matrix,
            annot=True,
            fmt=".2f",
            cmap='coolwarm',
            center=0,
            annot_kws={"size": 10},
            linewidths=0.5,
            cbar_kws={"shrink": 0.8})
plt.title('Feature Correlation Matrix', pad=20, fontsize=16)
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.yticks(fontsize=10)
plt.tight_layout()
plt.show()

# Calculate statistically significant correlations (p < 0.05).
# Every unordered feature pair is tested on the rows where both are present.
print("\nStatistically Significant Correlations (p < 0.05):")
for col1, col2 in itertools.combinations(features, 2):
    paired = df[[col1, col2]].dropna()
    if len(paired) > 10:  # Minimum sample size
        corr, p_value = pearsonr(paired[col1], paired[col2])
        if abs(corr) > 0.3 and p_value < 0.05:
            significant_corrs.append({
                'Feature 1': col1,
                'Feature 2': col2,
                'Correlation': corr,
                'p-value': p_value
            })

# Create DataFrame from significant correlations
if significant_corrs:
    signif_df = pd.DataFrame(significant_corrs).sort_values('Correlation', key=abs, ascending=False)
    display(signif_df.head(10))

    # Plot top correlations on a fixed 2x3 grid
    top_correlations = signif_df.head(6)
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    grid = axes.ravel()
    for ax, (_, row) in zip(grid, top_correlations.iterrows()):
        sns.regplot(data=df,
                    x=row['Feature 1'],
                    y=row['Feature 2'],
                    ax=ax,
                    scatter_kws={'alpha': 0.3, 's': 15},
                    line_kws={'color': 'red'})
        ax.set_title(f"r = {row['Correlation']:.2f}\np = {row['p-value']:.2e}")
        ax.set_xlabel(row['Feature 1'])
        ax.set_ylabel(row['Feature 2'])
    # Fewer than six significant pairs would leave blank axes; hide them
    for unused_ax in grid[len(top_correlations):]:
        unused_ax.set_visible(False)
    plt.suptitle('Top Significant Correlations', y=1.02, fontsize=16)
    plt.tight_layout()
    plt.show()
else:
    print("No significant correlations found with |r| > 0.3 and p < 0.05")

In [None]:
# Spread of hospital beds per thousand people, one box per continent
plt.figure(figsize=(10,5))
sns.boxplot(data=df, x='continent', y='hospital_beds_per_thousand')

In [None]:
# Violin plot of ICU patients per million, restricted to rows below 100
low_icu_rows = df[df['icu_patients_per_million'] < 100]
sns.violinplot(data=low_icu_rows, x='icu_patients_per_million')

In [None]:
# Daily case-fatality ratio. It is only defined where cases were reported:
# dividing by zero (or negative corrections) would give inf or negative
# ratios that the clip below would pile into the edge bins.
df['cfr'] = (df['new_deaths'] / df['new_cases']).where(df['new_cases'] > 0)
sns.displot(df['cfr'].dropna().clip(0,0.2), bins=50)

# Vaccination coverage

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Histogram of vaccination coverage in 10-point bins from 0 to 100
df['people_vaccinated_per_hundred'].hist(bins=np.arange(0,110,10))
plt.xlabel('Percentage of Population Vaccinated')
plt.ylabel('Number of Country\'s records')
plt.title('Distribution of Vaccination Rates Across Countries')
plt.show()

#  Time-Based Patterns
# Weekly Seasonality

In [None]:
# Mean new cases per day of week (0 = Monday ... 6 = Sunday)
df['weekday'] = df['date'].dt.dayofweek
weekday_means = df.groupby('weekday')['new_cases'].mean()
weekday_means.plot()

# Data Preparation and Preprocessing
# This section sets up the data pipeline that handles both the static training data and the live data feeds.

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import requests
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import joblib
import json

# Load pre-processed dataset
def load_preprocessed_data(filepath):
    """Load the static OWID training dataset and derive per-million columns.

    Args:
        filepath: Path (or URL) of the OWID CSV file. When it is ``None`` or
            the file cannot be found, the notebook's Google Drive copy is
            used instead, so callers that passed a placeholder path keep
            working.

    Returns:
        DataFrame sorted by ``date`` with ``cases_per_million`` and
        ``deaths_per_million`` added and missing values forward-filled.
    """
    default_path = '/content/drive/MyDrive/Medical_Resource_Prediction/owid-covid-data.csv'
    if filepath is None:
        filepath = default_path

    try:
        df = pd.read_csv(filepath, parse_dates=['date'])
    except FileNotFoundError:
        if filepath == default_path:
            raise
        print(f"Data file not found at {filepath!r}; falling back to {default_path}")
        df = pd.read_csv(default_path, parse_dates=['date'])

    # Convert date and handle missing values
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date')

    # Calculate additional features
    df['cases_per_million'] = (df['total_cases'] / df['population']) * 1e6
    df['deaths_per_million'] = (df['total_deaths'] / df['population']) * 1e6

    # Forward fill missing values for time series continuity.
    # ``ffill()`` replaces the deprecated ``fillna(method='ffill')`` form.
    # NOTE(review): rows are ordered by date only, so this fill carries
    # values across different locations; downstream code that calls
    # ``dropna()`` currently depends on that, so it is kept as is.
    df = df.ffill()

    return df

# Live data ingestion
class LiveDataIngestor:
    """Fetch live data from external sources and normalise it into DataFrames.

    Credentials and endpoints are read from a JSON configuration file with the
    optional keys ``newsapi_key``, ``cdc_endpoint`` and ``healthmap_key``.
    """

    def __init__(self, config_file='config.json'):
        """Load the JSON configuration.

        Args:
            config_file: Path of the JSON configuration file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        with open(config_file, encoding='utf-8') as f:
            self.config = json.load(f)

        # Initialize API credentials (each is None when absent from the file)
        self.newsapi_key = self.config.get('newsapi_key')
        self.cdc_endpoint = self.config.get('cdc_endpoint')
        self.healthmap_key = self.config.get('healthmap_key')

    def fetch_news_data(self, query, days=30):
        """Fetch disease-related news from NewsAPI.

        Args:
            query: Search expression passed to NewsAPI as ``q``.
            days: Number of days to look back from now.

        Returns:
            List of article dictionaries; an empty list when the request
            fails for any reason (the error is printed).
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)

        url = "https://newsapi.org/v2/everything"
        params = {
            "q": query,
            "from": start_date.strftime('%Y-%m-%d'),
            "to": end_date.strftime('%Y-%m-%d'),
            "sortBy": "publishedAt",
            "apiKey": self.newsapi_key,
            "pageSize": 100
        }

        try:
            response = requests.get(url, params=params, timeout=10)
            response.raise_for_status()
            # A response carrying "articles": null must still yield a list
            return response.json().get('articles') or []
        except Exception as e:
            # Best effort: live data is optional, so report and continue
            print(f"Error fetching news data: {e}")
            return []

    def fetch_cdc_data(self):
        """Fetch the latest CDC data from the configured endpoint.

        Returns:
            The decoded JSON payload, or ``None`` when no endpoint is
            configured or the request fails (the error is printed).
        """
        if not self.cdc_endpoint:
            print("Error fetching CDC data: no 'cdc_endpoint' configured")
            return None

        try:
            response = requests.get(self.cdc_endpoint, timeout=10)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"Error fetching CDC data: {e}")
            return None

    def process_live_data(self, raw_data, source):
        """Convert raw source data into a consistent tabular format.

        Args:
            raw_data: For ``'newsapi'``, an iterable of article dictionaries
                (``None`` is treated as empty).
            source: ``'newsapi'`` or ``'cdc'``. CDC processing is not
                implemented yet and yields an empty frame.

        Returns:
            DataFrame with the columns ``date``, ``source``, ``title``,
            ``content`` and ``url`` for news data; an empty DataFrame when
            there is nothing to process.
        """
        processed = []

        if source == 'newsapi':
            for article in raw_data or []:
                # Articles may omit fields; missing ones become None rather
                # than aborting the whole batch with a KeyError.
                published = pd.to_datetime(article.get('publishedAt'), errors='coerce')
                processed.append({
                    'date': published.date() if pd.notna(published) else None,
                    'source': 'newsapi',
                    'title': article.get('title'),
                    'content': article.get('content') or article.get('description'),
                    'url': article.get('url')
                })

        elif source == 'cdc':
            # Process CDC data format
            pass

        return pd.DataFrame(processed)

# Feature Engineering

In [None]:
def create_features(df, target='total_cases', lookback_window=14, forecast_horizon=7):
    """
    Create time-series features for training.

    The input frame is expected to hold one location's rows in chronological
    order, because lags and the target are built with row shifts. The caller's
    frame is left untouched; all new columns are added to a copy.

    Args:
        df: Input dataframe with a datetime ``date`` column and the target column
        target: Target variable to predict
        lookback_window: Number of days to look back for features
        forecast_horizon: Number of days to predict ahead
    Returns:
        New DataFrame with ``lag_1..lag_N``, rolling statistics, calendar
        features and a ``target`` column, restricted to rows that have no
        missing value in any column
    """
    # Work on a copy: callers usually pass a filtered slice, and adding
    # columns to it in place would mutate their data.
    df = df.copy()

    # Create lag features
    for i in range(1, lookback_window + 1):
        df[f'lag_{i}'] = df[target].shift(i)

    # Create rolling statistics
    df['rolling_7_mean'] = df[target].rolling(window=7).mean()
    df['rolling_14_mean'] = df[target].rolling(window=14).mean()
    df['rolling_7_std'] = df[target].rolling(window=7).std()

    # Create date-based features
    df['day_of_week'] = df['date'].dt.dayofweek
    df['day_of_month'] = df['date'].dt.day
    df['week_of_year'] = df['date'].dt.isocalendar().week

    # Calculate target for forecasting (shifted by forecast horizon)
    df['target'] = df[target].shift(-forecast_horizon)

    # Drop rows with missing values
    df = df.dropna()

    # isocalendar() yields a nullable UInt32 column; convert it to a plain
    # integer column (safe after dropna) so it behaves like the other features.
    return df.assign(week_of_year=df['week_of_year'].astype(int))

# Model Training Pipelines

# Random Forest Training

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

def train_random_forest(X_train, y_train, cv=None):
    """
    Train and tune a Random Forest model with a grid search.

    Args:
        X_train: Training features, in chronological order
        y_train: Training targets
        cv: Optional cross-validation strategy accepted by ``GridSearchCV``.
            Defaults to a 5-split ``TimeSeriesSplit`` so each validation fold
            lies after its training data; unshuffled K-fold would train on
            future observations of the series.
    Returns:
        Best trained Random Forest pipeline (scaler + regressor)
    """
    from sklearn.model_selection import TimeSeriesSplit

    if cv is None:
        cv = TimeSeriesSplit(n_splits=5)

    # Define pipeline
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('rf', RandomForestRegressor(random_state=42))
    ])

    # Hyperparameter grid
    param_grid = {
        'rf__n_estimators': [100, 200, 300],
        'rf__max_depth': [None, 10, 20, 30],
        'rf__min_samples_split': [2, 5, 10],
        'rf__min_samples_leaf': [1, 2, 4]
    }

    # Grid search with cross-validation
    grid_search = GridSearchCV(
        estimator=pipeline,
        param_grid=param_grid,
        cv=cv,
        scoring='neg_mean_squared_error',
        n_jobs=-1,
        verbose=1
    )

    grid_search.fit(X_train, y_train)

    print(f"Best Random Forest params: {grid_search.best_params_}")
    # best_score_ is the negated MSE, hence the sign flip before the root
    print(f"Best RMSE: {np.sqrt(-grid_search.best_score_)}")

    return grid_search.best_estimator_

#  Prophet Training

In [None]:
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics

def train_prophet(df, target='total_cases', cross_val=True):
    """
    Fit a Prophet model to one target series, optionally cross-validating it.

    Args:
        df: Input dataframe with 'date' and target columns
        target: Name of the column to forecast
        cross_val: When true, run Prophet's rolling-origin cross-validation
            and print the first rows of the resulting metrics table
    Returns:
        The fitted Prophet model. Cross-validation metrics are printed only;
        they are not part of the return value.
    """
    # Prophet expects the columns 'ds' (timestamp) and 'y' (value)
    history = df[['date', target]].rename(columns={'date': 'ds', target: 'y'})

    prophet_settings = dict(
        yearly_seasonality=True,
        weekly_seasonality=True,
        daily_seasonality=False,
        changepoint_prior_scale=0.05,
        seasonality_prior_scale=10.0
    )
    model = Prophet(**prophet_settings)
    model.fit(history)

    if not cross_val:
        return model

    # 180 days of initial history, a new cutoff every 30 days, 90-day horizon
    cv_results = cross_validation(
        model,
        initial='180 days',
        period='30 days',
        horizon='90 days'
    )
    cv_metrics = performance_metrics(cv_results)
    print("Prophet cross-validation metrics:")
    print(cv_metrics.head())

    return model

# XGBoost Training

In [None]:
from xgboost import XGBRegressor
from sklearn.model_selection import RandomizedSearchCV

def train_xgboost(X_train, y_train, cv=None):
    """
    Train and tune an XGBoost model with a randomized search.

    Args:
        X_train: Training features, in chronological order
        y_train: Training targets
        cv: Optional cross-validation strategy accepted by
            ``RandomizedSearchCV``. Defaults to a 5-split ``TimeSeriesSplit``
            so each validation fold lies after its training data; unshuffled
            K-fold would train on future observations of the series.
    Returns:
        Best trained XGBoost pipeline (scaler + regressor)
    """
    from sklearn.model_selection import TimeSeriesSplit

    if cv is None:
        cv = TimeSeriesSplit(n_splits=5)

    # Define pipeline
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('xgb', XGBRegressor(random_state=42))
    ])

    # Hyperparameter grid
    param_grid = {
        'xgb__n_estimators': [100, 200, 300],
        'xgb__max_depth': [3, 6, 9],
        'xgb__learning_rate': [0.01, 0.05, 0.1],
        'xgb__subsample': [0.8, 0.9, 1.0],
        'xgb__colsample_bytree': [0.8, 0.9, 1.0]
    }

    # Randomized search with cross-validation (20 sampled configurations)
    random_search = RandomizedSearchCV(
        estimator=pipeline,
        param_distributions=param_grid,
        n_iter=20,
        cv=cv,
        scoring='neg_mean_squared_error',
        random_state=42,
        n_jobs=-1,
        verbose=1
    )

    random_search.fit(X_train, y_train)

    print(f"Best XGBoost params: {random_search.best_params_}")
    # best_score_ is the negated MSE, hence the sign flip before the root
    print(f"Best RMSE: {np.sqrt(-random_search.best_score_)}")

    return random_search.best_estimator_

# Model Evaluation and Comparison

In [None]:
def evaluate_model(model, X_test, y_test, model_name):
    """
    Evaluate model performance on test set.

    Args:
        model: Trained model exposing ``predict``
        X_test: Test features
        y_test: Test targets
        model_name: Name of model for display
    Returns:
        Dictionary with the keys ``model``, ``rmse``, ``mae``, ``r2`` and
        ``mape``. MAPE is computed over the non-zero targets only and is NaN
        when every target is zero.
    """
    predictions = model.predict(X_test)

    # MAPE divides by the true value, so zero targets are excluded; including
    # them would turn the whole metric into inf/NaN.
    actual = np.asarray(y_test, dtype=float)
    predicted = np.asarray(predictions, dtype=float)
    nonzero = actual != 0
    if nonzero.any():
        mape = np.mean(np.abs((actual[nonzero] - predicted[nonzero]) / actual[nonzero])) * 100
    else:
        mape = float('nan')

    metrics = {
        'model': model_name,
        'rmse': np.sqrt(mean_squared_error(y_test, predictions)),
        'mae': mean_absolute_error(y_test, predictions),
        'r2': r2_score(y_test, predictions),
        'mape': mape
    }

    print(f"\n{model_name} Evaluation:")
    for metric, value in metrics.items():
        if metric != 'model':
            print(f"{metric.upper()}: {value:.4f}")

    return metrics

def compare_models(metrics_list):
    """Tabulate per-model metric dictionaries, print the table and return it.

    Each dictionary must carry a 'model' entry, which becomes the row index.
    """
    comparison_df = pd.DataFrame(metrics_list).set_index('model')

    print("\nModel Comparison:")
    print(comparison_df)

    return comparison_df

# Ensemble Model

In [None]:
from sklearn.ensemble import VotingRegressor

def create_ensemble(rf_model, prophet_model, xgb_model):
    """
    Create a weighted-average ensemble of the already trained tree models.

    Args:
        rf_model: Trained Random Forest model
        prophet_model: Trained Prophet model. Not part of the voting
            ensemble (Prophet has no scikit-learn interface); it is blended
            in separately at prediction time.
        xgb_model: Trained XGBoost model
    Returns:
        VotingRegressor that is immediately usable for ``predict`` and
        averages the two models with weights 0.4 (RF) and 0.6 (XGB)
    """
    from sklearn.utils import Bunch

    ensemble = VotingRegressor(
        estimators=[
            ('rf', rf_model),
            ('xgb', xgb_model)
        ],
        weights=[0.4, 0.6]  # Adjust based on individual model performance
    )

    # VotingRegressor has no "prefit" option and refuses to predict until
    # fitted. The members are already trained, so register them as the
    # fitted estimators directly instead of retraining clones. Calling
    # ``ensemble.fit`` later still works and refits fresh clones.
    ensemble.estimators_ = [rf_model, xgb_model]
    ensemble.named_estimators_ = Bunch(rf=rf_model, xgb=xgb_model)

    return ensemble

def ensemble_predict(ensemble, prophet_model, X_test, prophet_future):
    """
    Blend the tree-ensemble forecast with the Prophet forecast.

    Args:
        ensemble: Ensemble model (RF + XGB) exposing ``predict``
        prophet_model: Prophet model
        X_test: Feature rows passed to the RF/XGB ensemble
        prophet_future: Future dataframe passed to Prophet
    Returns:
        Element-wise mean of the two forecasts (NumPy broadcasting applies
        when their lengths differ)
    """
    tree_forecast = ensemble.predict(X_test)

    prophet_frame = prophet_model.predict(prophet_future)
    prophet_forecast = prophet_frame['yhat'].values

    # Equal-weight average of the two sources
    return (tree_forecast + prophet_forecast) / 2

# Full Training Workflow

In [None]:
def _regression_metrics(model_name, y_true, y_pred):
    """Return the RMSE/MAE/R2/MAPE dictionary used in the comparison table."""
    # Plain arrays make the comparison positional; subtracting two Series
    # would align them on unrelated index labels instead.
    actual = np.asarray(y_true, dtype=float)
    predicted = np.asarray(y_pred, dtype=float)
    nonzero = actual != 0  # MAPE is undefined where the true value is zero
    if nonzero.any():
        mape = np.mean(np.abs((actual[nonzero] - predicted[nonzero]) / actual[nonzero])) * 100
    else:
        mape = float('nan')
    return {
        'model': model_name,
        'rmse': np.sqrt(mean_squared_error(actual, predicted)),
        'mae': mean_absolute_error(actual, predicted),
        'r2': r2_score(actual, predicted),
        'mape': mape
    }


def full_training_workflow(data_path, location='United States', target='total_cases'):
    """
    Complete training workflow for all models.

    Trains Random Forest, XGBoost and Prophet for one location, scores them
    and their blend on the chronologically last 20% of the samples, saves the
    models to the working directory and returns them with the score table.

    Args:
        data_path: Path to preprocessed data
        location: Country/region to model
        target: Target variable to predict
    Returns:
        Dictionary with 'models' (random_forest, prophet, xgboost, ensemble)
        and 'metrics' (comparison DataFrame indexed by model name)
    Raises:
        ValueError: If too few usable rows remain to build a train/test split
    """
    from prophet.serialize import model_to_json

    # Horizon (in rows) used to build the 'target' column; the same value is
    # needed below to work out which calendar date each target belongs to.
    forecast_horizon = 7

    # Load and prepare data
    df = load_preprocessed_data(data_path)
    df = df[df['location'] == location].copy()

    # Feature engineering
    feature_df = create_features(df, target=target, forecast_horizon=forecast_horizon)

    # Split into features and target. Only numeric columns can be scaled and
    # fed to the regressors; text columns such as the location name are dropped.
    X = feature_df.drop(columns=['date', 'target']).select_dtypes(include='number')
    y = feature_df['target']

    # Train-test split (time-based)
    split_idx = int(0.8 * len(X))
    if split_idx == 0 or split_idx == len(X):
        raise ValueError(
            f"Not enough usable rows for location {location!r} to build a train/test split"
        )
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    # Calendar date of each test target: the date of the row that lies
    # `forecast_horizon` rows after the feature row.
    target_dates = df['date'].shift(-forecast_horizon).reindex(feature_df.index)
    test_target_dates = target_dates.iloc[split_idx:]
    prophet_future = pd.DataFrame({'ds': test_target_dates.to_numpy()})

    # Train models
    print("Training Random Forest...")
    rf_model = train_random_forest(X_train, y_train)

    print("\nTraining Prophet...")
    # Final Prophet model on the full history (this is the one saved and returned)
    prophet_model = train_prophet(df, target=target)
    # Separate Prophet model that has not seen the test period, used only for scoring
    prophet_eval_model = train_prophet(
        df[df['date'] < test_target_dates.min()], target=target, cross_val=False
    )

    print("\nTraining XGBoost...")
    xgb_model = train_xgboost(X_train, y_train)

    # Evaluate models
    rf_metrics = evaluate_model(rf_model, X_test, y_test, 'Random Forest')
    xgb_metrics = evaluate_model(xgb_model, X_test, y_test, 'XGBoost')

    # Prophet is scored on the dates the test targets actually refer to
    prophet_pred = prophet_eval_model.predict(prophet_future)['yhat'].to_numpy()
    prophet_metrics = _regression_metrics('Prophet', y_test, prophet_pred)

    # Compare models
    metrics_list = [rf_metrics, prophet_metrics, xgb_metrics]
    comparison_results = compare_models(metrics_list)

    # Create ensemble
    print("\nCreating Ensemble Model...")
    ensemble = create_ensemble(rf_model, prophet_model, xgb_model)
    if not hasattr(ensemble, 'estimators_'):
        # A VotingRegressor cannot predict until it has been fitted
        ensemble.fit(X_train, y_train)

    # Evaluate ensemble
    ensemble_pred = ensemble_predict(ensemble, prophet_eval_model, X_test, prophet_future)
    ensemble_metrics = _regression_metrics('Ensemble', y_test, ensemble_pred)

    comparison_results = pd.concat([
        comparison_results,
        pd.DataFrame([ensemble_metrics]).set_index('model')
    ])

    print("\nFinal Model Comparison:")
    print(comparison_results)

    # Save models
    joblib.dump(rf_model, 'random_forest_model.pkl')
    joblib.dump(xgb_model, 'xgboost_model.pkl')
    joblib.dump(ensemble, 'ensemble_model.pkl')

    # Save Prophet model. Prophet is serialised with model_to_json (it has no
    # to_json method); the JSON text is stored as a JSON string so readers
    # that call json.load on this file keep working.
    with open('prophet_model.json', 'w') as fout:
        json.dump(model_to_json(prophet_model), fout)

    return {
        'models': {
            'random_forest': rf_model,
            'prophet': prophet_model,
            'xgboost': xgb_model,
            'ensemble': ensemble
        },
        'metrics': comparison_results
    }

# Model Deployment with Live Data

In [None]:
class MedicalResourcePredictor:
    """Serve predictions from the saved models using static plus live data."""

    def __init__(self, model_paths):
        """Load the saved models.

        Args:
            model_paths: Mapping with the keys 'random_forest', 'prophet',
                'xgboost' and 'ensemble', each holding a file path.
        """
        # Load trained models
        self.rf_model = joblib.load(model_paths['random_forest'])
        self.prophet_model = self.load_prophet_model(model_paths['prophet'])
        self.xgb_model = joblib.load(model_paths['xgboost'])
        self.ensemble_model = joblib.load(model_paths['ensemble'])

        # Initialize data ingestor (reads its default configuration file)
        self.data_ingestor = LiveDataIngestor()

    @staticmethod
    def load_prophet_model(path):
        """Load a saved Prophet model.

        Accepts both a file containing Prophet's JSON document directly and a
        file in which that document was stored as a JSON-encoded string.
        """
        from prophet.serialize import model_from_json

        with open(path, 'r') as fin:
            payload = json.load(fin)
        # model_from_json expects JSON text, not a parsed object
        if not isinstance(payload, str):
            payload = json.dumps(payload)
        return model_from_json(payload)

    def predict_demand(self, location, forecast_days=30):
        """Generate predictions using all models.

        Args:
            location: Location name to filter the dataset by.
            forecast_days: Number of future days for the Prophet forecast.

        Returns:
            Dictionary with 'predictions' (one value per model),
            'resource_allocation', 'forecast_dates' and 'ensemble_forecast'.

        Raises:
            ValueError: If no usable feature rows exist for the location.
        """
        # Get latest data
        latest_data = self.get_latest_data(location)

        # Prepare features
        feature_df = create_features(latest_data)
        X = feature_df.drop(columns=['date', 'target'])

        # Present exactly the columns the models were trained on when that
        # information is stored on the model; otherwise keep numeric columns,
        # since text columns cannot be scaled.
        trained_columns = getattr(self.rf_model, 'feature_names_in_', None)
        if trained_columns is not None:
            X = X[list(trained_columns)]
        else:
            X = X.select_dtypes(include='number')

        if X.empty:
            raise ValueError(f"No usable feature rows for location {location!r}")
        latest_row = X.tail(1)

        # Generate predictions
        rf_pred = self.rf_model.predict(latest_row)
        xgb_pred = self.xgb_model.predict(latest_row)

        # Prophet prediction
        future = self.prophet_model.make_future_dataframe(periods=forecast_days)
        forecast_window = future.tail(forecast_days)
        prophet_pred = self.prophet_model.predict(future)['yhat'].tail(forecast_days)

        # Ensemble prediction (the single tree-model value is broadcast over
        # the Prophet forecast window)
        ensemble_pred = ensemble_predict(
            self.ensemble_model,
            self.prophet_model,
            latest_row,
            forecast_window
        )

        # Calculate resource needs
        resources_needed = self.calculate_resources(ensemble_pred.mean(), location)

        return {
            'predictions': {
                'random_forest': rf_pred[0],
                'prophet': prophet_pred.mean(),
                'xgboost': xgb_pred[0],
                'ensemble': ensemble_pred.mean()
            },
            'resource_allocation': resources_needed,
            'forecast_dates': forecast_window['ds'].dt.strftime('%Y-%m-%d').tolist(),
            'ensemble_forecast': ensemble_pred.tolist()
        }

    def get_latest_data(self, location):
        """Combine static and live data for prediction.

        Live news data is fetched and normalised but not yet merged into the
        result; the static rows for ``location`` are returned.
        """
        # Get static data
        static_data = load_preprocessed_data('owid-covid-data.csv')
        static_data = static_data[static_data['location'] == location]

        # Get live data (simplified example)
        news_data = self.data_ingestor.fetch_news_data('COVID')
        if news_data:
            # Here you would incorporate the news signals into your features
            # For example, count of articles per day as a feature
            self.data_ingestor.process_live_data(news_data, 'newsapi')

        # For now, just return the static data
        return static_data

    def calculate_resources(self, predicted_cases, location):
        """Calculate required medical resources.

        Not implemented yet: always returns ``None``.
        """
        # This would use similar logic to your existing resource calculation
        # but potentially enhanced with more detailed resource models
        return None

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from prophet import Prophet
import matplotlib.pyplot as plt

# Load the dataset
# Reads the OWID COVID-19 CSV straight from GitHub, so this cell needs network
# access every time it runs. `df` is the raw, unfiltered frame.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)

# Preprocessing
def preprocess_data(df, country='United States', lags=(1, 3, 7, 14)):
    """Build a cleaned, feature-enriched daily frame for one country.

    Steps: keep the rows of ``country`` and the modelling columns, sort by
    date, fill gaps (forward, then backward, then zero), add per-capita
    features and lagged case/hospital features, and drop the leading rows
    whose lags are undefined.

    Args:
        df: Raw OWID-style frame. Must contain ``location`` and every column
            listed in ``columns`` below; a missing column raises ``KeyError``.
        country: Value of ``location`` to keep.
        lags: Day offsets used for the ``cases_lag_<n>`` / ``hosp_lag_<n>``
            columns. The default matches the original hard-coded offsets.

    Returns:
        A new DataFrame sorted by ``date`` without the ``location`` column;
        the input frame is not modified. The first ``max(lags)`` rows are
        removed because their lag features are NaN.
    """
    columns = [
        'date', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
        'hosp_patients', 'icu_patients', 'reproduction_rate',
        'positive_rate', 'tests_per_case', 'people_vaccinated',
        'people_fully_vaccinated', 'population', 'population_density'
    ]
    out = df.loc[df['location'] == country, columns].copy()

    # Convert date and sort so that fills and shifts run in calendar order.
    out['date'] = pd.to_datetime(out['date'])
    out = out.sort_values('date')

    # Fill missing values. DataFrame.ffill()/bfill() replace the deprecated
    # fillna(method=...) spelling and produce the same result.
    out = out.ffill().bfill().fillna(0)

    # Calculate additional features
    out['cases_per_million'] = (out['total_cases'] / out['population']) * 1e6
    out['deaths_per_million'] = (out['total_deaths'] / out['population']) * 1e6
    out['vaccination_rate'] = out['people_vaccinated'] / out['population']

    # Create lag features
    for lag in lags:
        out[f'cases_lag_{lag}'] = out['new_cases'].shift(lag)
        out[f'hosp_lag_{lag}'] = out['hosp_patients'].shift(lag)

    # Remove the leading rows whose lag values do not exist yet.
    return out.dropna()

# Prepare data for US
# Runs the preprocessing for the United States and prints the resulting shape
# and the last rows as a quick sanity check.
df_processed = preprocess_data(df, 'United States')
print("Processed Data Shape:", df_processed.shape)
print(df_processed.tail())

# Model Training
# Random Forest
# Random Forest is an ensemble learning method that operates by constructing multiple decision trees during training.

In [None]:
# First, let's properly define country_df
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

# Load and prepare the data
# Reads the OWID CSV from GitHub; this cell needs network access.
url = "https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv"
df = pd.read_csv(url)

# Basic preprocessing: parse dates and order rows chronologically inside each location
df['date'] = pd.to_datetime(df['date'])
df = df.sort_values(['location', 'date'])

# Select relevant columns
cols = ['date', 'location', 'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
        'hosp_patients', 'icu_patients', 'people_vaccinated', 'population']
# .copy() gives an independent frame, so the column assignments below do not
# write into a slice of the raw data (avoids SettingWithCopyWarning).
df = df[cols].copy()

# Fill missing values
df['new_cases'] = df['new_cases'].fillna(0)
df['new_deaths'] = df['new_deaths'].fillna(0)
# Forward-fill within each location only. The frame holds all locations stacked
# one after another, so an ungrouped forward-fill would carry the last value of
# one location into the leading gaps of the next one.
df['hosp_patients'] = df.groupby('location')['hosp_patients'].ffill()
df['icu_patients'] = df.groupby('location')['icu_patients'].ffill()

# Calculate per capita metrics
df['cases_per_million'] = (df['total_cases'] / df['population']) * 1e6
df['deaths_per_million'] = (df['total_deaths'] / df['population']) * 1e6
df['hosp_per_million'] = (df['hosp_patients'] / df['population']) * 1e6
df['icu_per_million'] = (df['icu_patients'] / df['population']) * 1e6

# Create lag features
def create_lag_features(df, lags=7):
    """Add per-location lag columns for new cases and hospital patients.

    For every offset from 1 to ``lags`` two columns are written into ``df``:
    ``cases_lag_<offset>`` and ``hosp_lag_<offset>``. Values are shifted
    inside each ``location`` group, so a lag never reaches into another
    location's rows; the first ``offset`` rows of each group are NaN.

    Args:
        df: Frame with ``location``, ``new_cases`` and ``hosp_patients``
            columns. It is modified in place.
        lags: Largest offset (in rows) to generate.

    Returns:
        The same ``df`` object, with the lag columns added.
    """
    for offset in range(1, lags + 1):
        by_location = df.groupby('location')
        shifted_cases = by_location['new_cases'].shift(offset)
        shifted_hosp = by_location['hosp_patients'].shift(offset)
        df[f'cases_lag_{offset}'] = shifted_cases
        df[f'hosp_lag_{offset}'] = shifted_hosp
    return df

# Add the lag columns, then drop every row that still has a NaN in any column
# (this includes the first rows of each location, whose lags are undefined).
df = create_lag_features(df)
df = df.dropna()

# Filter to specific country (USA for this example)
country_df = df[df['location'] == 'United States'].copy()

# Now we can run the Random Forest code
# Predictors: every lag column plus the two per-million ratios.
features = [col for col in country_df.columns if 'lag_' in col] + ['cases_per_million', 'deaths_per_million']
target = 'hosp_patients'

X = country_df[features]
y = country_df[target]

# Train-test split
# shuffle=False keeps row order, so the last 20% of rows become the test set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

# Scale features
# The scaler is fitted on the training rows only and then applied to the test rows.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train Random Forest
rf_model = RandomForestRegressor(
    n_estimators=200,
    max_depth=10,
    min_samples_split=5,
    random_state=42,
    n_jobs=-1
)

rf_model.fit(X_train_scaled, y_train)

# Make predictions
rf_pred = rf_model.predict(X_test_scaled)

# Evaluate
# RMSE is the square root of the mean squared error on the held-out rows.
rf_rmse = np.sqrt(mean_squared_error(y_test, rf_pred))
rf_r2 = r2_score(y_test, rf_pred)
rf_mae = mean_absolute_error(y_test, rf_pred)

print(f"Random Forest Performance:")
print(f"RMSE: {rf_rmse:.2f}")
print(f"R²: {rf_r2:.2f}")
print(f"MAE: {rf_mae:.2f}")

# Feature importance
# One row per predictor, ordered from most to least important.
rf_importance = pd.DataFrame({
    'feature': features,
    'importance': rf_model.feature_importances_
}).sort_values('importance', ascending=False)

plt.figure(figsize=(10, 6))
plt.barh(rf_importance['feature'], rf_importance['importance'])
plt.title('Random Forest Feature Importance')
plt.show()

# Prophet

In [None]:
# Prepare data for Prophet
# The target and the extra regressors are kept in ONE frame so that the train
# set, the test set and the prediction frame all use exactly the observed dates.
# Building the prediction frame with make_future_dataframe() would generate
# contiguous daily dates instead; wherever country_df has gaps, the regressor
# merge would then yield NaN (which Prophet rejects) and the forecast tail would
# no longer line up with the test rows.
regressor_cols = ['new_cases', 'icu_patients']
prophet_df = (
    country_df[['date', 'hosp_patients'] + regressor_cols]
    .rename(columns={'date': 'ds', 'hosp_patients': 'y'})
    .dropna()
    .sort_values('ds')
    .reset_index(drop=True)
)

# Split into train/test (chronological: first 80% of rows train, rest test)
train_size = int(len(prophet_df) * 0.8)
train = prophet_df.iloc[:train_size]
test = prophet_df.iloc[train_size:]

# Train Prophet model
prophet_model = Prophet(
    yearly_seasonality=False,
    weekly_seasonality=True,
    daily_seasonality=False,
    changepoint_prior_scale=0.05,
    seasonality_prior_scale=10
)

# Add additional regressors (must be registered before fit)
for feature in regressor_cols:
    prophet_model.add_regressor(feature)

# The training frame already carries the regressor columns; the name is kept
# for any later cell that refers to it.
train_with_features = train

prophet_model.fit(train_with_features)

# Prediction frame: every observed date (history + test) with its regressors
future = prophet_df.drop(columns='y')

# Forecast
forecast = prophet_model.predict(future)
# The last len(test) forecast rows correspond one-to-one to the test rows.
prophet_pred = forecast['yhat'].tail(len(test)).values

# Evaluate
prophet_rmse = np.sqrt(mean_squared_error(test['y'], prophet_pred))
prophet_r2 = r2_score(test['y'], prophet_pred)
prophet_mae = mean_absolute_error(test['y'], prophet_pred)

print(f"\nProphet Performance:")
print(f"RMSE: {prophet_rmse:.2f}")
print(f"R²: {prophet_r2:.2f}")
print(f"MAE: {prophet_mae:.2f}")

# Plot components
fig = prophet_model.plot_components(forecast)
plt.show()

# XGBoost Model Training

In [None]:
import numpy as np
import pandas as pd
from xgboost import XGBRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit, train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Data Preparation (if not already done)
# Assumes a DataFrame 'df' holding the feature and target columns is already
# in scope from an earlier cell.
# NOTE(review): if 'df' still stacks several locations, the unshuffled split
# and TimeSeriesSplit below follow row position rather than calendar time.
# Confirm the input frame is a single, date-sorted series.

# Select features and target
target = 'hosp_patients'
candidate_features = ['total_cases', 'new_cases', 'total_deaths', 'new_deaths',
                      'hosp_patients', 'icu_patients', 'people_vaccinated']
# The target must not be one of its own predictors: with 'hosp_patients' among
# the inputs the model only has to copy that column, and every metric and
# importance score reported below would be meaningless.
features = [col for col in candidate_features if col != target]

X = df[features].copy()
y = df[target].copy()

# Remove rows with missing target values
valid_rows = y.notna()
X = X[valid_rows]
y = y[valid_rows]

# 2. Train-test split (no shuffling: the last 20% of rows form the test set)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)

# 3. Feature Scaling (fitted on the training rows only)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# 4. TimeSeriesSplit for CV
tscv = TimeSeriesSplit(n_splits=5)

# 5. XGBoost Model Setup
xgb = XGBRegressor(random_state=42, objective='reg:squarederror')

# 6. Hyperparameter Tuning
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [3, 6, 9],
    'learning_rate': [0.01, 0.1],
    'subsample': [0.8, 1.0],
    'colsample_bytree': [0.8, 1.0]
}

xgb_grid = GridSearchCV(
    xgb,
    param_grid,
    cv=tscv,
    scoring='neg_mean_squared_error',
    n_jobs=-1
)
xgb_grid.fit(X_train_scaled, y_train)

# 7. Get Best Model
best_xgb = xgb_grid.best_estimator_
print(f"Best XGBoost parameters: {xgb_grid.best_params_}")

# 8. Evaluation
y_pred = best_xgb.predict(X_test_scaled)

xgb_metrics = {
    'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
    'R2': r2_score(y_test, y_pred),
    'MAE': mean_absolute_error(y_test, y_pred)
}

print("\nXGBoost Performance:")
for metric, value in xgb_metrics.items():
    print(f"{metric}: {value:.4f}")

# 9. Feature Importance
# The model was trained on exactly the columns of X, so the importance vector
# has one entry per feature name and needs no truncation.
feature_names = X.columns

xgb_importance = pd.DataFrame({
    'feature': feature_names,
    'importance': best_xgb.feature_importances_
})

# Sort by importance
xgb_importance = xgb_importance.sort_values('importance', ascending=False)

print("\nTop 10 Important Features:")
print(xgb_importance.head(10))

# 10. Plotting
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
sns.barplot(x='importance', y='feature', data=xgb_importance.head(10))
plt.title('Top 10 Important Features')

plt.subplot(1, 2, 2)
plt.plot(y_test.values, label='Actual')
plt.plot(y_pred, label='Predicted')
plt.legend()
plt.title('Actual vs Predicted')

plt.tight_layout()
plt.show()

In [None]:
import dash
from dash import dcc, html
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import plotly.io as pio

# Set Plotly theme
# Applies to every figure created after this line.
pio.templates.default = "plotly_white"

# Colors
# Shared palette for the dashboard (hex strings).
COLORS = {
    'primary': '#3366CC',     # legend swatch for "Historical Data"
    'secondary': '#DC3912',   # same RGB (220, 57, 18) as the translucent "Uncertainty Interval" swatch
    'tertiary': '#FF9900',
    'quaternary': '#109618',  # legend swatch for "Forecast"
    'background': '#f8f9fa'   # same value as the page background in the custom CSS
}

# Load sample data (replace with your actual data loading code)
# You can use the COVID-19 data from Our World in Data
def build_sample_data(start='2020-01-01', end='2023-01-01',
                      locations=('USA', 'Germany', 'France', 'United Kingdom'),
                      seed=42):
    """Create a small synthetic dataset used when the online source is unavailable.

    One base series per metric is drawn and then scaled per location
    (factor ``0.8 + 0.4 * position``), so locations differ in magnitude only.
    Daily counts are clipped at zero and the cumulative totals are the running
    sums of the daily counts, so the totals never decrease or go negative.

    Args:
        start: First date of the daily range (inclusive).
        end: Last date of the daily range (inclusive).
        locations: Location names; one block of rows is produced per name.
        seed: Seed for the random generator, so repeated runs show the same data.

    Returns:
        DataFrame with columns ``date``, ``location``, ``total_cases``,
        ``new_cases``, ``total_deaths``, ``new_deaths``, ``hosp_patients`` and
        ``icu_patients``; ``len(locations)`` times the number of days rows.
    """
    rng = np.random.default_rng(seed)
    dates = pd.date_range(start=start, end=end, freq='D')
    n_days = len(dates)

    def noisy_counts(scale, noise_sd):
        # Gamma-distributed level plus Gaussian noise; counts cannot be negative.
        values = rng.gamma(1, 1, n_days) * scale + rng.normal(0, noise_sd, n_days)
        return np.clip(values, 0, None)

    new_cases = noisy_counts(10000 * 1.2, 500)
    new_deaths = noisy_counts(1000 * 1.1, 50)
    base = {
        'total_cases': np.cumsum(new_cases),
        'new_cases': new_cases,
        'total_deaths': np.cumsum(new_deaths),
        'new_deaths': new_deaths,
        'hosp_patients': noisy_counts(5000, 200),
        'icu_patients': noisy_counts(1000, 50),
    }

    frames = []
    for i, loc in enumerate(locations):
        factor = 0.8 + i * 0.4  # Different scale for each country
        frames.append(pd.DataFrame({
            'date': dates,
            'location': loc,
            **{metric: series * factor for metric, series in base.items()},
        }))
    return pd.concat(frames, ignore_index=True)


url = 'https://covid.ourworldindata.org/data/owid-covid-data.csv'
try:
    df = pd.read_csv(url)
    print("Data loaded successfully")
except Exception as e:
    # Deliberately broad: any download or parsing problem should fall back to
    # the sample data instead of preventing the dashboard from starting.
    print(f"Error loading data: {e}")
    # If the online source fails, use a small sample dataset
    df = build_sample_data()

# Convert date to datetime (the CSV delivers strings; a no-op for the sample data)
df['date'] = pd.to_datetime(df['date'])

# Ensure columns exist
# Missing metrics are added as all-zero columns so the dropdown options always resolve.
required_columns = ['total_cases', 'new_cases', 'total_deaths', 'new_deaths', 'hosp_patients', 'icu_patients']
for col in required_columns:
    if col not in df.columns:
        df[col] = 0

# Create app
# Dash application styled with the Bootstrap theme from dash_bootstrap_components.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
# Module-level handle on the app's underlying server object.
server = app.server

# Custom CSS
# Replaces Dash's default HTML page template. The {%...%} tokens are
# placeholders that Dash fills in when it renders the page; the <style> block
# defines the classes referenced by className in the layout (tab-container,
# control-label, metric-card, metric-title, metric-value, run-btn, legend-*).
app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>COVID-19 Forecast Dashboard</title>
        {%favicon%}
        {%css%}
        <style>
            body {
                background-color: #f8f9fa;
                font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            }
            .tab-container {
                background-color: white;
                border-radius: 5px;
                padding: 10px;
                box-shadow: 0 2px 4px rgba(0,0,0,0.05);
            }
            .control-label {
                font-weight: 600;
                color: #495057;
                margin-bottom: 5px;
            }
            .metric-card {
                background-color: white;
                border-radius: 8px;
                padding: 15px;
                box-shadow: 0 2px 4px rgba(0,0,0,0.05);
                height: 100%;
            }
            .metric-title {
                font-size: 0.9rem;
                color: #6c757d;
                margin-bottom: 5px;
            }
            .metric-value {
                font-size: 1.6rem;
                font-weight: 600;
                color: #212529;
            }
            .run-btn {
                width: 100%;
            }
            .legend-item {
                display: flex;
                align-items: center;
            }
            .legend-color {
                width: 15px;
                height: 15px;
                border-radius: 3px;
                margin-right: 5px;
            }
            .legend-label {
                font-size: 0.85rem;
            }
        </style>
    </head>
    <body>
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            {%renderer%}
        </footer>
    </body>
</html>
'''

# Main content
app.layout = dbc.Container([
    # Header
    dbc.Row([
        dbc.Col([
            html.H2("RTMRP PROTOTYPE Ver. 2 DASHBOARD", className="my-4 text-primary")
        ])
    ]),

    # Tabs
    html.Div([
        dbc.Tabs([
            dbc.Tab(label="Trend", tab_id="tab-trend", labelClassName="text-primary"),
            dbc.Tab(label="Compare", tab_id="tab-compare", labelClassName="text-primary"),
            dbc.Tab(label="Map", tab_id="tab-map", labelClassName="text-primary"),
            dbc.Tab(label="Resources", tab_id="tab-resources", labelClassName="text-primary"),
            dbc.Tab(label="Summary", tab_id="tab-summary", labelClassName="text-primary"),
        ], id="main-tabs", active_tab="tab-trend"),
    ], className="tab-container mb-4"),

    # Tab content
    html.Div([
        # Trend Tab Content
        html.Div([
            # Control panel
            dbc.Card([
                dbc.CardHeader("Forecast Settings"),
                dbc.CardBody([
                    dbc.Row([
                        # Country selector
                        dbc.Col([
                            html.Label("Location", className="control-label"),
                            dcc.Dropdown(
                                id='country-dropdown',
                                options=[{'label': loc, 'value': loc} for loc in sorted(df['location'].unique())],
                                value='USA',
                                clearable=False,
                                className="mb-3"
                            )
                        ], width=12, lg=3),

                        # Metric selector
                        dbc.Col([
                            html.Label("Metric", className="control-label"),
                            dcc.Dropdown(
                                id='target-variable',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                clearable=False,
                                className="mb-3"
                            )
                        ], width=12, lg=3),

                        # Model selector
                        dbc.Col([
                            html.Label("Model", className="control-label"),
                            dcc.Dropdown(
                                id='model-type',
                                options=[
                                    {'label': 'Ensemble', 'value': 'Ensemble'},
                                    {'label': 'Random Forest', 'value': 'Random Forest'},
                                    {'label': 'Prophet', 'value': 'Prophet'},
                                    {'label': 'XGBoost', 'value': 'XGBoost'}
                                ],
                                value='Ensemble',
                                clearable=False,
                                className="mb-3"
                            )
                        ], width=12, lg=3),

                        # Days selector
                        dbc.Col([
                            html.Label("Projection Days", className="control-label"),
                            dcc.Slider(
                                id='forecast-days',
                                min=7,
                                max=90,
                                step=7,
                                value=30,
                                marks={
                                    7: '7d',
                                    30: '30d',
                                    60: '60d',
                                    90: '90d'
                                },
                                className="mb-3"
                            )
                        ], width=12, lg=3),
                    ]),

                    # Run button
                    dbc.Button(
                        "Update Forecast",
                        id="run-btn",
                        color="primary",
                        className="run-btn mt-2"
                    )
                ])
            ], className="mb-4"),

            # Trend overview
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Forecast Trend"),
                        dbc.CardBody([
                            # Legend for the chart
                            dbc.Row([
                                dbc.Col([
                                    html.Div([
                                        html.Div([
                                            html.Div(className="legend-color", style={"background-color": COLORS['primary']}),
                                            html.Div("Historical Data", className="legend-label")
                                        ], className="legend-item"),
                                        html.Div([
                                            html.Div(className="legend-color", style={"background-color": COLORS['quaternary']}),
                                            html.Div("Forecast", className="legend-label")
                                        ], className="legend-item"),
                                        html.Div([
                                            html.Div(className="legend-color", style={"background-color": "rgba(220, 57, 18, 0.2)"}),
                                            html.Div("Uncertainty Interval", className="legend-label")
                                        ], className="legend-item"),
                                    ], style={"display": "flex", "justify-content": "center", "gap": "20px"})
                                ], width=12)
                            ], className="mb-3"),

                            # Chart
                            dcc.Loading(
                                id="loading-prediction",
                                type="circle",
                                children=dcc.Graph(
                                    id='prediction-plot',
                                    config={'displayModeBar': True, 'scrollZoom': True},
                                    style={'height': '400px'}
                                )
                            )
                        ])
                    ])
                ], width=12)
            ], className="mb-4"),

            # Data insights
            dbc.Row([
                # Key metrics
                dbc.Col([
                    html.H5("Key Metrics", className="mb-3"),
                    dbc.Row([
                        # Current cases
                        dbc.Col([
                            html.Div([
                                html.Div("Current Count", className="metric-title"),
                                html.Div(id="current-count", className="metric-value"),
                                html.Div(id="current-date", className="mt-2", style={"font-size": "0.8rem", "color": "#6c757d"})
                            ], className="metric-card")
                        ], width=6, lg=3, className="mb-3"),

                        # Projected peak
                        dbc.Col([
                            html.Div([
                                html.Div("Projected Peak", className="metric-title"),
                                html.Div(id="projected-peak", className="metric-value"),
                                html.Div(id="peak-date", className="mt-2", style={"font-size": "0.8rem", "color": "#6c757d"})
                            ], className="metric-card")
                        ], width=6, lg=3, className="mb-3"),

                        # Growth rate
                        dbc.Col([
                            html.Div([
                                html.Div("Growth Rate", className="metric-title"),
                                html.Div(id="growth-rate", className="metric-value"),
                                html.Div("Weekly average", className="mt-2", style={"font-size": "0.8rem", "color": "#6c757d"})
                            ], className="metric-card")
                        ], width=6, lg=3, className="mb-3"),

                        # Trend direction
                        dbc.Col([
                            html.Div([
                                html.Div("Trend", className="metric-title"),
                                html.Div(id="trend-direction", className="metric-value"),
                                html.Div("7-day trend", className="mt-2", style={"font-size": "0.8rem", "color": "#6c757d"})
                            ], className="metric-card")
                        ], width=6, lg=3, className="mb-3")
                    ])
                ], width=12)
            ], className="mb-4"),

            # Model performance
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Model Performance"),
                        dbc.CardBody([
                            html.Div(id="performance-metrics")
                        ])
                    ])
                ], width=12, lg=6),

                # Feature importance
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Feature Importance"),
                        dbc.CardBody([
                            dcc.Loading(
                                id="loading-features",
                                type="circle",
                                children=dcc.Graph(
                                    id='feature-importance-plot',
                                    config={'displayModeBar': False}
                                )
                            )
                        ])
                    ])
                ], width=12, lg=6)
            ])
        ], id="tab-trend-content"),

        # Compare Tab Content
        html.Div([
            dbc.Card([
                dbc.CardHeader("Compare Locations"),
                dbc.CardBody([
                    dbc.Row([
                        # Locations multi-select
                        dbc.Col([
                            html.Label("Select Locations", className="control-label"),
                            dcc.Dropdown(
                                id='compare-countries',
                                options=[{'label': loc, 'value': loc} for loc in sorted(df['location'].unique())],
                                value=['USA', 'United Kingdom', 'Germany', 'France'],
                                multi=True,
                                className="mb-3"
                            )
                        ], width=12, lg=6),

                        # Metric selector
                        dbc.Col([
                            html.Label("Select Metric", className="control-label"),
                            dcc.Dropdown(
                                id='compare-metric',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                clearable=False,
                                className="mb-3"
                            )
                        ], width=12, lg=6),
                    ]),

                    # Normalization option
                    dbc.Row([
                        dbc.Col([
                            dbc.Checklist(
                                options=[
                                    {"label": "Normalize by population", "value": 1},
                                ],
                                value=[],
                                id="normalize-data",
                                switch=True,
                            ),
                        ], width=12, className="mb-3"),
                    ]),

                    # Comparison chart
                    dbc.Row([
                        dbc.Col([
                            dcc.Loading(
                                id="loading-compare",
                                type="circle",
                                children=dcc.Graph(
                                    id='compare-plot',
                                    config={'displayModeBar': True, 'scrollZoom': True},
                                    style={'height': '500px'}
                                )
                            )
                        ], width=12)
                    ])
                ])
            ])
        ], id="tab-compare-content", style={'display': 'none'}),

        # Map Tab Content
        html.Div([
            dbc.Card([
                dbc.CardHeader("Global Map View"),
                dbc.CardBody([
                    dbc.Row([
                        # Metric selector for map
                        dbc.Col([
                            html.Label("Select Metric", className="control-label"),
                            dcc.Dropdown(
                                id='map-metric',
                                options=[
                                    {'label': 'Total Cases', 'value': 'total_cases'},
                                    {'label': 'New Cases', 'value': 'new_cases'},
                                    {'label': 'Total Deaths', 'value': 'total_deaths'},
                                    {'label': 'New Deaths', 'value': 'new_deaths'},
                                    {'label': 'Hospital Patients', 'value': 'hosp_patients'},
                                    {'label': 'ICU Patients', 'value': 'icu_patients'}
                                ],
                                value='total_cases',
                                clearable=False,
                                className="mb-3"
                            )
                        ], width=12, lg=6),

                        # Date selector for map
                        dbc.Col([
                            html.Label("Select Date", className="control-label"),
                            dcc.DatePickerSingle(
                                id='map-date',
                                min_date_allowed=df['date'].min().date(),
                                max_date_allowed=df['date'].max().date(),
                                initial_visible_month=df['date'].max().date(),
                                date=df['date'].max().date(),
                                className="mb-3"
                            )
                        ], width=12, lg=6),
                    ]),

                    # Map
                    dbc.Row([
                        dbc.Col([
                            dcc.Loading(
                                id="loading-map",
                                type="circle",
                                children=dcc.Graph(
                                    id='world-map',
                                    config={'displayModeBar': True},
                                    style={'height': '600px'}
                                )
                            )
                        ], width=12)
                    ])
                ])
            ])
        ], id="tab-map-content", style={'display': 'none'}),

        # Resources Tab Content
        html.Div([
            dbc.Card([
                dbc.CardHeader("COVID-19 Resources"),
                dbc.CardBody([
                    html.H4("Official Sources", className="mb-3"),
                    dbc.Row([
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("World Health Organization"),
                                    html.P("Official WHO COVID-19 information and guidance."),
                                    dbc.Button("Visit WHO", color="primary", href="https://www.who.int/emergencies/diseases/novel-coronavirus-2019", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("CDC"),
                                    html.P("US Centers for Disease Control COVID-19 resources."),
                                    dbc.Button("Visit CDC", color="primary", href="https://www.cdc.gov/coronavirus/2019-ncov/index.html", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("ECDC"),
                                    html.P("European Centre for Disease Prevention and Control."),
                                    dbc.Button("Visit ECDC", color="primary", href="https://www.ecdc.europa.eu/en/covid-19", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                    ]),

                    html.H4("Data Sources", className="my-3"),
                    dbc.Row([
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("Our World in Data"),
                                    html.P("Comprehensive COVID-19 statistics and visualizations."),
                                    dbc.Button("Visit OWID", color="primary", href="https://ourworldindata.org/coronavirus", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("Johns Hopkins"),
                                    html.P("Johns Hopkins Coronavirus Resource Center."),
                                    dbc.Button("Visit JHU", color="primary", href="https://coronavirus.jhu.edu/", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("IHME"),
                                    html.P("Institute for Health Metrics and Evaluation COVID-19 projections."),
                                    dbc.Button("Visit IHME", color="primary", href="https://covid19.healthdata.org/", target="_blank")
                                ])
                            ])
                        ], width=12, md=4, className="mb-3"),
                    ]),

                    html.H4("Research Papers", className="my-3"),
                    dbc.Row([
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("COVID-19 Research Database"),
                                    html.P("Collection of research papers on COVID-19."),
                                    dbc.Button("Access Papers", color="primary", href="https://www.ncbi.nlm.nih.gov/research/coronavirus/", target="_blank")
                                ])
                            ])
                        ], width=12, md=6, className="mb-3"),
                        dbc.Col([
                            dbc.Card([
                                dbc.CardBody([
                                    html.H5("medRxiv COVID-19 Preprints"),
                                    html.P("Latest preprint research papers on COVID-19."),
                                    dbc.Button("Access Preprints", color="primary", href="https://connect.medrxiv.org/relate/content/181", target="_blank")
                                ])
                            ])
                        ], width=12, md=6, className="mb-3"),
                    ]),
                ])
            ])
        ], id="tab-resources-content", style={'display': 'none'}),

        # Summary Tab Content
        html.Div([
            dbc.Card([
                dbc.CardHeader("COVID-19 Situation Summary"),
                dbc.CardBody([
                    # Global summary stats
                    dbc.Row([
                        dbc.Col([
                            html.H4("Global Overview", className="mb-3"),
                            dbc.Row([
                                dbc.Col([
                                    html.Div([
                                        html.Div("Total Cases", className="metric-title"),
                                        html.Div(id="global-cases", className="metric-value"),
                                    ], className="metric-card")
                                ], width=12, md=3, className="mb-3"),
                                dbc.Col([
                                    html.Div([
                                        html.Div("Total Deaths", className="metric-title"),
                                        html.Div(id="global-deaths", className="metric-value"),
                                    ], className="metric-card")
                                ], width=12, md=3, className="mb-3"),
                                dbc.Col([
                                    html.Div([
                                        html.Div("New Cases (7d avg)", className="metric-title"),
                                        html.Div(id="global-new-cases", className="metric-value"),
                                    ], className="metric-card")
                                ], width=12, md=3, className="mb-3"),
                                dbc.Col([
                                    html.Div([
                                        html.Div("New Deaths (7d avg)", className="metric-title"),
                                        html.Div(id="global-new-deaths", className="metric-value"),
                                    ], className="metric-card")
                                ], width=12, md=3, className="mb-3"),
                            ]),
                        ], width=12)
                    ], className="mb-4"),

                    # Top countries table
                    dbc.Row([
                        dbc.Col([
                            html.H4("Most Affected Countries", className="mb-3"),
                            dbc.Card([
                                dbc.CardBody([
                                    dcc.Loading(
                                        id="loading-table",
                                        type="circle",
                                        children=html.Div(id="top-countries-table")
                                    )
                                ])
                            ])
                        ], width=12)
                    ], className="mb-4"),

                    # Recent trends
                    dbc.Row([
                        dbc.Col([
                            html.H4("Recent Trends", className="mb-3"),
                            dcc.Loading(
                                id="loading-trends",
                                type="circle",
                                children=dcc.Graph(
                                    id='global-trends-plot',
                                    config={'displayModeBar': False},
                                    style={'height': '400px'}
                                )
                            )
                        ], width=12)
                    ])
                ])
            ])
        ], id="tab-summary-content", style={'display': 'none'}),
    ], className="tab-content")
], fluid=True, className="pb-5")

# Callbacks

# Tab switching
@app.callback(
    [Output("tab-trend-content", "style"),
     Output("tab-compare-content", "style"),
     Output("tab-map-content", "style"),
     Output("tab-resources-content", "style"),
     Output("tab-summary-content", "style")],
    [Input("main-tabs", "active_tab")]
)
def switch_tab(active_tab):
    """Return one CSS style dict per tab panel, showing only the active one.

    The list is ordered trend, compare, map, resources, summary. The panel
    whose id equals ``active_tab`` gets ``display: block``; every other
    panel gets ``display: none``. If ``active_tab`` matches none of the ids,
    all panels are hidden.
    """
    panel_ids = ("tab-trend", "tab-compare", "tab-map", "tab-resources", "tab-summary")
    return [
        {"display": "block" if panel_id == active_tab else "none"}
        for panel_id in panel_ids
    ]

# Generate forecasts for trend tab
@app.callback(
    [Output("prediction-plot", "figure"),
     Output("current-count", "children"),
     Output("current-date", "children"),
     Output("projected-peak", "children"),
     Output("peak-date", "children"),
     Output("growth-rate", "children"),
     Output("trend-direction", "children"),
     Output("trend-direction", "style"),
     Output("performance-metrics", "children"),
     Output("feature-importance-plot", "figure")],
    [Input("run-btn", "n_clicks")],
    [State("country-dropdown", "value"),
     State("target-variable", "value"),
     State("model-type", "value"),
     State("forecast-days", "value")]
)
def update_forecast(n_clicks, country, target_var, model_type, forecast_days):
    """Build every output of the trend tab for one country/metric selection.

    The forecast curve, its confidence band, the performance metrics and the
    feature importances are simulated placeholders (random numbers and
    hard-coded values), not the result of a fitted model. The historical
    trace, the current count and the growth rate are read from ``df``.

    Args:
        n_clicks: Click count of the run button; not used in the body.
        country: Value of ``df['location']`` to filter on.
        target_var: Name of the ``df`` column to display and project.
        model_type: Selected model label; chooses which mock metrics to show.
        forecast_days: Number of future days to simulate.

    Returns:
        A 10-tuple in callback-output order: forecast figure, current count,
        current date, projected peak value, peak date, growth rate, trend
        label, trend style dict, performance-metrics component, and the
        feature-importance figure. When there is nothing to show (unknown
        country or column, no reported value, or a non-positive horizon)
        placeholder values are returned instead of raising.
    """
    country_data = df[df['location'] == country].sort_values('date')
    horizon = int(forecast_days) if forecast_days else 0

    if target_var not in country_data.columns or horizon < 1:
        return _forecast_placeholder_outputs()

    # Only rows where the metric was actually reported; a trailing NaN row
    # would otherwise make int(latest_value) raise.
    reported = country_data[country_data[target_var].notna()]
    if reported.empty:
        return _forecast_placeholder_outputs()

    latest_data = reported.iloc[-1]
    latest_value = latest_data[target_var]
    current_count = f"{int(latest_value):,}"
    current_date = f"As of {latest_data['date'].strftime('%b %d, %Y')}"

    # Simulated forecast starting the day after the latest reported value
    is_cumulative = target_var.startswith('total')
    future_dates = pd.date_range(start=latest_data['date'] + timedelta(days=1), periods=horizon)
    forecast_values = _simulate_forecast_values(
        country_data[target_var], latest_value, horizon, is_cumulative
    )

    # Simulated confidence band that widens with the square root of the step
    spread = [0.2 * np.sqrt((step + 1) / horizon) for step in range(horizon)]
    lower_bound = [val * (1 - width) for val, width in zip(forecast_values, spread)]
    upper_bound = [val * (1 + width) for val, width in zip(forecast_values, spread)]

    fig_forecast = go.Figure()

    # Historical data (raw series, gaps included)
    fig_forecast.add_trace(go.Scatter(
        x=country_data['date'],
        y=country_data[target_var],
        mode='lines',
        name='Historical Data',
        line=dict(color=COLORS['primary'], width=2)
    ))

    # Forecast
    fig_forecast.add_trace(go.Scatter(
        x=future_dates,
        y=forecast_values,
        mode='lines',
        name='Forecast',
        line=dict(color=COLORS['quaternary'], width=2)
    ))

    # Confidence interval: upper bound forward, lower bound backward, so the
    # trace closes on itself and 'toself' fills the band.
    fig_forecast.add_trace(go.Scatter(
        x=list(future_dates) + list(future_dates)[::-1],
        y=list(upper_bound) + list(lower_bound)[::-1],
        fill='toself',
        fillcolor='rgba(220, 57, 18, 0.2)',
        line=dict(color='rgba(255, 255, 255, 0)'),
        hoverinfo='skip',
        name='Confidence Interval'
    ))

    metric_title = target_var.replace('_', ' ').title()
    fig_forecast.update_layout(
        title=f"{metric_title} Forecast for {country}",
        xaxis_title="Date",
        yaxis_title=metric_title,
        hovermode="x unified",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=40, r=40, t=60, b=40),
    )

    # Projected peak (for demonstration)
    if is_cumulative:
        # A cumulative metric is reported at the end of the horizon
        projected_peak_value = f"{int(forecast_values[-1]):,}"
        peak_date_str = f"By {future_dates[-1].strftime('%b %d, %Y')}"
    else:
        # A daily metric is reported at its simulated maximum
        peak_index = int(np.argmax(forecast_values))
        projected_peak_value = f"{int(forecast_values[peak_index]):,}"
        peak_date_str = f"~{future_dates[peak_index].strftime('%b %d, %Y')}"

    # Growth of the latest value against the mean of the (up to) seven
    # reported values before it. It stays None when that window is empty or
    # does not sum to a positive number, so the trend logic below never
    # reads an unassigned variable.
    avg_growth = None
    recent_values = reported[target_var].iloc[-8:-1].values
    if recent_values.sum() > 0:
        avg_growth = ((latest_value / np.mean(recent_values)) - 1) * 100

    if avg_growth is None:
        growth_rate_str = "N/A"
        trend_direction = "N/A"
        trend_style = {"color": "#6c757d"}  # Grey when the trend is unknown
    else:
        growth_rate_str = f"{avg_growth:.1f}%"
        if avg_growth > 5:
            trend_direction = "↗️ Rising"
            trend_style = {"color": "#dc3545"}  # Red for rising
        elif avg_growth < -5:
            trend_direction = "↘️ Falling"
            trend_style = {"color": "#28a745"}  # Green for falling
        else:
            trend_direction = "→ Stable"
            trend_style = {"color": "#17a2b8"}  # Blue for stable

    performance_metrics = _mock_performance_metrics(model_type)

    # Mock feature importance plot
    feature_names = ['Time', 'Seasonality', 'Previous Values', 'Testing Rate', 'Policy Measures']
    feature_importances = [0.35, 0.25, 0.20, 0.12, 0.08]

    fig_features = go.Figure()
    fig_features.add_trace(go.Bar(
        x=feature_importances,
        y=feature_names,
        orientation='h',
        marker=dict(color=COLORS['primary'])
    ))
    fig_features.update_layout(
        xaxis_title="Importance",
        margin=dict(l=40, r=40, t=20, b=20),
        height=300
    )

    return (fig_forecast, current_count, current_date, projected_peak_value, peak_date_str,
            growth_rate_str, trend_direction, trend_style, performance_metrics, fig_features)


def _forecast_placeholder_outputs():
    """Return the 10 trend-tab outputs used when there is no data to show."""
    return (
        go.Figure(),
        "N/A",
        "",
        "N/A",
        "",
        "N/A",
        "N/A",
        {"color": "#6c757d"},
        html.Div([html.P("No data available for this selection.")]),
        go.Figure(),
    )


def _simulate_forecast_values(series, latest_value, horizon, cumulative):
    """Return ``horizon`` simulated future values (random, not a fitted model)."""
    if cumulative:
        # Compound a random daily growth factor of roughly 2%
        level = latest_value
        values = []
        for _ in range(horizon):
            level *= np.random.normal(1.02, 0.005)
            values.append(level)
        return values

    # Daily metrics oscillate around the mean of the last 30 rows; if none of
    # those rows has a value, fall back to the latest reported one.
    baseline = series.iloc[-30:].mean()
    if pd.isna(baseline):
        baseline = latest_value

    values = []
    for day in range(horizon):
        oscillation = 0.3 * np.sin(day / 7 * np.pi) + np.random.normal(0, 0.1)
        values.append(max(0, baseline * (1 + oscillation)))  # Never negative
    return values


def _mock_performance_metrics(model_type):
    """Return the hard-coded metrics panel for ``model_type`` (XGBoost if unknown)."""
    xgboost = ("XGBoost", "3.9%", "251.2", "0.94")
    mock_scores = {
        "Ensemble": ("Ensemble (Random Forest, XGBoost, Prophet)", "3.2%", "214.5", "0.96"),
        "Random Forest": ("Random Forest", "4.8%", "287.3", "0.91"),
        "Prophet": ("Prophet", "5.2%", "328.9", "0.89"),
    }
    label, mape, rmse, r_squared = mock_scores.get(model_type, xgboost)
    return html.Div([
        html.P(f"Model: {label}"),
        html.P(f"MAPE: {mape}"),
        html.P(f"RMSE: {rmse}"),
        html.P(f"R²: {r_squared}")
    ])

# Compare tab callback
@app.callback(
    Output("compare-plot", "figure"),
    [Input("compare-countries", "value"),
     Input("compare-metric", "value"),
     Input("normalize-data", "value")]
)
def update_comparison(countries, metric, normalize):
    """Plot one line per selected country for the chosen metric.

    Args:
        countries: Iterable of ``df['location']`` values to plot.
        metric: Name of the ``df`` column to plot.
        normalize: Truthy to plot the metric per million inhabitants.

    Returns:
        A figure with one trace per country that has rows in ``df``; an
        empty figure when no country or no metric is selected.
    """
    if not countries or not metric:
        return go.Figure()

    # Approximate populations, used only when the data itself carries no
    # usable population figure for a country.
    fallback_populations = {
        'USA': 331000000,
        'United Kingdom': 67000000,
        'Germany': 83000000,
        'France': 67000000,
        'Italy': 60000000,
        'Spain': 47000000,
        'India': 1380000000,
        'Brazil': 212000000,
        'Russia': 144000000,
        'Japan': 126000000
    }

    fig = go.Figure()

    for country in countries:
        country_data = df[df['location'] == country].sort_values('date')

        # Skip if no data
        if country_data.empty:
            continue

        y_values = country_data[metric].values

        if normalize:
            # Prefer the population reported in the data; the mock table and
            # its 100,000,000 default are a last resort.
            population = None
            if 'population' in country_data.columns:
                known = country_data['population'].dropna()
                if not known.empty and known.iloc[-1] > 0:
                    population = known.iloc[-1]
            if population is None:
                population = fallback_populations.get(country, 100000000)
            y_values = (y_values / population) * 1000000  # Per million

        fig.add_trace(go.Scatter(
            x=country_data['date'],
            y=y_values,
            mode='lines',
            name=country
        ))

    metric_name = metric.replace('_', ' ').title()
    y_title = f"{metric_name} per Million" if normalize else metric_name

    fig.update_layout(
        title=f"{y_title} by Country",
        xaxis_title="Date",
        yaxis_title=y_title,
        hovermode="x unified",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=40, r=40, t=60, b=40),
    )

    return fig

# Map tab callback
@app.callback(
    Output("world-map", "figure"),
    [Input("map-metric", "value"),
     Input("map-date", "date")]
)
def update_map(metric, date_str):
    """Draw a world choropleth of ``metric`` for the data date nearest ``date_str``.

    Args:
        metric: Name of the ``df`` column used to colour the map.
        date_str: Selected date as a string parseable by ``pd.to_datetime``.

    Returns:
        The choropleth figure, or an empty figure when no date or metric is
        selected or ``df`` contains no dates.
    """
    if not date_str or not metric:
        return go.Figure()

    selected_date = pd.to_datetime(date_str)

    # Find the closest date present in the data. Going through a
    # DatetimeIndex keeps the lookup vectorised and always yields a
    # pandas Timestamp, so .strftime() below works regardless of what
    # scalar type Series.unique() produces on the installed pandas.
    available_dates = pd.DatetimeIndex(df['date'].dropna().unique())
    if available_dates.empty:
        return go.Figure()
    closest_date = available_dates[abs(available_dates - selected_date).argmin()]

    # Filter data for the selected date
    date_data = df[df['date'] == closest_date]

    metric_title = metric.replace('_', ' ').title()

    fig = px.choropleth(
        date_data,
        locations="location",  # Use location names (would normally use ISO codes)
        locationmode="country names",  # This works for many but not all countries
        color=metric,
        hover_name="location",
        color_continuous_scale=px.colors.sequential.Plasma,
        title=f"{metric_title} by Country on {closest_date.strftime('%b %d, %Y')}"
    )

    fig.update_layout(
        margin=dict(l=0, r=0, t=50, b=0),
        coloraxis_colorbar=dict(
            title=metric_title
        )
    )

    return fig

# Summary tab callbacks
@app.callback(
    [Output("global-cases", "children"),
     Output("global-deaths", "children"),
     Output("global-new-cases", "children"),
     Output("global-new-deaths", "children"),
     Output("top-countries-table", "children"),
     Output("global-trends-plot", "figure")],
    [Input("main-tabs", "active_tab")]
)
def update_summary(active_tab):
    """Compute the summary tab's headline numbers, top-10 table and trend plot.

    Args:
        active_tab: Id of the currently selected tab.

    Returns:
        A 6-tuple in callback-output order: total cases, total deaths,
        7-day average new cases, 7-day average new deaths (all formatted
        strings), the top-countries table, and the 30-day trends figure.
        When another tab is active, placeholder values are returned.
    """
    if active_tab != "tab-summary":
        # Return empty values if not on summary tab
        return "N/A", "N/A", "N/A", "N/A", None, go.Figure()

    latest_date = df['date'].max()

    # Rows for the most recent date, and for the 7 days ending on it
    # NOTE(review): every location row is summed; if df still contains
    # aggregate rows (e.g. a "World" entry) the totals double-count.
    # Confirm against the upstream loading code.
    global_data = df[df['date'] == latest_date]
    last_week = df[df['date'] > (latest_date - timedelta(days=7))]

    global_cases_str = _format_summary_count(global_data['total_cases'].sum())
    global_deaths_str = _format_summary_count(global_data['total_deaths'].sum())
    global_new_cases_str = _format_summary_count(last_week['new_cases'].sum() / 7)
    global_new_deaths_str = _format_summary_count(last_week['new_deaths'].sum() / 7)

    # Rank countries by total cases on the latest date
    top_countries = global_data.groupby('location')['total_cases'].sum().reset_index()
    top_countries = top_countries.sort_values('total_cases', ascending=False).head(10)

    # Per-country lookups built once instead of re-filtering for every cell:
    # the first latest-date row of each country, and its mean new cases over
    # the last week.
    latest_by_country = global_data.drop_duplicates(subset='location', keep='first').set_index('location')
    weekly_new_cases = last_week.groupby('location')['new_cases'].mean()

    table_rows = []
    for rank, country in enumerate(top_countries['location'], start=1):
        snapshot = latest_by_country.loc[country]
        table_rows.append(html.Tr([
            html.Td(rank),
            html.Td(country),
            # Missing figures render as "N/A" instead of raising on int(NaN)
            html.Td(_format_summary_count(snapshot['total_cases'])),
            html.Td(_format_summary_count(snapshot['total_deaths'])),
            html.Td(_format_summary_count(weekly_new_cases.get(country)))
        ]))

    top_countries_table = html.Table([
        html.Thead(
            html.Tr([
                html.Th("Rank"),
                html.Th("Country"),
                html.Th("Total Cases"),
                html.Th("Total Deaths"),
                html.Th("New Cases (7d avg)")
            ])
        ),
        html.Tbody(table_rows)
    ], className="table table-striped table-hover")

    # Global trends over the last 30 days
    last_30_days = df[df['date'] > (latest_date - timedelta(days=30))]
    daily_totals = last_30_days.groupby('date')[['new_cases', 'new_deaths']].sum().reset_index()

    fig_trends = go.Figure()

    # New cases trend
    fig_trends.add_trace(go.Scatter(
        x=daily_totals['date'],
        y=daily_totals['new_cases'],
        mode='lines',
        name='New Cases',
        line=dict(color=COLORS['primary'], width=2)
    ))

    # New deaths trend
    fig_trends.add_trace(go.Scatter(
        x=daily_totals['date'],
        y=daily_totals['new_deaths'] * 10,  # Scale up for visibility
        mode='lines',
        name='New Deaths (×10)',
        line=dict(color=COLORS['secondary'], width=2)
    ))

    fig_trends.update_layout(
        title="Global Daily Trends (Last 30 Days)",
        xaxis_title="Date",
        yaxis_title="Count",
        hovermode="x unified",
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        margin=dict(l=40, r=40, t=60, b=40),
    )

    return (global_cases_str, global_deaths_str, global_new_cases_str,
            global_new_deaths_str, top_countries_table, fig_trends)


def _format_summary_count(value):
    """Format ``value`` as a comma-grouped integer, or "N/A" if it has no integer form."""
    try:
        return f"{int(value):,}"
    except (TypeError, ValueError, OverflowError):
        # None, NaN and infinities cannot be converted by int()
        return "N/A"


# Start the Dash server only when this file is executed as a script, not
# when it is imported as a module.
# NOTE(review): debug=True presumably turns on Dash's development tooling;
# confirm it is switched off for any non-development deployment.
if __name__ == '__main__':
    app.run(debug=True)