<a href="https://colab.research.google.com/github/Gaurav478-dawn/Internship_Assignment/blob/main/Gaurav_intern_Assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1. Force-enable Colab Widget Manager & Install
from google.colab import output
output.enable_custom_widget_manager()
# Install dependencies (the leading '!' is IPython shell syntax, so this line only runs inside a notebook)

!pip install -q ipywidgets pandas plotly scikit-learn

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import ipywidgets as widgets
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from IPython.display import display, HTML, clear_output

# -----------------------------------------------------------------------------
# 2. Data Loading
# -----------------------------------------------------------------------------
try:
    df = pd.read_csv('/content/covid_19_data.csv')
except FileNotFoundError:
    # No CSV available: fall back to an empty frame; later code checks df.empty.
    df = pd.DataFrame()
else:
    # Parse dates, merge the 'Mainland China' label into 'China', derive the
    # Active count, then order rows chronologically.
    df = (
        df.assign(**{
            'ObservationDate': pd.to_datetime(df['ObservationDate']),
            'Country/Region': df['Country/Region'].replace('Mainland China', 'China'),
            'Active': df['Confirmed'] - df['Deaths'] - df['Recovered'],
        })
        .sort_values('ObservationDate')
    )

# -----------------------------------------------------------------------------
# 3. UI Helpers
# -----------------------------------------------------------------------------
_HEADER_HTML = """
    <div style="background-color: #f8f9fa; padding: 20px; border-bottom: 2px solid #3498db; margin-bottom: 20px; font-family: Arial, sans-serif;">
        <h2 style="margin: 0; color: #2c3e50;">Project: Disease Spread Analysis & Prediction</h2>
        <p style="margin: 5px 0 0 0; color: #555;"><b>Name:</b> Gaurav Khadka | <b>Domain:</b> AI in HealthCare</p>
    </div>
    """


def render_header():
    """Return the static HTML banner (project title, author, domain) for the dashboard."""
    return _HEADER_HTML

def render_kpi(title, value, color):
    """Return the HTML fragment for one coloured KPI card.

    Args:
        title: Caption shown above the number. It is HTML-escaped because
            callers interpolate country names read from the CSV into it.
        value: Number to display, formatted with thousands separators and no
            decimals (``{:,.0f}``).
        color: CSS background colour of the card, e.g. ``"#3498db"``.

    Returns:
        A ``<div>`` string intended to sit inside a flex-row container.
    """
    from html import escape  # local import keeps this helper self-contained

    return f"""
    <div style="background-color: {color}; color: white; padding: 12px; border-radius: 4px;
    text-align: center; margin: 5px; flex: 1; min-width: 100px; box-shadow: 0 1px 3px rgba(0,0,0,0.12);">
        <div style="font-size:12px; opacity:0.9;">{escape(str(title))}</div>
        <div style="font-size:22px; font-weight:bold;">{value:,.0f}</div>
    </div>
    """

# -----------------------------------------------------------------------------
# 4. Machine Learning Logic (Linear Regression)
# -----------------------------------------------------------------------------
def run_prediction(country_df, days_to_predict=15):
    """Fit a linear trend to cumulative confirmed cases and extrapolate it.

    Rows are summed per ``ObservationDate`` before fitting, so a frame with
    several rows for the same date is treated as a single series (input that
    already has one row per date is unaffected).

    Args:
        country_df: Frame with an ``ObservationDate`` (datetime) column and a
            ``Confirmed`` column. It is not modified.
        days_to_predict: Number of consecutive days after the last observation
            to forecast. Values < 1 yield empty predictions.

    Returns:
        ``(future_dates, future_preds, model)``: a list of Timestamps, a NumPy
        array of predicted ``Confirmed`` values of the same length, and the
        fitted ``LinearRegression``.

    Raises:
        ValueError: if ``country_df`` has no rows to fit.
    """
    # One row per date (groupby also sorts by date); otherwise rows sharing a
    # date would enter the regression as separate, smaller samples.
    daily = country_df.groupby('ObservationDate', as_index=False)['Confirmed'].sum()
    if daily.empty:
        raise ValueError("run_prediction() needs at least one observation to fit")

    first_date = daily['ObservationDate'].min()
    daily['Day_Num'] = (daily['ObservationDate'] - first_date).dt.days

    # Feature (X) = days since the first observation; target (y) = Confirmed.
    model = LinearRegression()
    model.fit(daily[['Day_Num']], daily['Confirmed'])

    last_day = int(daily['Day_Num'].max())
    last_date = daily['ObservationDate'].max()
    offsets = range(1, days_to_predict + 1)

    # Predict with a DataFrame carrying the training column name; a bare
    # ndarray makes sklearn warn that X has no valid feature names.
    future_X = pd.DataFrame({'Day_Num': [last_day + k for k in offsets]})
    future_preds = model.predict(future_X) if len(future_X) else np.empty(0)
    future_dates = [last_date + pd.Timedelta(days=k) for k in offsets]

    return future_dates, future_preds, model

