In [1]:
!pip install pypsa
!pip install neptune-client
!pip install gymnasium
#!pip install ipdb

Collecting pypsa
  Downloading pypsa-0.35.2-py3-none-any.whl.metadata (13 kB)
Collecting netcdf4 (from pypsa)
  Downloading netCDF4-1.7.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting linopy>=0.4 (from pypsa)
  Downloading linopy-0.5.7-py3-none-any.whl.metadata (9.4 kB)
Collecting shapely<2.1 (from pypsa)
  Downloading shapely-2.0.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting deprecation (from pypsa)
  Downloading deprecation-2.1.0-py2.py3-none-any.whl.metadata (4.6 kB)
Collecting validators (from pypsa)
  Downloading validators-0.35.0-py3-none-any.whl.metadata (3.9 kB)
Collecting cftime (from netcdf4->pypsa)
  Downloading cftime-1.6.4.post1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Downloading pypsa-0.35.2-py3-none-any.whl (267 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m267.3/267.3 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloadi

In [2]:
import logging
# Suppress PyPSA INFO messages (keep warnings and errors)
logging.getLogger('pypsa').setLevel(logging.WARNING)

import pypsa
import pandas as pd
import numpy as np
import gymnasium as gym
from gymnasium import spaces

import gc
import psutil
import matplotlib.pyplot as plt

import neptune

from torch.utils.data import TensorDataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import random

import os
#import ipdb



In [3]:
def set_all_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def calculate_offset_k_initialization(network_file, reward_scale_factor, k_samples=1000, seed=42, **env_kwargs):
    """
    Calculate the offset k for replacement reward method.
    Creates a temporary environment to perform the calculation.

    We also use this value to scale the reward of the

    Parameters:
    -----------
    network_file : str
      Path to the PyPSA network file
    input_dir : str
      Directory containing constraint mappings
    k_samples : int
      Number of random samples to use for estimation
    **env_kwargs : dict
      Additional keyword arguments for environment creation

    Returns:
    --------
    k_mean: Offset value k found by mean method. This is needed for the replacement reward function
    k_worst: the highest objective value sampled. we want to scale our rewards by this number/100 so that
    constraint_df: pandas DataFrame with constraint violations for each sample
    """
    print(f"Sampling {k_samples} random states to calculate offset k...")

    # Set up reproducible random number generator for action sampling
    import numpy as np
    local_rng = np.random.RandomState(seed)
    # Set up a separate RNG for snapshot selection
    snapshot_rng = np.random.RandomState(seed + 10000)  # Different seed space


    # Create temporary environment WITH SEED
    temp_env = EnvDispatchConstr(
        network_file=network_file, no_convergence_lpf_penalty=0,reward_scale_factor=reward_scale_factor,#these params aren't used, just pass
        seed=seed,
        **env_kwargs
    )
    #I'm just making this env to access certain attributes/ methods; which should be fine since none of these attributes/methods reference these two parameters.

    # Extract constraint names in the correct order
    slack_generators = temp_env.network.generators[temp_env.network.generators.control == "Slack"].index

    #remove this because don't check line constraints during training
    #line_names = temp_env.network.lines.index
    storage_names = temp_env.storage_names

    # Build constraint names in the order they appear in info['constraint_results']
    constraint_names = []

    # p_min for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_min_{gen_name}")

    # p_max for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_max_{gen_name}")

    # # s_max for each line
    # for line_name in line_names:
    #     constraint_names.append(f"s_max_{line_name}")

    # soc_min for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_min_{storage_name}")

    # soc_max for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_max_{storage_name}")

    print(f"Identified {len(constraint_names)} constraints:")
    for i, name in enumerate(constraint_names):
        print(f"  {i}: {name}")

    action_dim = temp_env.action_space.shape[0]

    objective_values = []
    constraint_data = []  # List to store dictionaries for DataFrame creation
    successful_samples = 0

    try:
        for i in range(k_samples):
            try:
                # Reset environment with deterministic seed
                env_seed = seed + i  # Different seed for each sample
                temp_env.reset(seed=env_seed)
                #go to random snapshot (env reset doesnt actually do that)
                temp_env.snapshot_idx = snapshot_rng.randint(0, temp_env.total_snapshots)
                # Sample random action using local RNG (not global random state)
                random_action = local_rng.uniform(0, 1, size=action_dim)  # [0,1] range# This ensures [0,1] range
                # Take step - this handles all action scaling and application
                # make sure to get snapshot before the step to use when call evaluate_objective_direc() again
                #i increase the current snapshot_idx by executing step() if do this line after step, when evaluate_objective_direct() is run it evaluates the objective for the next step!
                current_snapshot = temp_env.network.snapshots[temp_env.snapshot_idx]
                obs, reward, terminated, truncated, info = temp_env.step(random_action)
                # Get the base objective value (the -J(s) part, before any penalties or offsets)
                obj_value = reward_scale_factor*temp_env.evaluate_objective_direct(current_snapshot)
                objective_values.append(obj_value)

                # Create row data for this sample
                constraint_results = info['constraint_results']
                row_data = {
                    'sample_idx': successful_samples,
                    'snapshot_idx': temp_env.snapshot_idx,
                    'objective_value': obj_value
                }

                # Add constraint violations using proper names
                for j, violation in enumerate(constraint_results):
                    if j < len(constraint_names):
                        row_data[constraint_names[j]] = violation
                    else:
                        # Fallback if there's a mismatch
                        row_data[f'constraint_{j}'] = violation
                        print(f"Warning: More constraints than expected at index {j}")

                constraint_data.append(row_data)
                successful_samples += 1

                # Progress indicator every 200 samples
                if (i + 1) % 200 == 0:
                    print(f"  Completed {i + 1}/{k_samples} samples...")

            except Exception as e:
                # Skip failed samples but continue
                if i < 5:  # Only print first few errors to avoid spam
                    print(f"  Sample {i} failed: {e}")
                continue

        # Create DataFrame from collected data
        constraint_df = pd.DataFrame(constraint_data)

        # Calculate offset based on method
        if objective_values:
            k_worst = abs(max(objective_values))
            print(f"k_worst = |{max(objective_values):.2f}| = {k_worst:.2f}")
            mean_val = np.mean(objective_values)
            k_mean = abs(mean_val)
            print(f"k_mean = |{mean_val:.2f}| = {k_mean:.2f}")

            print(f"  Successfully sampled {successful_samples}/{k_samples} states")
            print(f"  Objective value range: [{min(objective_values):.2f}, {max(objective_values):.2f}]")
            print(f"  Constraint DataFrame shape: {constraint_df.shape}")
        else:
            print("  Warning: No successful samples")
            constraint_df = pd.DataFrame()  # Empty DataFrame

    except Exception as e:
          print(f"Error in offset calculation: {e}")
          import traceback
          traceback.print_exc()
          constraint_df = pd.DataFrame()  # Empty DataFrame on error

    return k_mean, k_worst, constraint_df

In [5]:
# seed = 42  # Define seed variable first
# set_all_seeds(seed)
# gdrive_base = './'  # or '/workspace/'
# network_file_path = os.path.join(gdrive_base, "networks_1_year_connected", "elec_s_10_ec_lc1.0_1h.nc")

# k_mean, k_worst, constraint_df = calculate_offset_k_initialization(network_file=network_file_path, seed=seed)

# # Save to text file
# kmean_file_path = os.path.join(gdrive_base, "offset", f"k_mean_seed{seed}.txt")
# with open(kmean_file_path, 'w') as f:
#     f.write(str(k_mean))

# kworst_file_path = os.path.join(gdrive_base, "offset", f"k_worst_seed{seed}.txt")
# with open(kworst_file_path, 'w') as f:
#     f.write(str(k_worst))

In [6]:
# def plot_constraint_violations(constraint_df, figsize=(10, 8), save_dir=None):
#     """
#     Plot constraint violations vs objective values as separate scatter plots for each constraint.

#     Parameters:
#     -----------
#     constraint_df : pandas.DataFrame
#         DataFrame returned by calculate_offset_k_initialization containing constraint data
#     figsize : tuple
#         Figure size (width, height) for each individual plot
#     save_dir : str, optional
#         If provided, save plots to this directory with names based on constraint names
#     """
#     if constraint_df.empty:
#         print("No data to plot - DataFrame is empty")
#         return

#     # Get constraint columns (excluding metadata columns)
#     constraint_cols = [col for col in constraint_df.columns
#                       if col not in ['sample_idx', 'snapshot_idx', 'objective_value']]

#     if not constraint_cols:
#         print("No constraint columns found in DataFrame")
#         return

#     num_constraints = len(constraint_cols)
#     print(f"Creating {num_constraints} separate plots for objective vs violations across {len(constraint_df)} samples")

#     # Get objective values for x-axis
#     objective_values = constraint_df['objective_value']

#     for i, constraint_col in enumerate(constraint_cols):
#         # Create a new figure for each constraint
#         fig, ax = plt.subplots(1, 1, figsize=figsize)

#         violations = constraint_df[constraint_col]

#         # Create scatter plot: objective value vs constraint violation
#         # Plot ALL data points, including NaN/inf values
#         scatter = ax.scatter(objective_values, violations, alpha=0.6, s=30, c=constraint_df['sample_idx'],
#                            cmap='viridis', edgecolors='black', linewidth=0.5)

#         # Use the actual constraint name as title
#         ax.set_title(f'{constraint_col}: Objective vs Violation', fontsize=14, fontweight='bold')
#         ax.set_xlabel('Objective Value', fontsize=12)
#         ax.set_ylabel('Constraint Violation', fontsize=12)
#         ax.grid(True, alpha=0.3)

#         # Add colorbar for sample index
#         cbar = plt.colorbar(scatter, ax=ax)
#         cbar.set_label('Sample Index', fontsize=11)

#         # Calculate statistics on ALL data (including NaN/inf)
#         mean_viol = violations.mean()
#         max_viol = violations.max()
#         min_viol = violations.min()
#         std_viol = violations.std()

#         # Count finite vs non-finite values
#         finite_count = np.isfinite(violations).sum()
#         total_count = len(violations)

#         # Calculate correlation ONLY on finite values
#         try:
#             # Filter for correlation calculation only
#             valid_mask = np.isfinite(objective_values) & np.isfinite(violations)
#             if valid_mask.sum() > 1:  # Need at least 2 valid points
#                 valid_obj = objective_values[valid_mask]
#                 valid_viol = violations[valid_mask]

#                 # Check if there's variance in both variables
#                 if np.std(valid_obj) > 1e-10 and np.std(valid_viol) > 1e-10:
#                     correlation = np.corrcoef(valid_obj, valid_viol)[0, 1]
#                     corr_text = f'{correlation:.3f}'
#                 else:
#                     corr_text = 'N/A (no variance)'
#             else:
#                 corr_text = 'N/A (insufficient data)'
#         except:
#             corr_text = 'N/A (calc error)'

#         stats_text = (f'Violation Statistics:\n'
#                      f'Mean: {mean_viol:.3f}\n'
#                      f'Max: {max_viol:.3f}\n'
#                      f'Min: {min_viol:.3f}\n'
#                      f'Std: {std_viol:.3f}\n'
#                      f'Finite values: {finite_count}/{total_count}\n'
#                      f'Correlation w/ Objective: {corr_text}')

#         ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
#                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.9),
#                 fontsize=10)

#         plt.tight_layout()

#         # Save individual plot if directory provided
#         if save_dir:
#             import os
#             os.makedirs(save_dir, exist_ok=True)
#             # Use constraint name for filename, replacing invalid characters
#             safe_filename = constraint_col.replace('/', '_').replace('\\', '_').replace(':', '_')
#             save_path = os.path.join(save_dir, f'{safe_filename}.png')
#             plt.savefig(save_path, dpi=300, bbox_inches='tight')
#             print(f"Saved {constraint_col} plot to {save_path}")

#         plt.show()

#     print(f"Completed plotting {num_constraints} constraint plots")

# def plot_constraint_violations_summary(constraint_df, figsize=(12, 8), save_path=None):
#     """
#     Create summary plots for constraint violations: histograms and box plots.

#     Parameters:
#     -----------
#     constraint_df : pandas.DataFrame
#         DataFrame returned by calculate_offset_k_initialization containing constraint data
#     figsize : tuple
#         Figure size (width, height)
#     save_path : str, optional
#         If provided, save the plot to this path
#     """
#     if constraint_df.empty:
#         print("No data to plot - DataFrame is empty")
#         return

#     # Get constraint columns (excluding metadata columns)
#     constraint_cols = [col for col in constraint_df.columns
#                       if col not in ['sample_idx', 'snapshot_idx', 'objective_value']]

#     if not constraint_cols:
#         print("No constraint columns found in DataFrame")
#         return

#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

#     # Box plot of all constraints - plot ALL data
#     violation_data = constraint_df[constraint_cols].values

#     # For box plots, we need to handle NaN/inf differently since matplotlib can't plot them
#     # Create a list of arrays, removing only NaN/inf for the box plot visualization
#     clean_data_for_boxplot = []
#     constraint_labels = []
#     for col in constraint_cols:
#         col_data = constraint_df[col].values
#         finite_vals = col_data[np.isfinite(col_data)]
#         clean_data_for_boxplot.append(finite_vals)
#         # Shorten label names for readability
#         short_label = col.replace('p_min_', 'Pmin_').replace('p_max_', 'Pmax_').replace('s_max_', 'Smax_').replace('soc_min_', 'SOCmin_').replace('soc_max_', 'SOCmax_')
#         constraint_labels.append(short_label)

#     ax1.boxplot(clean_data_for_boxplot, labels=constraint_labels)
#     ax1.set_title('Constraint Violation Distributions\n(NaN/inf excluded from display)')
#     ax1.set_xlabel('Constraint')
#     ax1.set_ylabel('Violation Value')
#     ax1.grid(True, alpha=0.3)
#     # Rotate x labels for better readability
#     ax1.tick_params(axis='x', rotation=45)

#     # Histogram of violation magnitudes - plot ALL finite data
#     all_violations = violation_data.flatten()
#     finite_violations = all_violations[np.isfinite(all_violations)]

#     if len(finite_violations) > 0:
#         ax2.hist(finite_violations, bins=50, alpha=0.7, edgecolor='black')
#         ax2.set_title('Distribution of All Violation Values\n(NaN/inf excluded from display)')
#         ax2.set_xlabel('Violation Value')
#         ax2.set_ylabel('Frequency')
#         ax2.grid(True, alpha=0.3)

#         # Calculate statistics on finite data only
#         mean_val = np.mean(finite_violations)
#         median_val = np.median(finite_violations)
#         ax2.axvline(mean_val, color='red', linestyle='--',
#                     label=f'Mean: {mean_val:.3f}')
#         ax2.axvline(median_val, color='green', linestyle='--',
#                     label=f'Median: {median_val:.3f}')
#         ax2.legend()

#         # Add info about excluded data
#         total_points = len(all_violations)
#         finite_points = len(finite_violations)
#         excluded_points = total_points - finite_points
#         if excluded_points > 0:
#             ax2.text(0.02, 0.98, f'Excluded {excluded_points}/{total_points} NaN/inf values',
#                     transform=ax2.transAxes, verticalalignment='top',
#                     bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.7))
#     else:
#         ax2.text(0.5, 0.5, 'All data contains NaN/inf values', transform=ax2.transAxes,
#                 ha='center', va='center')
#         ax2.set_title('Distribution of All Violation Values')

#     plt.tight_layout()

#     if save_path:
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"Summary plot saved to {save_path}")

#     plt.show()

# def plot_total_violations(constraint_df, figsize=(10, 8), save_path=None):
#     """
#     Plot objective values vs total constraint violations (sum of all individual constraint violations).

#     Parameters:
#     -----------
#     constraint_df : pandas.DataFrame
#         DataFrame returned by calculate_offset_k_initialization containing constraint data
#     figsize : tuple
#         Figure size (width, height)
#     save_path : str, optional
#         If provided, save the plot to this path
#     """
#     if constraint_df.empty:
#         print("No data to plot - DataFrame is empty")
#         return

#     # Get constraint columns (excluding metadata columns)
#     constraint_cols = [col for col in constraint_df.columns
#                       if col not in ['sample_idx', 'snapshot_idx', 'objective_value']]

#     if not constraint_cols:
#         print("No constraint columns found in DataFrame")
#         return

#     print(f"Calculating total violations from {len(constraint_cols)} constraints across {len(constraint_df)} samples")

#     # Calculate total violation for each sample (sum across all constraints)
#     total_violations = constraint_df[constraint_cols].sum(axis=1)
#     objective_values = constraint_df['objective_value']

#     # Find maximum total violation
#     max_total_violation = total_violations.max()
#     max_violation_idx = total_violations.idxmax()
#     max_sample = constraint_df.loc[max_violation_idx]

#     print(f"\nMaximum total violation: {max_total_violation:.4f}")
#     print(f"Occurred at sample {max_sample['sample_idx']} (snapshot {max_sample['snapshot_idx']})")
#     print(f"Objective value at max violation: {max_sample['objective_value']:.2f}")

#     # Create the plot
#     fig, ax = plt.subplots(1, 1, figsize=figsize)

#     # Create scatter plot: objective value vs total violation
#     scatter = ax.scatter(objective_values, total_violations, alpha=0.6, s=30,
#                         c=constraint_df['sample_idx'], cmap='viridis',
#                         edgecolors='black', linewidth=0.5)

#     # Highlight the maximum violation point
#     ax.scatter(max_sample['objective_value'], max_total_violation,
#               color='red', s=100, marker='*', edgecolors='black', linewidth=1,
#               label=f'Max Total Violation: {max_total_violation:.4f}')

