In [4]:
import json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Dict, Any, Tuple
from pathlib import Path
import argparse

# Import your existing framework components
from agentic_energy.schemas import BatteryParams, DayInputs, SolveRequest, SolveResponse
from agentic_energy.data_loader import BatteryDataLoader
from agentic_energy.milp.milp_mcp_server import solve_daily_milp

In [5]:
DATASET_CONFIG = {
    'italy': {
        'file': 'Italy_data.csv',
        'train_years': [2022],
        'test_years': [2023],
        'price_column': 'prices',  # Adjust based on your CSV structure
        'demand_column': 'consumption',  # Adjust based on your CSV structure
        'timestamp_column': 'timestamps',  # Adjust based on your CSV structure
    },
    'germany': {
        'file': 'Germany_energy_Data.csv',
        'train_years': [2019, 2020, 2021, 2022],
        'test_years': [2023],
        'price_column': 'prices',
        'demand_column': 'consumption',
        'timestamp_column': 'timestamps',
    },
    'caiso': {
        'file': 'CAISO_data.csv',
        'train_years': [2021, 2022],
        'test_years': [2023],
        'price_column': 'prices',
        'demand_column': 'consumption',
        'timestamp_column': 'timestamps',
    }
}

In [6]:
class MultiDatasetLoader:
    """Load and prepare data from multiple ISOs"""
    
    def __init__(self, data_dir: str = "./data"):
        self.data_dir = Path(data_dir)
        
    def load_dataset(
        self,
        dataset_name: str,
        split: str = 'train'
    ) -> pd.DataFrame:
        """
        Load dataset for specified split (train or test)
        
        Args:
            dataset_name: 'italy', 'germany', or 'caiso'
            split: 'train' or 'test'
            
        Returns:
            DataFrame with columns: timestamp, price, demand
        """
        if dataset_name not in DATASET_CONFIG:
            raise ValueError(f"Unknown dataset: {dataset_name}. Use 'italy', 'germany', or 'caiso'")
        
        config = DATASET_CONFIG[dataset_name]
        file_path = self.data_dir / config['file']
        
        if not file_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {file_path}")
        
        # Load CSV
        print(f"Loading {dataset_name} data from {file_path}...")
        df = pd.read_csv(file_path)
        
        # Ensure timestamp is datetime
        if config['timestamp_column'] in df.columns:
            df['timestamp'] = pd.to_datetime(df[config['timestamp_column']])
        else:
            raise ValueError(f"Timestamp column '{config['timestamp_column']}' not found in {file_path}")
        
        # Rename columns to standard names
        df = df.rename(columns={
            config['price_column']: 'price',
            config['demand_column']: 'demand'
        })
        
        # Extract year
        df['year'] = df['timestamp'].dt.year
        
        # Filter by split
        if split == 'train':
            years = config['train_years']
        elif split == 'test':
            years = config['test_years']
        else:
            raise ValueError(f"Split must be 'train' or 'test', got: {split}")
        
        df_filtered = df[df['year'].isin(years)].copy()
        
        print(f"  Loaded {len(df_filtered)} records from years {years}")
        print(f"  Date range: {df_filtered['timestamp'].min()} to {df_filtered['timestamp'].max()}")
        print(f"  Price range: {df_filtered['price'].min():.2f} - {df_filtered['price'].max():.2f}")
        print(f"  Demand range: {df_filtered['demand'].min():.2f} - {df_filtered['demand'].max():.2f}")
        
        return df_filtered[['timestamp', 'price', 'demand', 'year']].sort_values('timestamp')
    
    def get_daily_batches(self, df: pd.DataFrame) -> List[Tuple[str, np.ndarray, np.ndarray]]:
        """
        Split dataframe into daily batches (24-hour periods)
        
        Returns:
            List of tuples: (date_str, prices_array, demand_array)
        """
        df['date'] = df['timestamp'].dt.date
        
        daily_batches = []
        for date, group in df.groupby('date'):
            if len(group) == 24:  # Only complete days
                date_str = str(date)
                prices = group['price'].values
                demand = group['demand'].values
                daily_batches.append((date_str, prices, demand))
            else:
                print(f"  Warning: Skipping {date} (only {len(group)} hours)")
        
        print(f"  Extracted {len(daily_batches)} complete days")
        return daily_batches



