In [1]:
import pandas as pd
import geopandas as gpd
import io
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import seaborn as sns
import imageio
import PIL
from tqdm import tqdm
from datetime import datetime
import plotly.graph_objects as go

In [2]:
def world_map(cases_df):
    """Load the Natural Earth 1:110m country layer and rename selected countries.

    The layer is downloaded from the Natural Earth CDN on every call. Its
    ``NAME`` column is then passed through a hand-maintained rename table.

    Parameters
    ----------
    cases_df : pandas.DataFrame
        Kept so existing callers keep working; this function does not read it.

    Returns
    -------
    geopandas.GeoDataFrame
        The downloaded layer with the renames applied to its ``NAME`` column.
    """
    url = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
    world = gpd.read_file(url)

    # Keys are compared with the values already present in world["NAME"];
    # a key that does not occur there is simply left without effect.
    country_name_mapping = {
        "United States of America": "United States",
        "Russian Federation": "Russia",
        # NOTE(review): this entry maps the name to itself, so it changes
        # nothing. Confirm which spelling the map layer actually uses.
        "Democratic Republic of the Congo": "Democratic Republic of the Congo",
        "Czechia": "Czech Republic",
        "Viet Nam": "Vietnam",
        "Iran (Islamic Republic of)": "Iran",
        "Republic of Korea": "South Korea",
        "Syrian Arab Republic": "Syria",
        "United Kingdom of Great Britain and Northern Ireland": "United Kingdom",
        "Venezuela (Bolivarian Republic of)": "Venezuela",
        "Bolivia (Plurinational State of)": "Bolivia",
        "Lao People's Democratic Republic": "Laos",
        "Brunei Darussalam": "Brunei",
        "Republic of Moldova": "Moldova",
        "Taiwan": "Taiwan, China",
    }

    world["NAME"] = world["NAME"].replace(country_name_mapping)

    return world

In [3]:
# Function to validate if a string is a date
def is_valid_date(date_str):
    """Return True when ``date_str`` parses with the ``%Y-%m-%d`` format.

    Any value that ``datetime.strptime`` rejects yields False. That includes
    non-string values such as ``None`` or integers, so the function can be
    used to filter arbitrary column labels.
    """
    try:
        datetime.strptime(date_str, "%Y-%m-%d")
    except (TypeError, ValueError):
        # ValueError: a string that does not match the format.
        # TypeError: strptime only accepts str as its first argument.
        return False
    return True

In [4]:
def time_lapse_world_map_plot(df, world, arg, max_frames=45):
    """Animate per-date choropleth maps into ``COVID-19_{arg}_map.gif``.

    ``df`` is left-joined onto ``world`` with ``world.join(df, on='NAME')``,
    so ``df`` has to be indexed by the values found in ``world['NAME']``.
    Every joined column whose label is a ``YYYY-MM-DD`` string becomes one
    frame, and frames are ordered chronologically. The GIF is written to the
    current working directory.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose date columns are labelled with ``YYYY-MM-DD`` strings.
    world : geopandas.GeoDataFrame
        Country polygons with a ``NAME`` column.
    arg : str
        Metric name; interpolated into each frame title and the file name.
    max_frames : int or None, optional
        How many of the earliest dates to render. Defaults to 45; ``None``
        renders every date.

    Raises
    ------
    ValueError
        If there is no date column to draw.
    """
    merged = world.join(df, on='NAME', how='left')

    # Column labels also include the map layer's own attributes, and labels
    # are not guaranteed to be strings, so non-strings are skipped up front.
    date_columns = sorted(
        (
            label for label in merged.columns.to_list()
            if isinstance(label, str) and is_valid_date(label)
        ),
        key=lambda label: datetime.strptime(label, "%Y-%m-%d"),
    )[:max_frames]

    if not date_columns:
        raise ValueError(f"no YYYY-MM-DD columns available to plot for {arg!r}")

    # These do not change between frames, so build them once.
    cmap = plt.get_cmap('Blues')
    bins = [5, 100, 1000, 5000, 10000, 25000, 50000, 100000]  # Custom bins for classification
    # NOTE(review): the colorbar below is built from this norm independently
    # of the classification geopandas performs inside plot(); confirm that
    # both end up using the same colour breaks.
    norm = mcolors.BoundaryNorm(bins, cmap.N)

    img_frames = []

    for date_label in date_columns:
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            merged.plot(
                column=date_label, cmap=cmap, edgecolor='black', linewidth=0.4,
                scheme='user_defined',
                classification_kwds={'bins': bins},
                legend=False,
                ax=ax
            )

            # Horizontal colorbar with discrete boundaries
            sm = cm.ScalarMappable(norm=norm, cmap=cmap)
            fig.colorbar(sm, ax=ax, orientation='horizontal', fraction=0.018, pad=0.02, aspect=45)

            ax.set_title(f'Daily Confirmed Coronavirus {arg}: {date_label}', fontdict={'fontsize': 16}, pad=12.5)
            ax.set_axis_off()
            fig.tight_layout()

            # Render to an in-memory PNG. Image.open is lazy, so copy() forces
            # the pixels into memory before the buffer is closed.
            with io.BytesIO() as buffer:
                fig.savefig(buffer, format='png', bbox_inches='tight')
                buffer.seek(0)
                img_frames.append(PIL.Image.open(buffer).copy())
        finally:
            # Release the figure right away so frames do not pile up in memory.
            plt.close(fig)

    img_frames[0].save(f'COVID-19_{arg}_map.gif', format='GIF', append_images=img_frames[1:], save_all=True, duration=200, loop=0)