#     ax.set_title('Objective vs Total Constraint Violations', fontsize=14, fontweight='bold')
#     ax.set_xlabel('Objective Value', fontsize=12)
#     ax.set_ylabel('Total Constraint Violation (Sum)', fontsize=12)
#     ax.grid(True, alpha=0.3)
#     ax.legend()

#     # Add colorbar for sample index
#     cbar = plt.colorbar(scatter, ax=ax)
#     cbar.set_label('Sample Index', fontsize=11)

#     # Calculate statistics on total violations
#     mean_total = total_violations.mean()
#     max_total = total_violations.max()
#     min_total = total_violations.min()
#     std_total = total_violations.std()

#     # Count finite vs non-finite values
#     finite_count = np.isfinite(total_violations).sum()
#     total_count = len(total_violations)

#     # Calculate correlation between objective and total violations
#     try:
#         valid_mask = np.isfinite(objective_values) & np.isfinite(total_violations)
#         if valid_mask.sum() > 1:
#             valid_obj = objective_values[valid_mask]
#             valid_total = total_violations[valid_mask]

#             if np.std(valid_obj) > 1e-10 and np.std(valid_total) > 1e-10:
#                 correlation = np.corrcoef(valid_obj, valid_total)[0, 1]
#                 corr_text = f'{correlation:.3f}'
#             else:
#                 corr_text = 'N/A (no variance)'
#         else:
#             corr_text = 'N/A (insufficient data)'
#     except:
#         corr_text = 'N/A (calc error)'

#     stats_text = (f'Total Violation Statistics:\n'
#                  f'Mean: {mean_total:.4f}\n'
#                  f'Max: {max_total:.4f}\n'
#                  f'Min: {min_total:.4f}\n'
#                  f'Std: {std_total:.4f}\n'
#                  f'Finite values: {finite_count}/{total_count}\n'
#                  f'Correlation w/ Objective: {corr_text}')

#     ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
#             verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.9),
#             fontsize=10)

#     plt.tight_layout()

#     if save_path:
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"Total violations plot saved to {save_path}")

#     plt.show()

#     # Print breakdown of violations for the worst case
#     print(f"\nBreakdown of violations for sample with maximum total violation:")
#     print(f"Sample {max_sample['sample_idx']} (snapshot {max_sample['snapshot_idx']}):")
#     for constraint_col in constraint_cols:
#         violation_val = max_sample[constraint_col]
#         if pd.notna(violation_val) and violation_val != 0:
#             print(f"  {constraint_col}: {violation_val:.4f}")

#     return max_total_violation, max_violation_idx

In [7]:
# # Plot individual constraint violations as separate plots
# plot_constraint_violations(constraint_df, figsize=(10, 8), save_dir="workspace/constraints")

# # Plot objective vs total violations
# max_total_violation, worst_sample_idx = plot_total_violations(
#     constraint_df,
#     save_path="workspace/total_violations.png"
# )
# print(f"The worst case total violation was: {max_total_violation}")
# plot_constraint_violations_summary(constraint_df, save_path="workspace/constraint_summary.png")

In [8]:
# def plot_violation_count_distribution(constraint_df, violation_threshold=0.0, figsize=(12, 8), save_path=None):
#     """
#     Plot the distribution of number of violations per sample.

#     Parameters:
#     -----------
#     constraint_df : pandas.DataFrame
#         DataFrame returned by calculate_offset_k_initialization containing constraint data
#     violation_threshold : float
#         Threshold above which a constraint is considered violated (default: 0.0)
#     figsize : tuple
#         Figure size (width, height)
#     save_path : str, optional
#         If provided, save the plot to this path
#     """
#     if constraint_df.empty:
#         print("No data to plot - DataFrame is empty")
#         return

#     # Get constraint columns (excluding metadata columns)
#     constraint_cols = [col for col in constraint_df.columns
#                       if col not in ['sample_idx', 'snapshot_idx', 'objective_value']]

#     if not constraint_cols:
#         print("No constraint columns found in DataFrame")
#         return

#     print(f"Analyzing violation counts from {len(constraint_cols)} constraints across {len(constraint_df)} samples")
#     print(f"Using violation threshold: {violation_threshold}")

#     # Count violations per sample (number of constraints violated)
#     # A constraint is violated if its value > violation_threshold
#     violation_counts = []
#     for _, row in constraint_df.iterrows():
#         count = 0
#         valid_constraints = 0
#         for constraint_col in constraint_cols:
#             violation_val = row[constraint_col]
#             if pd.notna(violation_val) and np.isfinite(violation_val):
#                 valid_constraints += 1
#                 if violation_val > violation_threshold:
#                     count += 1
#         violation_counts.append(count)

#     violation_counts = np.array(violation_counts)

#     # Create subplots
#     fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)

#     # 1. Histogram of violation counts
#     unique_counts, count_frequencies = np.unique(violation_counts, return_counts=True)

#     ax1.bar(unique_counts, count_frequencies, alpha=0.7, edgecolor='black', width=0.8)
#     ax1.set_title('Distribution of Number of Violations per Sample')
#     ax1.set_xlabel('Number of Constraints Violated')
#     ax1.set_ylabel('Number of Samples')
#     ax1.grid(True, alpha=0.3, axis='y')

#     # Add percentage labels on bars
#     total_samples = len(violation_counts)
#     for i, (count, freq) in enumerate(zip(unique_counts, count_frequencies)):
#         percentage = (freq / total_samples) * 100
#         ax1.text(count, freq + 0.01 * max(count_frequencies), f'{percentage:.1f}%',
#                 ha='center', va='bottom', fontsize=9)

#     # 2. Pie chart of violation categories
#     no_violations = np.sum(violation_counts == 0)
#     some_violations = np.sum((violation_counts > 0) & (violation_counts < len(constraint_cols)))
#     all_violations = np.sum(violation_counts == len(constraint_cols))

#     categories = []
#     sizes = []
#     colors = ['green', 'orange', 'red']
#     labels = []

#     if no_violations > 0:
#         categories.append('No Violations')
#         sizes.append(no_violations)
#         labels.append(f'No Violations\n({no_violations} samples)')

#     if some_violations > 0:
#         categories.append('Some Violations')
#         sizes.append(some_violations)
#         labels.append(f'Some Violations\n({some_violations} samples)')

#     if all_violations > 0:
#         categories.append('All Violated')
#         sizes.append(all_violations)
#         labels.append(f'All Violated\n({all_violations} samples)')

#     if sizes:
#         wedges, texts, autotexts = ax2.pie(sizes, labels=labels, colors=colors[:len(sizes)],
#                                           autopct='%1.1f%%', startangle=90)
#         ax2.set_title('Violation Categories')
#     else:
#         ax2.text(0.5, 0.5, 'No valid data', ha='center', va='center', transform=ax2.transAxes)

#     # 3. Objective value vs number of violations
#     objective_values = constraint_df['objective_value']
#     scatter = ax3.scatter(objective_values, violation_counts, alpha=0.6, s=30,
#                          c=constraint_df['sample_idx'], cmap='viridis',
#                          edgecolors='black', linewidth=0.5)

#     ax3.set_title('Objective Value vs Number of Violations')
#     ax3.set_xlabel('Objective Value')
#     ax3.set_ylabel('Number of Constraints Violated')
#     ax3.grid(True, alpha=0.3)
#     ax3.set_yticks(range(max(violation_counts) + 1))

#     # Add colorbar
#     cbar = plt.colorbar(scatter, ax=ax3)
#     cbar.set_label('Sample Index', fontsize=10)

#     # 4. Box plot of objective values by violation count
#     violation_count_groups = {}
#     for i, count in enumerate(violation_counts):
#         if count not in violation_count_groups:
#             violation_count_groups[count] = []
#         obj_val = objective_values.iloc[i]
#         if pd.notna(obj_val) and np.isfinite(obj_val):
#             violation_count_groups[count].append(obj_val)

#     # Only include groups with at least one valid objective value
#     valid_groups = {k: v for k, v in violation_count_groups.items() if len(v) > 0}

#     if valid_groups:
#         box_data = [valid_groups[count] for count in sorted(valid_groups.keys())]
#         box_labels = [f'{count}\n(n={len(valid_groups[count])})' for count in sorted(valid_groups.keys())]

#         ax4.boxplot(box_data, labels=box_labels)
#         ax4.set_title('Objective Value Distribution by Violation Count')
#         ax4.set_xlabel('Number of Violations')
#         ax4.set_ylabel('Objective Value')
#         ax4.grid(True, alpha=0.3)
#     else:
#         ax4.text(0.5, 0.5, 'No valid objective values', ha='center', va='center', transform=ax4.transAxes)

#     # Calculate and display summary statistics
#     mean_violations = np.mean(violation_counts)
#     median_violations = np.median(violation_counts)
#     max_violations = np.max(violation_counts)
#     std_violations = np.std(violation_counts)

#     # Find samples with maximum violations
#     max_violation_samples = constraint_df[violation_counts == max_violations]

#     # Calculate correlation between objective and violation count
#     try:
#         valid_mask = np.isfinite(objective_values) & np.isfinite(violation_counts)
#         if valid_mask.sum() > 1:
#             valid_obj = objective_values[valid_mask]
#             valid_counts = violation_counts[valid_mask]

#             if np.std(valid_obj) > 1e-10 and np.std(valid_counts) > 1e-10:
#                 correlation = np.corrcoef(valid_obj, valid_counts)[0, 1]
#                 corr_text = f'{correlation:.3f}'
#             else:
#                 corr_text = 'N/A (no variance)'
#         else:
#             corr_text = 'N/A (insufficient data)'
#     except:
#         corr_text = 'N/A (calc error)'

#     # Add summary text box
#     summary_text = (f'Violation Count Statistics:\n'
#                    f'Mean: {mean_violations:.2f}\n'
#                    f'Median: {median_violations:.1f}\n'
#                    f'Max: {max_violations}\n'
#                    f'Std: {std_violations:.2f}\n'
#                    f'Total constraints: {len(constraint_cols)}\n'
#                    f'Samples with 0 violations: {no_violations}/{total_samples}\n'
#                    f'Correlation w/ Objective: {corr_text}')

#     fig.text(0.02, 0.02, summary_text, fontsize=10,
#              bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

#     plt.tight_layout()
#     plt.subplots_adjust(bottom=0.2)  # Make room for summary text

#     if save_path:
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"Violation count distribution plot saved to {save_path}")

#     plt.show()

#     # Print detailed summary
#     print(f"\n=== VIOLATION COUNT ANALYSIS ===")
#     print(f"Total samples: {total_samples}")
#     print(f"Total constraints: {len(constraint_cols)}")
#     print(f"Violation threshold: {violation_threshold}")
#     print(f"\nViolation count distribution:")
#     for count, freq in zip(unique_counts, count_frequencies):
#         percentage = (freq / total_samples) * 100
#         print(f"  {count} violations: {freq} samples ({percentage:.1f}%)")

#     print(f"\nSamples with maximum violations ({max_violations}):")
#     for _, sample in max_violation_samples.iterrows():
#         print(f"  Sample {sample['sample_idx']} (snapshot {sample['snapshot_idx']}) - Objective: {sample['objective_value']:.2f}")

#         # Show which constraints are violated for this sample
#         violated_constraints = []
#         for constraint_col in constraint_cols:
#             violation_val = sample[constraint_col]
#             if pd.notna(violation_val) and np.isfinite(violation_val) and violation_val > violation_threshold:
#                 violated_constraints.append(f"{constraint_col}: {violation_val:.4f}")

#         if violated_constraints:
#             print(f"    Violated constraints: {', '.join(violated_constraints[:3])}" +
#                   (f" (and {len(violated_constraints)-3} more)" if len(violated_constraints) > 3 else ""))

#     return violation_counts, unique_counts, count_frequencies

In [9]:
 # violation_counts, unique_counts, frequencies = plot_violation_count_distribution(constraint_df)


In [5]:
def fix_artificial_lines_reasonable(network):
    """
    Fix artificial lines with reasonable capacity values:
    - s_nom = based on connected bus demand (with safety factor)
    - s_nom_extendable = False (non-extendable)
    - Keep capacity high enough to meet demand
    """
    print("=== FIXING ARTIFICIAL LINES WITH REASONABLE CAPACITY ===")

    # Find artificial lines
    artificial_lines = [line for line in network.lines.index
                       if any(keyword in str(line).lower() for keyword in ['new', '<->', 'artificial'])]

    if not artificial_lines:
        # If no artificial lines found by name, look for lines with s_nom=0
        # which is often a sign of artificial lines
        zero_capacity_lines = network.lines[network.lines.s_nom == 0].index.tolist()
        if zero_capacity_lines:
            artificial_lines = zero_capacity_lines

    print(f"Found {len(artificial_lines)} artificial lines to fix:")

    # Get maximum demand per bus across all snapshots
    bus_max_demand = {}
    for bus in network.buses.index:
        bus_demand = 0
        for load_name, load in network.loads.iterrows():
            if load.bus == bus and load_name in network.loads_t.p_set.columns:
                bus_demand = max(bus_demand, network.loads_t.p_set[load_name].max())
        bus_max_demand[bus] = bus_demand

    # Fix each artificial line with reasonable capacity
    for line_name in artificial_lines:
        # Get connected buses
        bus0 = network.lines.loc[line_name, 'bus0']
        bus1 = network.lines.loc[line_name, 'bus1']

        # Get maximum demand at these buses
        bus0_demand = bus_max_demand.get(bus0, 0)
        bus1_demand = bus_max_demand.get(bus1, 0)

        # Calculate required capacity with safety factor
        # Use 3x the higher demand to ensure adequate capacity
        safety_factor = 3.0
        required_capacity = max(bus0_demand, bus1_demand) * safety_factor

        # Ensure minimum reasonable capacity (1000 MW)
        required_capacity = max(required_capacity, 1000)

        print(f"\n Fixing: {line_name}")
        print(f"    Connected buses: {bus0} ↔ {bus1}")
        print(f"    Bus demands: {bus0}: {bus0_demand:.1f} MW, {bus1}: {bus1_demand:.1f} MW")

        # Set s_nom to required capacity
        old_s_nom = network.lines.loc[line_name, 's_nom']
        network.lines.loc[line_name, 's_nom'] = required_capacity
        print(f"    s_nom: {old_s_nom} → {required_capacity:.1f} MW")

        # Make sure line is not extendable
        if 's_nom_extendable' not in network.lines.columns:
            network.lines['s_nom_extendable'] = False
        network.lines.loc[line_name, 's_nom_extendable'] = False
        print(f"    s_nom_extendable: → False")

    return network

def remove_offshore_wind(network):
    """
    Remove offshore wind generators.
    All of these have zero nominal capacity (likely missing data).
    Need to remove them to avoid division by zero error in constraint check for slack gens.
    Problem is still feasible without offwind slack since pypsa optimize still feasible.
    """

    # First, identify offshore wind generators
    offwind_gens = network.generators[
        network.generators.index.str.contains('offwind', case=False, na=False)
    ].index

    print(f"Found {len(offwind_gens)} offshore wind generators:")
    print(offwind_gens.tolist())

    # Check their properties
    offwind_data = network.generators.loc[offwind_gens, ['p_nom', 'control', 'carrier']]
    print("\nOffshore wind generator details:")
    print(offwind_data)

    # Remove offshore wind generators one by one
    print(f"\nRemoving {len(offwind_gens)} offshore wind generators...")
    for gen in offwind_gens:
        network.remove("Generator", gen)

def create_pypsa_network(network_file):
    """Create a PyPSA network from the .nc file."""
    # Initialize network
    network = pypsa.Network(network_file)
    for storage_name in network.storage_units.index:
        # Use .loc for direct assignment to avoid SettingWithCopyWarning
        network.storage_units.loc[storage_name, 'cyclic_state_of_charge'] = False

        # Set marginal_cost to 0.01
        network.storage_units.loc[storage_name, 'marginal_cost'] = 0.01

        # Set marginal_cost_storage to 0.01
        network.storage_units.loc[storage_name, 'marginal_cost_storage'] = 0.01

        # Set spill_cost to 0.1
        network.storage_units.loc[storage_name, 'spill_cost'] = 0.1

        network.storage_units.loc[storage_name, 'efficiency_store'] = 0.866025 #use phs efficiency (hydro didnt have an efficiency, but i want to model them all as the same)

        # Fix unrealistic max_hours values
        current_max_hours = network.storage_units.loc[storage_name, 'max_hours']

        if 'PHS' in storage_name:
            # PHS with missing data - set to typical range
            network.storage_units.loc[storage_name, 'max_hours'] = 8.0
            print(f"Fixed {storage_name}: set max_hours to 8.0")

        elif 'hydro' in storage_name:
            # Hydro with unrealistic data - set to validated range
            network.storage_units.loc[storage_name, 'max_hours'] = 6.0
            print(f"Fixed {storage_name}: corrected max_hours from {current_max_hours} to 6.0")


    fix_artificial_lines_reasonable(network)
    remove_offshore_wind(network)

    return network

