## This code runs for Case 2 Critical Futures - Chapter 6
## Parameterisation + Time series plotting + Demand flow generation + Evaluation metrics
## last updated on 04-08-2024
Contact asarfraz1@sheffield.ac.uk with any queries.


In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.stats import qmc
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def assure_path_exists(path):
    """Make sure the directory *path* exists, creating it and any parents if needed.

    Prints one line reporting whether the directory was created or was
    already present.
    """
    if os.path.exists(path):
        print(f"Directory already exists: {path}")
        return
    print(f"Creating directory: {path}")
    os.makedirs(path)

def generate_demand_flows(n_scenarios):
    """Build monthly demand for each crop, identical across scenarios.

    Each crop's annual demand is spread over the calendar year (Jan..Dec)
    using a fixed monthly profile whose fractions sum to 1.

    Parameters
    ----------
    n_scenarios : int
        Number of scenario rows to generate.

    Returns
    -------
    dict
        Maps crop name -> ndarray of shape (n_scenarios, 12). Every row is the
        same monthly profile, and each row sums to the crop's annual demand.
    """
    # Annual demand per crop. These are fixed values, not sampled ranges.
    annual_demands = {
        "Cotton": 99.61,
        "Rice": 66.18,
        "Wheat": 23.67,
        "Sugarcane": 4.16,
        "Miscellaneous": 16.42,
    }
    # Year-round crops get an exactly uniform profile. A rounded 0.083 per
    # month sums to only 0.996 and would drop 0.4% of the annual demand.
    uniform_profile = [1 / 12] * 12
    # Fraction of annual demand falling in each month (Jan..Dec); each list sums to 1.
    crop_percentages = {
        'Cotton': [0.00, 0.00, 0.00, 0.00, 0.10, 0.15, 0.20, 0.20, 0.15, 0.15, 0.05, 0.00],
        'Rice': [0.00, 0.00, 0.00, 0.00, 0.00, 0.25, 0.25, 0.20, 0.20, 0.10, 0.00, 0.00],
        'Wheat': [0.20, 0.25, 0.25, 0.15, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15],
        'Sugarcane': uniform_profile,
        'Miscellaneous': uniform_profile,
    }

    demand_flows = {}
    for crop, annual_demand in annual_demands.items():
        monthly_profile = np.asarray(crop_percentages[crop]) * annual_demand
        # Same profile repeated for every scenario row (copy, so rows are writable).
        demand_flows[crop] = np.tile(monthly_profile, (n_scenarios, 1))
        print(f"{crop} Demand: Value = {annual_demand:.3f}")
    return demand_flows

def lhs_params(param_ranges, n_samples):
    """Draw a Latin Hypercube sample and scale it to per-parameter bounds.

    Parameters
    ----------
    param_ranges : sequence of (min, max) pairs
        One pair per parameter; column i is mapped linearly onto pair i.
    n_samples : int
        Number of sample points (rows) to draw.

    Returns
    -------
    ndarray of shape (n_samples, len(param_ranges))
    """
    sampler = qmc.LatinHypercube(d=len(param_ranges))
    unit_cube = sampler.random(n=n_samples)
    # Map each unit-interval column onto its [low, high] range, in place.
    for col, (low, high) in enumerate(param_ranges):
        unit_cube[:, col] = low + (high - low) * unit_cube[:, col]
    return unit_cube