In [10]:
class DatasetGenerator:
    """
    Generates fine-tuning datasets from MILP solutions
    """
    
    def __init__(self, battery_config: Dict[str, float]):
        self.battery_config = battery_config
        self.instruction = self._create_instruction()
        
    def _create_instruction(self) -> str:
        """Create the static instruction component"""
        return """You are an energy storage optimization expert. Given electricity price forecasts and battery specifications, determine the optimal charge/discharge schedule for a 24-hour period to minimize operational costs while respecting all physical and operational constraints.

Your task is to:
1. Analyze the price forecast to identify charging opportunities (low prices) and discharging opportunities (high prices)
2. Respect battery capacity (kWh), power limits (kW), efficiency losses, and state of charge (SOC) bounds
3. Ensure SOC trajectory remains within [SOC_min, SOC_max] at all times
4. Meet demand at every timestep
5. Provide detailed reasoning for each hourly decision
6. Output the complete schedule in JSON format

Key decision principles:
- Charge during hours with prices below average (especially if below 25th percentile)
- Discharge during hours with prices above average (especially if above 75th percentile)
- Account for round-trip efficiency losses (charge efficiency × discharge efficiency)
- Maintain feasible SOC trajectory throughout the day
- Prioritize largest price arbitrage opportunities
- Consider the time value of stored energy
- Ensure smooth SOC transitions to avoid unnecessary cycling

Output format:
- Include complete charge/discharge schedule (kW for each hour)
- Include SOC trajectory
- Provide hourly decision reasoning
- Calculate total operational cost
- Validate all constraints are satisfied"""
    
    def create_input_component(
        self,
        date: str,
        prices: np.ndarray,
        demand: np.ndarray,
        battery_params: BatteryParams,
        dataset_name: str = None
    ) -> str:
        """
        Create the input component with all necessary context
        """
        # Calculate price statistics
        price_stats = {
            "min": float(np.min(prices)),
            "max": float(np.max(prices)),
            "mean": float(np.mean(prices)),
            "median": float(np.median(prices)),
            "std": float(np.std(prices)),
            "p25": float(np.percentile(prices, 25)),
            "p75": float(np.percentile(prices, 75)),
        }
        
        # Identify price patterns
        peak_hours = [int(h) for h in range(24) if prices[h] > price_stats["p75"]]
        valley_hours = [int(h) for h in range(24) if prices[h] < price_stats["p25"]]
        
        input_dict = {
            "date": date,
            "region": dataset_name.upper() if dataset_name else "UNKNOWN",
            "timestep": "hourly",
            "horizon_hours": 24,
            
            # Price data
            "price_forecast_eur_per_mwh": prices.tolist(),
            "price_statistics": price_stats,
            "price_patterns": {
                "peak_hours": peak_hours,
                "valley_hours": valley_hours,
                "price_range": float(price_stats["max"] - price_stats["min"]),
                "volatility_coefficient": float(price_stats["std"] / price_stats["mean"]) if price_stats["mean"] > 0 else 0.0
            },
            
            # Demand data
            "demand_forecast_kw": demand.tolist(),
            "demand_statistics": {
                "min": float(np.min(demand)),
                "max": float(np.max(demand)),
                "mean": float(np.mean(demand)),
                "total_daily_kwh": float(np.sum(demand))
            },
            
            # Battery specifications
            "battery_specifications": {
                "capacity_kwh": float(battery_params.capacity_kwh),
                "max_charge_power_kw": float(battery_params.cmax_kw),
                "max_discharge_power_kw": float(battery_params.dmax_kw),
                "charge_efficiency": float(battery_params.eta_c),
                "discharge_efficiency": float(battery_params.eta_d),
                "roundtrip_efficiency": float(battery_params.eta_c * battery_params.eta_d),
                "initial_soc": float(battery_params.soc_init),
                "soc_minimum": float(battery_params.soc_min),
                "soc_maximum": float(battery_params.soc_max),
                "soc_target_end_of_day": float(battery_params.soc_target) if battery_params.soc_target is not None else None,
            },
            
            # Operational constraints
            "operational_constraints": {
                "timestep_duration_hours": 1.0,
                "allow_grid_export": False,
                "must_meet_demand_every_hour": True,
                "simultaneous_charge_discharge": False
            },
            
            # Market context
            "market_context": {
                "day_of_week": datetime.strptime(date, "%Y-%m-%d").strftime("%A"),
                "season": self._get_season(date),
                "expected_arbitrage_opportunities": len(peak_hours) * len(valley_hours) > 0
            }
        }
        
        return json.dumps(input_dict, indent=2)
    
    def create_output_component(
        self,
        milp_solution: SolveResponse,
        prices: np.ndarray,
        demand: np.ndarray,
        battery_params: BatteryParams
    ) -> str:
        """
        Create the output component from MILP solution with detailed reasoning
        """
        # Extract MILP solution components
        charge_schedule = milp_solution.charge_kw if milp_solution.charge_kw else [0] * 24
        discharge_schedule = milp_solution.discharge_kw if milp_solution.discharge_kw else [0] * 24
        soc_trajectory = milp_solution.soc if milp_solution.soc else [battery_params.soc_init] * 25
        import_grid = milp_solution.import_kw if milp_solution.import_kw else demand.tolist()
        export_grid = milp_solution.export_kw if milp_solution.export_kw else [0] * 24
        total_cost = milp_solution.objective_cost if milp_solution.objective_cost else 0.0
        
        # Calculate statistics
        avg_price = np.mean(prices)
        p25_price = np.percentile(prices, 25)
        p75_price = np.percentile(prices, 75)
        
        # Generate hourly reasoning
        hourly_decisions = []
        for hour in range(24):
            decision = self._analyze_hour_decision(
                hour=hour,
                price=prices[hour],
                charge=charge_schedule[hour],
                discharge=discharge_schedule[hour],
                soc_before=soc_trajectory[hour],
                soc_after=soc_trajectory[hour + 1] if hour + 1 < len(soc_trajectory) else soc_trajectory[hour],
                avg_price=avg_price,
                p25_price=p25_price,
                p75_price=p75_price,
                battery_params=battery_params,
                demand=demand[hour]
            )
            hourly_decisions.append(decision)
        
        # Generate strategy summary
        strategy_summary = self._generate_strategy_summary(
            hourly_decisions, prices, charge_schedule, discharge_schedule, soc_trajectory
        )
        
        # Calculate performance metrics
        performance_metrics = self._calculate_performance_metrics(
            charge_schedule, discharge_schedule, soc_trajectory,
            prices, import_grid, export_grid, battery_params
        )
        
        # Construct output
        output_dict = {
            "optimization_status": milp_solution.status if milp_solution.status else "unknown",
            "objective_value": {
                "total_cost_eur": float(total_cost),
                "average_hourly_cost_eur": float(total_cost / 24) if total_cost else 0.0
            },
            
            "schedule": {
                "charge_kw": [float(c) for c in charge_schedule],
                "discharge_kw": [float(d) for d in discharge_schedule],
                "soc_trajectory": [float(s) for s in soc_trajectory],
                "grid_import_kw": [float(i) for i in import_grid],
                "grid_export_kw": [float(e) for e in export_grid]
            },
            
            "strategy_summary": strategy_summary,
            
            "hourly_decision_reasoning": hourly_decisions,
            
            "performance_metrics": performance_metrics,
            
            "validation": {
                "all_constraints_satisfied": self._validate_solution(
                    charge_schedule, discharge_schedule, soc_trajectory, battery_params
                ),
                "min_soc_value": float(np.min(soc_trajectory)),
                "max_soc_value": float(np.max(soc_trajectory)),
                "soc_violations": int(np.sum([
                    1 for s in soc_trajectory 
                    if s < battery_params.soc_min - 1e-6 or s > battery_params.soc_max + 1e-6
                ])),
                "power_violations": int(np.sum([
                    1 for c, d in zip(charge_schedule, discharge_schedule)
                    if c > battery_params.cmax_kw + 1e-6 or d > battery_params.dmax_kw + 1e-6
                ]))
            }
        }
        
        return json.dumps(output_dict, indent=2)
    
    def _analyze_hour_decision(
        self,
        hour: int,
        price: float,
        charge: float,
        discharge: float,
        soc_before: float,
        soc_after: float,
        avg_price: float,
        p25_price: float,
        p75_price: float,
        battery_params: BatteryParams,
        demand: float
    ) -> Dict[str, Any]:
        """Generate detailed reasoning for a single hour's decision"""
        
        # Determine decision type
        if charge > 0.01:
            decision_type = "CHARGE"
            action_magnitude = charge
        elif discharge > 0.01:
            decision_type = "DISCHARGE"
            action_magnitude = discharge
        else:
            decision_type = "IDLE"
            action_magnitude = 0.0
        
        # Generate reasoning
        reasoning_parts = []
        
        # Price analysis
        if decision_type == "CHARGE":
            if price < p25_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is in the bottom quartile (below {p25_price:.2f}), "
                    f"making this an excellent charging opportunity."
                )
            elif price < avg_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is below average ({avg_price:.2f}), "
                    f"making this a good charging window."
                )
            reasoning_parts.append(
                f"Charging at {charge:.2f} kW to store energy for later use during more expensive hours."
            )
            
        elif decision_type == "DISCHARGE":
            if price > p75_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is in the top quartile (above {p75_price:.2f}), "
                    f"making this an excellent discharging opportunity."
                )
            elif price > avg_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is above average ({avg_price:.2f}), "
                    f"making this a good discharging window."
                )
            reasoning_parts.append(
                f"Discharging at {discharge:.2f} kW to reduce grid imports during expensive hours."
            )
            
        else:  # IDLE
            price_position = "near average" if abs(price - avg_price) < 0.1 * avg_price else "moderate"
            reasoning_parts.append(
                f"Price {price:.2f} EUR/MWh is {price_position} ({avg_price:.2f} average), "
                f"not offering sufficient arbitrage opportunity to justify battery action."
            )
        
        # SOC constraints
        soc_change = soc_after - soc_before
        if abs(soc_change) > 0.001:
            reasoning_parts.append(
                f"SOC changes from {soc_before:.1%} to {soc_after:.1%} "
                f"({'increasing' if soc_change > 0 else 'decreasing'} by {abs(soc_change):.1%})."
            )
        
        # Constraint boundaries
        if soc_after >= battery_params.soc_max - 0.01:
            reasoning_parts.append("Battery reaches maximum SOC capacity.")
        elif soc_after <= battery_params.soc_min + 0.01:
            reasoning_parts.append("Battery reaches minimum SOC limit.")
        
        if decision_type == "CHARGE" and charge >= battery_params.cmax_kw - 0.01:
            reasoning_parts.append("Charging at maximum power capacity.")
        elif decision_type == "DISCHARGE" and discharge >= battery_params.dmax_kw - 0.01:
            reasoning_parts.append("Discharging at maximum power capacity.")
        
        # Energy balance
        net_demand = demand - discharge + charge
        reasoning_parts.append(f"Net grid import: {net_demand:.2f} kW to meet {demand:.2f} kW demand.")
        
        return {
            "hour": hour,
            "time_of_day": f"{hour:02d}:00",
            "electricity_price_eur_per_mwh": float(price),
            "price_vs_average": float(price - avg_price),
            "decision": decision_type,
            "charge_power_kw": float(charge),
            "discharge_power_kw": float(discharge),
            "soc_before_action": float(soc_before),
            "soc_after_action": float(soc_after),
            "soc_change": float(soc_change),
            "detailed_reasoning": " ".join(reasoning_parts)
        }
    
    def _generate_strategy_summary(
        self,
        hourly_decisions: List[Dict],
        prices: np.ndarray,
        charge_schedule: List[float],
        discharge_schedule: List[float],
        soc_trajectory: List[float]
    ) -> str:
        """Generate high-level strategy summary"""
        
        charge_hours = [d['hour'] for d in hourly_decisions if d['decision'] == 'CHARGE']
        discharge_hours = [d['hour'] for d in hourly_decisions if d['decision'] == 'DISCHARGE']
        
        total_charged = sum(charge_schedule)
        total_discharged = sum(discharge_schedule)
        
        avg_charge_price = np.mean([prices[h] for h in charge_hours]) if charge_hours else 0
        avg_discharge_price = np.mean([prices[h] for h in discharge_hours]) if discharge_hours else 0
        
        summary_parts = []
        
        # Charging summary
        if charge_hours:
            summary_parts.append(
                f"Charging strategy: Charge during {len(charge_hours)} hours "
                f"at average price {avg_charge_price:.2f} EUR/MWh. "
                f"Total energy charged: {total_charged:.2f} kWh."
            )
        
        # Discharging summary
        if discharge_hours:
            summary_parts.append(
                f"Discharging strategy: Discharge during {len(discharge_hours)} hours "
                f"at average price {avg_discharge_price:.2f} EUR/MWh. "
                f"Total energy discharged: {total_discharged:.2f} kWh."
            )
        
        # Arbitrage analysis
        if charge_hours and discharge_hours:
            price_spread = avg_discharge_price - avg_charge_price
            summary_parts.append(
                f"Price arbitrage: Buying at {avg_charge_price:.2f} and selling at {avg_discharge_price:.2f}, "
                f"capturing a spread of {price_spread:.2f} EUR/MWh."
            )
        
        # SOC management
        min_soc = min(soc_trajectory)
        max_soc = max(soc_trajectory)
        soc_range = max_soc - min_soc
        summary_parts.append(
            f"SOC management: Operates between {min_soc:.1%} and {max_soc:.1%} "
            f"(utilizing {soc_range:.1%} of available capacity)."
        )
        
        return " ".join(summary_parts)
    
    def _calculate_performance_metrics(
        self,
        charge_schedule: List[float],
        discharge_schedule: List[float],
        soc_trajectory: List[float],
        prices: np.ndarray,
        import_grid: List[float],
        export_grid: List[float],
        battery_params: BatteryParams
    ) -> Dict[str, float]:
        """Calculate various performance metrics"""
        
        total_charged_kwh = sum(charge_schedule)
        total_discharged_kwh = sum(discharge_schedule)
        
        # Energy metrics
        energy_throughput = total_charged_kwh + total_discharged_kwh
        cycle_count = total_charged_kwh / battery_params.capacity_kwh if battery_params.capacity_kwh > 0 else 0
        
        # Financial metrics
        charge_cost = sum(c * p / 1000 for c, p in zip(charge_schedule, prices))
        discharge_revenue = sum(d * p / 1000 for d, p in zip(discharge_schedule, prices))
        arbitrage_profit = discharge_revenue - charge_cost
        
        # Efficiency metrics
        roundtrip_efficiency = (total_discharged_kwh / total_charged_kwh * 100) if total_charged_kwh > 0 else 0
        theoretical_efficiency = battery_params.eta_c * battery_params.eta_d * 100
        
        # Utilization metrics
        avg_soc = np.mean(soc_trajectory)
        soc_range = max(soc_trajectory) - min(soc_trajectory)
        capacity_utilization = soc_range * 100
        
        return {
            "total_energy_charged_kwh": float(total_charged_kwh),
            "total_energy_discharged_kwh": float(total_discharged_kwh),
            "net_energy_flow_kwh": float(total_charged_kwh - total_discharged_kwh),
            "energy_throughput_kwh": float(energy_throughput),
            "cycle_count": float(cycle_count),
            "charging_cost_eur": float(charge_cost),
            "discharging_revenue_eur": float(discharge_revenue),
            "arbitrage_profit_eur": float(arbitrage_profit),
            "realized_roundtrip_efficiency_percent": float(roundtrip_efficiency),
            "theoretical_roundtrip_efficiency_percent": float(theoretical_efficiency),
            "average_soc": float(avg_soc),
            "soc_range_utilized": float(soc_range),
            "capacity_utilization_percent": float(capacity_utilization)
        }
    
    def _validate_solution(
        self,
        charge_schedule: List[float],
        discharge_schedule: List[float],
        soc_trajectory: List[float],
        battery_params: BatteryParams
    ) -> bool:
        """Validate that solution satisfies all constraints"""
        
        # Check SOC bounds
        soc_valid = all(
            battery_params.soc_min - 1e-6 <= s <= battery_params.soc_max + 1e-6 
            for s in soc_trajectory
        )
        
        # Check power bounds
        power_valid = all(
            c <= battery_params.cmax_kw + 1e-6 and 
            d <= battery_params.dmax_kw + 1e-6
            for c, d in zip(charge_schedule, discharge_schedule)
        )
        
        # Check no simultaneous charge/discharge
        no_simultaneous = all(
            c < 1e-6 or d < 1e-6 
            for c, d in zip(charge_schedule, discharge_schedule)
        )
        
        return soc_valid and power_valid and no_simultaneous
    
    def _get_season(self, date_str: str) -> str:
        """Determine season from date"""
        month = datetime.strptime(date_str, "%Y-%m-%d").month
        if month in [12, 1, 2]:
            return "winter"
        elif month in [3, 4, 5]:
            return "spring"
        elif month in [6, 7, 8]:
            return "summer"
        else:
            return "fall"