In [6]:
class EnvDispatchConstr(gym.Env):
    """
    OpenAI Gym environment for Optimal Power Flow using PyPSA.
    Enhanced to handle dispatchable generators, renewable generators, and storage units.

    Action Space: Continuous setpoints for all controllable components within their capacity limits
    - Dispatchable generators: scaled between p_min_pu*p_nom and p_max_pu*p_nom
    - Renewable generators: scaled between 0 and current p_max_pu*p_nom (time-varying)
    - Storage units: scaled between -p_nom (charging) and +p_nom (discharging)
    (This follows http://arxiv.org/abs/2403.17831.)

    Has train/test split functionality
    """

    def __init__(self, network_file, no_convergence_lpf_penalty, reward_scale_factor, constraint_penalty_factor=10, seed=None,
                 #test_start_date='2013-12-01 00:00:00', fixed_episode_length=None
                ):
        super().__init__()
        if seed is not None:
            np.random.seed(seed)  # Set initial seed

        self.network_file = network_file # Store network file path

        # Use provided network or create new one
        self.network =create_pypsa_network(network_file)
        self.reward_scale_factor=reward_scale_factor

        #self.test_start_date = pd.Timestamp(test_start_date)
        #Omit self._train_test_snapshots()

        self.penalty_factor= constraint_penalty_factor
        self.no_convergence_lpf_penalty= no_convergence_lpf_penalty
        self.reward_method = "summation" # Default reward method for the base class

        # Omit episode length configuration

        # Episode management
        self.total_snapshots = len(self.network.snapshots)
        self.snapshot_idx = 0  # Current snapshot index (cycles through all snapshots)

        # Initialize component categorization
        self._categorize_components()

        # Create action space
        self._create_action_space()

        # Initialize the network state
        self.reset()

        # Create observation space
        low_bounds, high_bounds = self.create_observation_bounds()
        self.observation_space = spaces.Box(
            low=low_bounds,
            high=high_bounds,
            dtype=np.float32
        )

        slack_generators = self.network.generators[self.network.generators.control == "Slack"].index

        self.n_slack=len(slack_generators)
        self.n_lines=len(self.network.lines.index)


        total_size_test=(2 * self.n_slack + self.n_lines + 2 * self.n_storage)
        self.initial_constr_test = np.zeros(total_size_test, dtype=np.float64)
        self.n_constr_test=len(self.initial_constr_test)

        # Remove line constraints for performance - only check slack and storage constraints
        total_size = (2 * self.n_slack + 2 * self.n_storage)
        self.initial_constr = np.zeros(total_size, dtype=np.float64)
        self.n_constr=len(self.initial_constr)

        # Cache component mappings for performance optimization
        # self._cache_component_mappings()
        self._initialize_power_flow_matrices()


    #Omit _train_test_snapshots which defines self.train_snapshots and self.test_snapshots
    def _initialize_power_flow_matrices(self):
        """Pre-compute network topology and power flow matrices for fast LPF."""
        print("Pre-computing network topology and power flow matrices...")

        # Step 1: Determine network topology (identifies sub-networks)
        self.network.determine_network_topology()

        # Step 2: Pre-compute power flow matrices for each sub-network
        for sub in self.network.sub_networks.obj:
            sub.calculate_B_H()

        print(f"Initialized {len(self.network.sub_networks.obj)} sub-networks for fast power flow")

    def _categorize_components(self):
        """
        Categorize generators and identify storage units for action space.
        """
        # Get generators with time-varying p_max_pu (renewable generators)
        renewable_gens = self.network.generators_t.p_max_pu.columns

        slack_generators = self.network.generators[self.network.generators.control == "Slack"].index
        # in the 10-node SA network there are 4 slack gens so this should return a list of indexes

        # Dispatchable generators: not slack, not renewable
        self.dispatchable_gens = self.network.generators[
            (~self.network.generators.index.isin(slack_generators)) &
            (~self.network.generators.index.isin(renewable_gens))
        ].index

        # Renewable generators: have time-varying p_max_pu, not slack
        self.renewable_gens = self.network.generators[
            (self.network.generators.index.isin(renewable_gens)) &
            (~self.network.generators.index.isin(slack_generators))
        ].index

        # Storage units (if any exist in the network)
        self.storage_units = self.network.storage_units.index

        # Store names as lists for easier indexing
        self.dispatchable_names = list(self.dispatchable_gens)
        self.renewable_names = list(self.renewable_gens)
        self.storage_names = list(self.storage_units)

        # Store counts
        self.n_dispatchable = len(self.dispatchable_names)
        self.n_renewable = len(self.renewable_names)
        self.n_storage = len(self.storage_names)

        # Get static limits for dispatchable generators
        if self.n_dispatchable > 0:
            dispatchable_df = self.network.generators.loc[self.dispatchable_gens]
            self.disp_p_min = (dispatchable_df.p_min_pu * dispatchable_df.p_nom).values#returns numpy arrays
            self.disp_p_max = (dispatchable_df.p_max_pu * dispatchable_df.p_nom).values
        else:
            self.disp_p_min = np.array([])
            self.disp_p_max = np.array([])

        # Get nominal capacities and minimum limits for renewable generators
        if self.n_renewable > 0:
            renewable_df = self.network.generators.loc[self.renewable_gens]
            self.renewable_p_nom = renewable_df.p_nom.values
            self.renewable_p_min_pu = renewable_df.p_min_pu.values
        else:
            self.renewable_p_nom = np.array([])
            self.renewable_p_min_pu = np.array([])

        # Classify storage types based on p_min_pu
        if self.n_storage > 0:
            storage_df = self.network.storage_units.loc[self.storage_units]

            # Bidirectional storage (PHS): p_min_pu < 0
            self.bidirectional_storage = storage_df[storage_df.p_min_pu < 0].index

            # Unidirectional storage (Hydro): p_min_pu >= 0
            self.unidirectional_storage = storage_df[storage_df.p_min_pu >= 0].index

            self.n_bidirectional = len(self.bidirectional_storage)
            self.n_unidirectional = len(self.unidirectional_storage)

            # Store names as lists
            self.bidirectional_names = list(self.bidirectional_storage)
            self.unidirectional_names = list(self.unidirectional_storage)
        else:
            self.bidirectional_storage = pd.Index([])
            self.unidirectional_storage = pd.Index([])
            self.n_bidirectional = 0
            self.n_unidirectional = 0
            self.bidirectional_names = []
            self.unidirectional_names = []

    def _cache_component_mappings(self):
        """Cache component-to-bus mappings and indices for fast power balance calculations."""
        # Cache bus mappings
        self._load_buses = self.network.loads['bus'].values
        self._gen_buses = self.network.generators['bus'].values
        self._storage_buses = self.network.storage_units['bus'].values if len(self.storage_names) > 0 else np.array([])

        # Cache component indices for iloc access
        self._all_buses = self.network.buses.index
        self._load_names = self.network.loads.index
        self._gen_names = self.network.generators.index

        # Cache non-slack generator info
        self._non_slack_gens = self.network.generators[self.network.generators.control != "Slack"].index
        self._non_slack_gen_buses = self.network.generators.loc[self._non_slack_gens, 'bus'].values

        # Cache slack generator info
        self._slack_gens = self.network.generators[self.network.generators.control == "Slack"].index
        if len(self._slack_gens) > 1:
            slack_p_noms = self.network.generators.loc[self._slack_gens, 'p_nom']
            self._slack_weights = (slack_p_noms / slack_p_noms.sum()).values
        else:
            self._slack_weights = None

        # Pre-create component-to-index mappings for fast iloc access
        self._load_idx_map = {name: idx for idx, name in enumerate(self._load_names)}
        self._gen_idx_map = {name: idx for idx, name in enumerate(self._gen_names)}
        self._storage_idx_map = {name: idx for idx, name in enumerate(self.storage_names)}

    def _create_action_space(self):
        """
        Create action space with four distinct parts:
        1. Dispatchable generators: [0,1] scaled to [p_min, p_max]
        2. Renewable generators: [0,1] scaled to [0, current_p_max_pu * p_nom]
        3. Storage p_store: [0,1] scaled to [0,p_nom] (charging magnitude)
        4. Storage p_dispatch: [0,1] scaled to [0, p_nom] (discharging magnitude)

        Create action space with different handling for storage types:
        - Bidirectional storage (PHS): 2 actions (p_store, p_dispatch)
        - Unidirectional storage (Hydro): 1 action (p_dispatch only)
        """
        total_actions = (self.n_dispatchable + self.n_renewable +
                    (2 * self.n_bidirectional) + self.n_unidirectional)
        self.action_space = gym.spaces.Box(0, 1, shape=(total_actions,))

        # Store action space structure
        current_idx = 0

        self.action_structure = {
            'dispatchable': {
                'start': current_idx,
                'end': current_idx + self.n_dispatchable,
                'count': self.n_dispatchable
            }
        }
        current_idx += self.n_dispatchable

        self.action_structure['renewable'] = {
            'start': current_idx,
            'end': current_idx + self.n_renewable,
            'count': self.n_renewable
        }
        current_idx += self.n_renewable

        # Bidirectional storage (2 actions each)
        self.action_structure['bidirectional_p_store'] = {
            'start': current_idx,
            'end': current_idx + self.n_bidirectional,
            'count': self.n_bidirectional
        }
        current_idx += self.n_bidirectional

        self.action_structure['bidirectional_p_dispatch'] = {
            'start': current_idx,
            'end': current_idx + self.n_bidirectional,
            'count': self.n_bidirectional
        }
        current_idx += self.n_bidirectional

        # Unidirectional storage (1 action each)
        self.action_structure['unidirectional_p_dispatch'] = {
            'start': current_idx,
            'end': current_idx + self.n_unidirectional,
            'count': self.n_unidirectional
        }

    def _get_storage_observation(self):
        """
        Get current storage unit states for observation.
        Returns current inflow (normalized) for each storage unit.
        """
        if self.n_storage == 0:
            return np.array([])

        current_snapshot = self.network.snapshots[self.snapshot_idx]
        storage_obs = []

        for storage_name in self.storage_names:
            # Current Inflow (normalized by p_nom)
            p_nom = self.network.storage_units.loc[storage_name, 'p_nom']
            if hasattr(self.network.storage_units_t, 'inflow') and storage_name in self.network.storage_units_t.inflow.columns:
                current_inflow = self.network.storage_units_t.inflow.loc[current_snapshot, storage_name]
                # Normalize by p_nom for consistent scaling
                normalized_inflow = current_inflow / p_nom if p_nom > 0 else 0
            else:
                # If no inflow data exists, use zero
                normalized_inflow = 0.0
            storage_obs.append(normalized_inflow)

        return np.array(storage_obs, dtype=np.float32)

    def create_storage_observation_bounds(self):
        """
        Create bounds for storage unit observations.
        Only includes inflow bounds since SOC is excluded.
        """
        if self.n_storage == 0:
            return np.array([]), np.array([])

        values_per_storage = 1  # Only inflow now
        total_storage_obs = self.n_storage * values_per_storage

        low_bounds = np.zeros(total_storage_obs)
        high_bounds = np.zeros(total_storage_obs)

        for i, storage_name in enumerate(self.storage_names):
            # Inflow bounds: Get from historical data
            if hasattr(self.network.storage_units_t, 'inflow'):
                p_nom = self.network.storage_units.loc[storage_name, 'p_nom']
                if storage_name in self.network.storage_units_t.inflow.columns:
                    inflow_data = self.network.storage_units_t.inflow[storage_name]

                    min_inflow_norm = inflow_data.min() / p_nom if p_nom > 0 else 0
                    max_inflow_norm = inflow_data.max() / p_nom if p_nom > 0 else 0

                    low_bounds[i] = min_inflow_norm
                    high_bounds[i] = max_inflow_norm
                else:
                    # If inflow data exists but not for this specific storage unit
                    low_bounds[i] = 0.0
                    high_bounds[i] = 0.0
            else:
                # No inflow data
                low_bounds[i] = 0.0
                high_bounds[i] = 0.0

        return low_bounds.astype(np.float32), high_bounds.astype(np.float32)

    def create_observation_bounds(self):
        """
        Create bounds for the observation space based on:
        - Load p_set values
        - Renewable generator p_max_pu values
        - Storage unit current inflow (normalized)
        """
        # 1. Load bounds
        load_p_set_all = self.network.loads_t.p_set  # DataFrame with all snapshots and loads
        load_low_bounds = load_p_set_all.min(axis=0).values  # Min across all snapshots for each load
        load_high_bounds = load_p_set_all.max(axis=0).values  # Max across all snapshots for each load

        # 2. Renewable generator bounds
        if self.n_renewable > 0:
            renewable_p_max_pu_all = self.network.generators_t.p_max_pu[self.renewable_names]
            renewable_low_bounds = renewable_p_max_pu_all.min(axis=0).values
            renewable_high_bounds = renewable_p_max_pu_all.max(axis=0).values
        else:
            renewable_low_bounds = np.array([])
            renewable_high_bounds = np.array([])

        # 3. Storage bounds (current inflow)
        storage_low_bounds, storage_high_bounds = self.create_storage_observation_bounds()

        # 4. Combine all bounds
        low_bounds = np.concatenate([load_low_bounds, renewable_low_bounds, storage_low_bounds])
        high_bounds = np.concatenate([load_high_bounds, renewable_high_bounds, storage_high_bounds])

        return low_bounds.astype(np.float32), high_bounds.astype(np.float32)

    def _get_observation(self):
        """
        Get current network state as observation.

        Returns observation vector with structure:
        [load_1_demand, load_2_demand, ..., load_n_demand,
        renewable_1_p_max_pu, renewable_2_p_max_pu, ..., renewable_m_p_max_pu,
        storage_1_current_inflow_norm, storage_2_current_inflow_norm, ..., storage_k_current_inflow_norm]
        """

        # 1. Load demands (dynamic values at current snapshot)
        load_demands = self.network.loads_t.p_set.iloc[self.snapshot_idx].values

        # 2. Renewable generator p_max_pu values (time-varying availability at current snapshot)
        if self.n_renewable > 0:
            renewable_p_max_pu = self.network.generators_t.p_max_pu.iloc[self.snapshot_idx][self.renewable_names].values
        else:
            renewable_p_max_pu = np.array([])

        # 3. Storage states (previous SOC normalized + current inflow normalized)
        storage_states = self._get_storage_observation()

        # 4. Combine all observations
        observation = np.concatenate([load_demands, renewable_p_max_pu, storage_states])

        return observation.astype(np.float32)

    def reset_network(self):
        """Reset and ensure essential DataFrames exist."""
        #Note that we do not just create a new network here, as this consumes more memory and previously led to a segmentation fault
        # we reset these ttributes for all snapshots, but they all start empty when the network is created so i think that's fine
        # Initialize/reset generators_t.p_set
        if not hasattr(self.network.generators_t, 'p_set') or self.network.generators_t.p_set.empty:
            self.network.generators_t.p_set = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.generators.index
            )
        else:
            self.network.generators_t.p_set.iloc[:, :] = 0.0

        # Initialize/reset storage_units_t attributes
        if not hasattr(self.network.storage_units_t, 'p_set') or self.network.storage_units_t.p_set.empty:
            self.network.storage_units_t.p_set = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.storage_units.index
            )
        else:
            self.network.storage_units_t.p_set.iloc[:, :] = 0.0


        if not hasattr(self.network.storage_units_t, 'p_dispatch') or self.network.storage_units_t.p_dispatch.empty:
            self.network.storage_units_t.p_dispatch = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.storage_units.index
            )
        else:
            self.network.storage_units_t.p_dispatch.iloc[:, :] = 0.0

        if not hasattr(self.network.storage_units_t, 'p_store') or self.network.storage_units_t.p_store.empty:
            self.network.storage_units_t.p_store = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.storage_units.index
            )
        else:
            self.network.storage_units_t.p_store.iloc[:, :] = 0.0

        if not hasattr(self.network.storage_units_t, 'state_of_charge') or self.network.storage_units_t.state_of_charge.empty:
            self.network.storage_units_t.state_of_charge = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.storage_units.index
            )
        else:
            self.network.storage_units_t.state_of_charge.iloc[:, :] = 0.0

        if not hasattr(self.network.storage_units_t, 'spill') or self.network.storage_units_t.spill.empty:
            self.network.storage_units_t.spill = pd.DataFrame(
                0.0,
                index=self.network.snapshots,
                columns=self.network.storage_units.index
            )
        else:
            self.network.storage_units_t.spill.iloc[:, :] = 0.0

    def reset(self, seed=None, options=None):
        """
        Reset to training data with proper episode length handling.
        """
        # Set seed but don't use it since start from first timestep in data
        if seed is not None:
            np.random.seed(seed)

        # Reset counter
        self.snapshot_idx = 0

        self.reset_network()

        obs = self._get_observation()
        info = {
            'snapshot_idx': self.snapshot_idx,
            'is_training': True
        }

        return obs, info

    # omit reset_for_testing; we will cycle through train data for initial testing
    # omit get_test_snapshots
    # omit compute_storage_power_bounds; we don't use it since now enforce storage unit constraint with penalty

    def scale_action(self, action):
        """
        Simplified storage action scaling to avoid infeasible bounds.
        """
        scaled_actions = {}

        # Dispatchable generators
        if self.n_dispatchable > 0:
            disp_actions = action[self.action_structure['dispatchable']['start']:
                                self.action_structure['dispatchable']['end']]
            scaled_actions['dispatchable'] = self.disp_p_min + disp_actions * (self.disp_p_max - self.disp_p_min)
        else:
            scaled_actions['dispatchable'] = np.array([])

        # Renewable generators
        if self.n_renewable > 0:
            renewable_actions = action[self.action_structure['renewable']['start']:
                                    self.action_structure['renewable']['end']]
            current_p_max_pu = self.network.generators_t.p_max_pu.iloc[self.snapshot_idx][self.renewable_names].values
            current_p_max = current_p_max_pu * self.renewable_p_nom
            current_p_min = self.renewable_p_min_pu * self.renewable_p_nom
            scaled_actions['renewable'] = current_p_min + renewable_actions * (current_p_max - current_p_min)
        else:
            scaled_actions['renewable'] = np.array([])

        # Simplified storage scaling
        # Bidirectional storage scaling
        if self.n_bidirectional > 0:
            bid_p_store_actions = action[self.action_structure['bidirectional_p_store']['start']:
                                       self.action_structure['bidirectional_p_store']['end']]
            bid_p_dispatch_actions = action[self.action_structure['bidirectional_p_dispatch']['start']:
                                           self.action_structure['bidirectional_p_dispatch']['end']]

            scaled_p_store_bid = np.zeros(self.n_bidirectional)
            scaled_p_dispatch_bid = np.zeros(self.n_bidirectional)

            for i, storage_name in enumerate(self.bidirectional_names):
                storage_df = self.network.storage_units.loc[storage_name]
                p_nom = storage_df.p_nom
                p_max_pu = storage_df.p_max_pu
                p_bound = p_nom * p_max_pu

                scaled_p_store_bid[i] = bid_p_store_actions[i] * p_bound
                scaled_p_dispatch_bid[i] = bid_p_dispatch_actions[i] * p_bound

            scaled_actions['bidirectional_p_store'] = scaled_p_store_bid
            scaled_actions['bidirectional_p_dispatch'] = scaled_p_dispatch_bid
        else:
            scaled_actions['bidirectional_p_store'] = np.array([])
            scaled_actions['bidirectional_p_dispatch'] = np.array([])

        # Unidirectional storage scaling (p_dispatch only)
        if self.n_unidirectional > 0:
            uni_p_dispatch_actions = action[self.action_structure['unidirectional_p_dispatch']['start']:
                                           self.action_structure['unidirectional_p_dispatch']['end']]

            scaled_p_dispatch_uni = np.zeros(self.n_unidirectional)

            for i, storage_name in enumerate(self.unidirectional_names):
                storage_df = self.network.storage_units.loc[storage_name]
                p_nom = storage_df.p_nom
                p_max_pu = storage_df.p_max_pu
                p_bound = p_nom * p_max_pu

                scaled_p_dispatch_uni[i] = uni_p_dispatch_actions[i] * p_bound

            scaled_actions['unidirectional_p_dispatch'] = scaled_p_dispatch_uni
        else:
            scaled_actions['unidirectional_p_dispatch'] = np.array([])

        return scaled_actions

    def _simplified_power_balance(self):
        """
        Optimized power balance calculation using vectorized operations and cached mappings.
        """
        snapshot_idx = self.snapshot_idx

        # 1. Vectorized generator updates using iloc
        if len(self._non_slack_gens) > 0:
            # Get column indices for non-slack generators
            non_slack_cols = [self.network.generators_t.p_set.columns.get_loc(gen) for gen in self._non_slack_gens]

            # Vectorized copy from p_set to p
            p_set_values = self.network.generators_t.p_set.iloc[snapshot_idx, non_slack_cols].values

            # Set realized power
            for i, gen_name in enumerate(self._non_slack_gens):
                gen_col = self.network.generators_t.p.columns.get_loc(gen_name)
                self.network.generators_t.p.iloc[snapshot_idx, gen_col] = p_set_values[i]

        # 2. Vectorized storage updates using iloc
        if len(self.storage_names) > 0:
            storage_cols = [self.network.storage_units_t.p_set.columns.get_loc(name) for name in self.storage_names]
            p_set_values = self.network.storage_units_t.p_set.iloc[snapshot_idx, storage_cols].values

            for i, storage_name in enumerate(self.storage_names):
                storage_col = self.network.storage_units_t.p.columns.get_loc(storage_name)
                self.network.storage_units_t.p.iloc[snapshot_idx, storage_col] = p_set_values[i]

        # 3. Vectorized bus power calculation using cached mappings
        # Get all power values at once using iloc
        load_powers = -self.network.loads_t.p_set.iloc[snapshot_idx].values  # Negative for consumption

        if len(self._non_slack_gens) > 0:
            non_slack_gen_cols = [self.network.generators_t.p.columns.get_loc(gen) for gen in self._non_slack_gens]
            gen_powers = self.network.generators_t.p.iloc[snapshot_idx, non_slack_gen_cols].values
        else:
            gen_powers = np.array([])

        if len(self.storage_names) > 0:
            storage_cols = [self.network.storage_units_t.p.columns.get_loc(name) for name in self.storage_names]
            storage_powers = self.network.storage_units_t.p.iloc[snapshot_idx, storage_cols].values
        else:
            storage_powers = np.array([])

        # Use pandas Series with cached bus mappings for fast groupby
        load_series = pd.Series(load_powers, index=self._load_names)
        load_bus_power = load_series.groupby(self._load_buses).sum()

        if len(gen_powers) > 0:
            gen_series = pd.Series(gen_powers, index=self._non_slack_gens)
            gen_bus_power = gen_series.groupby(self._non_slack_gen_buses).sum()
        else:
            gen_bus_power = pd.Series(dtype=float)

        if len(storage_powers) > 0:
            storage_series = pd.Series(storage_powers, index=self.storage_names)
            storage_bus_power = storage_series.groupby(self._storage_buses).sum()
        else:
            storage_bus_power = pd.Series(dtype=float)

        # Combine all bus injections efficiently
        total_bus_injection = (load_bus_power.reindex(self._all_buses, fill_value=0) +
                              gen_bus_power.reindex(self._all_buses, fill_value=0) +
                              storage_bus_power.reindex(self._all_buses, fill_value=0))

        # 4. Calculate total imbalance
        total_imbalance = total_bus_injection.sum()

        # 5. Vectorized slack generator update using iloc
        if len(self._slack_gens) == 1:
            slack_gen = self._slack_gens[0]
            slack_col_set = self.network.generators_t.p_set.columns.get_loc(slack_gen)
            slack_col_p = self.network.generators_t.p.columns.get_loc(slack_gen)

            current_p_set = self.network.generators_t.p_set.iloc[snapshot_idx, slack_col_set]
            self.network.generators_t.p.iloc[snapshot_idx, slack_col_p] = current_p_set - total_imbalance

        elif len(self._slack_gens) > 1:
            # Multiple slack generators with cached weights
            slack_cols_set = [self.network.generators_t.p_set.columns.get_loc(gen) for gen in self._slack_gens]
            slack_cols_p = [self.network.generators_t.p.columns.get_loc(gen) for gen in self._slack_gens]

            current_p_sets = self.network.generators_t.p_set.iloc[snapshot_idx, slack_cols_set].values
            slack_adjustments = self._slack_weights * (-total_imbalance)
            new_p_values = current_p_sets + slack_adjustments

            for i, slack_col_p in enumerate(slack_cols_p):
                self.network.generators_t.p.iloc[snapshot_idx, slack_col_p] = new_p_values[i]

    def _update_storage_soc_single_snapshot(self, storage_name):
        if self.snapshot_idx == 0:
            # For first snapshot, previous SOC is the initial value
            soc_prev = self.network.storage_units.state_of_charge_initial.loc[storage_name]
        else:
            previous_snapshot = self.network.snapshots[self.snapshot_idx - 1]
            soc_prev = self.network.storage_units_t.state_of_charge.loc[previous_snapshot, storage_name]

        current_snapshot = self.network.snapshots[self.snapshot_idx]

        #Get storage parameters
        storage_unit = self.network.storage_units.loc[storage_name]
        soc_max = storage_unit.p_nom * storage_unit.max_hours
        eff_store = storage_unit.efficiency_store
        eff_dispatch = storage_unit.efficiency_dispatch
        standing_loss = storage_unit.standing_loss

        # Get time step
        if hasattr(self.network.snapshot_weightings, 'stores'):
            delta_t = self.network.snapshot_weightings.stores.iloc[self.snapshot_idx]
        else:
            delta_t = self.network.snapshot_weightings.iloc[self.snapshot_idx]

        eff_standing = (1 - standing_loss) ** delta_t

        # Get current operations (these determine the SOC change)
        p_store = self.network.storage_units_t.p_store.loc[current_snapshot, storage_name]
        p_dispatch = self.network.storage_units_t.p_dispatch.loc[current_snapshot, storage_name]
        if storage_name in self.network.storage_units_t.inflow.columns:
          inflow = self.network.storage_units_t.inflow.loc[current_snapshot, storage_name]
        else:
          inflow=0

        # Calculate SOC without spill (could be non-zero even if soc_prev=0)
        soc_without_spill = (soc_prev * eff_standing +
                            (p_store * eff_store - p_dispatch/eff_dispatch + inflow) * delta_t)

        # Calculate required spill
        required_spill = max(0, (soc_without_spill - soc_max) / delta_t)

        # Final SOC after spill
        soc_actual = min(soc_without_spill, soc_max)

        # Update the network
        self.network.storage_units_t.state_of_charge.loc[current_snapshot, storage_name] = soc_actual
        if hasattr(self.network, 'storage_units_t') and 'spill' in self.network.storage_units_t:
            self.network.storage_units_t.spill.loc[current_snapshot, storage_name] = required_spill

    def evaluate_objective_direct(self, current_snapshot):
        """
        Direct evaluation of PyPSA operational objective function terms.

        This function evaluates only the operational terms (marginal costs) that PyPSA
        optimizes for generators and storage units, excluding capital costs and other
        investment-related terms.

        Returns
        -------
        float
            Total operational cost for the current snapshot including snapshot weighting
        """
        total_cost = 0.0

        # Get snapshot weighting for proper cost calculation
        snapshot_weighting = self.network.snapshot_weightings.objective.loc[current_snapshot]

        # Generator operational costs
        if len(self.network.generators) > 0:
            # Get marginal costs and power output
            gen_marginal_costs = self.network.generators['marginal_cost']

            gen_power = self.network.generators_t.p.loc[current_snapshot]

            # Calculate generator operational cost
            gen_cost = (gen_marginal_costs * gen_power).sum()
            total_cost += gen_cost

        # Storage unit operational costs
        if len(self.network.storage_units) > 0:
            # Marginal cost for storage dispatch (discharge)
            storage_marginal_costs = self.network.storage_units['marginal_cost']
            storage_p_dispatch = self.network.storage_units_t.p_dispatch.loc[current_snapshot]
            storage_cost = (storage_marginal_costs * storage_p_dispatch).sum()
            #multiply correpsonding entries of the pandas columns and then sum them
            total_cost += storage_cost

            # Marginal cost for storage charging
            storage_marginal_costs_storage = self.network.storage_units['marginal_cost_storage']
            storage_store_power = self.network.storage_units_t.p_store.loc[current_snapshot]
            storage_store_cost = (storage_marginal_costs_storage * storage_store_power).sum()
            total_cost += storage_store_cost

            spill_costs = self.network.storage_units['spill_cost']
            spill_amounts = self.network.storage_units_t.spill.loc[current_snapshot]
            spill_cost = (spill_costs * spill_amounts).sum()
            total_cost += spill_cost

        # Apply snapshot weighting (this is crucial for proper cost calculation)
        total_cost *= snapshot_weighting

        return total_cost

    def _calculate_reward(self):
        """Calculate reward using stored objective components."""
        # Get the current snapshot name

        current_snapshot = self.network.snapshots[self.snapshot_idx]
        return -1*self.reward_scale_factor * self.evaluate_objective_direct(current_snapshot)

    def calculate_reward_no_penalty(self):
        # Get base reward from objective function (negative for minimization)
        base_reward = self._calculate_reward()

        current_snapshot = self.network.snapshots[self.snapshot_idx]

        # Initialize constraint tracking
        constraint_results=self.initial_constr.copy()  # Use copy to avoid modifying original

        constraint_results_idx=0
        slack_generators = self.network.generators[self.network.generators.control == "Slack"].index
        if not slack_generators.empty:
            for gen_name in slack_generators:
                # Get actual power output after power flow
                p_actual = self.network.generators_t.p.loc[current_snapshot, gen_name]

                # Get limits
                p_min = self.network.generators.loc[gen_name, 'p_min_pu'] * self.network.generators.loc[gen_name, 'p_nom']
                p_max = self.network.generators.loc[gen_name, 'p_max_pu'] * self.network.generators.loc[gen_name, 'p_nom']

                # Check lower bound
                if p_actual < p_min:
                    violation = min(float((p_min - p_actual)/(p_max-p_min)), 1.0)
                    constraint_results[constraint_results_idx]=violation
                #if not violated, the entry remains at zero
                #The violation represents "what fraction of the allowable operating range am I violating by?"

                # Check upper bound
                if p_actual > p_max:
                    violation = min(float((p_actual - p_max)/(p_max-p_min)), 1.0)
                    constraint_results[constraint_results_idx+self.n_slack]=violation

                constraint_results_idx+=1

        constraint_results_idx=2 * self.n_slack

        # 3. Check SOC constraints for storage units
        #In PyPSA, storage unit SOC bounds are: 0 ≤ soc ≤ 1 * (p_nom * max_hours)
        # Even for PHS, SOC is nonnegative since its the energy stored in the reservoir (always ≥ 0)
        if self.n_storage > 0:
            for storage_name in self.storage_names:
                # Get current SOC
                current_soc = self.network.storage_units_t.state_of_charge.loc[current_snapshot, storage_name]

                # Get SOC limits
                storage_unit = self.network.storage_units.loc[storage_name]
                soc_min = 0.0  # Minimum SOC is typically 0
                soc_max = storage_unit.p_nom * storage_unit.max_hours  # Maximum energy capacity

                # Check lower SOC bound
                if current_soc < soc_min:
                    violation = min(float((soc_min - current_soc)/(soc_max-soc_min)),1.0)
                    constraint_results[constraint_results_idx]=violation

                # Check upper SOC bound
                if current_soc > soc_max:
                    violation = min(float((current_soc - soc_max)/(soc_max-soc_min)),1.0)
                    constraint_results[constraint_results_idx+self.n_storage]=violation

        return base_reward, constraint_results

    def calculate_reward_no_penalty_test(self):
        # Get base reward from objective function (negative for minimization)
        current_snapshot = self.network.snapshots[self.snapshot_idx]
        base_reward = self.evaluate_objective_direct(current_snapshot)

        current_snapshot = self.network.snapshots[self.snapshot_idx]

        # Initialize constraint tracking
        constraint_results=self.initial_constr_test.copy()  # Use copy to avoid modifying original

        constraint_results_idx=0
        slack_generators = self.network.generators[self.network.generators.control == "Slack"].index
        if not slack_generators.empty:
            for gen_name in slack_generators:
                # Get actual power output after power flow
                p_actual = self.network.generators_t.p.loc[current_snapshot, gen_name]

                # Get limits
                p_min = self.network.generators.loc[gen_name, 'p_min_pu'] * self.network.generators.loc[gen_name, 'p_nom']
                p_max = self.network.generators.loc[gen_name, 'p_max_pu'] * self.network.generators.loc[gen_name, 'p_nom']

                # Check lower bound
                if p_actual < p_min:
                    violation = min(float((p_min - p_actual)/(p_max-p_min)), 1.0)
                    constraint_results[constraint_results_idx]=violation
                #if not violated, the entry remains at zero
                #The violation represents "what fraction of the allowable operating range am I violating by?"

                # Check upper bound
                if p_actual > p_max:
                    violation = min(float((p_actual - p_max)/(p_max-p_min)), 1.0)
                    constraint_results[constraint_results_idx+self.n_slack]=violation

                constraint_results_idx+=1

        constraint_results_idx=2 * self.n_slack

        # 3. Check SOC constraints for storage units
        #In PyPSA, storage unit SOC bounds are: 0 ≤ soc ≤ 1 * (p_nom * max_hours)
        # Even for PHS, SOC is nonnegative since its the energy stored in the reservoir (always ≥ 0)
        if self.n_storage > 0:
            for storage_name in self.storage_names:
                # Get current SOC
                current_soc = self.network.storage_units_t.state_of_charge.loc[current_snapshot, storage_name]

                # Get SOC limits
                storage_unit = self.network.storage_units.loc[storage_name]
                soc_min = 0.0  # Minimum SOC is typically 0
                soc_max = storage_unit.p_nom * storage_unit.max_hours  # Maximum energy capacity

                # Check lower SOC bound
                if current_soc < soc_min:
                    violation = min(float((soc_min - current_soc)/(soc_max-soc_min)),1.0)
                    constraint_results[constraint_results_idx]=violation

                # Check upper SOC bound
                if current_soc > soc_max:
                    violation = min(float((current_soc - soc_max)/(soc_max-soc_min)),1.0)
                    constraint_results[constraint_results_idx+self.n_storage]=violation

        constraint_results_idx=2 * self.n_slack+2*self.n_storage
        # 2. Check line flow constraints
        for line_name in self.network.lines.index:
            # Get line parameters
            s_nom = self.network.lines.loc[line_name, 's_nom']
            s_max_pu = 1.0  # Default, or get from lines_t.s_max_pu if it exists

            # Calculate active power limit (this is what PyPSA's linear constraints check)
            s_max = s_max_pu * s_nom

            # Get active power flow from the linear power flow
            # In PyPSA's linear formulation, this is the 's' variable value
            p0 = abs(self.network.lines_t.p0.loc[current_snapshot, line_name])

            # Check if active power flow exceeds limit
            if p0 > s_max:
                violation = min(float((p0 - s_max)/s_max),1.0)
                constraint_results[constraint_results_idx]=violation
            constraint_results_idx+=1

        return base_reward, constraint_results

    def calculate_constrained_reward(self):
        """
        Calculate reward using summation method with dynamic constraint checking.

        Summation method:
        - Reward = -J(s) - P(s)
        """
        base_reward, constraint_results = self.calculate_reward_no_penalty()

        # Calculate penalty
        total_violation = np.sum(constraint_results)
        penalty = self.penalty_factor * (total_violation/self.n_constr)

        #Compute scaled base reward
        scaled_base_reward = base_reward

        # Calculate final reward using summation method
        constrained_reward = scaled_base_reward - penalty

        return constrained_reward, constraint_results

    def take_action(self, action):
        scaled_actions = self.scale_action(action)
        # Apply dispatchable generator setpoints
        if self.n_dispatchable > 0:
            for i, gen_name in enumerate(self.dispatchable_names):
                self.network.generators_t.p_set.iloc[self.snapshot_idx, self.network.generators_t.p_set.columns.get_loc(gen_name)] = scaled_actions['dispatchable'][i]

        # Apply renewable generator setpoints
        if self.n_renewable > 0:
            for i, gen_name in enumerate(self.renewable_names):
                self.network.generators_t.p_set.iloc[self.snapshot_idx,
                    self.network.generators_t.p_set.columns.get_loc(gen_name)] = scaled_actions['renewable'][i]

        # Apply bidirectional storage setpoints
        if self.n_bidirectional > 0:
            for i, storage_name in enumerate(self.bidirectional_names):
                self.network.storage_units_t.p_store.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_store.columns.get_loc(storage_name)] = scaled_actions['bidirectional_p_store'][i]
                self.network.storage_units_t.p_dispatch.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_dispatch.columns.get_loc(storage_name)] = scaled_actions['bidirectional_p_dispatch'][i]
                self.network.storage_units_t.p_set.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_set.columns.get_loc(storage_name)] = (
                        scaled_actions['bidirectional_p_dispatch'][i] - scaled_actions['bidirectional_p_store'][i])

        # Apply unidirectional storage setpoints
        if self.n_unidirectional > 0:
            for i, storage_name in enumerate(self.unidirectional_names):
                # Only p_dispatch is controllable, p_store is always 0
                self.network.storage_units_t.p_store.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_store.columns.get_loc(storage_name)] = 0.0
                self.network.storage_units_t.p_dispatch.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_dispatch.columns.get_loc(storage_name)] = scaled_actions['unidirectional_p_dispatch'][i]
                self.network.storage_units_t.p_set.iloc[self.snapshot_idx,
                    self.network.storage_units_t.p_set.columns.get_loc(storage_name)] = scaled_actions['unidirectional_p_dispatch'][i]

        # Update SOC for all storage units
        all_storage_names = self.bidirectional_names + self.unidirectional_names
        for storage_name in all_storage_names:
            if self.snapshot_idx > 0:
                self._update_storage_soc_single_snapshot(storage_name)

    def _simplified_power_balance(self):
        """
        Simplified power balance calculation that sets generator and storage realized values
        to their setpoints and calculates slack generator dispatch without full power flow.

        Based on PyPSA's approach but optimized for cases where line flows aren't needed.
        """
        current_snapshot = self.network.snapshots[self.snapshot_idx]

        # 1. Set non-slack generator realized power to their setpoints
        # (This is what PyPSA does in _calculate_controllable_nodal_power_balance)
        non_slack_gens = self.network.generators[self.network.generators.control != "Slack"].index
        for gen_name in non_slack_gens:
            p_set = self.network.generators_t.p_set.loc[current_snapshot, gen_name]
            self.network.generators_t.p.loc[current_snapshot, gen_name] = p_set

        # 2. Set storage unit realized power to their setpoints
        # PyPSA uses p_set for storage units as the net dispatch (positive = discharge, negative = charge)
        for storage_name in self.storage_names:
            p_set = self.network.storage_units_t.p_set.loc[current_snapshot, storage_name]
            self.network.storage_units_t.p.loc[current_snapshot, storage_name] = p_set

        # 3. Calculate bus power injections (similar to PyPSA's approach)
        buses_o = self.network.buses.index

        # Initialize bus power injections to load demands (negative = consumption)
        bus_p_injection = pd.Series(0.0, index=buses_o)

        # Add load consumption (loads have negative sign in PyPSA)
        for load_name in self.network.loads.index:
            bus = self.network.loads.loc[load_name, 'bus']
            p_load = self.network.loads_t.p_set.loc[current_snapshot, load_name]
            bus_p_injection[bus] -= p_load  # Loads consume power

        # Add non-slack generator injections (positive = generation)
        for gen_name in non_slack_gens:
            bus = self.network.generators.loc[gen_name, 'bus']
            p_gen = self.network.generators_t.p.loc[current_snapshot, gen_name]
            bus_p_injection[bus] += p_gen

        # Add storage unit injections (positive = discharge, negative = charge)
        for storage_name in self.storage_names:
            bus = self.network.storage_units.loc[storage_name, 'bus']
            p_storage = self.network.storage_units_t.p.loc[current_snapshot, storage_name]
            bus_p_injection[bus] += p_storage

        # 4. Calculate total system imbalance
        total_imbalance = bus_p_injection.sum()

        # 5. Distribute imbalance to slack generators
        # (Following PyPSA's approach in sub_network_pf_singlebus)
        slack_generators = self.network.generators[self.network.generators.control == "Slack"].index

        if len(slack_generators) == 1:
            # Single slack generator takes all imbalance
            slack_gen = slack_generators[0]
            current_p_set = self.network.generators_t.p_set.loc[current_snapshot, slack_gen]
            self.network.generators_t.p.loc[current_snapshot, slack_gen] = current_p_set - total_imbalance
        else:
            # Multiple slack generators - distribute proportionally to their p_nom
            slack_p_noms = self.network.generators.loc[slack_generators, 'p_nom']
            slack_weights = slack_p_noms / slack_p_noms.sum()

            for slack_gen in slack_generators:
                current_p_set = self.network.generators_t.p_set.loc[current_snapshot, slack_gen]
                slack_share = slack_weights[slack_gen] * (-total_imbalance)
                self.network.generators_t.p.loc[current_snapshot, slack_gen] = current_p_set + slack_share


    def step(self, action):
        """
        Execute one time step within the environment.

        Args:
            action: Array of setpoints for all controllable components [disp_gen1, disp_gen2, ...,
                   renewable_gen1, renewable_gen2, ..., storage1, storage2, ...]

        Returns:
            observation: Network state after action
            reward: Reward for this action
            terminated: Whether episode is finished due to task completion
            truncated: Whether episode is finished due to time limit
            info: Additional information
        """
        self.take_action(action)

        # Run simplified power balance instead of full power flow
        try:
            # self._simplified_power_balance()
            self.network.lpf(self.network.snapshots[self.snapshot_idx],skip_pre=True)
            power_flow_converged = True
        except Exception as e:
            print(f"Power balance failed: {e}")
            power_flow_converged = False

        if not power_flow_converged:
            reward = self.no_convergence_lpf_penalty
            constraint_results= np.ones(self.n_constr) * -1# -1 indicates power flow did not converge. If want to get total constraint violation we will need to
        else:
            # Calculate reward using constrained reward function
            reward, constraint_results = self.calculate_constrained_reward()

        # Increment step counters
        self.snapshot_idx += 1

        # Check if episode is done
        truncated, terminated = self._check_done()

        # Get observation
        if terminated or truncated:
            observation = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            observation = self._get_observation()

        # Additional info
        info = {
            'power_flow_converged': power_flow_converged,
            'snapshot_idx': self.snapshot_idx,
            'constraint_results': constraint_results
        }

        return observation, reward, terminated, truncated, info

    def step_test(self, action):
        """
        Execute one time step within the environment. Use reward without penalty so we can compare to the reward from the baseline.

        Args:
            action: Array of setpoints for all controllable components [disp_gen1, disp_gen2, ...,
                   renewable_gen1, renewable_gen2, ..., storage1, storage2, ...]

        Returns:
            observation: Network state after action
            reward: Reward for this action
            terminated: Whether episode is finished due to task completion
            truncated: Whether episode is finished due to time limit
            info: Additional information
        """
        self.take_action(action)

        # Run power flow to get new network state
        try:
            self.network.lpf(self.network.snapshots[self.snapshot_idx], skip_pre=True)
            power_flow_converged = True
        except Exception as e:
            print(f"Power flow failed: {e}")
            power_flow_converged = False

        if not power_flow_converged:
            reward = self.no_convergence_lpf_penalty
            constraint_results= np.ones(self.n_constr_test) * -1# -1 indicates power flow did not converge. If want to get total constraint violation we will need to
        else:
            # Calculate reward using constrained reward function
            reward, constraint_results = self.calculate_reward_no_penalty_test()

        # Increment step counters
        self.snapshot_idx += 1

        # Check if episode is done
        truncated, terminated = self._check_done()
        #replaced _check_done_test

        # Get observation
        if terminated or truncated:
            observation = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            observation = self._get_observation()

        # Additional info
        info = {
            'power_flow_converged': power_flow_converged,
            'snapshot_idx': self.snapshot_idx,
            'constraint_results': constraint_results
        }

        return observation, reward, terminated, truncated, info

    def _check_done(self):
        """
        Modified to handle both fixed and variable episode lengths.
        """
        truncated=False
        terminated=False
        # For all episodes, stop if we've reached the test data boundary
        if self.snapshot_idx >= self.total_snapshots:
            terminated=True

        return truncated, terminated

    #Omit _check_done_test

    def seed(self, seed=None):
        """
        Set the random seed for reproducible experiments.
        """
        np.random.seed(seed)
        return [seed]

        #omit render function