def ComponentRescaling(sites, HistoricalMonthlyQ, LHsamples, rescaled_folder):
    """Write one rescaled meltwater series per LHS scenario (and per site).

    The historical monthly record is first averaged by calendar month. For a
    scenario row (amplitude, lambda_coeff), the new value for month m is

        amplitude * (lambda_coeff * avg[(m + 3) % 12] + (1 - lambda_coeff) * avg[m])

    This blends the mean hydrograph with a copy shifted three months earlier,
    then scales the result by amplitude. The resulting 12-month pattern is
    repeated for every year in the record.

    Files written for scenario j (1-based):
      * ``<site>_Scenario<j>.csv`` in *rescaled_folder*, one per site. Every
        site receives the same series.
      * ``Scenario<j>_scaling_factors.txt``: the new first-year monthly values
        divided by the historical monthly means.

    HistoricalMonthlyQ is assumed to be 1-D; its length must be a multiple
    of 12, or the reshape raises. Each LHsamples row must have exactly two
    values (amplitude, lambda_coeff).
    """
    months_per_year = 12
    n_years = len(HistoricalMonthlyQ) // months_per_year
    climatology = np.mean(HistoricalMonthlyQ.reshape(n_years, months_per_year), axis=0)
    # shifted[m] == climatology[(m + 3) % 12]
    shifted = np.roll(climatology, -3)

    for scenario_no, (amplitude, lambda_coeff) in enumerate(LHsamples, start=1):
        pattern = amplitude * (lambda_coeff * shifted + (1 - lambda_coeff) * climatology)
        series = np.tile(pattern, n_years)

        for site in sites:
            np.savetxt(
                os.path.join(rescaled_folder, f"{site}_Scenario{scenario_no}.csv"),
                series,
                fmt='%.3f',
                delimiter=',',
            )

        ratios = series[:months_per_year] / climatology
        np.savetxt(
            os.path.join(rescaled_folder, f"Scenario{scenario_no}_scaling_factors.txt"),
            ratios,
            fmt='%.3f',
        )