In [11]:
def generate_dataset_for_region(
    dataset_name: str,
    split: str,
    battery_params: BatteryParams,
    data_dir: str = "./data",
    output_dir: str = "./datasets",
    max_days: int = None
) -> List[Dict[str, str]]:
    """
    Generate fine-tuning dataset for a specific region and split
    
    Args:
        dataset_name: 'italy', 'germany', or 'caiso'
        split: 'train' or 'test'
        battery_params: Battery configuration
        data_dir: Directory containing CSV files
        output_dir: Directory to save output datasets
        max_days: Maximum number of days to process (None = all)
        
    Returns:
        List of dataset entries
    """
    print(f"\n{'='*80}")
    print(f"GENERATING DATASET: {dataset_name.upper()} - {split.upper()}")
    print(f"{'='*80}\n")
    
    # Load data
    loader = MultiDatasetLoader(data_dir=data_dir)
    df = loader.load_dataset(dataset_name, split)
    daily_batches = loader.get_daily_batches(df)
    
    if max_days:
        daily_batches = daily_batches[:max_days]
        print(f"Limiting to first {max_days} days")
    
    # Initialize generator
    generator = DatasetGenerator(battery_config={})
    
    # Generate dataset
    dataset = []
    failures = []
    
    for idx, (date, prices, demand) in enumerate(daily_batches):
        try:
            # Create day inputs
            day = DayInputs(
                prices_buy=prices.tolist(),
                demand_kw=demand.tolist(),
                prices_sell=prices.tolist(),
                allow_export=False,
                dt_hours=1.0
            )

            milp_solution = solve_daily_milp(
                batt=battery_params,
                day=day,
                solver="GUROBI", 
                solver_opts={}
            )
            # Run MILP solver
            # milp_solution = solve_daily_milp(
            #     batt=battery_params,
            #     day=day,
            #     solver=None, 
            #     solver_opts={}
            # )

# # Run MILP solver
# milp_solution = solve_daily_milp(
#     batt=battery_params,
#     day=day,
#     solver="GUROBI", 
#     solver_opts={}
# )
            
            # Check if solution is valid
            if milp_solution.status not in ["optimal", "optimal_inaccurate"]:
                print(f"  Warning: Day {date} - MILP status: {milp_solution.status}")
                failures.append({'date': date, 'reason': milp_solution.status})
                continue
            
            # Create dataset entry
            input_str = generator.create_input_component(
                date=date,
                prices=prices,
                demand=demand,
                battery_params=battery_params,
                dataset_name=dataset_name
            )
            
            output_str = generator.create_output_component(
                milp_solution=milp_solution,
                prices=prices,
                demand=demand,
                battery_params=battery_params
            )
            
            dataset.append({
                'instruction': generator.instruction,
                'input': input_str,
                'output': output_str
            })
            
            # Update battery SOC for next day
            if milp_solution.soc and len(milp_solution.soc) > 0:
                battery_params.soc_init = milp_solution.soc[-1]
            
            if (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(daily_batches)} days")
                
        except Exception as e:
            print(f"  Error processing day {date}: {e}")
            failures.append({'date': date, 'reason': str(e)})
            continue
    
    print(f"\n{'='*80}")
    print(f"GENERATION COMPLETE: {dataset_name.upper()} - {split.upper()}")
    print(f"  Successful: {len(dataset)} examples")
    print(f"  Failed: {len(failures)} examples")
    print(f"{'='*80}\n")
    
    # Save dataset
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True, parents=True)
    
    output_file = output_path / f"{dataset_name}_{split}.json"
    with open(output_file, 'w') as f:
        json.dump(dataset, f, indent=2)
    
    print(f"Dataset saved to: {output_file}")
    print(f"Size: {len(json.dumps(dataset)) / 1024 / 1024:.2f} MB\n")
    
    # Save failures log if any
    if failures:
        failures_file = output_path / f"{dataset_name}_{split}_failures.json"
        with open(failures_file, 'w') as f:
            json.dump(failures, f, indent=2)
        print(f"Failures logged to: {failures_file}\n")
    
    return dataset


def combine_datasets(
    datasets: List[List[Dict]],
    output_path: str
):
    """
    Combine multiple datasets into one
    """
    print(f"\nCombining {len(datasets)} datasets...")
    
    combined = []
    for ds in datasets:
        combined.extend(ds)
    
    # Shuffle
    import random
    random.seed(42)
    random.shuffle(combined)
    
    # Save
    with open(output_path, 'w') as f:
        json.dump(combined, f, indent=2)
    
    print(f"Combined dataset saved to: {output_path}")
    print(f"Total examples: {len(combined)}")
    print(f"Size: {len(json.dumps(combined)) / 1024 / 1024:.2f} MB\n")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Generate fine-tuning datasets from MILP solutions"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=['italy', 'germany', 'caiso', 'all'],
        default='all',
        help="Which dataset to generate"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=['train', 'test', 'both'],
        default='both',
        help="Which split to generate"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./data",
        help="Directory containing CSV files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./datasets",
        help="Directory to save output datasets"
    )
    parser.add_argument(
        "--max-days",
        type=int,
        default=None,
        help="Maximum number of days per dataset (for testing)"
    )
    parser.add_argument(
        "--generate-all",
        action="store_true",
        help="Generate all datasets (all regions, train and test)"
    )
    
    args = parser.parse_args()
    
    # Define battery configuration
    # Adjust these based on your actual battery specs
    battery_params = BatteryParams(
        capacity_kwh=21.89,
        cmax_kw=5.47,
        dmax_kw=5.47,
        eta_c=0.95,
        eta_d=0.95,
        soc_init=0.5,
        soc_min=0.0,
        soc_max=1.0,
        soc_target=0.5
    )
    
    print("\n" + "="*80)
    print("ENERGY STORAGE DATASET GENERATION")
    print("="*80)
    print("\nBattery Configuration:")
    print(f"  Capacity: {battery_params.capacity_kwh} kWh")
    print(f"  Max Charge: {battery_params.cmax_kw} kW")
    print(f"  Max Discharge: {battery_params.dmax_kw} kW")
    print(f"  Efficiency: {battery_params.eta_c}/{battery_params.eta_d}")
    print(f"  SOC Range: [{battery_params.soc_min}, {battery_params.soc_max}]")
    
    # Determine which datasets to generate
    if args.generate_all or args.dataset == 'all':
        datasets_to_generate = ['italy', 'germany', 'caiso']
    else:
        datasets_to_generate = [args.dataset]
    
    # Determine which splits to generate
    if args.split == 'both':
        splits_to_generate = ['train', 'test']
    else:
        splits_to_generate = [args.split]
    
    # Generate datasets
    all_train_datasets = []
    all_test_datasets = []
    
    for dataset_name in datasets_to_generate:
        # Reset battery SOC for each region
        battery_params.soc_init = 0.5
        
        for split in splits_to_generate:
            dataset = generate_dataset_for_region(
                dataset_name=dataset_name,
                split=split,
                battery_params=battery_params,
                data_dir=args.data_dir,
                output_dir=args.output_dir,
                max_days=args.max_days
            )
            
            if split == 'train':
                all_train_datasets.append(dataset)
            else:
                all_test_datasets.append(dataset)
    
    # Combine all training datasets
    if len(all_train_datasets) > 1:
        combine_datasets(
            all_train_datasets,
            f"{args.output_dir}/combined_train.json"
        )
    
    # Combine all test datasets
    if len(all_test_datasets) > 1:
        combine_datasets(
            all_test_datasets,
            f"{args.output_dir}/combined_test.json"
        )
    
    print("\n" + "="*80)
    print("DATASET GENERATION PIPELINE COMPLETE!")
    print("="*80)
    print("\nNext steps:")
    print("  1. Review generated datasets in:", args.output_dir)
    print("  2. Use combined_train.json for fine-tuning")
    print("  3. Use combined_test.json or individual test sets for evaluation")
    print("  4. Run: python train_with_unsloth.py --dataset ./datasets/combined_train.json")
    print("\n")



In [12]:
from agentic_energy.schemas import BatteryParams

battery_params = BatteryParams(
    capacity_kwh=21.89,
    cmax_kw=5.47,
    dmax_kw=5.47,
    eta_c=0.95,
    eta_d=0.95,
    soc_init=0.5,
    soc_min=0.0,
    soc_max=1.0,
    soc_target=0.5
)

dataset = generate_dataset_for_region(
    dataset_name='italy',
    split='train',
    battery_params=battery_params,
    data_dir="./agentic_energy/data",
    output_dir="./datasets",
    max_days=5
)


GENERATING DATASET: ITALY - TRAIN

Loading italy data from agentic_energy\data\Italy_data.csv...
  Loaded 8758 records from years [2022]
  Date range: 2022-01-01 00:00:00 to 2022-12-31 23:00:00
  Price range: 10.00 - 870.00
  Demand range: 17.82 - 49.07
  Extracted 364 complete days
Limiting to first 5 days


This use of ``*`` has resulted in matrix multiplication.
Using ``*`` for matrix multiplication has been deprecated since CVXPY 1.1.
    Use ``*`` for matrix-scalar and vector-scalar multiplication.
    Use ``@`` for matrix-matrix and matrix-vector multiplication.
    Use ``multiply`` for elementwise multiplication.