# network_file_path= "/Users/antoniagrindrod/Documents/pypsa-earth_project/pypsa-earth-RL/networks/elec_s_10_ec_lc1.0_1h.nc"
# input_dir="/Users/antoniagrindrod/Documents/pypsa-earth_project/pypsa-earth-RL/RL/var_constraint_map"
# replacement_reward_offset=calculate_offset_k_initialization(network_file=network_file_path, input_dir=input_dir)

In [7]:
from tqdm import tqdm
class BackboneNetwork(nn.Module):
    def __init__(self, input_features, hidden_dimensions, out_features, dropout):
        super(BackboneNetwork, self).__init__()

        # SIMPLIFIED: Single hidden layer network for debugging
        self.neuralnet = nn.Sequential(
            nn.Linear(input_features, hidden_dimensions),
            nn.ReLU(),
            nn.Linear(hidden_dimensions, hidden_dimensions),
            nn.ReLU(),
            nn.Linear(hidden_dimensions, out_features)
        )

    def forward(self, x):
        output = self.neuralnet(x)
        return output

#Define the actor-critic network
class actorCritic(nn.Module):
    def __init__(self, actor, critic):
        super().__init__()
        self.actor = actor
        self.critic = critic
    def forward(self, state):
        action_pred = self.actor(state)
        value_pred = self.critic(state)
        return action_pred, value_pred
        #Returns both the action predictions and the value predictions.