def CombineRescaledWithRainbase(sites, nScenarios, rescaled_folder, rainbase_data, combined_folder):
    """Add a rain-fed baseflow series to every rescaled meltwater series.

    The baseflow record is flattened, then repeated and/or truncated to
    exactly 30 years x 12 months. For each scenario k (1-based) and each site,
    the function:
      * reads ``<site>_Scenario<k>.csv`` from *rescaled_folder*,
      * reshapes it to (30, 12),
      * adds the baseflow,
      * writes ``total_<site>_Scenario<k>.csv`` to *combined_folder*.

    Raises
    ------
    ValueError
        If *rainbase_data* is empty.
    """
    nYears = 30
    mMonths = 12
    total_months = nYears * mMonths

    # Flatten so that a record stored as a 2-D (years, months) table is
    # tiled along time rather than along its last axis.
    rainbase_data = np.asarray(rainbase_data).ravel()
    if rainbase_data.size == 0:
        raise ValueError("rainbase_data is empty; cannot build a baseflow series")

    # Repeat a short record to cover the horizon, then truncate. Truncating
    # unconditionally also handles records longer than the horizon, which
    # would otherwise break the reshape.
    n_repeats = -(-total_months // rainbase_data.size)  # ceiling division
    rainbase_reshaped = np.tile(rainbase_data, n_repeats)[:total_months].reshape(nYears, mMonths)

    for i in range(nScenarios):
        for site in sites:
            rescaled_data = np.loadtxt(
                os.path.join(rescaled_folder, f"{site}_Scenario{i+1}.csv"), delimiter=','
            )
            # No-op when the file is already (30, 12); otherwise a flat series.
            combined_data = rescaled_data.reshape(nYears, mMonths) + rainbase_reshaped
            file_path = os.path.join(combined_folder, f"total_{site}_Scenario{i+1}.csv")
            np.savetxt(file_path, combined_data, fmt='%.3f', delimiter=',')

def CombineSupplyFlows(sites, nScenarios, combined_folder, supply_flows_rescaled_folder):
    """Sum per-site combined flows into one supply file per scenario.

    For scenario k (1-based), adds ``total_<site>_Scenario<k>.csv`` from
    *combined_folder* for every site onto a 30 x 12 zero grid. It then
    writes ``supply_Scenario<k>.csv`` to *supply_flows_rescaled_folder* and
    prints the output path.
    """
    grid_shape = (30, 12)  # (years, months)

    def site_file(site, scenario):
        return os.path.join(combined_folder, f"total_{site}_Scenario{scenario}.csv")

    for scenario in range(1, nScenarios + 1):
        aggregate = np.zeros(grid_shape)
        for site in sites:
            aggregate += np.loadtxt(site_file(site, scenario), delimiter=',')

        out_path = os.path.join(supply_flows_rescaled_folder, f"supply_Scenario{scenario}.csv")
        np.savetxt(out_path, aggregate, fmt='%.3f', delimiter=',')
        print(f"Saved combined supply flow to {out_path}")

def calculate_average_flows(input_folder, output_folder, prefix):
    """Reduce every CSV in *input_folder* to its column means.

    Each ``<name>.csv`` is loaded and averaged down its rows (axis 0). The
    result is written as a single comma-separated row to
    ``<output_folder>/<prefix>_<name>.csv``. *output_folder* is created
    first if needed. Files not ending in ``.csv`` are ignored.
    """
    assure_path_exists(output_folder)

    # The directory listing is taken once, up front.
    csv_names = (name for name in os.listdir(input_folder) if name.endswith(".csv"))
    for name in csv_names:
        table = np.loadtxt(os.path.join(input_folder, name), delimiter=',')
        column_means = np.mean(table, axis=0)
        np.savetxt(
            os.path.join(output_folder, f"{prefix}_{name}"),
            column_means.reshape(1, -1),
            fmt='%.3f',
            delimiter=',',
        )

In [2]:
def month_range_str(start_month):
    """Return a 'Mon-Mon-Mon' label for three consecutive months.

    *start_month* is a 0-based month index (0 = Jan). The second and third
    months wrap around past December.
    """
    abbreviations = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    first = abbreviations[start_month]
    following = (abbreviations[(start_month + step) % 12] for step in (1, 2))
    return "-".join((first, *following))

def create_time_series_plots_with_crops(supply_flows_rescaled_FINAL, demand_flows, results_folder):
    """Plot the supply envelope against per-crop and total demand, then print summary stats.

    Parameters
    ----------
    supply_flows_rescaled_FINAL : str
        Folder of ``*.csv`` files. Each file is expected to hold one row of
        12 monthly supply values, as written by calculate_average_flows.
    demand_flows : dict
        Crop name -> array whose rows are 12 monthly demands. Must contain
        'Cotton', 'Rice', 'Wheat', 'Sugarcane' and 'Miscellaneous'.
    results_folder : str
        Existing folder that receives 'case1_historic.png'.

    Raises
    ------
    ValueError
        If the supply folder contains no CSV files.
    """
    supply_files = [f for f in os.listdir(supply_flows_rescaled_FINAL) if f.endswith('.csv')]
    if not supply_files:
        # np.max over an empty list would otherwise fail with an opaque
        # "zero-size array" error.
        raise ValueError(f"No supply CSV files found in {supply_flows_rescaled_FINAL}")

    n_months = 12
    month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    crops = ['Cotton', 'Rice', 'Wheat', 'Sugarcane', 'Miscellaneous']
    colors = ['yellow', 'purple', 'blue', 'red', 'cyan']

    supply_data = [
        np.loadtxt(os.path.join(supply_flows_rescaled_FINAL, name), delimiter=',')
        for name in supply_files
    ]
    # Per-month envelope and mean across all supply scenarios.
    supply_max = np.max(supply_data, axis=0)
    supply_min = np.min(supply_data, axis=0)
    supply_avg = np.mean(supply_data, axis=0)

    # Scenario-averaged demand, computed once and reused for the plot and the printout.
    crop_demand_avgs = {crop: np.mean(demand_flows[crop], axis=0) for crop in crops}
    total_demand_data = np.sum([demand_flows[crop] for crop in crops], axis=0)
    total_demand_avg = np.mean(total_demand_data, axis=0)

    plt.figure(figsize=(12, 8))
    months = range(1, n_months + 1)

    plt.fill_between(months, supply_min, supply_max, alpha=0.2, color='gray', label='Supply Range')
    plt.plot(months, supply_avg, label='Average Supply', color='gray', linewidth=2)
    for crop, color in zip(crops, colors):
        plt.plot(months, crop_demand_avgs[crop], label=f'{crop} Demand', color=color, linewidth=2)
    plt.plot(months, total_demand_avg, label='Total Demand', color='black', linewidth=2, linestyle='--')

    plt.xlabel('Month')
    plt.ylabel('Water Volume')
    plt.title('Supply and Demand Time Series')
    plt.xticks(months, month_labels)
    plt.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=4)

    plt.tight_layout()
    # NOTE(review): the file name says "case1" although this script is Case 2.
    # It is kept unchanged in case other material references it; confirm.
    plt.savefig(os.path.join(results_folder, 'case1_historic.png'), bbox_inches='tight', dpi=600)
    plt.close()

    print(f"Time series plot has been saved in {results_folder}")

    print("\nSummary Statistics:")
    print(f"Supply - Min: {supply_min.min():.4f}, Max: {supply_max.max():.4f}, Avg: {supply_avg.mean():.4f}")
    print(f"Total Demand - Min: {total_demand_avg.min():.4f}, Max: {total_demand_avg.max():.4f}, Avg: {total_demand_avg.mean():.4f}")
    for crop in crops:
        crop_demand_avg = crop_demand_avgs[crop]
        print(f"{crop} Demand - Min: {crop_demand_avg.min():.4f}, Max: {crop_demand_avg.max():.4f}, Avg: {crop_demand_avg.mean():.4f}")

