### Imports

In [10]:
import gurobipy as gp
from gurobipy import GRB
import numpy as np
import scipy.sparse as sp
import pandas as pd
import matplotlib.pyplot as plt
import math 
import os 

### Visualization Functions

In [11]:
# Function to save and plot SIRV proportions
def save_and_plot_sirv_proportions(S, I, R, X, population, counties, time_periods, output_folder):
    """Save per-county S/I/R/X proportions to CSV and plot them in a grid.

    For each county ``i`` and period ``t``, the solved values ``S[i, t].x``,
    ``I[i, t].x``, ``R[i, t].x`` and ``X[i, t].x`` are divided by
    ``population[i]``. All four proportions are recorded as 0 when that
    population is not positive.

    Outputs:
        ``<output_folder>/sirvProportions.csv``, with columns ``region``,
        ``timePeriod``, ``S_proportion``, ``I_proportion``, ``R_proportion``
        and ``X_proportion``.
        ``<output_folder>/sirvProportions.png``, with one subplot per county.
        X is drawn with the legend label "V".

    Args:
        S, I, R, X: Solved model variables indexed by ``(county, period)``.
            Only their ``.x`` attribute is read.
        population: Indexable by county; gives the county population.
        counties: Integer county indices. Subplot titles are looked up by
            position in ``inputData/floridaCountyNames.csv`` (column
            ``County``, path relative to the working directory). Indices past
            the end of that list are titled ``Region <i>``.
        time_periods: Periods to record for every county.
        output_folder: Existing directory that receives the CSV and PNG.
    """
    # Load county names (path is relative to the current working directory).
    county_names_path = os.path.join("inputData", "floridaCountyNames.csv")
    county_names_df = pd.read_csv(county_names_path)
    county_names = county_names_df["County"].tolist()

    # Both are iterated several times below, so one-shot iterables must be materialized.
    counties = list(counties)
    time_periods = list(time_periods)

    # Collect SIRV proportions data.
    sirv_data = []
    for i in counties:
        total_population = population[i]
        for t in time_periods:
            if total_population > 0:
                proportions = [var[i, t].x / total_population for var in (S, I, R, X)]
            else:
                proportions = [0, 0, 0, 0]
            sirv_data.append([i, t, *proportions])

    sirv_df = pd.DataFrame(sirv_data, columns=["region", "timePeriod", "S_proportion", "I_proportion", "R_proportion", "X_proportion"])
    sirv_csv_path = os.path.join(output_folder, "sirvProportions.csv")
    sirv_df.to_csv(sirv_csv_path, index=False)
    print(f"SIRV proportions data saved to '{sirv_csv_path}'")

    num_regions = len(counties)
    if num_regions == 0:
        # Nothing to plot: the grid-size computation below would divide by zero.
        return

    # Plot SIRV proportions in a grid of at most 23 rows.
    rows = min(23, num_regions)
    cols = math.ceil(num_regions / rows)

    # squeeze=False always yields a 2-D array of Axes. Without it, a single
    # column (<= 23 counties) gives a 1-D array and one county gives a bare
    # Axes, and axes[row, col] would fail in both cases.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 60), sharex=True, sharey=True, squeeze=False)
    fig.suptitle("SIRV Proportions Over Time for Each County", fontsize=16, y=0.95)

    for idx, i in enumerate(counties):
        row, col = divmod(idx, cols)
        ax = axes[row, col]

        region_data = sirv_df[sirv_df["region"] == i]
        county_name = county_names[i] if i < len(county_names) else f"Region {i}"
        ax.plot(region_data["timePeriod"], region_data["S_proportion"], label="S", color="blue")
        ax.plot(region_data["timePeriod"], region_data["I_proportion"], label="I", color="red")
        ax.plot(region_data["timePeriod"], region_data["R_proportion"], label="R", color="green")
        ax.plot(region_data["timePeriod"], region_data["X_proportion"], label="V", color="orange")
        ax.legend(loc="upper right", fontsize=6)
        ax.set_title(county_name, fontsize=10)
        ax.grid(False)

        # The last grid row may be partly empty, so label the lowest *used*
        # subplot in each column. sharex hides tick labels on the rows above,
        # so re-enable them here.
        if idx + cols >= num_regions:
            ax.set_xlabel("Time Period")
            ax.tick_params(axis="x", labelbottom=True)
        if col == 0:
            ax.set_ylabel("Proportion")

    # rows * cols can exceed the number of counties; hide the empty cells.
    for ax in axes.flat[num_regions:]:
        ax.set_visible(False)

    # Lay out only after all titles and labels exist, so they are taken into account.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    plot_path = os.path.join(output_folder, "sirvProportions.png")
    fig.savefig(plot_path, bbox_inches='tight')
    print(f"SIRV proportions plot saved to '{plot_path}'")
    plt.show()
    # Release the (large) figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Function to save decision variables to CSV files