In [5]:
# Function to get the last word of country names
def get_last_word(name):
    """Return the last whitespace-separated word of ``name``.

    A name with no words (empty or whitespace only) yields an empty string
    rather than raising ``IndexError``.
    """
    words = name.split()
    return words[-1] if words else ""

In [6]:
def time_lapse_bar_plot(cases_df, max_frames=45):
    """Animate the ten highest cumulative death counts per date as a GIF.

    The table is pivoted to one row per ``date`` and one column per
    ``country``, with missing values treated as 0. For each date a horizontal
    bar chart of the ten largest values is rendered, and the frames are saved
    to ``COVID-19_Deaths_timeline.gif`` in the current working directory.

    Parameters
    ----------
    cases_df : pandas.DataFrame
        Needs ``date``, ``country`` and ``cumulative_deaths`` columns, with at
        most one row per (date, country) pair because ``pivot`` is used. The
        ``date`` values must provide ``strftime``.
    max_frames : int or None, optional
        How many of the earliest dates to render. Defaults to 45; ``None``
        renders every date.

    Raises
    ------
    ValueError
        If there is no date to render.
    """
    # Reshape to a date x country table (pivot does not aggregate duplicates)
    cumulative_deaths = cases_df.pivot(index="date", columns="country", values="cumulative_deaths").fillna(0)

    # Dates to animate, in sorted order
    frame_dates = cumulative_deaths.index.sort_values()[:max_frames]
    if len(frame_dates) == 0:
        raise ValueError("no dates available to render")

    # Fixed colors, assigned by bar position (smallest of the top 10 first)
    colors = ['#264653', '#2a9d8f', '#e9c46a', '#f4a261', '#e76f51',
              '#f4acb7', '#9f6976', '#52796f', '#f8ffe5', '#ef476f']

    fig, ax = plt.subplots(figsize=(12, 8))
    img_frames = []

    try:
        for date in frame_dates:
            ax.clear()

            # Top 10 countries for this date, smallest first so the largest
            # bar ends up at the top of the horizontal chart
            top_10 = cumulative_deaths.loc[date].nlargest(10).sort_values()

            # Short labels use the last word of the name; when two countries
            # would share a short label, fall back to their full names.
            short_names = [get_last_word(country) for country in top_10.index]
            country_labels = [
                country if short_names.count(short) > 1 else short
                for country, short in zip(top_10.index, short_names)
            ]

            # Bars are placed on numeric positions and labelled afterwards.
            # Passing the labels as y values would put equal labels on the
            # same position.
            positions = list(range(len(top_10)))
            ax.barh(positions, top_10.values, color=colors[:len(top_10)], edgecolor="black", height=0.6)
            ax.set_yticks(positions)
            ax.set_yticklabels(country_labels)

            ax.set_title(f"Num of Deaths due to COVID-19: {date.strftime('%Y-%m-%d')}", fontsize=16)
            ax.set_xlabel("Cumulative Deaths", fontsize=12)
            ax.set_ylabel("Country", fontsize=12)
            ax.grid(True, linestyle="--", alpha=0.5)

            # Render to an in-memory PNG. Image.open is lazy, so copy() forces
            # the pixels into memory before the buffer is closed.
            with io.BytesIO() as buffer:
                fig.savefig(buffer, format="png", bbox_inches="tight")
                buffer.seek(0)
                img_frames.append(PIL.Image.open(buffer).copy())
    finally:
        # The figure is only a scratch canvas for the frames.
        plt.close(fig)

    img_frames[0].save("COVID-19_Deaths_timeline.gif", format="GIF", append_images=img_frames[1:], save_all=True, duration=200, loop=0)