In [4]:
def calculate_var_metrics(supply_data, demand_flows, current_storage=14.3, potential_storage=18):
    """Compute per-scenario supply/demand balance metrics.

    For scenario i:
      * monthly supply is the mean of ``supply_data[i]`` over axis 0 (years);
      * monthly demand is the sum over crops of ``demand_flows[crop][i]``.
    A circular 3-month rolling mean of (supply - demand) then identifies the
    best and worst seasons.

    Parameters
    ----------
    supply_data : sequence of arrays
        One per scenario. Each is reduced along axis 0 to 12 monthly values
        (Jan..Dec assumed).
    demand_flows : dict
        Crop name -> array indexable by scenario position i. Each row holds
        12 monthly values.
    current_storage, potential_storage : float
        Storage volumes that the variability is compared against.

    Returns
    -------
    tuple of 8 arrays, each of length len(supply_data):
        annual_scarcity        : annual demand / annual supply.
        variability            : best minus worst 3-month rolling balance.
        surplus_capacity       : best 3-month rolling balance.
        deficit_severity       : negated worst 3-month rolling balance.
        best_months            : 0-based index (0 = Jan) of the FIRST month of
                                 the best 3-month window (as month_range_str expects).
        worst_months           : same, for the worst 3-month window.
        storage_adequacy       : current_storage / variability (inf, with a
                                 numpy warning, if variability is 0).
        storage_adequacy_points: 1 if variability <= current_storage,
                                 2 if <= current_storage + potential_storage,
                                 else 0.
    """
    n_scenarios = len(supply_data)
    window = 3          # months per rolling window
    pad = window - 1    # months wrapped onto each end of the year
    kernel = np.ones(window)

    annual_scarcity = np.zeros(n_scenarios)
    surplus_capacity = np.zeros(n_scenarios)
    deficit_severity = np.zeros(n_scenarios)
    variability = np.zeros(n_scenarios)
    best_months = np.zeros(n_scenarios, dtype=int)
    worst_months = np.zeros(n_scenarios, dtype=int)
    storage_adequacy = np.zeros(n_scenarios)
    storage_adequacy_points = np.zeros(n_scenarios, dtype=int)

    for i in range(n_scenarios):
        monthly_supply = np.mean(supply_data[i], axis=0)
        monthly_demand = np.sum([demand_flows[crop][i] for crop in demand_flows], axis=0)

        annual_supply = np.sum(monthly_supply)
        annual_demand = np.sum(monthly_demand)
        annual_scarcity[i] = annual_demand / annual_supply

        # Wrap the year (last 2 months | 12 months | first 2 months) so that
        # windows spanning New Year are included.
        extended_supply = np.concatenate([monthly_supply[-pad:], monthly_supply, monthly_supply[:pad]])
        extended_demand = np.concatenate([monthly_demand[-pad:], monthly_demand, monthly_demand[:pad]])

        rolling_supply = np.convolve(extended_supply, kernel, 'valid') / window
        rolling_demand = np.convolve(extended_demand, kernel, 'valid') / window

        rolling_balance = rolling_supply - rolling_demand

        best_balance = np.max(rolling_balance)
        worst_balance = np.min(rolling_balance)
        variability[i] = best_balance - worst_balance

        # Window j averages extended[j:j + 3], i.e. calendar months
        # (j - 2) .. j (mod 12). Store the window's FIRST month. Using j % 12
        # (the last month) would shift every month_range_str label two months late.
        best_months[i] = (np.argmax(rolling_balance) - pad) % 12
        worst_months[i] = (np.argmin(rolling_balance) - pad) % 12

        surplus_capacity[i] = best_balance
        deficit_severity[i] = -worst_balance

        storage_adequacy[i] = current_storage / variability[i]

        if variability[i] <= current_storage:
            storage_adequacy_points[i] = 1  # Adequate for both current and potential storage
        elif variability[i] <= (current_storage + potential_storage):
            storage_adequacy_points[i] = 2  # Adequate for future combined
        else:
            storage_adequacy_points[i] = 0  # Inadequate for both

    return annual_scarcity, variability, surplus_capacity, deficit_severity, best_months, worst_months, storage_adequacy, storage_adequacy_points

