# In [None]:
import os
import sys
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta

# Optional imports
try:
    import seaborn as sns
    sns.set()
except Exception:
    pass

try:
    import requests
except Exception:
    print(" Please install the 'requests' package: pip install requests")
    sys.exit(1)

# -----------------------
# Config
# -----------------------
# Remote CSV that load_dataset() tries to download first.
OWID_URL = "https://covid.ourworldindata.org/data/owid-covid-data.csv"
# Both working files live in the current working directory:
# the synthetic fallback dataset and the cleaned output.
LOCAL_SAMPLE, CLEANED_OUTPUT = (
    os.path.join(os.getcwd(), filename)
    for filename in ("sample_covid_data.csv", "cleaned_covid_data.csv")
)

# -----------------------
# Create synthetic sample dataset
# -----------------------
def create_sample_csv(path, days=120, seed=42):
    """Generate a synthetic two-country COVID-style dataset and save it as CSV.

    One row is produced per country per day, starting on 2021-01-01, for
    'CountryA' and 'CountryB'.

    Args:
        path: Destination path of the CSV file (written without the index).
        days: Number of consecutive days to simulate per country. Zero or a
            negative value yields an empty table that still has all columns.
        seed: Seed for the pseudo-random noise; the same seed reproduces the
            same table.

    Returns:
        The generated ``pandas.DataFrame``. The ``date`` column holds
        'YYYY-MM-DD' strings, exactly as written to the CSV.
    """
    columns = [
        'date',
        'location',
        'new_cases',
        'new_deaths',
        'total_cases',
        'total_deaths',
        'people_vaccinated_per_hundred',
        'people_fully_vaccinated_per_hundred',
        'stringency_index',
        'population',
    ]

    # A private legacy generator yields the same stream as
    # np.random.seed(seed) + np.random.normal(...), but does not reseed the
    # process-wide NumPy RNG behind the caller's back.
    rng = np.random.RandomState(seed)
    start_date = datetime(2021, 1, 1)

    # Starting totals and model parameters per country. Insertion order
    # matters: it fixes the order in which random numbers are consumed.
    profiles = {
        'CountryA': {
            'total_cases': 1000,
            'total_deaths': 50,
            'vaccinated': 0.0,
            'fully_vaccinated': 0.0,
            'case_offset': 8,
            'death_rate': 0.01 + 0.002,
            'population': 1_000_000,
        },
        'CountryB': {
            'total_cases': 5000,
            'total_deaths': 200,
            'vaccinated': 5.0,
            'fully_vaccinated': 1.0,
            'case_offset': 30,
            'death_rate': 0.01 + 0.005,
            'population': 5_000_000,
        },
    }

    rows = []
    for country, profile in profiles.items():
        total_cases = profile['total_cases']
        total_deaths = profile['total_deaths']
        vaccinated = profile['vaccinated']
        fully_vaccinated = profile['fully_vaccinated']

        for i in range(days):
            date = start_date + timedelta(days=i)
            # Sinusoidal wave plus a country offset and Gaussian noise;
            # abs() keeps the daily count non-negative.
            new_cases = int(np.abs(120 * np.sin(i / 18.0) +
                                   profile['case_offset'] +
                                   rng.normal(0, 18)))
            new_deaths = max(0, int(new_cases * profile['death_rate'] +
                                    rng.normal(0, 1)))
            total_cases += new_cases
            total_deaths += new_deaths
            # Coverage only ever grows and is capped at 100 %.
            vaccinated = min(100.0, vaccinated + max(0, rng.normal(0.25, 0.06)))
            fully_vaccinated = min(100.0, fully_vaccinated + max(0, rng.normal(0.18, 0.05)))
            stringency_index = max(0, min(100, 60 + 18 * np.sin(i / 28.0) + rng.normal(0, 6)))

            rows.append({
                'date': date.strftime('%Y-%m-%d'),
                'location': country,
                'new_cases': new_cases,
                'new_deaths': new_deaths,
                'total_cases': total_cases,
                'total_deaths': total_deaths,
                'people_vaccinated_per_hundred': round(vaccinated, 2),
                'people_fully_vaccinated_per_hundred': round(fully_vaccinated, 2),
                'stringency_index': round(stringency_index, 2),
                'population': profile['population'],
            })

    # Passing the schema explicitly keeps the header even when no rows were
    # generated, so the CSV can always be read back with a 'date' column.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(path, index=False)
    print(f" Created sample CSV at: {path}")
    return df