# -----------------------------------------------------------------------------
# 5. Main UI Logic
# -----------------------------------------------------------------------------
out_display = widgets.Output()


def _daily_totals(frame, columns):
    """Return *frame* summed per ``ObservationDate`` for *columns*, sorted by date.

    Several rows can share one date (presumably one per province/state —
    verify against the CSV); summing them yields the national/global figure.
    """
    return frame.groupby('ObservationDate')[list(columns)].sum().reset_index()


def _show_no_data():
    """Render a short notice when the current selection matches no rows."""
    display(HTML("<p style='color:#777; font-family: Arial, sans-serif;'>"
                 "No data available for the current selection.</p>"))


def _show_kpi_row(cards):
    """Render ``(title, value, color)`` triples as one flex row of KPI cards."""
    cells = "".join(render_kpi(title, value, color) for title, value, color in cards)
    display(HTML(f'<div style="display: flex; gap: 10px; margin-bottom: 10px;">{cells}</div>'))


def _show_global(df_filtered):
    """Global view: worldwide KPIs for the latest date plus summed trend lines."""
    totals = _daily_totals(df_filtered, ('Confirmed', 'Active', 'Recovered', 'Deaths'))
    latest = totals.iloc[-1]
    _show_kpi_row([
        ("Confirmed", latest['Confirmed'], "#3498db"),
        ("Active", latest['Active'], "#f39c12"),
        ("Deaths", latest['Deaths'], "#e74c3c"),
    ])
    fig = px.line(totals, x='ObservationDate', y=['Confirmed', 'Recovered', 'Deaths'],
                  title='Global Trends', height=400, template='simple_white')
    fig.show()


def _show_nation(df_filtered, country):
    """Nation view: KPIs and confirmed/active trajectory for one country."""
    df_nation = df_filtered[df_filtered['Country/Region'] == country]
    if df_nation.empty:
        _show_no_data()
        return

    # KPIs come from the per-date totals, not the last raw row, which may be
    # only one of several rows sharing the latest date.
    trend = _daily_totals(df_nation, ('Confirmed', 'Active', 'Deaths'))
    latest = trend.iloc[-1]
    _show_kpi_row([
        (f"{country} Confirmed", latest['Confirmed'], "#2980b9"),
        ("Active", latest['Active'], "#d35400"),
        ("Deaths", latest['Deaths'], "#c0392b"),
    ])

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=trend['ObservationDate'], y=trend['Confirmed'], name='Confirmed',
                             line=dict(color='#3498db')))
    fig.add_trace(go.Scatter(x=trend['ObservationDate'], y=trend['Active'], name='Active',
                             line=dict(dash='dash', color='#f39c12')))
    fig.update_layout(title=f'Spread Trajectory: {country}', height=400, template='simple_white')
    fig.show()