This code path has been hit 6 times so far.

This use of ``*`` has resulted in matrix multiplication.
Using ``*`` for matrix multiplication has been deprecated since CVXPY 1.1.
    Use ``*`` for matrix-scalar and vector-scalar multiplication.
    Use ``@`` for matrix-matrix and matrix-vector multiplication.
    Use ``multiply`` for elementwise multiplication.
This code path has been hit 7 times so far.

This use of ``*`` has resulted in matrix multiplication.
Using ``*`` for matrix multiplication has been deprecated since CVXPY 1.1.
    Use ``*`` for matrix-scalar and vector-scalar multiplication.
    Use ``@`` for matrix-matrix and matrix-vector multiplication.
    Use ``mu


GENERATION COMPLETE: ITALY - TRAIN
  Successful: 5 examples
  Failed: 0 examples

Dataset saved to: datasets\italy_train.json
Size: 0.11 MB



This use of ``*`` has resulted in matrix multiplication.
Using ``*`` for matrix multiplication has been deprecated since CVXPY 1.1.
    Use ``*`` for matrix-scalar and vector-scalar multiplication.
    Use ``@`` for matrix-matrix and matrix-vector multiplication.
    Use ``multiply`` for elementwise multiplication.
This code path has been hit 10 times so far.



In [2]:
import json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
import argparse

# Import your existing framework components
from agentic_energy.schemas import BatteryParams, DayInputs, SolveRequest, SolveResponse
from agentic_energy.data_loader import BatteryDataLoader
from agentic_energy.milp.milp_mcp_server import solve_daily_milp

DATASET_CONFIG = {
    'italy': {
        'file': 'Italy_data.csv',
        'train_years': [2022],
        'test_years': [2023],
        'price_column': 'prices',  # Adjust based on your CSV structure
        'demand_column': 'consumption',  # Adjust based on your CSV structure
        'timestamp_column': 'timestamps',  # Adjust based on your CSV structure
    },
    'germany': {
        'file': 'Germany_energy_Data.csv',
        'train_years': [2019, 2020, 2021, 2022],
        'test_years': [2023],
        'price_column': 'prices',
        'demand_column': 'consumption',
        'timestamp_column': 'timestamps',
    },
    'caiso': {
        'file': 'CAISO_data.csv',
        'train_years': [2021, 2022],
        'test_years': [2023],
        'price_column': 'prices',
        'demand_column': 'consumption',
        'timestamp_column': 'timestamps',
    }
}

class MultiDatasetLoader:
    """Load and prepare data from multiple ISOs"""
    
    def __init__(self, data_dir: str = "./data"):
        self.data_dir = Path(data_dir)
        
    def load_dataset(
        self,
        dataset_name: str,
        split: str = 'train'
    ) -> pd.DataFrame:
        """
        Load dataset for specified split (train or test)
        
        Args:
            dataset_name: 'italy', 'germany', or 'caiso'
            split: 'train' or 'test'
            
        Returns:
            DataFrame with columns: timestamp, price, demand
        """
        if dataset_name not in DATASET_CONFIG:
            raise ValueError(f"Unknown dataset: {dataset_name}. Use 'italy', 'germany', or 'caiso'")
        
        config = DATASET_CONFIG[dataset_name]
        file_path = self.data_dir / config['file']
        
        if not file_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {file_path}")
        
        # Load CSV
        print(f"Loading {dataset_name} data from {file_path}...")
        df = pd.read_csv(file_path)
        
        # Ensure timestamp is datetime
        if config['timestamp_column'] in df.columns:
            df['timestamp'] = pd.to_datetime(df[config['timestamp_column']])
        else:
            raise ValueError(f"Timestamp column '{config['timestamp_column']}' not found in {file_path}")
        
        # Rename columns to standard names
        df = df.rename(columns={
            config['price_column']: 'price',
            config['demand_column']: 'demand'
        })
        
        # ✅ Handle European number format (comma as decimal separator)
        # Convert to string first, then replace commas with periods
        df['price'] = df['price'].astype(str).str.replace(',', '.', regex=False)
        df['demand'] = df['demand'].astype(str).str.replace(',', '.', regex=False)
        
        # Now convert to numeric
        df['price'] = pd.to_numeric(df['price'], errors='coerce')
        df['demand'] = pd.to_numeric(df['demand'], errors='coerce')

        if dataset_name == 'germany':
            # original_mean = df['demand'].mean()
            df['demand'] = df['demand'] / 1000
            # print(f"  ⚙️  Scaled Germany demand: {original_mean:.0f} MW → {df['demand'].mean():.2f} MW")
    
        
        # Drop rows with missing values
        rows_before = len(df)
        df = df.dropna(subset=['price', 'demand'])
        rows_after = len(df)
        
        if rows_before != rows_after:
            print(f"  ⚠️  Dropped {rows_before - rows_after} rows with invalid/missing data")
        
        # Extract year
        df['year'] = df['timestamp'].dt.year
        
        # Filter by split
        if split == 'train':
            years = config['train_years']
        elif split == 'test':
            years = config['test_years']
        else:
            raise ValueError(f"Split must be 'train' or 'test', got: {split}")
        
        df_filtered = df[df['year'].isin(years)].copy()
        
        print(f"  Loaded {len(df_filtered)} records from years {years}")
        print(f"  Date range: {df_filtered['timestamp'].min()} to {df_filtered['timestamp'].max()}")
        print(f"  Price range: {df_filtered['price'].min():.2f} - {df_filtered['price'].max():.2f}")
        print(f"  Demand range: {df_filtered['demand'].min():.2f} - {df_filtered['demand'].max():.2f}")
        
        return df_filtered[['timestamp', 'price', 'demand', 'year']].sort_values('timestamp')
    # def load_dataset(
    #     self,
    #     dataset_name: str,
    #     split: str = 'train'
    # ) -> pd.DataFrame:
    #     """
    #     Load dataset for specified split (train or test)
        
    #     Args:
    #         dataset_name: 'italy', 'germany', or 'caiso'
    #         split: 'train' or 'test'
            
    #     Returns:
    #         DataFrame with columns: timestamp, price, demand
    #     """
    #     if dataset_name not in DATASET_CONFIG:
    #         raise ValueError(f"Unknown dataset: {dataset_name}. Use 'italy', 'germany', or 'caiso'")
        
    #     config = DATASET_CONFIG[dataset_name]
    #     file_path = self.data_dir / config['file']
        
    #     if not file_path.exists():
    #         raise FileNotFoundError(f"Dataset file not found: {file_path}")
        
    #     # Load CSV
    #     print(f"Loading {dataset_name} data from {file_path}...")
    #     df = pd.read_csv(file_path)
        
    #     # Ensure timestamp is datetime
    #     if config['timestamp_column'] in df.columns:
    #         df['timestamp'] = pd.to_datetime(df[config['timestamp_column']])
    #     else:
    #         raise ValueError(f"Timestamp column '{config['timestamp_column']}' not found in {file_path}")
        
    #     # Rename columns to standard names
    #     df = df.rename(columns={
    #         config['price_column']: 'price',
    #         config['demand_column']: 'demand'
    #     })
        
    #     # Extract year
    #     df['year'] = df['timestamp'].dt.year
        
    #     # Filter by split
    #     if split == 'train':
    #         years = config['train_years']
    #     elif split == 'test':
    #         years = config['test_years']
    #     else:
    #         raise ValueError(f"Split must be 'train' or 'test', got: {split}")
        
    #     df_filtered = df[df['year'].isin(years)].copy()
        
    #     print(f"  Loaded {len(df_filtered)} records from years {years}")
    #     print(f"  Date range: {df_filtered['timestamp'].min()} to {df_filtered['timestamp'].max()}")
    #     print(f"  Price range: {df_filtered['price'].min():.2f} - {df_filtered['price'].max():.2f}")
    #     print(f"  Demand range: {df_filtered['demand'].min():.2f} - {df_filtered['demand'].max():.2f}")
        
    #     return df_filtered[['timestamp', 'price', 'demand', 'year']].sort_values('timestamp')
    
    def get_daily_batches(self, df: pd.DataFrame) -> List[Tuple[str, np.ndarray, np.ndarray]]:
        """
        Split dataframe into daily batches (24-hour periods)
        
        Returns:
            List of tuples: (date_str, prices_array, demand_array)
        """
        df['date'] = df['timestamp'].dt.date
        
        daily_batches = []
        for date, group in df.groupby('date'):
            if len(group) == 24:  # Only complete days
                date_str = str(date)
                prices = group['price'].values
                demand = group['demand'].values
                daily_batches.append((date_str, prices, demand))
            else:
                print(f"  Warning: Skipping {date} (only {len(group)} hours)")
        
        print(f"  Extracted {len(daily_batches)} complete days")
        return daily_batches