In [7]:
def cases_deaths_plot(daily_cases, daily_deaths):
    """Plot daily cases and deaths on twin y-axes with COVID-19 waves shaded.

    Cases are drawn in blue against the left axis and deaths in red against
    the right axis. Four fixed date ranges are shaded, marked with a dashed
    line at their start and annotated with the wave name. The figure is shown
    with ``plt.show()``.

    Parameters
    ----------
    daily_cases, daily_deaths : pandas.Series
        The index supplies the x values and the values supply the y values.
        Missing values are allowed (for example at the start of a rolling
        average).
    """
    # name -> (start date, end date, shading color)
    covid_waves = {
        "First Wave": ("2020-03-01", "2020-07-31", "#FF9999"),   # Light Red
        "Second Wave": ("2020-10-01", "2021-03-31", "#FFCC99"),  # Light Orange
        "Delta Wave": ("2021-06-01", "2021-09-30", "#99CCFF"),   # Light Blue
        "Omicron Wave": ("2021-11-01", "2022-02-28", "#99FF99")  # Light Green
    }

    fig, ax1 = plt.subplots(figsize=(14, 7))

    # Left axis: daily cases. legend=False stops seaborn from drawing a
    # per-axis legend; a single combined legend is added further down.
    ax1.set_xlabel("Date", fontsize=12)
    ax1.set_ylabel("New Cases", fontsize=12, color="blue")
    sns.lineplot(x=daily_cases.index, y=daily_cases.values, linewidth=2, color="blue",
                 label="Daily Cases", legend=False, ax=ax1)
    ax1.tick_params(axis='y', labelcolor="blue")

    # Right axis: daily deaths
    ax2 = ax1.twinx()
    ax2.set_ylabel("New Deaths", fontsize=12, color="red")
    sns.lineplot(x=daily_deaths.index, y=daily_deaths.values, linewidth=2, color="red",
                 label="Daily Deaths", legend=False, ax=ax2)
    ax2.tick_params(axis='y', labelcolor="red")

    # Series.max() skips missing values. The builtin max() over the raw array
    # returns NaN as soon as the first element is NaN.
    label_height = daily_cases.max() * 0.6

    for wave, (start, end, color) in covid_waves.items():
        start_date = pd.to_datetime(start)
        end_date = pd.to_datetime(end)
        ax1.axvspan(start_date, end_date, color=color, alpha=0.2, label=wave)  # Shaded wave region
        ax1.axvline(start_date, color=color, linestyle="--", alpha=0.8)  # Vertical line at wave start
        if pd.notna(label_height):
            ax1.text(start_date, label_height, wave, rotation=90, fontsize=10, color="black")  # Annotate wave

    # One legend for both axes, collected after the spans exist so the wave
    # labels are included. It is attached to ax2 because the twin axis is
    # drawn on top of ax1.
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles1 + handles2, labels1 + labels2, loc="upper left")

    ax1.set_title("Global Daily COVID-19 Cases & Deaths with Waves", fontsize=14)
    fig.tight_layout()

    plt.show()