def save_decision_variables(Z, A, S, I, R, X, u, output_folder, time_periods, decision_periods, counties):
    """Write the solved decision and state variables to CSV files in ``output_folder``.

    Files written:
        transferData.csv: ``(i, j, t')`` rows where ``A[i, j, t'].x > 0.5``.
            Columns are fromRegion, toRegion, timePeriod, transferIndicator.
        travelData.csv: ``(i, j, t')`` rows where ``Z[i, j, t'].x > 0``.
            Columns are fromRegion, toFacility, timePeriod, quantity.
        vaccineAdminData.csv: ``(i, t)`` rows where ``X[i, t].x > 0``.
            Columns are region, timePeriod, administeredVaccines.
        susceptibleData.csv, infectedData.csv, recoveredData.csv and
        vaccinatedData.csv: every ``(i, t)`` value of S, I, R and X over
        ``time_periods``.
        unmetDemandData.csv: every ``(i, t')`` value of ``u`` over
        ``decision_periods``.

    Every variable is only accessed through its ``.x`` attribute.
    """
    # Each of these is iterated many times below, so materialize one-shot iterables.
    counties = list(counties)
    time_periods = list(time_periods)
    decision_periods = list(decision_periods)

    # Save patient transfer data based on the binary variable A.
    transfer_data = []
    for i in counties:
        for j in counties:
            for t_prime in decision_periods:
                a_value = A[i, j, t_prime].x
                if a_value > 0.5:  # active binary decision
                    # The solver returns binaries within a tolerance (e.g. 0.9999999).
                    # int() would truncate that to 0, so round to the nearest integer.
                    transfer_data.append([i, j, t_prime, int(round(a_value))])
    transfer_csv_path = os.path.join(output_folder, "transferData.csv")
    transfer_df = pd.DataFrame(transfer_data, columns=["fromRegion", "toRegion", "timePeriod", "transferIndicator"])
    transfer_df.to_csv(transfer_csv_path, index=False)
    print(f"Patient transfer data saved to '{transfer_csv_path}'")

    # Save actual patient movement quantities in Z.
    travel_data = []
    for i in counties:
        for j in counties:
            for t_prime in decision_periods:
                quantity = Z[i, j, t_prime].x
                if quantity > 0:
                    travel_data.append([i, j, t_prime, quantity])
    travel_csv_path = os.path.join(output_folder, "travelData.csv")
    travel_df = pd.DataFrame(travel_data, columns=["fromRegion", "toFacility", "timePeriod", "quantity"])
    travel_df.to_csv(travel_csv_path, index=False)
    print(f"Travel data saved to '{travel_csv_path}'")

    # Save vaccine administration data (only strictly positive amounts).
    vaccine_admin_data = []
    for i in counties:
        for t in time_periods:
            quantity = X[i, t].x
            if quantity > 0:
                vaccine_admin_data.append([i, t, quantity])
    vaccine_admin_csv_path = os.path.join(output_folder, "vaccineAdminData.csv")
    vaccine_admin_df = pd.DataFrame(vaccine_admin_data, columns=["region", "timePeriod", "administeredVaccines"])
    vaccine_admin_df.to_csv(vaccine_admin_csv_path, index=False)
    print(f"Vaccine administration data saved to '{vaccine_admin_csv_path}'")

    # Save the full SIRV trajectories and the unmet demand.
    save_variable_to_csv(S, "S", "susceptibleData.csv", output_folder, counties, time_periods)
    save_variable_to_csv(I, "I", "infectedData.csv", output_folder, counties, time_periods)
    save_variable_to_csv(R, "R", "recoveredData.csv", output_folder, counties, time_periods)
    save_variable_to_csv(X, "X", "vaccinatedData.csv", output_folder, counties, time_periods)
    save_variable_to_csv(u, "u", "unmetDemandData.csv", output_folder, counties, decision_periods)

# Helper function to save each variable to CSV
def save_variable_to_csv(var, var_name, file_name, output_folder, counties, time_periods):
    """Dump ``var[region, period].x`` for every region/period pair to a CSV file.

    Rows are county-major: all periods of the first county, then the next
    county, and so on. The columns are ``region``, ``timePeriod`` and
    ``var_name``. The file is written without an index to
    ``os.path.join(output_folder, file_name)``, and a confirmation line is
    printed.
    """
    records = [
        [region, period, var[region, period].x]
        for region in counties
        for period in time_periods
    ]
    frame = pd.DataFrame(records, columns=["region", "timePeriod", var_name])
    out_path = os.path.join(output_folder, file_name)
    frame.to_csv(out_path, index=False)
    print(f"{var_name} data saved to '{out_path}'")

# Function to plot unmet demand
def plot_unmet_demand(u, decision_periods, counties, output_folder=None):
    """Plot the unmet demand summed over all counties for each decision period.

    Args:
        u: Solved variables indexed by ``(county, period)``. Only ``.x`` is read.
        decision_periods: Periods to plot. They are used as x values and as
            x-tick positions.
        counties: Counties whose ``u`` values are summed for each period.
        output_folder: Optional directory. When given, the figure is also
            saved there as ``unmetDemand.png``. By default nothing is written,
            as before.
    """
    # Both are re-iterated (periods: data, x values and ticks; counties: once
    # per period), so materialize them in case one-shot iterables are passed.
    decision_periods = list(decision_periods)
    counties = list(counties)

    unmet_demand = [sum(u[i, t_prime].x for i in counties) for t_prime in decision_periods]
    fig = plt.figure(figsize=(10, 6))
    plt.plot(decision_periods, unmet_demand, marker='o', linestyle='-', color='b')
    plt.title("Total Unmet Healthcare Demand Over Time")
    plt.xlabel("Time Period (days)")
    plt.ylabel("Total Unmet Demand (thousands of individuals)")
    plt.grid(True)
    plt.xticks(decision_periods)

    if output_folder is not None:
        plot_path = os.path.join(output_folder, "unmetDemand.png")
        fig.savefig(plot_path, bbox_inches='tight')
        print(f"Unmet demand plot saved to '{plot_path}'")

    plt.show()
    plt.close(fig)


### Run the Model (Travel Constraints Not Working)

### Plot and Save Results