class DatasetGenerator:
    """
    Generates fine-tuning datasets from MILP solutions
    """
    
    def __init__(self, battery_config: Dict[str, float]):
        self.battery_config = battery_config
        self.instruction = self._create_instruction()
        
    def _create_instruction(self) -> str:
        """Create the static instruction component"""
        return """You are an energy storage optimization expert. Given electricity price forecasts and battery specifications, determine the optimal charge/discharge schedule for a 24-hour period to minimize operational costs while respecting all physical and operational constraints.

Your task is to:
1. Analyze the price forecast to identify charging opportunities (low prices) and discharging opportunities (high prices)
2. Respect battery capacity (kWh), power limits (kW), efficiency losses, and state of charge (SOC) bounds
3. Ensure SOC trajectory remains within [SOC_min, SOC_max] at all times
4. Meet demand at every timestep
5. Provide detailed reasoning for each hourly decision
6. Output the complete schedule in JSON format

Key decision principles:
- Charge during hours with prices below average (especially if below 25th percentile)
- Discharge during hours with prices above average (especially if above 75th percentile)
- Account for round-trip efficiency losses (charge efficiency × discharge efficiency)
- Maintain feasible SOC trajectory throughout the day
- Prioritize largest price arbitrage opportunities
- Consider the time value of stored energy
- Ensure smooth SOC transitions to avoid unnecessary cycling

Output format:
- Include complete charge/discharge schedule (kW for each hour)
- Include SOC trajectory
- Provide hourly decision reasoning
- Calculate total operational cost
- Validate all constraints are satisfied"""
    
    def create_input_component(
        self,
        date: str,
        prices: np.ndarray,
        demand: np.ndarray,
        battery_params: BatteryParams,
        dataset_name: str = None
    ) -> str:
        """
        Create the input component with all necessary context
        """
        # Calculate price statistics
        price_stats = {
            "min": float(np.min(prices)),
            "max": float(np.max(prices)),
            "mean": float(np.mean(prices)),
            "median": float(np.median(prices)),
            "std": float(np.std(prices)),
            "p25": float(np.percentile(prices, 25)),
            "p75": float(np.percentile(prices, 75)),
        }
        
        # Identify price patterns
        peak_hours = [int(h) for h in range(24) if prices[h] > price_stats["p75"]]
        valley_hours = [int(h) for h in range(24) if prices[h] < price_stats["p25"]]
        
        input_dict = {
            "date": date,
            "region": dataset_name.upper() if dataset_name else "UNKNOWN",
            "timestep": "hourly",
            "horizon_hours": 24,
            
            # Price data
            "price_forecast_eur_per_mwh": prices.tolist(),
            "price_statistics": price_stats,
            "price_patterns": {
                "peak_hours": peak_hours,
                "valley_hours": valley_hours,
                "price_range": float(price_stats["max"] - price_stats["min"]),
                "volatility_coefficient": float(price_stats["std"] / price_stats["mean"]) if price_stats["mean"] > 0 else 0.0
            },
            
            # Demand data
            "demand_forecast_kw": demand.tolist(),
            "demand_statistics": {
                "min": float(np.min(demand)),
                "max": float(np.max(demand)),
                "mean": float(np.mean(demand)),
                "total_daily_kwh": float(np.sum(demand))
            },
            
            # Battery specifications
            "battery_specifications": {
                "capacity_kwh": float(battery_params.capacity_kwh),
                "max_charge_power_kw": float(battery_params.cmax_kw),
                "max_discharge_power_kw": float(battery_params.dmax_kw),
                "charge_efficiency": float(battery_params.eta_c),
                "discharge_efficiency": float(battery_params.eta_d),
                "roundtrip_efficiency": float(battery_params.eta_c * battery_params.eta_d),
                "initial_soc": float(battery_params.soc_init),
                "soc_minimum": float(battery_params.soc_min),
                "soc_maximum": float(battery_params.soc_max),
                "soc_target_end_of_day": float(battery_params.soc_target) if battery_params.soc_target is not None else None,
            },
            
            # Operational constraints
            "operational_constraints": {
                "timestep_duration_hours": 1.0,
                "allow_grid_export": True,
                "must_meet_demand_every_hour": True,
                "simultaneous_charge_discharge": False
            },
            
            # Market context
            "market_context": {
                "day_of_week": datetime.strptime(date, "%Y-%m-%d").strftime("%A"),
                "season": self._get_season(date),
                "expected_arbitrage_opportunities": len(peak_hours) * len(valley_hours) > 0
            }
        }
        
        return json.dumps(input_dict, indent=2)
    
    def create_output_component(
        self,
        milp_solution: SolveResponse,
        prices: np.ndarray,
        demand: np.ndarray,
        battery_params: BatteryParams
    ) -> str:
        """
        Create the output component from MILP solution with detailed reasoning
        """
        # Extract MILP solution components
        charge_schedule = milp_solution.charge_kw if milp_solution.charge_kw else [0] * 24
        discharge_schedule = milp_solution.discharge_kw if milp_solution.discharge_kw else [0] * 24
        soc_trajectory = milp_solution.soc if milp_solution.soc else [battery_params.soc_init] * 25
        import_grid = milp_solution.import_kw if milp_solution.import_kw else demand.tolist()
        export_grid = milp_solution.export_kw if milp_solution.export_kw else [0] * 24
        total_cost = milp_solution.objective_cost if milp_solution.objective_cost else 0.0
        
        # Calculate statistics
        avg_price = np.mean(prices)
        p25_price = np.percentile(prices, 25)
        p75_price = np.percentile(prices, 75)
        
        # Generate hourly reasoning
        hourly_decisions = []
        for hour in range(24):
            decision = self._analyze_hour_decision(
                hour=hour,
                price=prices[hour],
                charge=charge_schedule[hour],
                discharge=discharge_schedule[hour],
                soc_before=soc_trajectory[hour],
                soc_after=soc_trajectory[hour + 1] if hour + 1 < len(soc_trajectory) else soc_trajectory[hour],
                avg_price=avg_price,
                p25_price=p25_price,
                p75_price=p75_price,
                battery_params=battery_params,
                demand=demand[hour]
            )
            hourly_decisions.append(decision)
        
        # Generate strategy summary
        strategy_summary = self._generate_strategy_summary(
            hourly_decisions, prices, charge_schedule, discharge_schedule, soc_trajectory
        )
        
        # Calculate performance metrics
        performance_metrics = self._calculate_performance_metrics(
            charge_schedule, discharge_schedule, soc_trajectory,
            prices, import_grid, export_grid, battery_params
        )
        
        # Construct output
        output_dict = {
            "optimization_status": milp_solution.status if milp_solution.status else "unknown",
            "objective_value": {
                "total_cost_eur": float(total_cost),
                "average_hourly_cost_eur": float(total_cost / 24) if total_cost else 0.0
            },
            
            "schedule": {
                "charge_kw": [float(c) for c in charge_schedule],
                "discharge_kw": [float(d) for d in discharge_schedule],
                "soc_trajectory": [float(s) for s in soc_trajectory],
                "grid_import_kw": [float(i) for i in import_grid],
                "grid_export_kw": [float(e) for e in export_grid]
            },
            
            "strategy_summary": strategy_summary,
            
            "hourly_decision_reasoning": hourly_decisions,
            
            "performance_metrics": performance_metrics,
            
            "validation": {
                "all_constraints_satisfied": self._validate_solution(
                    charge_schedule, discharge_schedule, soc_trajectory, battery_params
                ),
                "min_soc_value": float(np.min(soc_trajectory)),
                "max_soc_value": float(np.max(soc_trajectory)),
                "soc_violations": int(np.sum([
                    1 for s in soc_trajectory 
                    if s < battery_params.soc_min - 1e-6 or s > battery_params.soc_max + 1e-6
                ])),
                "power_violations": int(np.sum([
                    1 for c, d in zip(charge_schedule, discharge_schedule)
                    if c > battery_params.cmax_kw + 1e-6 or d > battery_params.dmax_kw + 1e-6
                ]))
            }
        }
        
        return json.dumps(output_dict, indent=2)
    
    def _analyze_hour_decision(
        self,
        hour: int,
        price: float,
        charge: float,
        discharge: float,
        soc_before: float,
        soc_after: float,
        avg_price: float,
        p25_price: float,
        p75_price: float,
        battery_params: BatteryParams,
        demand: float
    ) -> Dict[str, Any]:
        """Generate detailed reasoning for a single hour's decision"""
        
        # Determine decision type
        if charge > 0.01:
            decision_type = "CHARGE"
            action_magnitude = charge
        elif discharge > 0.01:
            decision_type = "DISCHARGE"
            action_magnitude = discharge
        else:
            decision_type = "IDLE"
            action_magnitude = 0.0
        
        # Generate reasoning
        reasoning_parts = []
        
        # Price analysis
        if decision_type == "CHARGE":
            if price < p25_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is in the bottom quartile (below {p25_price:.2f}), "
                    f"making this an excellent charging opportunity."
                )
            elif price < avg_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is below average ({avg_price:.2f}), "
                    f"making this a good charging window."
                )

            # reasoning misses the correct way to analyze, can you scale the 
            reasoning_parts.append(
                f"Charging at {charge:.2f} kW to store energy for later use during more expensive hours."
            )
            
        elif decision_type == "DISCHARGE":
            if price > p75_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is in the top quartile (above {p75_price:.2f}), "
                    f"making this an excellent discharging opportunity."
                )
            elif price > avg_price:
                reasoning_parts.append(
                    f"Price {price:.2f} EUR/MWh is above average ({avg_price:.2f}), "
                    f"making this a good discharging window."
                )
            reasoning_parts.append(
                f"Discharging at {discharge:.2f} kW to reduce grid imports during expensive hours."
            )
            
        else:  # IDLE
            price_position = "near average" if abs(price - avg_price) < 0.1 * avg_price else "moderate"
            reasoning_parts.append(
                f"Price {price:.2f} EUR/MWh is {price_position} ({avg_price:.2f} average), "
                f"not offering sufficient arbitrage opportunity to justify battery action."
            )
        
        # SOC constraints
        soc_change = soc_after - soc_before
        if abs(soc_change) > 0.001:
            reasoning_parts.append(
                f"SOC changes from {soc_before:.1%} to {soc_after:.1%} "
                f"({'increasing' if soc_change > 0 else 'decreasing'} by {abs(soc_change):.1%})."
            )
        
        # Constraint boundaries
        if soc_after >= battery_params.soc_max - 0.01:
            reasoning_parts.append("Battery reaches maximum SOC capacity.")
        elif soc_after <= battery_params.soc_min + 0.01:
            reasoning_parts.append("Battery reaches minimum SOC limit.")
        
        if decision_type == "CHARGE" and charge >= battery_params.cmax_kw - 0.01:
            reasoning_parts.append("Charging at maximum power capacity.")
        elif decision_type == "DISCHARGE" and discharge >= battery_params.dmax_kw - 0.01:
            reasoning_parts.append("Discharging at maximum power capacity.")
        
        # Energy balance
        net_demand = demand - discharge + charge
        reasoning_parts.append(f"Net grid import: {net_demand:.2f} kW to meet {demand:.2f} kW demand.")
        
        return {
            "hour": hour,
            "time_of_day": f"{hour:02d}:00",
            "electricity_price_eur_per_mwh": float(price),
            "price_vs_average": float(price - avg_price),
            "decision": decision_type,
            "charge_power_kw": float(charge),
            "discharge_power_kw": float(discharge),
            "soc_before_action": float(soc_before),
            "soc_after_action": float(soc_after),
            "soc_change": float(soc_change),
            "detailed_reasoning": " ".join(reasoning_parts)
        }
    
    def _generate_strategy_summary(
        self,
        hourly_decisions: List[Dict],
        prices: np.ndarray,                         # EUR/MWh
        charge_schedule: List[float],               # kWh/interval
        discharge_schedule: List[float],            # kWh/interval
        soc_trajectory: List[float],                # fraction in [0,1]
        *,
        # physics/econ
        eta_c: float = 0.95,
        eta_d: float = 0.95,
        var_om_eur_per_mwh: float = 0.0,
        schedule_side: str = "ac",                  # "ac" or "dc"
        dt_hours: float = 1.0,
        capacity_kwh: Optional[float] = None,
        p_max_kw: Optional[float] = None,
        soc_floor: float = 0.0,
        soc_ceiling: float = 1.0,
        # analysis windows
        kelly_window: int = 12,
        spread_window_future: int = 6,
        spread_window_past: int = 6,
        # output style
        verbose: bool = True,
        eps: float = 1e-6,
    ) -> str:
        import numpy as np

        T = len(prices)
        p = np.asarray(prices, float)
        ch = np.asarray(charge_schedule, float)
        dis = np.asarray(discharge_schedule, float)
        soc = np.asarray(soc_trajectory, float)
        if len(soc) == T:
            soc = np.concatenate([soc, soc[-1:]])

        # decisions → hour sets (fallback to schedules)
        charge_hours = [d['hour'] for d in hourly_decisions if d.get('decision') == 'CHARGE'] \
                    if hourly_decisions else list(np.where(ch > eps)[0])
        discharge_hours = [d['hour'] for d in hourly_decisions if d.get('decision') == 'DISCHARGE'] \
                        if hourly_decisions else list(np.where(dis > eps)[0])

        # Market-side energy for economics
        if schedule_side.lower() == "dc":
            grid_buy_mwh  = (ch / max(eta_c, 1e-9)) / 1000.0
            grid_sell_mwh = (dis * eta_d) / 1000.0
        else:
            grid_buy_mwh  = ch / 1000.0
            grid_sell_mwh = dis / 1000.0

        # Totals & prices
        total_charged_kwh = float(ch.sum())
        total_discharged_kwh = float(dis.sum())
        idle_hours = int(np.sum((ch <= eps) & (dis <= eps)))
        avg_charge_price = float(np.mean(p[charge_hours])) if charge_hours else 0.0
        avg_discharge_price = float(np.mean(p[discharge_hours])) if discharge_hours else 0.0

        # Economics
        revenue_eur = float(np.sum(p * grid_sell_mwh))
        energy_cost_eur = float(np.sum(p * grid_buy_mwh))
        throughput_mwh = float(grid_buy_mwh.sum() + grid_sell_mwh.sum())
        vom_eur = var_om_eur_per_mwh * throughput_mwh
        gross_profit_eur = revenue_eur - energy_cost_eur
        net_profit_eur = gross_profit_eur - vom_eur

        # Spread tests (temporal arbitrage condition)
        eta_rt = eta_c * eta_d
        spread_realized = (avg_discharge_price - avg_charge_price) if (charge_hours and discharge_hours) else 0.0
        spread_required = (
            avg_charge_price * (1.0/max(eta_rt,1e-9) - 1.0) + var_om_eur_per_mwh / max(eta_rt, 1e-9)
        ) if charge_hours else 0.0
        spread_margin = spread_realized - spread_required if (charge_hours and discharge_hours) else 0.0
        arbitrage_profitable = (spread_margin > 0.0)

        # Quantile diagnostics
        q25, q50, q75 = np.quantile(p, [0.25, 0.50, 0.75])
        frac_charge_below_q25 = (np.mean(p[charge_hours] <= q25) if charge_hours else 0.0)
        frac_discharge_above_q75 = (np.mean(p[discharge_hours] >= q75) if discharge_hours else 0.0)

        # Streaks
        def longest_streak(mask):
            best = cur = 0
            for v in mask:
                cur = cur + 1 if v else 0
                best = max(best, cur)
            return best
        longest_charge_streak = longest_streak(ch > eps)
        longest_discharge_streak = longest_streak(dis > eps)

        # Bindings/utilization
        at_floor = int(np.sum(soc[:-1] <= (soc_floor + 1e-9)))
        at_ceiling = int(np.sum(soc[:-1] >= (soc_ceiling - 1e-9)))
        charge_at_pmax = discharge_at_pmax = 0
        if p_max_kw is not None and dt_hours > 0:
            charge_at_pmax = int(np.sum(ch >= (0.99 * p_max_kw * dt_hours)))
            discharge_at_pmax = int(np.sum(dis >= (0.99 * p_max_kw * dt_hours)))
        time_utilization = 1.0 - (idle_hours / max(T,1))
        energy_utilization = (
            (total_charged_kwh + total_discharged_kwh) / (T * (p_max_kw * dt_hours))
            if (p_max_kw is not None and p_max_kw > 0) else None
        )
        fce = (
            (total_charged_kwh + total_discharged_kwh) / (2.0 * capacity_kwh)
            if (capacity_kwh is not None and capacity_kwh > 0) else None
        )

        # SoC band
        min_soc, max_soc = float(np.min(soc)), float(np.max(soc))
        soc_range = max_soc - min_soc
        soc_drift = float(soc[-1] - soc[0])

        # Arbitrage indicator & rolling spread windows
        dp = np.diff(p, prepend=p[0])
        p_avg = p.mean()
        p_std = p.std(ddof=0) if p.std(ddof=0) > 0 else 1.0
        A = (p < (p_avg - p_std)).astype(int) - (p > (p_avg + p_std)).astype(int)    # +1 charge, -1 discharge, 0 idle
        frac_A_pos = float(np.mean(A == 1.0))
        frac_A_neg = float(np.mean(A == -1.0))
        act = np.zeros(T, dtype=int); act[ch > eps] = 1; act[dis > eps] = -1
        indicator_alignment = float(np.mean(A == act)) if T > 0 else 0.0

        fw = spread_window_future
        bw = spread_window_past
        future_max = np.array([p[t+1:t+1+fw].max() if t+1 < T else p[t] for t in range(T)])
        past_min   = np.array([p[max(0, t-bw):t].min() if t > 0 else p[t] for t in range(T)])
        spread_future = future_max - p
        spread_past   = p - past_min
        feas_charge = int(np.sum(spread_future > spread_required))
        feas_discharge = int(np.sum(spread_past   > spread_required))

        # Arbitrage Index
        denom = float(np.sum(np.abs(dp))) or 1.0
        arbitrage_index = float(np.sum(np.maximum(0.0, dp)) / denom)

        # Kelly proxy
        r = np.zeros(T); r[1:] = (p[1:] - p[:-1]) / np.maximum(p[:-1], 1e-12)
        mu = np.array([np.mean(r[max(0,t-kelly_window+1):t+1]) for t in range(T)])
        var = np.array([np.var (r[max(0,t-kelly_window+1):t+1], ddof=0) for t in range(T)])
        var[var < 1e-12] = 1e-12
        kelly_f = np.clip(mu / var, -1.0, 1.0)
        ksign = np.sign(kelly_f)
        nonzero = (act != 0)
        kelly_alignment = float(np.mean(ksign[nonzero] == act[nonzero])) if np.any(nonzero) else 0.0
        avg_kelly_mag_charge = float(np.mean(np.abs(kelly_f[act == 1]))) if np.any(act == 1) else 0.0
        avg_kelly_mag_discharge = float(np.mean(np.abs(kelly_f[act == -1]))) if np.any(act == -1) else 0.0

        # ---------- Formatting helpers ----------
        pct = lambda x: f"{100.0 * x:.1f}%"
        eur = lambda x: f"{x:,.2f} €"
        mwh = lambda x: f"{x:.3f} MWh"
        kwh = lambda x: f"{x:.1f} kWh"

        parts = []

        if verbose:
            # Charging paragraph
            if charge_hours:
                parts.append(
                    (
                        f"The strategy charges the battery during {len(charge_hours)} distinct hours when market prices "
                        f"tend to be relatively low. The average purchase price paid during charging is "
                        f"{avg_charge_price:.2f} €/MWh. Notably, {pct(frac_charge_below_q25)} of all charging activity occurs "
                        f"when the price falls at or below the 25th percentile of the entire price distribution, indicating a "
                        f"systematic preference for the cheapest periods. In total, the battery buys {kwh(total_charged_kwh)} "
                        f"of energy from the grid over the horizon."
                    )
                )
            # Discharging paragraph
            if discharge_hours:
                parts.append(
                    (
                        f"The strategy discharges during {len(discharge_hours)} hours that coincide with elevated prices. "
                        f"The average selling price realized during discharging is {avg_discharge_price:.2f} €/MWh. "
                        f"Furthermore, {pct(frac_discharge_above_q75)} of discharging occurs when the price is at or above the "
                        f"75th percentile, which confirms that the policy concentrates sales in the most lucrative windows. "
                        f"Across the horizon, the system sells {kwh(total_discharged_kwh)} of energy back to the grid."
                    )
                )
            # Arbitrage economics paragraph
            if charge_hours and discharge_hours:
                profitability_text = (
                    "This realized spread exceeds the minimum required spread once round-trip efficiency losses and variable "
                    "O&M are accounted for, which implies that the temporal arbitrage executed by the policy is profitable on average."
                    if arbitrage_profitable else
                    "This realized spread does not meet the minimum required threshold after accounting for round-trip efficiency "
                    "losses and variable O&M, suggesting that the temporal arbitrage is not profitable on average."
                )
                parts.append(
                    (
                        f"From a temporal arbitrage perspective, the policy buys low and sells high: the realized average spread "
                        f"between selling and buying prices is {spread_realized:.2f} €/MWh. Based on a round-trip efficiency of "
                        f"{eta_rt:.3f} and a variable O&M estimate of {var_om_eur_per_mwh:.2f} €/MWh, the required break-even spread "
                        f"is {spread_required:.2f} €/MWh. The resulting margin is {spread_margin:.2f} €/MWh. {profitability_text}"
                    )
                )
            # P&L paragraph
            parts.append(
                (
                    f"Economic results reflect these mechanics: total market revenue from sales is {eur(revenue_eur)}, while "
                    f"the cost of energy purchased is {eur(energy_cost_eur)}. On a throughput of {mwh(throughput_mwh)}, the variable "
                    f"O&M cost amounts to {eur(vom_eur)}. Combining these terms yields a gross profit of {eur(gross_profit_eur)} and "
                    f"a net profit of {eur(net_profit_eur)} over the analyzed period."
                )
            )
            # SoC management paragraph
            parts.append(
                (
                    f"The state of charge (SoC) operates within a band from {100*min_soc:.1f}% to {100*max_soc:.1f}% "
                    f"(a utilization range of {pct(soc_range)}). Over the full horizon, the SoC exhibits a net drift of "
                    f"{pct(soc_drift)}, indicating the overall tendency to end with a higher or lower inventory relative to the start. "
                    f"The trajectory reaches the SoC floor on {at_floor} hour(s) and the SoC ceiling on {at_ceiling} hour(s), which "
                    f"highlights how often the energy capacity constraints are binding. The asset remains idle for {idle_hours} hour(s), "
                    f"and the longest continuous charging and discharging streaks are {longest_charge_streak} and "
                    f"{longest_discharge_streak} hour(s), respectively."
                )
            )
            # Power/Utilization/Cycles paragraph
            if p_max_kw is not None:
                hit_text = (
                    f"Charging power hits its upper limit during {charge_at_pmax} hour(s) and discharging power hits its "
                    f"upper limit during {discharge_at_pmax} hour(s), suggesting frequent operation at nameplate constraints."
                )
            else:
                hit_text = "Power limit diagnostics were not evaluated because a nameplate limit was not provided."
            util_text = (
                f"The time-based utilization is {pct(time_utilization)} of the horizon. "
                f"{('Energy utilization relative to the power limit is ' + pct(energy_utilization) + '. ') if energy_utilization is not None else ''}"
                f"{('Equivalent full cycles are estimated at ' + f'{fce:.2f}' + ' over the period, which contextualizes throughput vs. capacity.') if fce is not None else ''}"
            )
            parts.append(hit_text + " " + util_text)

            # Indicator paragraph
            parts.append(
                (
                    f"To interpret action timing independently of the optimizer, we compute an arbitrage indicator "
                    f"A_t = 1{{p_t < μ_p − σ_p}} − 1{{p_t > μ_p + σ_p}}, where μ_p and σ_p are the sample mean and standard deviation of price. "
                    f"According to this heuristic, {pct(frac_A_pos)} of intervals suggested charging and {pct(frac_A_neg)} suggested discharging. "
                    f"The realized actions agree with this indicator in {pct(indicator_alignment)} of intervals, indicating the degree to which the "
                    f"policy follows a simple buy-low/sell-high rule of thumb."
                )
            )

            # Rolling spread window paragraph
            parts.append(
                (
                    f"We also examine rolling spread windows to quantify actionable opportunities. For each interval, the forward-looking "
                    f"{spread_window_future}-hour window identifies the best future price relative to the current price, yielding a potential "
                    f"charge-side spread; symmetrically, the backward-looking {spread_window_past}-hour window compares the current price with the "
                    f"minimum of recent prices, yielding a discharge-side spread. Counting only windows whose spread exceeds the break-even requirement "
                    f"({spread_required:.2f} €/MWh), we find {feas_charge} feasible charge opportunities and {feas_discharge} feasible "
                    f"discharge opportunities. These counts reflect how often the market provides economically meaningful temporal arbitrage."
                )
            )

            # Arbitrage Index paragraph
            parts.append(
                (
                    f"As a summary measure of directional structure versus oscillation in the price series, the Arbitrage Index is "
                    f"AI = Σ max(0, Δp_t) / Σ |Δp_t| = {arbitrage_index:.2f}. Values near 0.5 indicate a balanced, mean-reverting profile "
                    f"that is typically favorable for short-horizon arbitrage; values approaching 1 indicate persistent uptrends (fewer buy-low "
                    f"opportunities), while values approaching 0 indicate persistent downtrends (fewer sell-high opportunities)."
                )
            )

            # Kelly proxy paragraph
            parts.append(
                (
                    f"Finally, we compare actions to a Kelly-style control signal derived from rolling price returns. Using a window of {kelly_window} "
                    f"periods, we estimate the expected return and variance and form a bounded Kelly fraction f*_t ≈ μ_t / σ_t². The sign of this fraction "
                    f"indicates whether charging (positive) or discharging (negative) would maximize expected log-growth, while its magnitude reflects "
                    f"the suggested aggressiveness. Over intervals with non-zero actions, the sign of the Kelly proxy matches realized actions in "
                    f"{pct(kelly_alignment)} of cases. Conditional on action, the average recommended aggressiveness is {avg_kelly_mag_charge:.2f} "
                    f"while charging and {avg_kelly_mag_discharge:.2f} while discharging, providing a risk-aware benchmark for allocation sizing."
                )
            )

        else:
            # Concise mode (previous short form)
            if charge_hours:
                parts.append(
                    f"Charging: {len(charge_hours)} h @ {avg_charge_price:.2f} €/MWh, {pct(frac_charge_below_q25)} ≤ Q25, total {kwh(total_charged_kwh)}."
                )
            if discharge_hours:
                parts.append(
                    f"Discharging: {len(discharge_hours)} h @ {avg_discharge_price:.2f} €/MWh, {pct(frac_discharge_above_q75)} ≥ Q75, total {kwh(total_discharged_kwh)}."
                )
            if charge_hours and discharge_hours:
                parts.append(
                    f"Arbitrage: spread {spread_realized:.2f} vs req {spread_required:.2f} (η_rt={eta_rt:.3f}, VOM={var_om_eur_per_mwh:.2f}); margin {spread_margin:.2f} → "
                    f"{'profitable' if arbitrage_profitable else 'not profitable'}."
                )
            parts.append(
                f"Economics: revenue {eur(revenue_eur)}, energy cost {eur(energy_cost_eur)}, VOM {eur(vom_eur)} on {mwh(throughput_mwh)} → net {eur(net_profit_eur)}."
            )
            parts.append(
                f"SoC: {100*min_soc:.1f}%–{100*max_soc:.1f}% (range {pct(soc_range)}), drift {pct(soc_drift)}; floor {at_floor} h, ceiling {at_ceiling} h; idle {idle_hours} h."
            )
            if p_max_kw is not None:
                parts.append(f"Power limits hit: charge {charge_at_pmax} h, discharge {discharge_at_pmax} h.")
            parts.append(f"Streaks: charge {longest_charge_streak} h, discharge {longest_discharge_streak} h.")
            if energy_utilization is not None:
                parts.append(f"Energy util vs P_max: {pct(energy_utilization)}.")
            if fce is not None:
                parts.append(f"FCE: {fce:.2f}.")
            parts.append(
                f"A_t: {pct(frac_A_pos)} charge, {pct(frac_A_neg)} discharge; align {pct(indicator_alignment)}. "
                f"Windows(±{spread_window_past}/{spread_window_future}h): charge {feas_charge} h, discharge {feas_discharge} h. "
                f"AI={arbitrage_index:.2f}. Kelly align {pct(kelly_alignment)}."
            )

        return " ".join(parts)

    
    def _calculate_performance_metrics(
        self,
        charge_schedule: List[float],
        discharge_schedule: List[float],
        soc_trajectory: List[float],
        prices: np.ndarray,
        import_grid: List[float],
        export_grid: List[float],
        battery_params: BatteryParams
    ) -> Dict[str, float]:
        """Calculate various performance metrics"""
        
        total_charged_kwh = sum(charge_schedule)
        total_discharged_kwh = sum(discharge_schedule)
        
        # Energy metrics
        energy_throughput = total_charged_kwh + total_discharged_kwh
        cycle_count = total_charged_kwh / battery_params.capacity_kwh if battery_params.capacity_kwh > 0 else 0
        
        # Financial metrics
        charge_cost = sum(c * p / 1000 for c, p in zip(charge_schedule, prices))
        discharge_revenue = sum(d * p / 1000 for d, p in zip(discharge_schedule, prices))
        arbitrage_profit = discharge_revenue - charge_cost
        
        # Efficiency metrics
        roundtrip_efficiency = (total_discharged_kwh / total_charged_kwh * 100) if total_charged_kwh > 0 else 0
        theoretical_efficiency = battery_params.eta_c * battery_params.eta_d * 100
        
        # Utilization metrics
        avg_soc = np.mean(soc_trajectory)
        soc_range = max(soc_trajectory) - min(soc_trajectory)
        capacity_utilization = soc_range * 100
        
        return {
            "total_energy_charged_kwh": float(total_charged_kwh),
            "total_energy_discharged_kwh": float(total_discharged_kwh),
            "net_energy_flow_kwh": float(total_charged_kwh - total_discharged_kwh),
            "energy_throughput_kwh": float(energy_throughput),
            "cycle_count": float(cycle_count),
            "charging_cost_eur": float(charge_cost),
            "discharging_revenue_eur": float(discharge_revenue),
            "arbitrage_profit_eur": float(arbitrage_profit),
            "realized_roundtrip_efficiency_percent": float(roundtrip_efficiency),
            "theoretical_roundtrip_efficiency_percent": float(theoretical_efficiency),
            "average_soc": float(avg_soc),
            "soc_range_utilized": float(soc_range),
            "capacity_utilization_percent": float(capacity_utilization)
        }
    
    def _validate_solution(
        self,
        charge_schedule: List[float],
        discharge_schedule: List[float],
        soc_trajectory: List[float],
        battery_params: BatteryParams
    ) -> bool:
        """Validate that solution satisfies all constraints"""
        
        # Check SOC bounds
        soc_valid = all(
            battery_params.soc_min - 1e-6 <= s <= battery_params.soc_max + 1e-6 
            for s in soc_trajectory
        )
        
        # Check power bounds
        power_valid = all(
            c <= battery_params.cmax_kw + 1e-6 and 
            d <= battery_params.dmax_kw + 1e-6
            for c, d in zip(charge_schedule, discharge_schedule)
        )
        
        # Check no simultaneous charge/discharge
        no_simultaneous = all(
            c < 1e-6 or d < 1e-6 
            for c, d in zip(charge_schedule, discharge_schedule)
        )
        
        return soc_valid and power_valid and no_simultaneous
    
    def _get_season(self, date_str: str) -> str:
        """Determine season from date"""
        month = datetime.strptime(date_str, "%Y-%m-%d").month
        if month in [12, 1, 2]:
            return "winter"
        elif month in [3, 4, 5]:
            return "spring"
        elif month in [6, 7, 8]:
            return "summer"
        else:
            return "fall"
        