def _show_prediction(country, forecast_days=15):
    """Prediction view: linear-regression forecast trained on the full history."""
    df_nation = df[df['Country/Region'] == country]
    if df_nation.empty:
        _show_no_data()
        return

    # One row per date so the historical line and the model both see totals.
    history = _daily_totals(df_nation, ('Confirmed',))
    dates, preds, _model = run_prediction(history, days_to_predict=forecast_days)

    display(HTML(f"""<div style="padding:10px; background:#e8f4f8; border-radius:5px; margin-bottom:10px;">
                <b>ML Model Used:</b> Linear Regression (Scikit-Learn)<br>
                <b>Projection:</b> Predicting next {forecast_days} days for {country} based on historical trajectory.
            </div>"""))

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=history['ObservationDate'], y=history['Confirmed'],
                             name='Historical Data', line=dict(color='gray')))
    fig.add_trace(go.Scatter(x=dates, y=preds, name=f'ML Forecast (Next {forecast_days} Days)',
                             line=dict(color='red', width=3, dash='dot')))
    fig.update_layout(title=f'AI Forecast: Confirmed Cases in {country}', height=400, template='simple_white')
    fig.show()


def update_view(change=None):
    """Redraw ``out_display`` for the current mode, country and date range.

    Called directly and as an ``observe`` callback, so ``change`` is accepted
    but unused. Reads the module-level widgets ``w_mode``, ``w_country`` and
    ``w_date``, which must exist before the first call.
    """
    with out_display:
        clear_output(wait=True)

        mode = w_mode.value
        country = w_country.value

        # The forecast trains on the full history (the date slider is hidden
        # in this mode), so it must not depend on the range filter.
        if mode == 'Prediction (ML)':
            _show_prediction(country)
            return

        start, end = pd.to_datetime(w_date.value[0]), pd.to_datetime(w_date.value[1])
        df_filtered = df.loc[df['ObservationDate'].between(start, end)]
        if df_filtered.empty:
            _show_no_data()
            return

        if mode == 'Global':
            _show_global(df_filtered)
        elif mode == 'Nation':
            _show_nation(df_filtered, country)

# -----------------------------------------------------------------------------
# 6. Widget Construction
# -----------------------------------------------------------------------------
if not df.empty:
    w_mode = widgets.ToggleButtons(options=['Global', 'Nation', 'Prediction (ML)'], description='View:', style={'description_width': 'initial'})

    dates = df['ObservationDate'].sort_values().unique()
    opts = [pd.to_datetime(d).date() for d in dates]
    w_date = widgets.SelectionRangeSlider(options=opts, index=(0, len(opts)-1), description='Range:', layout={'width': '600px'}, style={'description_width': 'initial'})

    # dropna(): a missing country (NaN) would make sorted() compare float with str.
    countries = sorted(df['Country/Region'].dropna().unique())
    default_country = 'US' if 'US' in countries else next(iter(countries), None)
    w_country = widgets.Dropdown(options=countries, value=default_country, description='Nation:', style={'description_width': 'initial'})

    # Visibility per view: (country dropdown display, date slider display).
    # Prediction hides the slider because the forecast uses the full history;
    # any mode not listed shows both controls.
    MODE_VISIBILITY = {
        'Global': ('none', 'flex'),
        'Prediction (ML)': ('block', 'none'),
    }

    def on_mode_change(change):
        """Show/hide the controls relevant to the newly selected view, then redraw."""
        country_display, date_display = MODE_VISIBILITY.get(change['new'], ('block', 'flex'))
        w_country.layout.display = country_display
        w_date.layout.display = date_display
        update_view()

    w_mode.observe(on_mode_change, names='value')
    w_country.observe(update_view, names='value')
    w_date.observe(update_view, names='value')

    w_country.layout.display = 'none'  # initial mode is 'Global'

    ui = widgets.VBox([
        w_mode,
        widgets.HBox([w_country, w_date]),
        widgets.HTML("<hr style='margin: 10px 0; border: 0; border-top: 1px solid #ccc;'>"),
        out_display
    ])

    display(HTML(render_header()))
    display(ui)
    update_view()
else:
    # Without this branch the cell finishes silently when the CSV is missing or empty.
    display(HTML(render_header()))
    display(HTML("<p style='color:#c0392b; font-family: Arial, sans-serif;'>"
                 "No data loaded: check that <code>/content/covid_19_data.csv</code> "
                 "exists and is not empty, then re-run this cell.</p>"))

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/1.6 MB ? eta -:--:--
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 53.6 MB/s eta 0:00:00

VBox(children=(ToggleButtons(description='View:', options=('Global', 'Nation', 'Prediction (ML)'), style=Toggl…