# -----------------------
# Load dataset
# -----------------------
def load_dataset():
    """Load the COVID dataset, preferring the live OWID download.

    The function tries, in order:
      1. downloading ``OWID_URL`` (15 s timeout),
      2. reading the local sample CSV at ``LOCAL_SAMPLE`` if it exists,
      3. generating a new synthetic sample via ``create_sample_csv``.

    Returns:
        A ``(df, source)`` tuple where ``source`` is ``"owid"`` for the
        downloaded data and ``"sample"`` for either local fallback. On every
        path the ``date`` column has been parsed to datetimes.
    """
    from io import StringIO

    try:
        print("Downloading OWID dataset...")
        r = requests.get(OWID_URL, timeout=15)
        r.raise_for_status()
        df = pd.read_csv(StringIO(r.text), parse_dates=['date'], low_memory=False)
        print(f"Downloaded OWID dataset. Shape: {df.shape}")
        return df, "owid"
    except Exception as exc:
        # Deliberately broad: any network or parsing problem should fall back
        # to the local sample instead of aborting the analysis.
        print(f" Download failed: {exc}")

    if os.path.exists(LOCAL_SAMPLE):
        print(f" Using local sample dataset from {LOCAL_SAMPLE}")
        try:
            return pd.read_csv(LOCAL_SAMPLE, parse_dates=['date']), "sample"
        except (ValueError, OSError) as exc:
            # pandas' parser errors derive from ValueError; a truncated or
            # column-less file is replaced rather than treated as fatal.
            print(f" Could not read local sample ({exc}); regenerating it")

    df = create_sample_csv(LOCAL_SAMPLE)
    # The generator returns 'YYYY-MM-DD' strings; parse them so this path
    # returns the same dtype as the two read_csv paths above.
    df['date'] = pd.to_datetime(df['date'])
    return df, "sample"

df, source = load_dataset()

# -----------------------
# Cleaning & preprocessing
# -----------------------
# Header names may carry stray whitespace.
df.columns = [c.strip() for c in df.columns]

# Parse dates unless the loader already did. Testing for "not datetime"
# (rather than "is object") also covers pandas' dedicated string dtype.
if not pd.api.types.is_datetime64_any_dtype(df['date']):
    df['date'] = pd.to_datetime(df['date'], errors='coerce')

# .copy() makes the filtered frame independent, so the column assignments
# below never write into a view of the unfiltered data.
df = df.dropna(subset=['date', 'location']).copy()

# Convert numeric
for col in df.columns:
    if col in ('date', 'location'):
        continue
    if pd.api.types.is_numeric_dtype(df[col]) or pd.api.types.is_datetime64_any_dtype(df[col]):
        continue
    converted = pd.to_numeric(df[col], errors='coerce')
    # Only adopt the conversion when it produced at least one number (or the
    # column was empty to begin with). Purely textual columns would otherwise
    # be coerced to NaN and then overwritten with 0 by the fill below.
    if converted.notna().any() or df[col].isna().all():
        df[col] = converted

numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
df[numeric_cols] = df[numeric_cols].fillna(0)

# Add new derived metrics
if 'population' in df.columns and 'total_cases' in df.columns:
    # Missing populations were filled with 0 above; mask them so the division
    # yields NaN (reported as 0) instead of inf.
    valid_population = df['population'].where(df['population'] > 0)
    df['cases_per_million'] = ((df['total_cases'] / valid_population) * 1e6).fillna(0)
    if 'new_cases' in df.columns:
        df['new_cases_per_million'] = ((df['new_cases'] / valid_population) * 1e6).fillna(0)

if 'total_deaths' in df.columns and 'total_cases' in df.columns:
    # Rows with zero cases give inf/NaN; report those as 0 %.
    df['case_fatality_rate'] = (df['total_deaths'] / df['total_cases']).replace([np.inf, np.nan], 0) * 100

if 'people_vaccinated_per_hundred' in df.columns and 'people_fully_vaccinated_per_hundred' in df.columns:
    # Percentage points between "at least one dose" and "fully vaccinated".
    df['vaccination_gap'] = df['people_vaccinated_per_hundred'] - df['people_fully_vaccinated_per_hundred']

# Save cleaned dataset
df.to_csv(CLEANED_OUTPUT, index=False)
print(f" Cleaned dataset saved at: {CLEANED_OUTPUT}")

