In [1]:
"""
Corrected Dataset Generator - Matches Runtime DayInputs Format
Uses exact field names and output schema from runtime
"""

import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict

from agentic_energy.schemas import BatteryParams, DayInputs, SolveRequest, SolveResponse
from agentic_energy.data_loader import BatteryDataLoader
from agentic_energy.milp.milp_mcp_server import solve_daily_milp

from agentic_energy.milp.milp_mcp_server import solve_daily_milp
import asyncio

class CorrectedDatasetGenerator:
    """Generate training data matching runtime DayInputs/SolveResponse format"""
    
    def __init__(self, battery_params: BatteryParams):
        self.battery_params = battery_params
        self.instruction = self._create_instruction()
    
    def _create_instruction(self) -> str:
        """Create instruction matching runtime format"""
        bp = self.battery_params
        
        # Simplified but compatible instruction
        return f"""Solve battery optimization using forecast data to minimize costs.

BATTERY: {bp.capacity_kwh:.2f} kWh, {bp.cmax_kw:.2f} kW charge/discharge, efficiency {bp.eta_c}/{bp.eta_d}
INITIAL SOC: {bp.soc_init:.2f}, TARGET: {bp.soc_target:.2f}, RANGE: [{bp.soc_min:.2f}, {bp.soc_max:.2f}]

CONSTRAINTS:
- SOC[t+1] = SOC[t] + (charge[t]*{bp.eta_c} - discharge[t]/{bp.eta_d})*dt/{bp.capacity_kwh}
- 0 ≤ charge[t] ≤ {bp.cmax_kw}, 0 ≤ discharge[t] ≤ {bp.dmax_kw}
- {bp.soc_min} ≤ SOC[t] ≤ {bp.soc_max}
- charge[t] * discharge[t] = 0 (no simultaneous)
- SOC[0] = {bp.soc_init}, SOC[T] ≥ {bp.soc_target}

STRATEGY: Use prices_buy_forecast, prices_sell_forecast, demand_kw_forecast to decide charge/discharge.
Charge when forecast price < p25, discharge when > p75. Calculate cost with actual prices.

OUTPUT: JSON with status, message, objective_cost, charge_kw, discharge_kw, import_kw, export_kw, soc, decision."""
    
    def create_input(
        self,
        date: str,
        prices: np.ndarray,
        demand: np.ndarray,
        dataset_name: str = None
    ) -> str:
        """Create input in DayInputs format"""
        
        # Calculate stats for strategy hints
        p25 = float(np.percentile(prices, 25))
        p75 = float(np.percentile(prices, 75))
        
        # Exact DayInputs format
        input_dict = {
            "prices_buy": [round(float(p), 2) for p in prices],
            "prices_sell": [round(float(p), 2) for p in prices],
            "demand_kw": [round(float(d), 2) for d in demand],
            "allow_export": True,
            "dt_hours": 1.0,
            "prices_buy_forecast": [round(float(p), 2) for p in prices],
            "prices_sell_forecast": [round(float(p), 2) for p in prices],
            "demand_kw_forecast": [round(float(d), 2) for d in demand],
            "p25": round(p25, 2),
            "p75": round(p75, 2)
        }
        
        return json.dumps(input_dict, separators=(',', ':'))
    
    def create_output(
        self,
        milp_solution: SolveResponse,
        prices: np.ndarray
    ) -> str:
        """Create output in SolveResponse format"""
        
        T = len(prices)
        charge = milp_solution.charge_kw or [0.0] * T
        discharge = milp_solution.discharge_kw or [0.0] * T
        soc = milp_solution.soc or [self.battery_params.soc_init] * (T + 1)
        cost = milp_solution.objective_cost or 0.0
        
        # Calculate import/export
        import_kw = milp_solution.import_kw or [0.0] * T
        export_kw = milp_solution.export_kw or [0.0] * T
        
        # Calculate decision
        decision = [1.0 if c > 0.01 else (-1.0 if d > 0.01 else 0.0) 
                   for c, d in zip(charge, discharge)]
        
        # Brief strategy message
        charge_hours = sum(1 for d in decision if d > 0)
        discharge_hours = sum(1 for d in decision if d < 0)
        
        if charge_hours > 0 and discharge_hours > 0:
            msg = f"Charge {charge_hours}h, discharge {discharge_hours}h"
        elif charge_hours > 0:
            msg = f"Charge {charge_hours}h only"
        elif discharge_hours > 0:
            msg = f"Discharge {discharge_hours}h only"
        else:
            msg = "No arbitrage"
        
        # Match SolveResponse schema exactly
        output_dict = {
            "status": "success",
            "message": msg,
            "objective_cost": round(float(cost), 2),
            "charge_kw": [round(float(c), 2) for c in charge],
            "discharge_kw": [round(float(d), 2) for d in discharge],
            "import_kw": [round(float(i), 2) for i in import_kw],
            "export_kw": [round(float(e), 2) for e in export_kw],
            "soc": [round(float(s), 4) for s in soc],
            "decision": [float(d) for d in decision]
        }
        
        return json.dumps(output_dict, separators=(',', ':'))
    
    def create_training_example(
        self,
        date: str,
        prices: np.ndarray,
        demand: np.ndarray,
        milp_solution: SolveResponse,
        dataset_name: str = None
    ) -> Dict[str, str]:
        """Create complete training example"""
        
        return {
            "instruction": self.instruction,
            "input": self.create_input(date, prices, demand, dataset_name),
            "output": self.create_output(milp_solution, prices)
        }