def generate_dataset_for_region(
    dataset_name: str,
    split: str,
    battery_params: BatteryParams,
    data_dir: str = "./data",
    output_dir: str = "./datasets",
    max_days: int = None
) -> List[Dict[str, str]]:
    """
    Generate fine-tuning dataset for a specific region and split
    
    Args:
        dataset_name: 'italy', 'germany', or 'caiso'
        split: 'train' or 'test'
        battery_params: Battery configuration
        data_dir: Directory containing CSV files
        output_dir: Directory to save output datasets
        max_days: Maximum number of days to process (None = all)
        
    Returns:
        List of dataset entries
    """
    print(f"\n{'='*80}")
    print(f"GENERATING DATASET: {dataset_name.upper()} - {split.upper()}")
    print(f"{'='*80}\n")
    
    # Load data
    loader = MultiDatasetLoader(data_dir=data_dir)
    df = loader.load_dataset(dataset_name, split)
    daily_batches = loader.get_daily_batches(df)
    
    if max_days:
        daily_batches = daily_batches[:max_days]
        print(f"Limiting to first {max_days} days")
    
    # Initialize generator
    generator = DatasetGenerator(battery_config={})
    
    # Generate dataset
    dataset = []
    failures = []
    
    for idx, (date, prices, demand) in enumerate(daily_batches):
        try:
            # Create day inputs
            day = DayInputs(
                prices_buy=prices.tolist(),
                demand_kw=demand.tolist(),
                prices_sell=prices.tolist(),
                allow_export=False,
                dt_hours=1.0
            )
            
            # Run MILP solver
            milp_solution = solve_daily_milp(
                batt=battery_params,
                day=day,
                solver=None, 
                solver_opts={}
            )