def plot_lambda_vs_amplitude_updated_again(data, metrics, output_file):
    """Draw a 2x3 grid of lambda-vs-amplitude scatter panels and save it.

    Each entry of *metrics* (at most 5) fills the next panel:
      * 'Annual_Scarcity', 'Variability', 'Deficit_Severity' and
        'Surplus_Capacity' are colour-mapped with 'coolwarm' and get a colourbar;
      * 'Storage_Adequacy_Points' uses discrete colours:
        0 -> red, 1 -> yellow, anything else -> '#1f77b4'.
    A figure-level legend explains the storage-adequacy colours. The sixth
    panel is removed, and the figure is saved to *output_file* at 600 dpi.

    *data* must provide the columns 'Input_lambda', 'Input_Amplitude' and
    one column per metric.
    """
    from matplotlib.lines import Line2D

    continuous_titles = {
        'Variability': 'Variability',
        'Deficit_Severity': 'Deficit Severity',
        'Surplus_Capacity': 'Surplus Capacity',
    }
    panel_tags = ['(a)', '(b)', '(c)', '(d)', '(e)']
    adequacy_palette = {0: 'red', 1: 'yellow'}

    fig, grid = plt.subplots(2, 3, figsize=(18, 10))  # wide enough for the side legend
    fig.suptitle('Lambda vs Amplitude', fontsize=16)
    panels = grid.flatten()

    def scatter_on(ax, **style):
        return ax.scatter(data['Input_lambda'], data['Input_Amplitude'], **style)

    for idx, metric in enumerate(metrics):
        ax = panels[idx]

        if metric == 'Annual_Scarcity':
            points = scatter_on(ax, c=data[metric], cmap='coolwarm')
            ax.set_title('Annual Scarcity')
            plt.colorbar(points, ax=ax, label='Annual Scarcity')
        elif metric in continuous_titles:
            label = continuous_titles[metric]
            points = scatter_on(ax, c=data[metric], cmap='coolwarm')
            ax.set_title(label)
            plt.colorbar(points, ax=ax).set_label(label, labelpad=15)
        elif metric == 'Storage_Adequacy_Points':
            point_colors = [adequacy_palette.get(val, '#1f77b4') for val in data[metric]]
            scatter_on(ax, c=point_colors)
            ax.set_title('Storage Adequacy')

        ax.set_xlabel('Lambda')
        ax.set_ylabel('Amplitude')
        ax.text(0.03, 1.07, panel_tags[idx], transform=ax.transAxes, fontsize=12,
                verticalalignment='top', bbox=dict(boxstyle='round,pad=0.3', edgecolor='none', facecolor='white'))

    # Figure-level legend for the discrete storage-adequacy colours, placed
    # in the space freed by the removed sixth panel.
    legend_specs = [
        ('Inadequate for both (0)', 'red'),
        ('Adequate for both current and potential (1)', 'yellow'),
        ('Adequate for future combined (2)', '#1f77b4'),
    ]
    handles = [
        Line2D([0], [0], marker='o', color='w', label=text, markerfacecolor=shade, markersize=10)
        for text, shade in legend_specs
    ]
    fig.legend(handles=handles, loc='center right', bbox_to_anchor=(0.9, 0.3), title='Storage Adequacy')

    fig.delaxes(panels[5])

    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight', dpi=600)
    plt.close()