def generate_corrected_dataset(
    dataset_name: str,
    split: str,
    battery_params: BatteryParams,
    data_dir: str = "./agentic_energy/data",
    output_dir: str = "./datasets_corrected",
    max_days: int = None
) -> List[Dict[str, str]]:
    """Generate corrected dataset"""
    
    print(f"\n{'='*80}")
    print(f"GENERATING: {dataset_name.upper()} ({split})")
    print(f"{'='*80}")
    
    from agentic_energy.data_loader import EnergyDataLoader
    
    # Load data
    loader = EnergyDataLoader(region=dataset_name.upper(), data_version="actual")
    data = loader.load_region_data()
    
    # Split data
    split_idx = int(len(data) * 0.5)
    if split == "train":
        data_subset = data[:split_idx]
    else:
        data_subset = data[split_idx:]
    
    # Limit if requested
    if max_days:
        data_subset = data_subset[:max_days * 24]
    
    generator = CorrectedDatasetGenerator(battery_params)
    examples = []
    
    # Process in 24-hour chunks
    for i in range(0, len(data_subset) - 23, 24):
        day_data = data_subset[i:i+24]
        
        prices = np.array([d.prices for d in day_data], dtype=float)
        demand = np.array([d.consumption for d in day_data], dtype=float)
        date = day_data[0].timestamps[:10] if hasattr(day_data[0], 'timestamps') else f"day_{i//24}"
        
        # Create DayInputs
        day_inputs = DayInputs(
            prices_buy=prices.tolist(),
            prices_sell=prices.tolist(),
            demand_kw=demand.tolist(),
            allow_export=True,
            dt_hours=1.0,
            prices_buy_forecast=prices.tolist(),
            prices_sell_forecast=prices.tolist(),
            demand_kw_forecast=demand.tolist()
        )
        
        # Solve with MILP
        req = SolveRequest(
            battery=battery_params,
            day=day_inputs,
            solver=None,
            solver_opts=None
        )
        
        try:
            # milp_solution = solve_daily_milp(req)
            milp_solution = asyncio.run(solve_daily_milp(req)) if asyncio.iscoroutinefunction(solve_daily_milp) else solve_daily_milp(req)

            
            if milp_solution.status == "success":
                example = generator.create_training_example(
                    date, prices, demand, milp_solution, dataset_name
                )
                examples.append(example)
                
                if len(examples) % 50 == 0:
                    print(f"  Generated {len(examples)} examples...")
        except Exception as e:
            print(f"  Skipping day {i//24}: {e}")
            continue
    
    # Save
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    filename = f"{dataset_name}_{split}_corrected.json"
    filepath = output_path / filename
    
    with open(filepath, 'w') as f:
        json.dump(examples, f, indent=2)
    
    print(f"\n✓ Saved {len(examples)} examples to {filepath}")
    print(f"  Size: {filepath.stat().st_size / (1024**2):.2f} MB")
    
    return examples


def analyze_token_lengths(dataset_path: str):
    """Analyze token lengths"""
    import tiktoken
    
    enc = tiktoken.get_encoding("cl100k_base")
    
    with open(dataset_path, 'r') as f:
        data = json.load(f)
    
    lengths = []
    for ex in data:
        full_text = f"{ex['instruction']}\n\nINPUT:\n{ex['input']}\n\nOUTPUT:\n{ex['output']}"
        tokens = len(enc.encode(full_text))
        lengths.append(tokens)
    
    lengths = np.array(lengths)
    
    print(f"\n{'='*80}")
    print("TOKEN LENGTH ANALYSIS")
    print(f"{'='*80}")
    print(f"Dataset: {dataset_path}")
    print(f"Examples: {len(data)}")
    print(f"\nToken Statistics:")
    print(f"  Min:    {int(np.min(lengths)):,}")
    print(f"  Max:    {int(np.max(lengths)):,}")
    print(f"  Mean:   {int(np.mean(lengths)):,}")
    print(f"  Median: {int(np.median(lengths)):,}")
    print(f"  P95:    {int(np.percentile(lengths, 95)):,}")
    print(f"  P99:    {int(np.percentile(lengths, 99)):,}")
    
    for limit in [2048, 4096, 8192]:
        truncated = np.sum(lengths > limit)
        pct = 100 * truncated / len(lengths)
        print(f"\n  With max_seq_length={limit}:")
        print(f"    Truncated: {truncated}/{len(lengths)} ({pct:.1f}%)")
    
    # Show example
    print(f"\n{'='*80}")
    print("EXAMPLE (First Training Instance)")
    print(f"{'='*80}")
    ex = data[0]
    print(f"\nInstruction length: {len(ex['instruction'])} chars")
    print(f"Input length: {len(ex['input'])} chars")
    print(f"Output length: {len(ex['output'])} chars")
    print(f"Total tokens: {lengths[0]:,}")
    print(f"\n--- INPUT SAMPLE (first 200 chars) ---")
    print(ex['input'][:200] + "...")
    print(f"\n--- OUTPUT SAMPLE (first 200 chars) ---")
    print(ex['output'][:200] + "...")