# -----------------------
# Plots
# -----------------------
def plot_total_cases_and_deaths(location):
    """Plot cumulative cases and deaths over time for one location.

    Rows are taken from the module-level ``df`` and ordered by date. Nothing
    is drawn (and nothing is printed) when the location has no rows.

    Args:
        location: Value to match against the ``location`` column.
    """
    subset = df.loc[df['location'] == location].sort_values('date')
    if subset.empty:
        return

    series_to_draw = (
        ('total_cases', "Total cases"),
        ('total_deaths', "Total deaths"),
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    for column, legend_text in series_to_draw:
        ax.plot(subset['date'], subset[column], label=legend_text, linewidth=2)
    ax.set_title(f"Cumulative COVID-19 impact in {location}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Count")
    ax.legend()
    fig.tight_layout()
    plt.show()

def plot_vaccination_progress(location):
    """Plot first-dose and full-vaccination coverage over time for a location.

    Rows are taken from the module-level ``df`` and ordered by date. The
    function returns silently when the location has no rows, and prints a
    notice when either vaccination column is missing.

    Args:
        location: Value to match against the ``location`` column.
    """
    subset = df.loc[df['location'] == location].sort_values('date')
    if subset.empty:
        return

    series_to_draw = (
        ('people_vaccinated_per_hundred', "At least one dose (%)"),
        ('people_fully_vaccinated_per_hundred', "Fully vaccinated (%)"),
    )
    if not all(column in subset.columns for column, _ in series_to_draw):
        print(f" No vaccination data for {location}")
        return

    fig, ax = plt.subplots(figsize=(12, 6))
    for column, legend_text in series_to_draw:
        ax.plot(subset['date'], subset[column], label=legend_text, linewidth=2)
    ax.set_title(f" Vaccination Progress in {location}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Percent of population (%)")
    ax.legend()
    fig.tight_layout()
    plt.show()

def covid_heatmap(metric="cases_per_million", top_n=20):
    """Draw a location-by-date heatmap of ``metric`` for the top locations.

    Locations are ranked by their maximum value of ``metric`` in the
    module-level ``df`` and the ``top_n`` highest are shown, one row each.

    Args:
        metric: Name of the column in ``df`` to visualise.
        top_n: Maximum number of locations to include.

    Prints a notice and returns without plotting when the metric column is
    missing or there is no data to show. Uses seaborn when it was imported
    successfully, otherwise falls back to a plain matplotlib image.
    """
    if metric not in df.columns:
        print(f" Metric '{metric}' not found in dataset")
        return

    # DataFrame.pivot raises on duplicate (date, location) pairs, so keep
    # only the last row of each pair before reshaping.
    unique_rows = df.drop_duplicates(subset=['date', 'location'], keep='last')
    pivot = unique_rows.pivot(index="date", columns="location", values=metric)
    if pivot.empty:
        print(f" No data available for metric '{metric}'")
        return

    top_countries = pivot.max().sort_values(ascending=False).head(top_n).index
    pivot = pivot[top_countries]

    # Show plain dates on the axis instead of full nanosecond timestamps.
    if isinstance(pivot.index, pd.DatetimeIndex):
        pivot.index = pivot.index.strftime('%Y-%m-%d')

    plt.figure(figsize=(14, 8))
    # seaborn is an optional import in this file; look it up rather than
    # assuming the name exists.
    heatmap_lib = globals().get('sns')
    if heatmap_lib is not None:
        heatmap_lib.heatmap(pivot.T, cmap="Reds", cbar_kws={'label': metric}, xticklabels=30)
    else:
        image = plt.imshow(pivot.T.to_numpy(dtype=float), aspect='auto',
                           cmap="Reds", interpolation='nearest')
        plt.colorbar(image, label=metric)
        plt.yticks(list(range(len(pivot.columns))), [str(name) for name in pivot.columns])
        tick_positions = list(range(0, len(pivot.index), 30))
        plt.xticks(tick_positions, [str(pivot.index[i]) for i in tick_positions], rotation=90)
    plt.title(f" COVID-19 {metric} Heatmap (Top {top_n} countries)")
    plt.xlabel("Date")
    plt.ylabel("Country")
    plt.tight_layout()
    plt.show()

# Example usage
locations = df['location'].unique().tolist()
if locations:
    # Repeating the list lets the two-name unpacking succeed even when the
    # dataset contains a single location.
    loc1, loc2 = (locations + locations)[:2]
else:
    # Empty dataset: nothing to unpack, nothing to plot.
    loc1 = loc2 = None

# Draw each location once, even if loc1 and loc2 are the same.
plot_targets = [] if loc1 is None else ([loc1] if loc1 == loc2 else [loc1, loc2])

if plot_targets:
    print(f"\n Plots for: {' and '.join(str(name) for name in plot_targets)}")
    for name in plot_targets:
        plot_total_cases_and_deaths(name)
    for name in plot_targets:
        plot_vaccination_progress(name)
    covid_heatmap("cases_per_million", top_n=15)
else:
    print("\n No locations available to plot")

# -----------------------
# Summary statistics
# -----------------------
print("\n Summary Statistics:")

if 'cases_per_million' in df.columns:
    # Per-location maximum of the cases-per-million column, highest five.
    top_countries = df.groupby('location')['cases_per_million'].max().sort_values(ascending=False).head(5)
    print(" Top 5 countries by cases per million:\n", top_countries)

if 'people_fully_vaccinated_per_hundred' in df.columns:
    # Mean over locations of each location's highest recorded coverage.
    avg_vacc = df.groupby('location')['people_fully_vaccinated_per_hundred'].max().mean()
    print(f"\n Average vaccination coverage across countries: {avg_vacc:.2f}%")

if 'case_fatality_rate' in df.columns:
    # NOTE(review): this averages each location's peak CFR, not its most
    # recent value - confirm that is the intended statistic.
    avg_cfr = df.groupby('location')['case_fatality_rate'].max().mean()
    print(f" Average case fatality rate: {avg_cfr:.2f}%")

print("\n Done.")