In [5]:
def main():
    """Run the Case 2 pipeline end to end.

    Steps:
      1. Latin Hypercube sampling of (amplitude, lambda).
      2. Demand flow generation.
      3. Meltwater rescaling.
      4. Adding the rain baseflow.
      5. Summing over sites.
      6. Averaging over years.
      7. Supply/demand plot.
      8. Evaluation metrics.
      9. CSV outputs and the 5-panel results figure.

    Reads 'meltwater_hist_km3.txt' and 'rainbase_km3.txt' from the working
    directory and writes everything under './Case2_km3'. It also stores the
    LHS samples in the module-level global ``supply_lhs_samples``.

    Raises
    ------
    RuntimeError
        If no supply scenario files can be loaded.
    """
    sites = ['All']

    all_param_ranges = [
        [0.47, 1.78],  # Amplitude
        [0, 1],        # lambda
    ]
    n_samples = 1000

    all_lhs_samples = lhs_params(all_param_ranges, n_samples)
    np.savetxt("LHsamples_all_params.txt", all_lhs_samples, fmt='%.6f')

    # Kept as a module-level global so other notebook cells can reuse the samples.
    global supply_lhs_samples
    supply_lhs_samples = all_lhs_samples

    demand_flows = generate_demand_flows(n_samples)

    Qgs = np.loadtxt('meltwater_hist_km3.txt')
    rainbase_data = np.loadtxt('rainbase_km3.txt')

    supply_flows_folder = './Case2_km3'
    rescaled_folder = os.path.join(supply_flows_folder, '1. rescaled_meltwater')
    combined_folder = os.path.join(supply_flows_folder, '2. combined')
    supply_flows_rescaled_folder = os.path.join(supply_flows_folder, '3. supply_flows_rescaled_combined')
    supply_flows_rescaled_FINAL = os.path.join(supply_flows_folder, '4. supply_flows_rescaled_FINAL')
    supply_demand_plot_folder = os.path.join(supply_flows_folder, '5. case2_supply_demand_plots')
    results_folder = os.path.join(supply_flows_folder, '6. results')
    individual_plots_folder = os.path.join(supply_flows_folder, '7. individual_scenario_plots')

    for folder in [supply_flows_folder, rescaled_folder, combined_folder, supply_flows_rescaled_folder,
                   supply_flows_rescaled_FINAL, supply_demand_plot_folder, results_folder,
                   individual_plots_folder]:
        assure_path_exists(folder)

    ComponentRescaling(sites, Qgs, supply_lhs_samples, rescaled_folder)
    CombineRescaledWithRainbase(sites, n_samples, rescaled_folder, rainbase_data, combined_folder)
    CombineSupplyFlows(sites, n_samples, combined_folder, supply_flows_rescaled_folder)
    calculate_average_flows(supply_flows_rescaled_folder, supply_flows_rescaled_FINAL, "FINAL_combined_rescaled")

    # Load every scenario's supply grid and remember which ones were found.
    # If any file is skipped, the metrics, demand rows and LHS inputs must
    # still line up scenario-for-scenario.
    supply_data = []
    loaded_idx = []
    for i in range(n_samples):
        supply_path = os.path.join(supply_flows_rescaled_folder, f"supply_Scenario{i+1}.csv")
        try:
            supply = np.loadtxt(supply_path, delimiter=',')
        except FileNotFoundError:
            print(f"Warning: supply_Scenario{i+1}.csv not found. Skipping this scenario.")
            continue
        supply_data.append(supply)
        loaded_idx.append(i)

    print(f"Number of supply scenarios loaded: {len(supply_data)}")
    if not supply_data:
        raise RuntimeError(f"No supply scenarios could be loaded from {supply_flows_rescaled_folder}")
    loaded_idx = np.asarray(loaded_idx, dtype=np.int64)

    create_time_series_plots_with_crops(supply_flows_rescaled_FINAL, demand_flows, results_folder)

    current_storage = 14.3  # km³
    potential_storage = 18.0  # km³

    # Demand rows restricted to the scenarios that were actually loaded.
    loaded_demand = {crop: flows[loaded_idx] for crop, flows in demand_flows.items()}
    (annual_scarcity, variability, surplus_capacity, deficit_severity,
     best_months, worst_months, storage_adequacy, storage_adequacy_points) = calculate_var_metrics(
        supply_data, loaded_demand, current_storage, potential_storage)

    print(f"Storage Adequacy - Min: {storage_adequacy.min():.4f}, Max: {storage_adequacy.max():.4f}, Mean: {storage_adequacy.mean():.4f}")
    print(f"Storage Adequacy Points - Counts: {np.bincount(storage_adequacy_points)}")

    scenario_data = pd.DataFrame({
        'Scenario': loaded_idx + 1,
        'Annual_Scarcity': annual_scarcity,
        'Variability': variability,
        'Surplus_Capacity': surplus_capacity,
        'Deficit_Severity': deficit_severity,
        'Storage_Adequacy': storage_adequacy,
        'Storage_Adequacy_Points': storage_adequacy_points,
        'Best_Months': [month_range_str(m) for m in best_months],
        'Worst_Months': [month_range_str(m) for m in worst_months],
        'Input_Amplitude': supply_lhs_samples[loaded_idx, 0],
        'Input_lambda': supply_lhs_samples[loaded_idx, 1],
    })

    # Save results
    scenario_data.to_csv(os.path.join(results_folder, 'complete_results.csv'), index=False)

    # Create summary statistics
    summary_stats = pd.DataFrame({
        'Statistic': ['Mean Annual Scarcity', 'Mean Variability', 'Mean Surplus Capacity', 'Mean Deficit Severity', 'Mean Storage Adequacy', 'Storage Adequacy Points Distribution'],
        'Value': [
            np.mean(annual_scarcity),
            np.mean(variability),
            np.mean(surplus_capacity),
            np.mean(deficit_severity),
            np.mean(storage_adequacy),
            str(np.bincount(storage_adequacy_points)),
        ],
    })
    summary_stats.to_csv(os.path.join(results_folder, 'summary_statistics.csv'), index=False)

    # Plot results
    metrics = ['Annual_Scarcity', 'Variability', 'Deficit_Severity', 'Surplus_Capacity', 'Storage_Adequacy_Points']
    plot_lambda_vs_amplitude_updated_again(scenario_data, metrics, os.path.join(results_folder, 'Case2_results_5_panels.png'))

    print("Process completed successfully!")

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Cotton Demand: Value = 99.610
Rice Demand: Value = 66.180
Wheat Demand: Value = 23.670
Sugarcane Demand: Value = 4.160
Miscellaneous Demand: Value = 16.420
Creating directory: ./Case2_km3_06
Creating directory: ./Case2_km3_06\1. rescaled_meltwater
Creating directory: ./Case2_km3_06\2. combined
Creating directory: ./Case2_km3_06\3. supply_flows_rescaled_combined
Creating directory: ./Case2_km3_06\4. supply_flows_rescaled_FINAL
Creating directory: ./Case2_km3_06\5. case2_supply_demand_plots
Creating directory: ./Case2_km3_06\6. results
Creating directory: ./Case2_km3_06\7. individual_scenario_plots
Saved combined supply flow to ./Case2_km3_06\3. supply_flows_rescaled_combined\supply_Scenario1.csv
Saved combined supply flow to ./Case2_km3_06\3. supply_flows_rescaled_combined\supply_Scenario2.csv
Saved combined supply flow to ./Case2_km3_06\3. supply_flows_rescaled_combined\supply_Scenario3.csv
Saved combined supply flow to ./Case2_km3_06\3. supply_flows_rescaled_combined\supply_Scenario4.