# # Run MILP solver
# milp_solution = solve_daily_milp(
#     batt=battery_params,
#     day=day,
#     solver="GUROBI", 
#     solver_opts={}
# )
            
            # Check if solution is valid
            if milp_solution.status not in ["optimal", "optimal_inaccurate"]:
                print(f"  Warning: Day {date} - MILP status: {milp_solution.status}")
                failures.append({'date': date, 'reason': milp_solution.status})
                continue
            
            # Create dataset entry
            input_str = generator.create_input_component(
                date=date,
                prices=prices,
                demand=demand,
                battery_params=battery_params,
                dataset_name=dataset_name
            )
            
            output_str = generator.create_output_component(
                milp_solution=milp_solution,
                prices=prices,
                demand=demand,
                battery_params=battery_params
            )
            
            dataset.append({
                'instruction': generator.instruction,
                'input': input_str,
                'output': output_str
            })
            
            # Update battery SOC for next day
            if milp_solution.soc and len(milp_solution.soc) > 0:
                battery_params.soc_init = milp_solution.soc[-1]
            
            if (idx + 1) % 50 == 0:
                print(f"  Processed {idx + 1}/{len(daily_batches)} days")
                
        except Exception as e:
            print(f"  Error processing day {date}: {e}")
            failures.append({'date': date, 'reason': str(e)})
            continue
    
    print(f"\n{'='*80}")
    print(f"GENERATION COMPLETE: {dataset_name.upper()} - {split.upper()}")
    print(f"  Successful: {len(dataset)} examples")
    print(f"  Failed: {len(failures)} examples")
    print(f"{'='*80}\n")
    
    # Save dataset
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True, parents=True)
    
    output_file = output_path / f"{dataset_name}_{split}.json"
    with open(output_file, 'w') as f:
        json.dump(dataset, f, indent=2)
    
    print(f"Dataset saved to: {output_file}")
    print(f"Size: {len(json.dumps(dataset)) / 1024 / 1024:.2f} MB\n")
    
    # Save failures log if any
    if failures:
        failures_file = output_path / f"{dataset_name}_{split}_failures.json"
        with open(failures_file, 'w') as f:
            json.dump(failures, f, indent=2)
        print(f"Failures logged to: {failures_file}\n")
    
    return dataset