In [8]:
def interactive_plot(daily_cases, daily_deaths, vaccination_all_dose_w, vaccination_one_dose_w):
    """Show an interactive Plotly chart of cases, deaths and vaccination curves.

    Four line traces are drawn, each against its own y-axis: cases on the
    left axis, deaths on a right-hand axis, and the two vaccination series on
    overlaid axes whose tick labels are hidden. Four fixed COVID-19 wave
    periods are shaded behind the lines and labelled at the top.

    Parameters
    ----------
    daily_cases, daily_deaths : pandas.Series
        The index supplies the x values and the values supply the y values.
    vaccination_all_dose_w : pandas.DataFrame
        Read through its ``day`` and ``people_fully_vaccinated`` columns.
    vaccination_one_dose_w : pandas.DataFrame
        Read through its ``day`` and ``people_one_dose_vaccinated`` columns.
    """

    def tinted_axis(label, tint, **extra):
        """Axis settings whose title and tick labels share one color."""
        return dict(
            title=dict(text=label, font=dict(color=tint, size=12)),
            tickfont=dict(color=tint, size=10),
            **extra
        )

    def hidden_overlay_axis():
        """Right-hand overlay axis that shows neither a title nor tick labels."""
        return dict(
            showticklabels=False,
            overlaying="y",
            side="right",
            anchor="free",
            position=1
        )

    # name -> (start date, end date, fill color)
    wave_periods = {
        "First Wave": ("2020-03-01", "2020-07-31", "rgba(255, 165, 0, 0.5)"),
        "Second Wave": ("2020-10-01", "2021-03-31", "rgba(255, 99, 71, 0.5)"),
        "Delta Wave": ("2021-06-01", "2021-09-30", "rgba(30, 144, 255, 0.5)"),
        "Omicron Wave": ("2021-11-01", "2022-02-28", "rgba(34, 139, 34, 0.5)"),
    }

    fig = go.Figure()

    for wave_name, (begin, finish, fill) in wave_periods.items():
        # Full-height band behind the traces
        fig.add_shape(
            type="rect",
            x0=begin, x1=finish,
            y0=0, y1=1, yref="paper",
            fillcolor=fill, opacity=0.3,
            layer="below", line_width=0
        )
        # Rotated label centred on the middle of the period
        begin_ts = pd.to_datetime(begin)
        finish_ts = pd.to_datetime(finish)
        fig.add_annotation(
            x=begin_ts + (finish_ts - begin_ts) / 2,
            y=1, yref="paper",
            text=wave_name,
            showarrow=False,
            font={"size": 10, "color": "black"},
            textangle=270
        )

    # (x, y, legend name, line color, y-axis reference); the first trace is
    # left on the default axis, so it carries no reference.
    series = (
        (daily_cases.index, daily_cases.values,
         "Daily Cases", 'rgba(0, 0, 255, 0.4)', None),
        (daily_deaths.index, daily_deaths.values,
         "Daily Deaths", 'rgba(255, 0, 0, 0.4)', "y2"),
        (vaccination_all_dose_w['day'], vaccination_all_dose_w['people_fully_vaccinated'],
         "Fully Vaccinated (%)", 'rgba(0, 128, 0, 0.7)', "y3"),
        (vaccination_one_dose_w['day'], vaccination_one_dose_w['people_one_dose_vaccinated'],
         "One Dose Vaccinated (%)", 'rgba(128, 0, 128, 0.7)', "y4"),
    )
    for xs, ys, legend_name, line_color, axis_ref in series:
        trace_options = dict(
            x=xs,
            y=ys,
            mode='lines',
            name=legend_name,
            line=dict(color=line_color, width=2)
        )
        if axis_ref is not None:
            trace_options["yaxis"] = axis_ref
        fig.add_trace(go.Scatter(**trace_options))

    # Keyword order is kept as-is: `title` comes before the `title_*` keys.
    fig.update_layout(
        title="Interactive Global COVID-19 Cases, Deaths & Vaccination Rates",
        width=1000,
        height=600,
        title_x=0.5,
        title_font_size=16,
        xaxis={
            "title": "Date",
            "showgrid": False,
            "tickangle": -30,
            "tickfont": {"size": 10},
        },
        yaxis=tinted_axis("New Cases", "blue"),
        yaxis2=tinted_axis("New Deaths", "red", overlaying="y", side="right"),
        yaxis3=hidden_overlay_axis(),
        yaxis4=hidden_overlay_axis(),
        # Horizontal legend below the plot
        legend={
            "orientation": "h",
            "yanchor": "top",
            "y": -0.2,
            "xanchor": "center",
            "x": 0.5,
            "font": {"size": 12},
        },
        # Transparent backgrounds
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        hovermode="x"
    )

    fig.show()