In [3]:
# ## this is updated as of 04-08-2024

# def calculate_var_metrics(supply_data, demand_flows, current_storage=14.3, potential_storage=18):
#     n_scenarios = len(supply_data)
    
#     annual_scarcity = np.zeros(n_scenarios)
#     surplus_capacity = np.zeros(n_scenarios)
#     deficit_severity = np.zeros(n_scenarios)
#     variability = np.zeros(n_scenarios)
#     best_months = np.zeros(n_scenarios, dtype=int)
#     worst_months = np.zeros(n_scenarios, dtype=int)
#     storage_adequacy = np.zeros(n_scenarios)
#     storage_adequacy_points = np.zeros(n_scenarios, dtype=int)
    
#     for i in range(n_scenarios):
#         monthly_supply = np.mean(supply_data[i], axis=0)
#         monthly_demand = np.sum([demand_flows[crop][i] for crop in demand_flows], axis=0)
        
#         annual_supply = np.sum(monthly_supply)
#         annual_demand = np.sum(monthly_demand)
#         annual_scarcity[i] = annual_demand / annual_supply
        
#         extended_supply = np.concatenate([monthly_supply[-2:], monthly_supply, monthly_supply[:2]])
#         extended_demand = np.concatenate([monthly_demand[-2:], monthly_demand, monthly_demand[:2]])
        