def combine_datasets(
    datasets: List[List[Dict]],
    output_path: str
):
    """
    Combine multiple datasets into one
    """
    print(f"\nCombining {len(datasets)} datasets...")
    
    combined = []
    for ds in datasets:
        combined.extend(ds)
    
    # Shuffle
    import random
    random.seed(42)
    random.shuffle(combined)
    
    # Save
    with open(output_path, 'w') as f:
        json.dump(combined, f, indent=2)
    
    print(f"Combined dataset saved to: {output_path}")
    print(f"Total examples: {len(combined)}")
    print(f"Size: {len(json.dumps(combined)) / 1024 / 1024:.2f} MB\n")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Generate fine-tuning datasets from MILP solutions"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=['italy', 'germany', 'caiso', 'all'],
        default='all',
        help="Which dataset to generate"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=['train', 'test', 'both'],
        default='both',
        help="Which split to generate"
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="./data",
        help="Directory containing CSV files"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./datasets",
        help="Directory to save output datasets"
    )
    parser.add_argument(
        "--max-days",
        type=int,
        default=None,
        help="Maximum number of days per dataset (for testing)"
    )
    parser.add_argument(
        "--generate-all",
        action="store_true",
        help="Generate all datasets (all regions, train and test)"
    )
    
    args = parser.parse_args()
    
    # Define battery configuration
    # Adjust these based on your actual battery specs
    battery_params = BatteryParams(
        capacity_kwh=21.89,
        cmax_kw=5.47,
        dmax_kw=5.47,
        eta_c=0.95,
        eta_d=0.95,
        soc_init=0.5,
        soc_min=0.0,
        soc_max=1.0,
        soc_target=0.5
    )
    
    print("\n" + "="*80)
    print("ENERGY STORAGE DATASET GENERATION")
    print("="*80)
    print("\nBattery Configuration:")
    print(f"  Capacity: {battery_params.capacity_kwh} kWh")
    print(f"  Max Charge: {battery_params.cmax_kw} kW")
    print(f"  Max Discharge: {battery_params.dmax_kw} kW")
    print(f"  Efficiency: {battery_params.eta_c}/{battery_params.eta_d}")
    print(f"  SOC Range: [{battery_params.soc_min}, {battery_params.soc_max}]")
    
    # Determine which datasets to generate
    if args.generate_all or args.dataset == 'all':
        datasets_to_generate = ['italy', 'germany', 'caiso']
    else:
        datasets_to_generate = [args.dataset]
    
    # Determine which splits to generate
    if args.split == 'both':
        splits_to_generate = ['train', 'test']
    else:
        splits_to_generate = [args.split]
    
    # Generate datasets
    all_train_datasets = []
    all_test_datasets = []
    
    for dataset_name in datasets_to_generate:
        # Reset battery SOC for each region
        battery_params.soc_init = 0.5
        
        for split in splits_to_generate:
            dataset = generate_dataset_for_region(
                dataset_name=dataset_name,
                split=split,
                battery_params=battery_params,
                data_dir=args.data_dir,
                output_dir=args.output_dir,
                max_days=args.max_days
            )
            
            if split == 'train':
                all_train_datasets.append(dataset)
            else:
                all_test_datasets.append(dataset)
    
    # Combine all training datasets
    if len(all_train_datasets) > 1:
        combine_datasets(
            all_train_datasets,
            f"{args.output_dir}/combined_train.json"
        )
    
    # Combine all test datasets
    if len(all_test_datasets) > 1:
        combine_datasets(
            all_test_datasets,
            f"{args.output_dir}/combined_test.json"
        )
    
    print("\n" + "="*80)
    print("DATASET GENERATION PIPELINE COMPLETE!")
    print("="*80)
    print("\nNext steps:")
    print("  1. Review generated datasets in:", args.output_dir)
    print("  2. Use combined_train.json for fine-tuning")
    print("  3. Use combined_test.json or individual test sets for evaluation")
    print("  4. Run: python train_with_unsloth.py --dataset ./datasets/combined_train.json")
    print("\n")

from agentic_energy.schemas import BatteryParams

battery_params = BatteryParams(
    capacity_kwh=49.44,
    cmax_kw=12.36,
    dmax_kw=12.36,
    eta_c=0.95,
    eta_d=0.95,
    soc_init=0.5,
    soc_min=0.0,
    soc_max=1.0,
    soc_target=0.5
)

# dataset = generate_dataset_for_region(
#     dataset_name='caiso',
#     split='train',
#     battery_params=battery_params,
#     data_dir="./agentic_energy/data",
#     output_dir="./datasets"
# )

Forecast Engine using device: cpu


In [3]:
import json

# Load the three individual datasets
with open('./datasets/italy_train.json', 'r') as f:
    italy_train = json.load(f)

with open('./datasets/germany_train.json', 'r') as f:
    germany_train = json.load(f)

with open('./datasets/caiso_train.json', 'r') as f:
    caiso_train = json.load(f)

# Combine them
combine_datasets(
    datasets=[italy_train, germany_train, caiso_train],
    output_path='./datasets/combined_train.json'
)

print(f"\n✅ Combined dataset created!")
print(f"   - Italy: {len(italy_train)} examples")
print(f"   - Germany: {len(germany_train)} examples")
print(f"   - CAISO: {len(caiso_train)} examples")
print(f"   - Total: {len(italy_train) + len(germany_train) + len(caiso_train)} examples")


Combining 3 datasets...
Combined dataset saved to: ./datasets/combined_train.json
Total examples: 2329
Size: 60.74 MB


✅ Combined dataset created!
   - Italy: 364 examples
   - Germany: 1333 examples
   - CAISO: 632 examples
   - Total: 2329 examples


In [1]:
import os
import json
import torch
# from datasets import Dataset, load_dataset
# from transformers import TrainingArguments
# from trl import SFTTrainer
# from unsloth import FastLanguageModel

In [2]:
torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Using device: cpu