#We'll use the networks defined above to create an actor and a critic. Then, we will create an agent, including the actor and the critic.
#finish this step later
# def create_agent(hidden_dimensions, dropout):
#     INPUT_FEATURES =env_train.
class PPO_agent:
    def __init__(self,
                 env,
                 device,
                 run,
                 hidden_dimensions,
                 dropout, discount_factor,
                 max_episodes,
                 print_interval,
                 PPO_steps,
                 n_trials,
                 epsilon,
                 entropy_coefficient,
                 learning_rate,
                 batch_size,
                 optimizer_name,
                 seed, training_start_time=None):

        self.seed = seed+20000
        if seed is not None:
            # Set PyTorch seed for this class
            torch.manual_seed(self.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(self.seed)
        self.env = env  # Store the environment as an attribute

        self.device = device
        self.run = run

        # Get observation and action space dimensions for gymnasium environment
        obs, _ = self.env.reset()
        self.action_dim = self.env.action_space.shape[0]

        self.INPUT_FEATURES = obs.shape[0]  # Flattened observation size
        self.ACTOR_OUTPUT_FEATURES = self.action_dim* 2  # 2 parameters (alpha, beta) per action dimension

        self.HIDDEN_DIMENSIONS = hidden_dimensions

        self.CRITIC_OUTPUT_FEATURES = 1
        self.DROPOUT = dropout

        self.discount_factor = discount_factor
        self.max_episodes = max_episodes
        self.print_interval = print_interval
        self.PPO_steps=PPO_steps
        self.n_trials=n_trials
        self.epsilon=epsilon
        self.entropy_coefficient=entropy_coefficient
        self.learning_rate=learning_rate

        self.batch_size=batch_size

        # Initialize actor network
        self.actor = BackboneNetwork(
            self.INPUT_FEATURES, self.HIDDEN_DIMENSIONS, self.ACTOR_OUTPUT_FEATURES, self.DROPOUT
        ).to(self.device)

        # Initialize the final layer bias for Beta distribution
        for name, param in self.actor.named_parameters():
            if 'neuralnet.4.bias' in name:  # Adjust index based on your network structure
                # Initialize to produce alpha=beta=2 (uniform-like distribution centered at 0.5)
                param.data.fill_(0.0)  # softplus(0) + 1 = 2
                print(f"Initialized Beta parameters to produce uniform-like distribution")

        # Initialize critic network
        self.critic = BackboneNetwork(
            self.INPUT_FEATURES, self.HIDDEN_DIMENSIONS, self.CRITIC_OUTPUT_FEATURES, self.DROPOUT
        ).to(self.device)

        #Better move the .to(self.device) call separately for both self.actor and self.critic. This ensures the individual parts of the model are moved to the correct device before combined into the actorCritic class
        # Combine into a single actor-critic model
        self.model = actorCritic(self.actor, self.critic)

        self.ONE_TENSOR = torch.tensor(1.0, device=self.device)

        try:
            # Try to get the optimizer from torch.optim based on the provided name
            self.optimizer = getattr(torch.optim, optimizer_name)(self.model.parameters(), lr=self.learning_rate)
        except AttributeError:
            # Raise an error if the optimizer_name is not valid
            raise ValueError(f"Optimizer '{optimizer_name}' is not available in torch.optim.")

    def save_model_to_neptune(agent, neptune_run, run_id=None):
        """Save the trained model to Neptune - synchronous version"""
        import tempfile
        import torch
        import json
        import os
        from datetime import datetime

        # Create temporary files for the model components
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save actor-critic network state dict
            if hasattr(agent, 'model') and agent.model is not None:
                actor_critic_path = os.path.join(temp_dir, 'actor_critic.pt')
                torch.save(agent.model.state_dict(), actor_critic_path)
            else:
                print("Warning: Agent does not have a model attribute to save.")
                return

            # Save complete model info
            model_info = {
                'state_space_dim': agent.INPUT_FEATURES,
                'action_space_dim': agent.action_dim,
                'hidden_dimensions': agent.HIDDEN_DIMENSIONS,
                'dropout': agent.DROPOUT,
                'learning_rate': agent.learning_rate,
                'discount_factor': agent.discount_factor,
                'epsilon': agent.epsilon,
                'entropy_coefficient': agent.entropy_coefficient,
                'batch_size': agent.batch_size,
                'ppo_steps': agent.PPO_steps,
                'model_architecture': str(agent.model) if hasattr(agent, 'model') and agent.model is not None else "N/A"
            }

            if run_id is not None:
                model_info['run_id'] = run_id

            model_info_path = os.path.join(temp_dir, 'model_info.json')
            with open(model_info_path, 'w') as f:
                json.dump(model_info, f, indent=2)

            # Upload to Neptune synchronously
            neptune_run["model/actor_critic"].upload(actor_critic_path)
            neptune_run["model/model_info"].upload(model_info_path)

            # Wait for uploads to complete
            neptune_run.wait()

            neptune_run["model/saved_at"] = datetime.now().isoformat()

            print(f"✓ Model saved to Neptune")

    def calculate_returns(self, rewards):
        returns = []
        cumulative_reward = 0
        for r in reversed(rewards):
            cumulative_reward = r +cumulative_reward*self.discount_factor
            returns.insert(0, cumulative_reward)
        returns = torch.tensor(returns).to(self.device)

        # Only normalize if we have more than one element to avoid std() warning
        if returns.numel() > 1:
            epsilon = 1e-8  # Small constant to avoid division by zero
            returns_std = returns.std()
            if not torch.isnan(returns_std) and returns_std >= epsilon:
                returns = (returns - returns.mean()) / (returns_std + epsilon)

        #I had conceptual trouble with normalizing the reward by an average, because it seemed to me since we're adding more rewards for earlier timesteps, the cumulative reward for earlier times would be a lot larger. But need to consider dicount facotr.
        # Future rewards contribute significantly to the cumulative return, so earlier timesteps will likely have larger returns.
        #if gamma is close to 0, future rewards have little influence, and the return at each timestep will closely resemble the immediate reward, meaning the pattern might not be as clear.
        return returns

    #The advantage is calculated as the difference between the value predicted by the critic and the expected return from the actions chosen by the actor according to the policy.
    def calculate_advantages(self, returns, values):
        advantages = returns - values

        # Only normalize if we have more than one element to avoid std() warning
        if advantages.numel() > 1:
            epsilon = 1e-8
            advantages_std = advantages.std()
            if not torch.isnan(advantages_std) and advantages_std >= epsilon:
                advantages = (advantages - advantages.mean()) / (advantages_std + epsilon)

        return advantages

    #The standard policy gradient loss is calculated as the product of the policy action probabilities and the advantage function
    #The standard policy gradietn loss cannot make corrections for abrupt policy changes. The surrogate loss modifies the standard loss to restrict the amount the policy can change in each iteration.
    #The surrogate loss is the minimum of (policy ratio X advantage function) and (clipped value of policy ratio X advantage function) where the policy ratio is between the action probabilities according to the old versus new policies and clipping restricts the value to a region near 1.

    def calculate_surrogate_loss(self, actions_log_probability_old, actions_log_probability_new, advantages):
        advantages = advantages.detach()
        # creates a new tensor that shares the same underlying data as the original tensor but breaks the computation graph. This means:
        # The new tensor is treated as a constant with no gradients.
        # Any operations involving this tensor do not affect the gradients of earlier computations in the graph.

        #If the advantages are not detached, the backpropagation of the loss computed using the surrogate_loss would affect both the actor and the critic networks
        # The surrogate loss is meant to update only the policy (actor).
        # Allowing gradients to flow back through the advantages would inadvertently update the critic, potentially disrupting its learning process.

        policy_ratio  = (actions_log_probability_new - actions_log_probability_old).exp()
        surrogate_loss_1 = policy_ratio*advantages
        surrogate_loss_2 = torch.clamp(policy_ratio, min =1.0-self.epsilon, max = 1.0+self.epsilon)*advantages
        surrogate_loss=torch.min(surrogate_loss_1, surrogate_loss_2)
        return surrogate_loss

    #TRAINING THE AGENT
    #Policy loss is the sum of the surrogate loss and the entropy bonus. It is used to update the actor (policy network)
    #Value loss is based on the difference between the value predicted by the critic and the returns (cumulative reward) generated by the policy. This loss is used to update the critic (value network) to make predictions more accurate.

    def calculate_losses(self, surrogate_loss, entropy, returns, value_pred):
        entropy_bonus = self.entropy_coefficient*entropy
        policy_loss = -(surrogate_loss+entropy_bonus).sum()
        value_loss = torch.nn.functional.smooth_l1_loss(returns, value_pred).sum() #helps to smoothen the loss function and makes it less sensitive to outliers.
        return policy_loss, value_loss

    def init_training(self):
        #create a set of buffers as empty arrays. To be used during training to store information
        states = []
        actions = []
        actions_log_probability = []
        values = []
        rewards = []
        done = False
        episode_reward = 0
        return states, actions, actions_log_probability, values, rewards, done, episode_reward

    def forward_pass(self, episode=None):  # Add episode_num parameter
        # Use different seed for each episode but still reproducible
        episode_seed = self.seed + episode
        state, _ = self.env.reset(seed=episode_seed)

        states, actions, actions_log_probability, values, rewards, done, episode_reward = self.init_training()

        state, _ = self.env.reset()  # Gymnasium format returns (obs, info)

        self.model.train() # Set model to training mode

        # Initialize constraint violation tracking
        episode_constraint_violations = None

        while True:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
                states.append(state_tensor)

                # Get action predictions and values
                action_mean, value_pred = self.model(state_tensor)

                # Split actor output into alpha and beta parameters
                alpha_raw, beta_raw = torch.split(action_mean, self.action_dim, dim=-1)

                # Ensure alpha, beta > 1 for well-behaved Beta distribution
                alpha = torch.nn.functional.softplus(alpha_raw).add_(self.ONE_TENSOR)
                beta = torch.nn.functional.softplus(beta_raw).add_(self.ONE_TENSOR)

                # Create Beta distribution for continuous actions in [0,1]
                dist = torch.distributions.Beta(alpha, beta)

                action = dist.sample()

                # No clamping needed - Beta distribution naturally outputs [0,1]
                action_clamped = action

                log_prob_action = dist.log_prob(action).sum(dim=-1)  # Sum over action dimensions

            # Step environment with numpy action
            action_np = action_clamped.detach().cpu().numpy().flatten()
            state, reward, terminated, truncated, info = self.env.step(action_np)

            # Track total constraint violation (over the full episode) for each constraint.
            if 'constraint_results' in info:
                if episode_constraint_violations is None:
                    # Initialize with the same shape as constraint_results
                    episode_constraint_violations = info['constraint_results'].copy()
                else:
                    # Accumulate violations for each constraint
                    episode_constraint_violations += info['constraint_results']

            done = terminated or truncated
            actions.append(action_clamped)
            actions_log_probability.append(log_prob_action)
            values.append(value_pred)
            rewards.append(reward)
            episode_reward += reward

            if done:
                break
        states=torch.cat(states).to(self.device)#converts the list of individual states into a sinlem tensor that is necessary for later processing
        #Creates a single tensor with dimensions like (N, state_dim), where: N is the number of states collected in the episode; state_dim is the dimensionality of each state.
        #torch.cat() expects a sequence (e.g. list or tuple) of PyTorch tensors as input.
        actions=torch.cat(actions).to(self.device)
        #Note that, in the loop, both state and action are PyTorch tensors so that states and actions are both lists of PyTorch tensors
        actions_log_probability=torch.cat(actions_log_probability).to(self.device)
        values=torch.cat(values).squeeze(-1).to(self.device)# .squeeze removes a dimension of size 1 only from tensor at the specified position, in this case, -1, the last dimesion in the tensor. Note that .squeeze() does not do anything if the size of the dimension at the specified potision is not 1.
        # print(f"rewards NaNs: {torch.isnan(torch.tensor(rewards, dtype=torch.float32)).any()}")
        # print(f"values NaNs: {torch.isnan(torch.tensor(values, dtype=torch.float32)).any()}")
        returns = self.calculate_returns(rewards)
        advantages = self.calculate_advantages(returns, values)

        # print(f"Returns NaNs: {torch.isnan(returns).any()}")
        # print(f"advantages NaNs (after calculation): {torch.isnan(advantages).any()}")

        return episode_reward, states, actions, actions_log_probability, advantages, returns, episode_constraint_violations


    def update_policy(self,
            states,
            actions,
            actions_log_probability_old,
            advantages,
            returns):
        #print(f"Returns NaNs: {torch.isnan(returns).any()}")
        total_policy_loss = 0
        total_value_loss = 0
        actions_log_probability_old = actions_log_probability_old.detach()
        actions=actions.detach()

        # print(f"Returns NaNs: {torch.isnan(returns).any()}")
        # print(f"advantages NaNs (after calculation): {torch.isnan(advantages).any()}")


        #detach() is used to remove the tensor from the computation graph, meaning no gradients will be calculated for that tensor when performing backpropagation.
        #In this context, it's used to ensure that the old actions and log probabilities do not participate in the gradient computation during the optimization of the policy, as we want to update the model based on the current policy rather than the old one.
        #print(type(states), type(actions),type(actions_log_probability_old), type(advantages), type(returns))
        training_results_dataset= TensorDataset(
                states,
                actions,
                actions_log_probability_old,
                advantages,
                returns) #TensorDataset class expects all the arguments passed to it to be tensors (or other compatible types like NumPy arrays, which will be automatically converted to tensor
        batch_dataset = DataLoader(
                training_results_dataset,
                batch_size=self.batch_size,
                shuffle=False)
        #creates a DataLoader instance in PyTorch, which is used to load the training_results_dataset in batches during training.
        #batch_size defines how many samples will be included in each batch. The dataset will be divided into batches of size BATCH_SIZE. The model will then process one batch at a time, rather than all of the data at once,
        #shuffle argument controls whether or not the data will be shuffled before being split into batches.
        #Because shuffle is false, dataloader will provide the batches in the order the data appears in training_results_dataset. In this case, the batches will be formed from consecutive entries in the dataset, and the observations will appear in the same sequence as they are stored in the dataset.
        for _ in range(self.PPO_steps):
            for batch_idx, (states,actions,actions_log_probability_old, advantages, returns) in enumerate(batch_dataset):
                #get new log prob of actions for all input states
                action_mean, value_pred = self.model(states)
                value_pred = value_pred.squeeze(-1)

                # For continuous actions with Beta distribution
                alpha_raw, beta_raw = torch.split(action_mean, self.action_dim, dim=-1)

                # Ensure alpha, beta > 1 for well-behaved Beta distribution
                alpha = torch.nn.functional.softplus(alpha_raw).add_(self.ONE_TENSOR)
                beta = torch.nn.functional.softplus(beta_raw).add_(self.ONE_TENSOR)

                probability_distribution_new = torch.distributions.Beta(alpha, beta)
                entropy = probability_distribution_new.entropy().sum(dim=-1)

                #estimate new log probabilities using old actions
                actions_log_probability_new = probability_distribution_new.log_prob(actions).sum(dim=-1)
                # # Check for NaN or Inf in log probabilities
                # if torch.isnan(actions_log_probability_old).any() or torch.isinf(actions_log_probability_old).any():
                #     print("NaN or Inf detected in actions_log_probability_old!")
                #     return  # You can return or handle this case as needed

                # if torch.isnan(actions_log_probability_new).any() or torch.isinf(actions_log_probability_new).any():
                #     print("NaN or Inf detected in actions_log_probability_new!")
                #     return  # You can return or handle this case as needed

                # print(f"actions_log_probability_old NaNs: {torch.isnan(actions_log_probability_old).any()}")
                # print(f"actions_log_probability_new NaNs: {torch.isnan(actions_log_probability_new).any()}")
                # print(f"advantages NaNs: {torch.isnan(advantages).any()}")

                surrogate_loss = self.calculate_surrogate_loss(
                    actions_log_probability_old,
                    actions_log_probability_new,
                    advantages
                )

                # print(f"Surrogate Loss NaNs: {torch.isnan(surrogate_loss).any()}")
                # print(f"Entropy NaNs: {torch.isnan(entropy).any()}")
                # print(f"Returns NaNs: {torch.isnan(returns).any()}")
                # print(f"Value Predictions NaNs: {torch.isnan(value_pred).any()}")

                policy_loss, value_loss = self.calculate_losses(
                    surrogate_loss,
                    entropy,
                    returns,
                    value_pred
                )
                self.optimizer.zero_grad() #clear existing gradietns in the optimizer (so that these don't propagate accross multiple .backward(). Ensures each optimization step uses only the gradients computed during the current batch.

                # Skip backward pass if loss is NaN
                if torch.isnan(policy_loss).any():
                    print("NaN detected in policy_loss - skipping backward pass!")
                    continue
                if torch.isnan(value_loss).any():
                    print("NaN detected in value_loss - skipping backward pass!")
                    continue

                policy_loss.backward() #computes gradients for policy_loss with respect to the agent's parameters
                # #Check for NaN gradients after policy_loss backward
                # for param in self.model.parameters():
                #     if param.grad is not None:  # Check if gradients exist for this parameter
                #         if torch.isnan(param.grad).any():
                #             print("NaN gradient detected in policy_loss!")
                # #             return
                value_loss.backward()
                # Check for NaN gradients after value_loss backwardor param in self.model.parameters():
                # for param in self.model.parameters():
                #     if param.grad is not None:  # Check if gradients exist for this parameter
                #         if torch.isnan(param.grad).any():
                #             print("NaN gradient detected in value_loss!")
                #             return

                self.optimizer.step()
                #The update step is based on the learning rate and other hyperparameters of the optimizer
                # The parameters of the agent are adjusted to reduce the policy and value losses.
                total_policy_loss += policy_loss.item() #accumulate the scalar value of the policy loss for logging/ analysis
                #policy_loss.item() extracts the numerical value of the loss tensor (detaching it from the computational graph).
                #This value is added to total_policy_loss to compute the cumulative loss over all batches in the current PPO step.
                #Result: tracks the total policy loss for the current training epoch
                # The loss over the whole dataset is the sum of the losses over all batches.
                #The training dataset is split into batches during the training process. Each batch represents a subset of the collected training data from one episode.
                # Loss calculation is performed for each batch (policy loss and value loss)
                # for each batch, gradients are calculated with respect to the total loss for that batch and the optimizer then updates the network parameters using these gradients.
                # this is because the surrogate loss is only calculated over a single batch of data
                #look at the formula for surrogate loss.
                # It is written in terms of an expectation ˆ Et[. . .] that indicates the empirical average over a finite batch of samples.
                # This means you have collected a set of data (time steps) from the environment, and you're averaging over these data points. The hat symbol implies you're approximating the true expectation with a finite sample of data from the environment. This empirical average can be computed as the mean of values from the sampled transitions
                # the expectation is taken over all the data you've collected
                #If you're training with multiple batches (i.e., collecting data in chunks), then you can think of the expectation as being computed over each batch.
                #The overall expectation can indeed be seen as the sum of expectations computed for each batch, but The expectation of the sum is generally not exactly equal to the sum of the expectations unless the samples are independent, but in practical reinforcement learning algorithms, it's typically a good enough approximation
                #For samples to be independent, the outcome of one sample must not provide any information about the outcome of another. Specifically, in the context of reinforcement learning, this means that the states, actions, rewards, and subsequent states observed in different time steps or different episodes should be independent of each other.
                total_value_loss += value_loss.item()
                #Notice that we are calculating an empirical average, which is already an approximation on the true value (the true expectation would be the average over an infinite amount of data, and the empirical average is the average over the finite amount of data that we have collected).
                #But furthermore, we are approximating even the empirical average istelf. The empirical average is the average over all our collected datal, but here we actually batch our data, calculate average over each batch and then sum these averages, which is not exaclty equal to the average of the sums (but is a decent approximation).
        return total_policy_loss / self.PPO_steps, total_value_loss / self.PPO_steps

    def train(self):
        train_rewards = []
        avg_rewards =[]
        violation_rates=[]
        # test_rewards = []
        policy_losses = []
        value_losses = []
        #lens = []


        for episode in range(1, self.max_episodes + 1):

            #check timing for forward pass
            # Perform a forward pass and collect experience
            train_reward, states, actions, actions_log_probability, advantages, returns, episode_constraint_violations = self.forward_pass(episode)

            #check timing for policy update
            # Update the policy using the experience collected
            policy_loss, value_loss = self.update_policy(
                states,
                actions,
                actions_log_probability,
                advantages,
                returns)

            # Log the results
            policy_losses.append(policy_loss)
            value_losses.append(value_loss)
            train_rewards.append(train_reward)
            self.run["policy_loss"].log(policy_loss)
            self.run["value_loss"].log(value_loss)
            self.run["train_reward"].log(train_reward)
            # Total episode return (discounted)
            self.run["episode_return"].log(returns[0].item())  # The discounted version
            # Log episode constraint violations array (total violation per constraint)
            # Log individual total constraint violations for episode
            for i, violation in enumerate(episode_constraint_violations):
                self.run[f"constraint_{i}"].log(float(violation))

            # Log sum of violations for all constraints
            self.run["total_constraint_violation"].log(float(np.sum(episode_constraint_violations)))


            # Print with new metrics and save model at interval
            if episode % self.print_interval == 0:
                print(f'Episode: {episode:3} ')
                self.save_model_to_neptune(self.run)
                #the model will get overridden each time you save. Neptune will keep only the most recent version of the model files.

        return train_rewards

def plot_train_rewards(train_rewards, reward_threshold):
    plt.figure(figsize=(12, 8))
    plt.plot(train_rewards, label='Training Reward')
    plt.xlabel('Episode', fontsize=20)
    plt.ylabel('Training Reward', fontsize=20)
    plt.hlines(reward_threshold, 0, len(train_rewards), color='y')
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

def plot_test_rewards(test_rewards, reward_threshold):
    plt.figure(figsize=(12, 8))
    plt.plot(test_rewards, label='Testing Reward')
    plt.xlabel('Episode', fontsize=20)
    plt.ylabel('Testing Reward', fontsize=20)
    plt.hlines(reward_threshold, 0, len(test_rewards), color='y')
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

def plot_losses(policy_losses, value_losses):
    plt.figure(figsize=(12, 8))
    plt.plot(value_losses, label='Value Losses')
    plt.plot(policy_losses, label='Policy Losses')
    plt.xlabel('Episode', fontsize=20)
    plt.ylabel('Loss', fontsize=20)
    plt.legend(loc='lower right')
    plt.grid()
    plt.show()

In [8]:
class EnvDispatchReplacement(EnvDispatchConstr):
    """
    Environment using the Replacement reward method instead of Summation.

    Inherits from Env2Gen1LoadConstr but modifies the reward calculation
    to implement the replacement method from the RL-OPF paper.
    """

    def __init__(self,network_file,no_convergence_lpf_penalty, reward_scale_factor, constraint_penalty_factor=10, seed=None, offset_k=2500):
        """
        Initialize the replacement reward environment.

        Parameters:
        -----------
        network_file : str
            Path to the PyPSA network file
        episode_length : int, optional
            Length of episodes (defaults to total snapshots)
        constraint_penalty_factor : float
            Penalty factor for constraint violations
        offset_k : float
            Offset value for replacement reward method
        test_start_date : str
            Start date for test period (everything from this date onwards is test data)
        fixed_episode_length : int, optional
            Fixed episode length if specified, otherwise episodes are variable
        """
        # Call parent constructor - this will initialize all base attributes
        super().__init__(network_file, no_convergence_lpf_penalty, reward_scale_factor, constraint_penalty_factor=10, seed=None)

        # Add replacement-specific attributes
        self.offset_k = offset_k
        self.reward_method = "replacement"

    def calculate_constrained_reward(self):
        """
        Calculate reward using replacement method with dynamic constraint checking.

        Replacement method:
        - If all constraints satisfied: return -J(s) + k
        - If constraints violated: return -P(s)
        """
        base_reward, constraint_results = self.calculate_reward_no_penalty()

        total_violation = np.sum(constraint_results)
        penalty = self.penalty_factor * (total_violation/self.n_constr)

        #Compute scaled base reward
        scaled_base_reward = base_reward

        # Apply replacement method
        if total_violation==0:
            # All constraints satisfied: return optimization reward + offset k
            constrained_reward = scaled_base_reward + self.offset_k
        else:
            constrained_reward = -1*penalty

        # Ensure reward is a scalar
        if hasattr(constrained_reward, '__len__'):
            constrained_reward = float(constrained_reward)

        return constrained_reward, constraint_results

    def get_reward_method_info(self):
        """
        Get information about the reward method being used.

        Returns:
        --------
        dict: Information about the reward method
        """
        return {
            'method': 'replacement',
            'offset_k': self.offset_k,
            'penalty_factor': self.penalty_factor,
            'train_snapshots': self.train_snapshots,
            'test_snapshots': self.test_snapshots,
            'test_start_date': str(self.test_start_date),
            'fixed_episode_length': self.fixed_episode_length
        }

# network_file_path= "/Users/antoniagrindrod/Documents/pypsa-earth_project/pypsa-earth-RL/networks/elec_s_10_ec_lc1.0_1h.nc"
# input_dir="/Users/antoniagrindrod/Documents/pypsa-earth_project/pypsa-earth-RL/RL/var_constraint_map"
# replacement_reward_offset=calculate_offset_k_initialization(network_file=network_file_path, input_dir=input_dir)

In [9]:
from google.colab import drive
drive.mount('/content/drive')
def execute_training_simple_constr(env_class, seed):
    import os
    import torch
    import neptune
    import random

    base_params = {
        "optimizer_name": "Adam",
        "MAX_EPISODES": 1000,
        "PRINT_INTERVAL": 10,
        "N_TRIALS": 8,
        "DROPOUT": 0,
        "network_file": "elec_s_10_ec_lc1.0_1h.nc",
        "optimization_result_file": "elec_s_10_ec_lc1.0_1h_Test_Objective.txt",
        "reward_scale_factor": 10**(-8),
        "no_convergence_lpf_penalty": 100,
        "env_class": env_class,
    }

    # Parameters to sweep (from your original code)
    sweep_params = {
        "LEARNING_RATE": [1e-4, 1e-3, 5e-3],
        "EPSILON": [0.1, 0.2, 0.3],
        "ENTROPY_COEFFICIENT": [0.01, 0.1],
        "HIDDEN_DIMENSIONS": [32, 128],
        "PPO_STEPS": [8, 16],
        "BATCH_SIZE": [128, 256],
        "DISCOUNT_FACTOR": [0.95, 0.99],
        "constraint_penalty_factor": [10**(-1), 0.5*10**(-1)],
        "seed": seed,
    }

    # Base priority configurations
    priority_configs = [
        {"LEARNING_RATE": 1e-3, "EPSILON": 0.2, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 10**(-1), "seed": seed},
        {"LEARNING_RATE": 1e-4, "EPSILON": 0.2, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 10**(-1), "seed": seed},
        {"LEARNING_RATE": 1e-3, "EPSILON": 0.1, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 10**(-1), "seed": seed},
        {"LEARNING_RATE": 1e-3, "EPSILON": 0.2, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 0.5*10**(-1), "seed": seed},
        {"LEARNING_RATE": 1e-4, "EPSILON": 0.1, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 10**(-1), "seed": seed},
        {"LEARNING_RATE": 1e-3, "EPSILON": 0.1, "ENTROPY_COEFFICIENT": 0.01,
         "HIDDEN_DIMENSIONS": 128, "PPO_STEPS": 16, "BATCH_SIZE": 256, "DISCOUNT_FACTOR": 0.99,
         "constraint_penalty_factor": 0.5*10**(-1), "seed": seed},
    ]


    # Generate 10 additional random configurations
    def generate_random_config(seed_val):
        """Generate a random configuration from the sweep parameters"""
        config = {"seed": seed_val}
        for param, values in sweep_params.items():
            if param != "seed":  # seed is already set
                config[param] = random.choice(values)
        return config

    # Set random seed for reproducible random config generation
    random.seed(seed + 12345)  # Add offset to avoid conflicts with training seed

    additional_configs = []
    for i in range(10):
        random_config = generate_random_config(seed)
        additional_configs.append(random_config)

    # Combine all configurations
    all_configs = priority_configs + additional_configs

    # Store results from all configurations
    all_results = []

    # Loop through each configuration
    for config_idx, config in enumerate(all_configs):
        config_type = "priority" if config_idx < len(priority_configs) else "random"
        print(f"Running configuration {config_idx + 1}/{len(all_configs)} ({config_type})")
        print(f"Config: {config}")

        # Merge base params with current config
        full_config = {**base_params, **config}

        # Set up paths
        network_file = full_config["network_file"]
        gdrive_base = '/content/drive/My Drive/Colab_Notebooks'#'./'  # or '/workspace/'
        network_file_path = os.path.join(gdrive_base, "networks_1_year_connected", network_file)

        # Set seeds
        current_seed = full_config["seed"]
        set_all_seeds(current_seed)

        # Calculate offset (only needed for replacement env)
        replacement_reward_offset = None
        if full_config["env_class"] == "EnvDispatchReplacement":
            replacement_reward_offset,_,_ = calculate_offset_k_initialization(
                network_file=network_file_path,
                seed=current_seed,  # Pass the experiment seed
                reward_scale_factor=full_config["reward_scale_factor"]
            )

        # Initialize Neptune
        API_TOKEN = "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI1ODQwZjA5OS05MDFmLTQ2MWYtYWJiMi0yMDkzYmEwNzgzMzEifQ=="
        PROJECT_NAME = "EnergyGridRL/elec-s-10-ec-lc10-1h-sweep"
        run = neptune.init_run(
            project=PROJECT_NAME,
            api_token=API_TOKEN
        )

        # Log parameters
        for key, value in full_config.items():
            if key not in ['run_id', 'config_name']:
                run[f"parameters/{key}"] = value

        # Log config metadata
        run["parameters/config_idx"] = config_idx
        run["parameters/config_type"] = config_type

        if replacement_reward_offset is not None:
            run["replacement_reward_offset"] = replacement_reward_offset

        # Create environment based on env_class
        if full_config["env_class"] == "EnvDispatchConstr":
            env = EnvDispatchConstr(
                network_file=network_file_path,
                no_convergence_lpf_penalty=full_config["no_convergence_lpf_penalty"],
                reward_scale_factor=full_config["reward_scale_factor"],
                constraint_penalty_factor=full_config["constraint_penalty_factor"],
                seed=current_seed
            )

        elif full_config["env_class"] == "EnvDispatchReplacement":
            env = EnvDispatchReplacement(
                network_file=network_file_path,
                no_convergence_lpf_penalty=full_config["no_convergence_lpf_penalty"],
                reward_scale_factor=full_config["reward_scale_factor"],
                constraint_penalty_factor=full_config["constraint_penalty_factor"],
                offset_k=replacement_reward_offset,
                seed=current_seed
            )
        else:
            raise ValueError(f"Unknown environment class: {full_config['env_class']}")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        agent = PPO_agent(
            env=env,
            run=run,
            device=device,
            hidden_dimensions=full_config["HIDDEN_DIMENSIONS"],
            dropout=full_config["DROPOUT"],
            discount_factor=full_config["DISCOUNT_FACTOR"],
            optimizer_name=full_config["optimizer_name"],
            max_episodes=full_config["MAX_EPISODES"],
            print_interval=full_config["PRINT_INTERVAL"],
            PPO_steps=full_config["PPO_STEPS"],
            n_trials=full_config["N_TRIALS"],
            epsilon=full_config["EPSILON"],
            entropy_coefficient=full_config["ENTROPY_COEFFICIENT"],
            learning_rate=full_config["LEARNING_RATE"],
            batch_size=full_config["BATCH_SIZE"],
            seed=current_seed
        )

        # Train
        try:
            train_rewards = agent.train()

            # Store results
            result = {
                'config_idx': config_idx,
                'config_type': config_type,
                'config': full_config.copy(),
                'train_rewards': train_rewards,
                'final_reward': train_rewards[-1] if train_rewards else None
            }
            all_results.append(result)

            print(f"Configuration {config_idx + 1} completed successfully")

        except Exception as e:
            print(f"Configuration {config_idx + 1} failed with error: {e}")
            result = {
                'config_idx': config_idx,
                'config_type': config_type,
                'config': full_config.copy(),
                'error': str(e),
                'train_rewards': None,
                'final_reward': None
            }
            all_results.append(result)

        finally:
            # Always stop the Neptune run
            run.stop()

    # Print summary
    print(f"\nTraining completed for all {len(all_configs)} configurations")
    print(f"Priority configs: {len(priority_configs)}")
    print(f"Random configs: {len(additional_configs)}")
    successful_runs = [r for r in all_results if 'error' not in r]
    print(f"Successful runs: {len(successful_runs)}/{len(all_configs)}")

    # Print summary by type
    priority_successful = [r for r in all_results if r.get('config_type') == 'priority' and 'error' not in r]
    random_successful = [r for r in all_results if r.get('config_type') == 'random' and 'error' not in r]
    print(f"Priority successful: {len(priority_successful)}/{len(priority_configs)}")
    print(f"Random successful: {len(random_successful)}/{len(additional_configs)}")

    return all_results

Mounted at /content/drive


In [15]:
execute_training_simple_constr(env_class="EnvDispatchReplacement", seed=35)

Running configuration 1/16 (priority)
Config: {'LEARNING_RATE': 0.001, 'EPSILON': 0.2, 'ENTROPY_COEFFICIENT': 0.01, 'HIDDEN_DIMENSIONS': 128, 'PPO_STEPS': 16, 'BATCH_SIZE': 256, 'DISCOUNT_FACTOR': 0.99, 'constraint_penalty_factor': 0.1, 'seed': 35}
Sampling 1000 random states to calculate offset k...




Fixed ZA0 0 PHS: set max_hours to 8.0
Fixed ZA0 5 PHS: set max_hours to 8.0
Fixed ZA0 6 hydro: corrected max_hours from 3831.6270020496813 to 6.0
=== FIXING ARTIFICIAL LINES WITH REASONABLE CAPACITY ===
Found 3 artificial lines to fix:

 Fixing: lines new ZA0 4 <-> ZA2 0 AC
    Connected buses: ZA0 4 ↔ ZA2 0
    Bus demands: ZA0 4: 15945.8 MW, ZA2 0: 452.6 MW
    s_nom: 0.0 → 47837.3 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA1 0 AC
    Connected buses: ZA0 0 ↔ ZA1 0
    Bus demands: ZA0 0: 3513.0 MW, ZA1 0: 1386.9 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA3 0 AC
    Connected buses: ZA0 0 ↔ ZA3 0
    Bus demands: ZA0 0: 3513.0 MW, ZA3 0: 721.1 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False
Found 12 offshore wind generators:
['ZA0 1 offwind-ac', 'ZA0 1 offwind-dc', 'ZA0 5 offwind-ac', 'ZA0 5 offwind-dc', 'ZA0 7 offwind-ac', 'ZA0 7 offwind-dc', 'ZA0 8 offwind-ac', 'ZA0 8 offwind-dc', 'ZA1 0 offwind-ac



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/EnergyGridRL/elec-s-10-ec-lc10-1h-sweep/e/EL-21




Fixed ZA0 0 PHS: set max_hours to 8.0
Fixed ZA0 5 PHS: set max_hours to 8.0
Fixed ZA0 6 hydro: corrected max_hours from 3831.6270020496813 to 6.0
=== FIXING ARTIFICIAL LINES WITH REASONABLE CAPACITY ===
Found 3 artificial lines to fix:

 Fixing: lines new ZA0 4 <-> ZA2 0 AC
    Connected buses: ZA0 4 ↔ ZA2 0
    Bus demands: ZA0 4: 15945.8 MW, ZA2 0: 452.6 MW
    s_nom: 0.0 → 47837.3 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA1 0 AC
    Connected buses: ZA0 0 ↔ ZA1 0
    Bus demands: ZA0 0: 3513.0 MW, ZA1 0: 1386.9 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA3 0 AC
    Connected buses: ZA0 0 ↔ ZA3 0
    Bus demands: ZA0 0: 3513.0 MW, ZA3 0: 721.1 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False
Found 12 offshore wind generators:
['ZA0 1 offwind-ac', 'ZA0 1 offwind-dc', 'ZA0 5 offwind-ac', 'ZA0 5 offwind-dc', 'ZA0 7 offwind-ac', 'ZA0 7 offwind-dc', 'ZA0 8 offwind-ac', 'ZA0 8 offwind-dc', 'ZA1 0 offwind-ac

KeyboardInterrupt: 

In [None]:
# execute_training_simple_constr(env_class="EnvDispatchConstr", seed=61)

In [None]:
# execute_training_simple_constr(env_class="EnvDispatchConstr", seed=7)

In [10]:
import tempfile
import torch
import json
import os
import neptune

def load_model_from_neptune(neptune_run, device='cpu'):
    """
    Load model state dict and configuration from Neptune

    Args:
        neptune_run: Neptune run object (can be existing run or fetched run)
        device: Device to load the model on ('cpu' or 'cuda')

    Returns:
        tuple: (model_state_dict, model_config)
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Download actor-critic model state dict
        actor_critic_path = os.path.join(temp_dir, 'actor_critic.pt')
        neptune_run["model/actor_critic"].download(actor_critic_path)

        # Download model configuration
        model_info_path = os.path.join(temp_dir, 'model_info.json')
        neptune_run["model/model_info"].download(model_info_path)

        # Load state dict
        model_state_dict = torch.load(actor_critic_path, map_location=device,weights_only=True)

        # Load configuration
        with open(model_info_path, 'r') as f:
            model_config = json.load(f)

        print(f"✓ Model loaded from Neptune")
        print(f"✓ Model architecture: {model_config.get('model_architecture', 'N/A')}")

        return model_state_dict, model_config

def reconstruct_agent_from_saved_model(model_state_dict, model_config, env, device='cpu', run=None):
    """
    Reconstruct a PPO agent from saved model state and configuration

    Args:
        model_state_dict: PyTorch state dict loaded from Neptune
        model_config: Configuration dictionary loaded from Neptune
        env: Environment instance (needed for observation/action space info)
        device: Device to load the model on ('cpu' or 'cuda')
        run: Neptune run object for logging (optional, can be None for inference)

    Returns:
        PPO_agent: Reconstructed agent ready for inference or further training
    """

    # Create agent with saved configuration
    # Note: Some parameters are not critical for inference and can use defaults
    agent = PPO_agent(
        env=env,
        device=device,
        run=run,  # Can be None for inference
        hidden_dimensions=model_config['hidden_dimensions'],
        dropout=model_config['dropout'],
        discount_factor=model_config['discount_factor'],
        max_episodes=1,  # Not needed for inference
        print_interval=1,  # Not needed for inference
        PPO_steps=model_config['ppo_steps'],
        n_trials=1,  # Not needed for inference
        epsilon=model_config['epsilon'],
        entropy_coefficient=model_config['entropy_coefficient'],
        learning_rate=model_config['learning_rate'],
        batch_size=model_config['batch_size'],
        optimizer_name='Adam',  # Default, not critical for inference
        seed=42,  # Default, can be any value
        training_start_time=None
    )

    # Load the saved state dict into the model
    agent.model.load_state_dict(model_state_dict)

    # Set model to evaluation mode
    agent.model.eval()

    print(f"✓ Agent reconstructed successfully")
    print(f"✓ Input features: {agent.INPUT_FEATURES}")
    print(f"✓ Action dimensions: {agent.action_dim}")
    print(f"✓ Hidden dimensions: {agent.HIDDEN_DIMENSIONS}")

    return agent

def load_trained_agent_from_neptune(run_id, project_name, env, device='cpu'):
    """
    Complete function to load a trained agent from Neptune using run ID

    Args:
        run_id: Neptune run ID (e.g., 'PROJECT-123')
        project_name: Neptune project name (e.g., 'username/project-name')
        env: Environment instance
        device: Device to load model on

    Returns:
        PPO_agent: Loaded agent ready for inference
    """

    # Initialize Neptune and fetch the run
    print(f"Fetching Neptune run: {run_id}")
    run = neptune.init_run(
        project=project_name,
        run=run_id,
        mode="read-only"
    )

    try:
        # Load model and config from Neptune
        model_state_dict, model_config = load_model_from_neptune(run, device)

        # Reconstruct agent
        agent = reconstruct_agent_from_saved_model(
            model_state_dict,
            model_config,
            env,
            device,
            run=None  # No need for logging during inference
        )

        return agent

    finally:
        # Close the Neptune run
        run.stop()

In [None]:
# !pip install ipdb
# import ipdb
def evaluate_deterministic(env, agent, optimal_objective_value, seed=35):
    """
    Deterministic evaluation of trained agent on test data.
    Uses Beta distribution mean for consistent, reproducible results.
    """

    # Set seed for full reproducibility
    set_all_seeds(seed)

    # Reset environment to test data start
    obs, info = env.reset(seed=seed)
    agent.model.eval()

    # Track metrics
    rl_total_reward = 0.0
    total_test_snapshots = env.total_snapshots
    rewards= []
    violations={}
    # Extract constraint names in the correct order
    slack_generators = env.network.generators[env.network.generators.control == "Slack"].index

    #remove this because don't check line constraints during training
    line_names = env.network.lines.index
    storage_names = env.storage_names

    # Build constraint names in the order they appear in info['constraint_results']
    constraint_names = []

    # p_min for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_min_{gen_name}")

    # p_max for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_max_{gen_name}")

    # # s_max for each line

    # soc_min for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_min_{storage_name}")

    # soc_max for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_max_{storage_name}")

    for line_name in line_names:
        constraint_names.append(f"s_max_{line_name}")

    constraint_names.append("redundant_entry1")
    constraint_names.append("redundant_entry2")
    first_violation_append=True
    # Run agent on all test snapshots
    for step in range(total_test_snapshots):
        # Get DETERMINISTIC action from PPO agent
        with torch.no_grad():
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(agent.device)
            action_mean, _ = agent.model(obs_tensor)

            action_dim = env.action_space.shape[0]
            alpha_raw, beta_raw = torch.split(action_mean, action_dim, dim=-1)
            alpha = torch.nn.functional.softplus(alpha_raw) + 1.0
            beta = torch.nn.functional.softplus(beta_raw) + 1.0

            # DETERMINISTIC: Use mean of Beta distribution
            action_tensor = alpha / (alpha + beta)
            action = action_tensor.detach().cpu().numpy().flatten()
        # Take step in environment
        obs, reward, terminated, truncated, info = env.step_test(action) #call step_test for base reward without penalty/without offset k
        if first_violation_append:
            for i in range(len(constraint_names)):
                violations["violations_"+constraint_names[i]]=np.zeros(total_test_snapshots)
            first_violation_append=False
        for i, violation in enumerate(info["constraint_results"]):
                violations["violations_"+constraint_names[i]][step]=violation

        rewards.append(reward)

        if terminated or truncated:
            break

    # Calculate MAPE using only valid samples (Equation 22)
    if len(valid_rewards) > 0:
        total_reward = sum(rewards)
        mape = abs(valid_total_reward - optimal_objective_value) / abs(optimal_objective_value) * 100.0

    avg_violation={}
    for i in range(len(constraint_names)):
        avg_violations["avg_violation_"+constraint_names[i]]=np.sum(violations["violations_"+constraint_names[i]])/len(violations["violations_"+constraint_names[i]])


    results = {
        'mape': mape,
        'rl_total_objective': sum(rewards),  # Keep total for comparison
        'optimal_total_objective': optimal_objective_value,
        'total_test_snapshots': total_test_snapshots,
        'evaluation_method': 'deterministic',
        'rewards':rewards,
        'violations':violations
    }

    combined_results = {**results, **violations,**avg_violation}

    return results

In [11]:
def evaluate_deterministic(env, agent, optimal_objective_value, seed=35):
    """
    Deterministic evaluation of trained agent on test data.
    Uses Beta distribution mean for consistent, reproducible results.
    """

    # Set seed for full reproducibility
    set_all_seeds(seed)

    # Reset environment to test data start
    obs, info = env.reset(seed=seed)
    agent.model.eval()

    # Track metrics
    rl_total_reward = 0.0
    total_test_snapshots = env.total_snapshots
    rewards = []
    violations = {}

    # Extract constraint names in the correct order
    slack_generators = env.network.generators[env.network.generators.control == "Slack"].index
    line_names = env.network.lines.index
    storage_names = env.storage_names

    # Build constraint names in the order they appear in info['constraint_results']
    constraint_names = []

    # p_min for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_min_{gen_name}")

    # p_max for each slack generator
    for gen_name in slack_generators:
        constraint_names.append(f"p_max_{gen_name}")

    # soc_min for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_min_{storage_name}")

    # soc_max for each storage
    for storage_name in storage_names:
        constraint_names.append(f"soc_max_{storage_name}")

    # s_max for each line (only for step_test which includes line constraints)
    for line_name in line_names:
        constraint_names.append(f"s_max_{line_name}")

    constraint_names.append("redundant_entry1")
    constraint_names.append("redundant_entry2")
    first_violation_append = True

    # Run agent on all test snapshots
    for step in range(total_test_snapshots):
        # Get DETERMINISTIC action from PPO agent
        with torch.no_grad():
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(agent.device)
            action_mean, *_ = agent.model(obs_tensor)  # Fixed: obs_tensor instead of obs*tensor

            action_dim = env.action_space.shape[0]
            alpha_raw, beta_raw = torch.split(action_mean, action_dim, dim=-1)
            alpha = torch.nn.functional.softplus(alpha_raw) + 1.0
            beta = torch.nn.functional.softplus(beta_raw) + 1.0

            # DETERMINISTIC: Use mean of Beta distribution
            action_tensor = alpha / (alpha + beta)
            action = action_tensor.detach().cpu().numpy().flatten()

        # Take step in environment
        obs, reward, terminated, truncated, info = env.step_test(action)

        if first_violation_append:
            for i in range(len(constraint_names)):
                violations[f"violations_{constraint_names[i]}"] = np.zeros(total_test_snapshots)
            first_violation_append = False

        for i, violation in enumerate(info["constraint_results"]):
            violations[f"violations_{constraint_names[i]}"][step] = violation

        rewards.append(reward)
        if terminated or truncated:
            break

    # Calculate MAPE using total rewards (Fixed: removed undefined variables)
    total_reward = sum(rewards)
    mape = abs(total_reward - optimal_objective_value) / abs(optimal_objective_value) * 100.0

    # Calculate average violations
    avg_violations = {}  # Fixed: consistent variable name
    for i in range(len(constraint_names)):
        constraint_name = constraint_names[i]
        violation_key = f"violations_{constraint_name}"
        avg_key = f"avg_violation_{constraint_name}"
        avg_violations[avg_key] = np.sum(violations[violation_key]) / len(violations[violation_key])

    results = {
        'mape': mape,
        'rl_total_objective': total_reward,  # Fixed: use total_reward instead of sum(rewards)
        'optimal_total_objective': optimal_objective_value,
        'total_test_snapshots': total_test_snapshots,
        'evaluation_method': 'deterministic',
        'rewards': rewards,
        'violations': violations
    }

    # Combine all results
    combined_results = {**results, **violations, **avg_violations}
    return combined_results  # Fixed: return combined_results instead of results

In [12]:
import time
def evaluate_trained_agent(run_id, env_class, seed=42):
    """
    Evaluate a trained PPO agent from Neptune

    Args:
        run_id: Neptune run ID to load the trained agent from
        seed: Random seed for reproducible evaluation
    """

    # Set all seeds for reproducibility
    set_all_seeds(seed)

    # Configuration (should match training config for environment setup)
    config = {
        "network_file": "elec_s_10_ec_lc1.0_1h.nc",
        "optimization_result_file": "elec_s_10_ec_lc1.0_1h_Test_Objective.txt",
        "constraint_penalty_factor": 1,
        "env_class": env_class,
        "seed": seed
    }

    # Set up paths
    gdrive_base = '/content/drive/My Drive/Colab_Notebooks'  # or '/workspace/'
    network_file_path = os.path.join(gdrive_base, "networks_1_year_connected", config["network_file"])
    optimization_result_path = os.path.join(gdrive_base, "optimized_network", config["optimization_result_file"])

    # Load optimization result (optimal objective value)
    with open(optimization_result_path, 'r') as f:
        objective = float(f.read().strip())

    # Multiply by negative one since comparing to reward found by RL agent
    objective = -1 * objective

    # Calculate offset if using replacement environment
    replacement_reward_offset = None

    # Create test environment (same type as training)
    # Create environment based on env_class
    if config["env_class"] == "EnvDispatchConstr":
        env = EnvDispatchConstr(
            network_file=network_file_path,
            no_convergence_lpf_penalty=1,
            reward_scale_factor=1,
            constraint_penalty_factor=1,
            seed=seed
        )

    elif config["env_class"] == "EnvDispatchReplacement":
        env = EnvDispatchReplacement(
            network_file=network_file_path,
            no_convergence_lpf_penalty=1,
            reward_scale_factor=1,
            constraint_penalty_factor=1,
            offset_k=replacement_reward_offset,
            seed=seed
        )

    # Connect to the existing training run for logging evaluation results
    API_TOKEN = "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI1ODQwZjA5OS05MDFmLTQ2MWYtYWJiMi0yMDkzYmEwNzgzMzEifQ=="
    PROJECT_NAME = "EnergyGridRL/elec-s-10-ec-lc10-1h-Sweep"

    # Reconnect to the existing training run
    print(f"Connecting to existing training run: {run_id}")
    eval_run = neptune.init_run(
        project=PROJECT_NAME,
        api_token=API_TOKEN,
        with_id=run_id,  # Connect to existing run using with_id
        mode="async"  # Use async mode to append to existing run
    )

    try:
        # Log evaluation metadata
        eval_run["evaluation/optimal_objective"] = objective
        eval_run["evaluation/seed"] = seed
        eval_run["evaluation/env_class"] = config["env_class"]
        eval_run["evaluation/network_file"] = config["network_file"]
        eval_run["evaluation/timestamp"] = time.time()

        if replacement_reward_offset is not None:
            eval_run["evaluation/replacement_reward_offset"] = replacement_reward_offset

        # Load trained agent from Neptune (using a separate read-only connection)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading agent from run: {run_id}")

        # Create read-only connection for loading the model
        read_run = neptune.init_run(
            project=PROJECT_NAME,
            api_token=API_TOKEN,
            with_id=run_id,  # Use with_id instead of run
            mode="read-only"
        )

        try:
            # Load model and config from Neptune
            model_state_dict, model_config = load_model_from_neptune(read_run, device)

            # Reconstruct agent
            agent = reconstruct_agent_from_saved_model(
                model_state_dict,
                model_config,
                env,
                device,
                run=None
            )
        finally:
            read_run.stop()

        print("Agent loaded successfully. Starting evaluation...")

        # Evaluate the agent
        test_results = evaluate_deterministic(env, agent, objective, seed=seed)

        # Log all test results to the same training run
        for key, value in test_results.items():
            if key not in ['rewards', 'violations']:  # Log individual arrays separately if needed
                eval_run[f"evaluation/{key}"] = value

        # # Log rewards and violations as series if you want to plot them
        # if 'rewards' in test_results:
        #     for i, reward in enumerate(test_results['rewards']):
        #         eval_run[f"evaluation/step_rewards"].log(reward)

        # if 'violations' in test_results:
        #     for i, violation in enumerate(test_results['violations']):
        #         eval_run[f"evaluation/step_violations"].log(violation)

        return test_results

    finally:
        # Stop the evaluation run
        eval_run.stop()

def evaluate_multiple_agents(run_ids, seed=42):
    """
    Evaluate multiple trained agents and compare results

    Args:
        run_ids: List of Neptune run IDs to evaluate
        seed: Random seed for reproducible evaluation

    Returns:
        dict: Comparison results for all agents
    """

    comparison_results = {}

    for run_id in run_ids:
        print(f"\nEvaluating agent from run: {run_id}")
        try:
            results = evaluate_trained_agent(run_id, seed=seed)
            comparison_results[run_id] = results
        except Exception as e:
            print(f"Error evaluating run {run_id}: {str(e)}")
            comparison_results[run_id] = None

    # Print comparison summary
    print("\n" + "="*80)
    print("COMPARISON SUMMARY")
    print("="*80)
    print(f"{'Run ID':<15} {'MAPE (%)':<10} {'Violations (%)':<15} {'Total Reward':<15}")
    print("-" * 80)

    for run_id, results in comparison_results.items():
        if results is not None:
            print(f"{run_id:<15} {results['mape']:<10.2f} {results['constraint_violation_percentage']:<15.2f} {results['rl_total_objective']:<15.2f}")
        else:
            print(f"{run_id:<15} {'ERROR':<10} {'ERROR':<15} {'ERROR':<15}")

    print("="*80)

    return comparison_results

In [13]:
test_results=evaluate_trained_agent("EL-14", "EnvDispatchReplacement", seed=35)



Fixed ZA0 0 PHS: set max_hours to 8.0
Fixed ZA0 5 PHS: set max_hours to 8.0
Fixed ZA0 6 hydro: corrected max_hours from 3831.6270020496813 to 6.0
=== FIXING ARTIFICIAL LINES WITH REASONABLE CAPACITY ===
Found 3 artificial lines to fix:

 Fixing: lines new ZA0 4 <-> ZA2 0 AC
    Connected buses: ZA0 4 ↔ ZA2 0
    Bus demands: ZA0 4: 15945.8 MW, ZA2 0: 452.6 MW
    s_nom: 0.0 → 47837.3 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA1 0 AC
    Connected buses: ZA0 0 ↔ ZA1 0
    Bus demands: ZA0 0: 3513.0 MW, ZA1 0: 1386.9 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA3 0 AC
    Connected buses: ZA0 0 ↔ ZA3 0
    Bus demands: ZA0 0: 3513.0 MW, ZA3 0: 721.1 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False
Found 12 offshore wind generators:
['ZA0 1 offwind-ac', 'ZA0 1 offwind-dc', 'ZA0 5 offwind-ac', 'ZA0 5 offwind-dc', 'ZA0 7 offwind-ac', 'ZA0 7 offwind-dc', 'ZA0 8 offwind-ac', 'ZA0 8 offwind-dc', 'ZA1 0 offwind-ac



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/EnergyGridRL/elec-s-10-ec-lc10-1h-sweep/e/EL-14
Loading agent from run: EL-14
[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/EnergyGridRL/elec-s-10-ec-lc10-1h-sweep/e/EL-14
✓ Model loaded from Neptune
✓ Model architecture: actorCritic(
  (actor): BackboneNetwork(
    (neuralnet): Sequential(
      (0): Linear(in_features=45, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=92, bias=True)
    )
  )
  (critic): BackboneNetwork(
    (neuralnet): Sequential(
      (0): Linear(in_features=45, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
Initialized Beta parameters to produce uniform-like distribut

        Convert the value to a supported type, such as a string or float, or use stringify_unsupported(obj)
        for dictionaries or collections that contain unsupported values.
        For more, see https://docs-legacy.neptune.ai/help/value_of_unsupported_type


[neptune] [info   ] All 36 operations synced, thanks for waiting!
[neptune] [info   ] Explore the metadata in the Neptune app: https://app.neptune.ai/EnergyGridRL/elec-s-10-ec-lc10-1h-sweep/e/EL-14/metadata


In [15]:
network_file = "elec_s_10_ec_lc1.0_1h.nc"

gdrive_base = '/content/drive/My Drive/Colab_Notebooks'  # or '/workspace/'
network_file_path = os.path.join(gdrive_base, "networks_1_year_connected", network_file)

env = EnvDispatchReplacement(
            network_file=network_file_path,
            no_convergence_lpf_penalty=1,
            reward_scale_factor=1,
            constraint_penalty_factor=1,
            offset_k=0,
            seed=35
        )
# Extract constraint names in the correct order
slack_generators = env.network.generators[env.network.generators.control == "Slack"].index
line_names = env.network.lines.index
storage_names = env.storage_names

# Build constraint names in the order they appear in info['constraint_results']
constraint_names = []

# p_min for each slack generator
for gen_name in slack_generators:
    constraint_names.append(f"p_min_{gen_name}")

# p_max for each slack generator
for gen_name in slack_generators:
    constraint_names.append(f"p_max_{gen_name}")

# soc_min for each storage
for storage_name in storage_names:
    constraint_names.append(f"soc_min_{storage_name}")

# soc_max for each storage
for storage_name in storage_names:
    constraint_names.append(f"soc_max_{storage_name}")

# s_max for each line (only for step_test which includes line constraints)
for line_name in line_names:
    constraint_names.append(f"s_max_{line_name}")

constraint_names.append("redundant_entry1")
constraint_names.append("redundant_entry2")



Fixed ZA0 0 PHS: set max_hours to 8.0
Fixed ZA0 5 PHS: set max_hours to 8.0
Fixed ZA0 6 hydro: corrected max_hours from 3831.6270020496813 to 6.0
=== FIXING ARTIFICIAL LINES WITH REASONABLE CAPACITY ===
Found 3 artificial lines to fix:

 Fixing: lines new ZA0 4 <-> ZA2 0 AC
    Connected buses: ZA0 4 ↔ ZA2 0
    Bus demands: ZA0 4: 15945.8 MW, ZA2 0: 452.6 MW
    s_nom: 0.0 → 47837.3 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA1 0 AC
    Connected buses: ZA0 0 ↔ ZA1 0
    Bus demands: ZA0 0: 3513.0 MW, ZA1 0: 1386.9 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False

 Fixing: lines new ZA0 0 <-> ZA3 0 AC
    Connected buses: ZA0 0 ↔ ZA3 0
    Bus demands: ZA0 0: 3513.0 MW, ZA3 0: 721.1 MW
    s_nom: 0.0 → 10538.9 MW
    s_nom_extendable: → False
Found 12 offshore wind generators:
['ZA0 1 offwind-ac', 'ZA0 1 offwind-dc', 'ZA0 5 offwind-ac', 'ZA0 5 offwind-dc', 'ZA0 7 offwind-ac', 'ZA0 7 offwind-dc', 'ZA0 8 offwind-ac', 'ZA0 8 offwind-dc', 'ZA1 0 offwind-ac

In [16]:
print("MAPE: ",test_results['mape'],"%")
for i in range(len(constraint_names)):
        constraint_name = constraint_names[i]
        violation_key = f"violations_{constraint_name}"
        avg_key = f"avg_violation_{constraint_name}"
        print(constraint_name, ":",test_results[avg_key])

MAPE:  205.67762552470347 %
p_min_ZA0 0 coal : 0.0
p_min_ZA2 0 onwind : 0.0
p_max_ZA0 0 coal : 0.9138832624457134
p_max_ZA2 0 onwind : 0.0
soc_min_ZA0 0 PHS : 0.993249433326811
soc_min_ZA0 5 PHS : 0.0
soc_min_ZA0 6 hydro : 0.0
soc_max_ZA0 0 PHS : 0.0
soc_max_ZA0 5 PHS : 0.0
soc_max_ZA0 6 hydro : 0.0
s_max_0 : 0.0
s_max_1 : 0.0
s_max_10 : 0.0
s_max_11 : 0.0
s_max_12 : 0.0
s_max_13 : 0.0
s_max_14 : 0.0
s_max_15 : 0.0
s_max_2 : 0.0
s_max_3 : 0.0
s_max_4 : 0.0
s_max_5 : 0.0
s_max_6 : 0.0
s_max_7 : 0.0
s_max_8 : 0.0
s_max_9 : 0.0
s_max_lines new ZA0 4 <-> ZA2 0 AC : 0.0
s_max_lines new ZA0 0 <-> ZA1 0 AC : 0.0
s_max_lines new ZA0 0 <-> ZA3 0 AC : 0.0
redundant_entry1 : 0.0
redundant_entry2 : 0.0