if __name__ == "__main__":
    # Battery configuration - matches your setup
    battery_params = BatteryParams(
        capacity_kwh=49.44,
        cmax_kw=12.36,
        dmax_kw=12.36,
        eta_c=0.95,
        eta_d=0.95,
        soc_init=0.5,
        soc_min=0.0,
        soc_max=1.0,
        soc_target=0.5
    )
    
    # Generate datasets
    datasets_to_generate = ['italy', 'germany', 'caiso']
    
    all_train = []
    all_test = []
    
    for dataset_name in datasets_to_generate:
        battery_params.soc_init = 0.5
        
        # Train
        train_data = generate_corrected_dataset(
            dataset_name=dataset_name,
            split='train',
            battery_params=battery_params,
            data_dir="./agentic_energy/data",
            output_dir="./datasets_corrected",
            max_days=None
        )
        all_train.extend(train_data)
        
        # Test
        battery_params.soc_init = 0.5
        test_data = generate_corrected_dataset(
            dataset_name=dataset_name,
            split='test',
            battery_params=battery_params,
            data_dir="./agentic_energy/data",
            output_dir="./datasets_corrected",
            max_days=None
        )
        all_test.extend(test_data)
    
    # Combine
    print(f"\n{'='*80}")
    print("COMBINING DATASETS")
    print(f"{'='*80}\n")
    
    combined_train_path = Path("./datasets_corrected/combined_train_corrected.json")
    combined_test_path = Path("./datasets_corrected/combined_test_corrected.json")
    
    with open(combined_train_path, 'w') as f:
        json.dump(all_train, f, indent=2)
    
    with open(combined_test_path, 'w') as f:
        json.dump(all_test, f, indent=2)
    
    print(f"✓ Combined train: {combined_train_path}")
    print(f"  Examples: {len(all_train)}")
    print(f"  Size: {combined_train_path.stat().st_size / (1024**2):.2f} MB")
    
    print(f"\n✓ Combined test: {combined_test_path}")
    print(f"  Examples: {len(all_test)}")
    print(f"  Size: {combined_test_path.stat().st_size / (1024**2):.2f} MB")
    
    # Analyze
    print(f"\n{'='*80}")
    print("ANALYZING TRAINING DATASET")
    print(f"{'='*80}")
    analyze_token_lengths(str(combined_train_path))

Forecast Engine using device: cpu


2025-11-16 02:36:30.665 | DEBUG    | agentics.core.agentics:from_csv:307 - Importing Agentics of type EnergyDataRecord from CSV c:\Users\16467\OneDrive\Desktop\Columbia\Agentics\Another\Agentics_for_EnergyArbitrage_Battery\energy_arbitrage\agentic_energy\data\Italy_data_actual.csv
2025-11-16 02:36:30.713 | DEBUG    | agentics.core.llm_connections:get_llm_provider:32 - Available LLM providers: ['gemini', 'openai', 'ollama']. None specified, defaulting to 'ollama'
2025-11-16 02:36:30.818 | DEBUG    | agentics.core.agentics:from_csv:307 - Importing Agentics of type EnergyDataRecord from CSV c:\Users\16467\OneDrive\Desktop\Columbia\Agentics\Another\Agentics_for_EnergyArbitrage_Battery\energy_arbitrage\agentic_energy\data\Italy_data_actual.csv



GENERATING: ITALY (train)
  Skipping day 0: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 1: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 2: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 3: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 4: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 5: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 6: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 7: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 8: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 9: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 10: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 11: solve_daily_milp() missing 1 required po

2025-11-16 02:36:30.869 | DEBUG    | agentics.core.llm_connections:get_llm_provider:32 - Available LLM providers: ['gemini', 'openai', 'ollama']. None specified, defaulting to 'ollama'
2025-11-16 02:36:30.946 | DEBUG    | agentics.core.agentics:from_csv:307 - Importing Agentics of type EnergyDataRecord from CSV c:\Users\16467\OneDrive\Desktop\Columbia\Agentics\Another\Agentics_for_EnergyArbitrage_Battery\energy_arbitrage\agentic_energy\data\Germany_energy_Data.csv


  Skipping day 0: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 1: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 2: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 3: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 4: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 5: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 6: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 7: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 8: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 9: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 10: solve_daily_milp() missing 1 required positional argument: 'day'
  Skipping day 11: solve_daily_milp() missing 1 required positional argument: 'day'
  

ValidationError: 1 validation error for EnergyDataRecord
prices
  Input should be a valid number, unable to parse string as a number [type=float_parsing, input_value='28,32', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/float_parsing