#         rolling_supply = np.convolve(extended_supply, np.ones(3), 'valid') /3
#         rolling_demand = np.convolve(extended_demand, np.ones(3), 'valid') /3
        
#         rolling_balance = rolling_supply - rolling_demand
        
#         best_balance = np.max(rolling_balance)
#         worst_balance = np.min(rolling_balance)
#         variability[i] = best_balance - worst_balance
        
#         best_months[i] = np.argmax(rolling_balance) % 12
#         worst_months[i] = np.argmin(rolling_balance) % 12
        
#         surplus_capacity[i] = best_balance
#         deficit_severity[i] = -worst_balance
        
#         storage_adequacy[i] = current_storage / variability[i]
        
#         if variability[i] <= current_storage:
#             storage_adequacy_points[i] = 1  # Adequate for current storage
#         elif variability[i] <= (current_storage + potential_storage):
#             storage_adequacy_points[i] = 2  # Adequate for current + potential storage
#         else:
#             storage_adequacy_points[i] = 0  # Inadequate
    
#     return annual_scarcity, variability, surplus_capacity, deficit_severity, best_months, worst_months, storage_adequacy, storage_adequacy_points


# ## plots evaluation metrics 
# def plot_lambda_vs_amplitude_updated_again(data, metrics, output_file):
#     descriptive_labels = {
#         'Variability': 'Variability',
#         'Deficit_Severity': 'Deficit Severity',
#         'Surplus_Capacity': 'Surplus Capacity',
#     }
#     fig, axs = plt.subplots(2, 3, figsize=(15, 10))
#     fig.suptitle('Lambda vs Amplitude', fontsize=16)
    
#     axs = axs.flatten()
#     subplot_labels = ['(a)', '(b)', '(c)', '(d)', '(e)']
    
#     for i, metric in enumerate(metrics):
#         ax = axs[i]
        
#         if metric == 'Annual_Scarcity':
#             scatter = ax.scatter(data['Input_lambda'], data['Input_Amplitude'], 
#                                  c=data[metric], cmap='coolwarm')
#             ax.set_title(f'Annual Scarcity')
#             plt.colorbar(scatter, ax=ax, label='Annual Scarcity')
        
#         elif metric in ['Variability','Deficit_Severity','Surplus_Capacity', ]:
#             scatter = ax.scatter(data['Input_lambda'], data['Input_Amplitude'], 
#                                  c=data[metric], cmap='coolwarm')
#             ax.set_title(descriptive_labels[metric])
#             cbar = plt.colorbar(scatter, ax=ax)
#             cbar.set_label(descriptive_labels[metric],labelpad=15)
        
#         elif metric == 'Storage_Adequacy_Points':
#             colors = ['red' if val == 0 else 'yellow' if val == 1 else '#1f77b4' for val in data[metric]]
#             scatter = ax.scatter(data['Input_lambda'], data['Input_Amplitude'], c=colors)
#             ax.set_title('Storage Adequacy')
            
#             from matplotlib.lines import Line2D
#             legend_elements = [Line2D([0], [0], marker='o', color='w', label='Inadequate (0)',
#                                       markerfacecolor='red', markersize=10),
#                                Line2D([0], [0], marker='o', color='w', label='Adequate for Current (1)',
#                                       markerfacecolor='yellow', markersize=10),
#                                Line2D([0], [0], marker='o', color='w', label='Adequate for Current + Potential (2)',
#                                       markerfacecolor='#1f77b4', markersize=10)]
#             ax.legend(handles=legend_elements, loc='best')
        
#         ax.set_xlabel('Lambda')
#         ax.set_ylabel('Amplitude')
#         ax.text(0.03, 1.07, subplot_labels[i], transform=ax.transAxes, fontsize=12,
#                 verticalalignment='top', bbox=dict(boxstyle='round,pad=0.3', edgecolor='none', facecolor='white'))
    
#     fig.delaxes(axs[5])
    
#     plt.tight_layout()
#     plt.savefig(output_file, bbox_inches='tight', dpi=600)
#     plt.close()


