In [None]:
# 🔧 Setup: Run this cell first!
# Reports GPU availability and the Python / PyTorch versions, then seeds every
# random number generator this notebook touches. Nothing is installed here.

import torch
import sys

# Pick the compute device and report it.
# NOTE(review): `device` is not read by any later cell visible in this notebook
# (the Q-learning code is pure NumPy) — confirm whether the torch dependency is needed.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; dividing by 1e9 prints decimal gigabytes.
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

# sys.version starts with the version number; keep only that first token.
print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

# One shared seed for Python's `random`, NumPy's global RNG and PyTorch (CPU and,
# when present, every CUDA device).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Q-Learning for Battery Storage Optimization -- Implementation Notebook

*Vizuara Case Study: Volterra Energy Solutions*

## Overview

In this notebook, we implement the Q-Learning battery dispatch agent described in the case study. We will:

1. Build a realistic electricity price simulator based on historical patterns
2. Implement the battery environment as an MDP
3. Train a Q-Learning agent to optimize charge/discharge decisions
4. Compare against the rule-based baseline
5. Analyze the learned policy

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import random

# Reproducibility: seed NumPy's global RNG and Python's `random` module.
# 42 is the same value as SEED in the first setup cell, so running this cell
# restarts both global random streams from that seed.
np.random.seed(42)
random.seed(42)

print("Setup complete.")

## 1. Electricity Price Simulator

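We simulate realistic day-ahead electricity prices using historical patterns from the ERCOT market. The simulator does not load market data: it starts from a fixed hourly price curve, interpolates it to 15-minute intervals (96 per day), and then applies seasonal and weekend scaling, multiplicative noise, and occasional price spikes.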

In [None]:
class ElectricityPriceSimulator:
    """
    Simulates realistic electricity prices with daily patterns,
    random volatility, and occasional price spikes.

    A day is built from a fixed 24-value hourly base curve that is interpolated
    to 15-minute resolution, scaled for season and weekend, perturbed with
    multiplicative Gaussian noise, multiplied by occasional spikes and finally
    floored at 5.0.

    Parameters
    ----------
    seed : int, default 42
        Seed for the simulator's private ``np.random.RandomState``. All draws
        come from this generator, so two simulators created with the same seed
        and called in the same order produce identical days.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

        # Base daily price pattern (24 hours)
        # Low overnight, rising in morning, peak afternoon, declining evening
        self.base_pattern = np.array([
            22, 20, 18, 17, 16, 17,   # 00:00 - 05:00
            20, 28, 35, 40, 42, 45,   # 06:00 - 11:00
            50, 55, 65, 80, 95, 90,   # 12:00 - 17:00
            70, 55, 42, 35, 30, 25,   # 18:00 - 23:00
        ], dtype=float)

        # Weekend prices are generally lower
        self.weekend_factor = 0.75

    def generate_day(self, is_weekend=False, season='summer'):
        """Generate 96 price points (15-min intervals) for one day.

        Parameters
        ----------
        is_weekend : bool, default False
            When true, every price is multiplied by ``self.weekend_factor``.
        season : str, default 'summer'
            One of 'summer', 'winter', 'spring', 'fall'. Any other value
            leaves the base curve unscaled (factor 1.0).

        Returns
        -------
        numpy.ndarray
            Float array of shape ``(96,)``; element ``i`` is the price for the
            interval starting ``i / 4`` hours after midnight. Values are never
            below 5.0.
        """
        n_intervals = 96

        # Season adjustment
        season_factors = {
            'summer': 1.3,  # Higher due to AC demand
            'winter': 1.1,  # Moderate heating demand
            'spring': 0.85,
            'fall': 0.90,
        }
        sfactor = season_factors.get(season, 1.0)

        # Interpolate hourly to 15-min.
        # Interval i starts at hour i / 4, so the query grid is 0.00, 0.25, ...,
        # 23.75. (The previous grid, np.linspace(0, 23, 96), squeezed the day
        # onto 0..23 h and shifted the afternoon peak away from the hour that
        # callers derive from the interval index.) `period=24` makes the last
        # three quarter-hours blend from the 23:00 value toward the 00:00 value
        # instead of being clamped.
        hours = np.arange(24)
        quarters = np.arange(n_intervals) / 4
        prices_15min = np.interp(quarters, hours, self.base_pattern, period=24)

        # Apply adjustments (the multiplication creates a fresh array, so the
        # in-place operations below never touch base_pattern)
        prices = prices_15min * sfactor
        if is_weekend:
            prices *= self.weekend_factor

        # Add random noise (10-15% volatility)
        noise = self.rng.normal(0, 0.12, size=n_intervals)
        prices *= (1 + noise)

        # Occasional price spikes (3% chance per interval). A multiplier is
        # drawn for every interval and only applied where the mask is set.
        spike_mask = self.rng.random(n_intervals) < 0.03
        spike_multiplier = self.rng.uniform(2.0, 4.0, size=n_intervals)
        prices[spike_mask] *= spike_multiplier[spike_mask]

        # Floor at 5 (prices rarely go negative in our simplified model)
        prices = np.maximum(prices, 5.0)

        return prices


# Test the simulator
# `sim` is also handed to BatteryStorageEnv later in the notebook, so the days
# drawn here advance the same random stream the environment will use.
sim = ElectricityPriceSimulator()

fig, axes = plt.subplots(2, 2, figsize=(14, 8))
# One panel per scenario: (panel title, is_weekend, season)
scenarios = [
    ('Summer Weekday', False, 'summer'),
    ('Summer Weekend', True, 'summer'),
    ('Winter Weekday', False, 'winter'),
    ('Spring Weekday', False, 'spring'),
]

for ax, (label, weekend, season) in zip(axes.flat, scenarios):
    # Overlay five independently sampled days to show the day-to-day spread.
    for _ in range(5):
        prices = sim.generate_day(is_weekend=weekend, season=season)
        # np.arange(96) / 4 converts the 15-minute interval index to hours.
        ax.plot(np.arange(96) / 4, prices, alpha=0.5)
    ax.set_title(label, fontsize=12, fontweight='bold')
    ax.set_xlabel('Hour of Day')
    ax.set_ylabel('Price ($/MWh)')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 24)

plt.suptitle('Simulated Electricity Prices (5 random days each)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 2. Battery Storage Environment

In [None]:
class BatteryStorageEnv:
    """
    MDP environment for battery storage dispatch.

    State: (SoC_bin, Price_bin, Hour_bin, DayType)
    Actions: 0=Charge, 1=Hold, 2=Discharge

    One episode is one simulated day; each step covers ``dt`` = 0.25 h.
    States are exposed to agents as a single integer (see ``_state_to_int``).
    Rewards are in thousands of dollars ($K): the energy cash flow
    ``price * energy_per_step / 1000`` minus ``degradation_cost`` for every
    executed charge or discharge.

    NOTE(review): ``degradation_cost`` is subtracted directly from the $K
    reward, so the default 0.50 is charged as $500 per executed step — confirm
    this magnitude is intended.

    Parameters
    ----------
    price_simulator : object
        Must provide ``generate_day(is_weekend=..., season=...)`` returning a
        sequence of prices for one day.
    battery_capacity_mwh : float, default 50
    power_rating_mw : float, default 25
    soc_min, soc_max : float, defaults 0.10 and 0.90
        Allowed state-of-charge window (fractions of capacity).
    degradation_cost : float, default 0.50
        Penalty per executed charge/discharge step, in reward units.
    """

    def __init__(self, price_simulator, battery_capacity_mwh=50, power_rating_mw=25,
                 soc_min=0.10, soc_max=0.90, degradation_cost=0.50):
        self.sim = price_simulator
        self.capacity = battery_capacity_mwh
        self.power = power_rating_mw
        self.dt = 0.25  # 15 minutes in hours
        self.soc_min = soc_min
        self.soc_max = soc_max
        self.deg_cost = degradation_cost

        # Energy per step: 25 MW * 0.25 hr = 6.25 MWh (with the defaults)
        self.energy_per_step = self.power * self.dt
        # SoC change per step: 6.25 / 50 = 0.125 (12.5%) (with the defaults)
        self.soc_step = self.energy_per_step / self.capacity

        # Discretization
        # The +0.001 keeps soc_max itself as the last edge despite float rounding.
        self.soc_bins = np.arange(self.soc_min, self.soc_max + 0.001, 0.05)  # 17 levels with the defaults
        self.price_bins = np.concatenate([
            np.arange(0, 50, 5),      # Fine resolution at low prices
            np.arange(50, 150, 10),    # Medium resolution
            np.arange(150, 350, 50),   # Coarse at high prices
        ])
        self.n_soc = len(self.soc_bins)
        self.n_price = len(self.price_bins)
        self.n_hours = 24
        self.n_daytypes = 2
        self.n_states = self.n_soc * self.n_price * self.n_hours * self.n_daytypes
        self.n_actions = 3  # charge, hold, discharge

        # Episode state (overwritten by reset())
        self.soc = 0.5
        self.current_prices = None
        self.current_step = 0
        self.is_weekend = False

    def _episode_length(self):
        """Return the number of steps in the active day; raise if there is none."""
        if self.current_prices is None:
            raise RuntimeError(
                "No active episode: call reset() before step() or get_price()."
            )
        return len(self.current_prices)

    def _discretize_soc(self, soc):
        """Map a state of charge to an index into ``soc_bins`` (clipped to the valid range)."""
        return int(np.clip(np.searchsorted(self.soc_bins, soc) - 1, 0, self.n_soc - 1))

    def _discretize_price(self, price):
        """Map a price to an index into ``price_bins`` (clipped to the valid range)."""
        return int(np.clip(np.searchsorted(self.price_bins, price) - 1, 0, self.n_price - 1))

    def _get_state(self):
        """Return the current ``(soc_idx, price_idx, hour, day_idx)`` tuple."""
        # Step index -> hour of day; the min() guards against values past 23.
        hour = min(int(self.current_step * self.dt), 23)
        soc_idx = self._discretize_soc(self.soc)
        price_idx = self._discretize_price(self.current_prices[self.current_step])
        day_idx = 1 if self.is_weekend else 0
        return (soc_idx, price_idx, hour, day_idx)

    def _state_to_int(self, state):
        """Flatten a state tuple into one integer in ``[0, n_states)``.

        Mixed-radix encoding with SoC as the most significant digit and the
        day type as the least significant one.
        """
        soc_idx, price_idx, hour, day_idx = state
        return (soc_idx * self.n_price * self.n_hours * self.n_daytypes +
                price_idx * self.n_hours * self.n_daytypes +
                hour * self.n_daytypes +
                day_idx)

    def reset(self, season='summer'):
        """Start a new day and return its initial state index.

        The day type (weekend with probability 2/7) and the starting SoC
        (uniform on [0.3, 0.7]) are drawn from NumPy's global RNG; the prices
        come from the price simulator.
        """
        self.is_weekend = np.random.random() < 2/7
        self.current_prices = self.sim.generate_day(
            is_weekend=self.is_weekend, season=season
        )
        self.soc = np.random.uniform(0.3, 0.7)
        self.current_step = 0
        return self._state_to_int(self._get_state())

    def step(self, action):
        """Apply ``action`` for one 15-minute interval.

        Parameters
        ----------
        action : int
            0 = charge, 2 = discharge; any other value is treated as hold.

        Returns
        -------
        (next_state, reward, done)
            ``next_state`` is the integer state after the step, or 0 once the
            episode is over. ``reward`` is in $K. A charge/discharge that would
            leave the ``[soc_min, soc_max]`` window is not executed and earns
            a fixed -0.01 penalty instead.

        Raises
        ------
        RuntimeError
            If called before ``reset()`` or after the episode has finished.
        """
        n_steps = self._episode_length()
        if self.current_step >= n_steps:
            raise RuntimeError("Episode has finished: call reset() to start a new day.")

        price = self.current_prices[self.current_step]
        reward = 0.0

        if action == 0:  # Charge
            if self.soc + self.soc_step <= self.soc_max:
                self.soc += self.soc_step
                reward = -price * self.energy_per_step / 1000 - self.deg_cost  # Cost in $K
            else:
                reward = -0.01  # Penalty for invalid action

        elif action == 2:  # Discharge
            if self.soc - self.soc_step >= self.soc_min:
                self.soc -= self.soc_step
                reward = price * self.energy_per_step / 1000 - self.deg_cost  # Revenue in $K
            else:
                reward = -0.01

        # action == 1: Hold, reward = 0

        self.current_step += 1
        done = self.current_step >= n_steps  # End of day

        # There is no price for a step past the end, so the terminal state is a
        # placeholder (0); learners must not bootstrap from it when done is True.
        next_state = self._state_to_int(self._get_state()) if not done else 0
        return next_state, reward, done

    def get_price(self):
        """Return the price for the current step, or 0 once the day is over.

        Raises
        ------
        RuntimeError
            If called before ``reset()``.
        """
        n_steps = self._episode_length()
        return self.current_prices[self.current_step] if self.current_step < n_steps else 0


# Build the environment on top of the shared price simulator `sim` and print
# the quantities that size the Q-table and the per-step battery movement.
env = BatteryStorageEnv(sim)
print(f"State space size: {env.n_states}")
print(f"Action space: {env.n_actions} (Charge, Hold, Discharge)")
print(f"Energy per step: {env.energy_per_step} MWh")
print(f"SoC change per step: {env.soc_step:.3f} ({env.soc_step*100:.1f}%)")

## 3. Rule-Based Baseline

In [None]:
class RuleBasedAgent:
    """Threshold dispatcher from the case study: buy cheap, sell dear, otherwise wait."""

    def __init__(self, charge_threshold=30, discharge_threshold=80,
                 soc_charge_max=0.80, soc_discharge_min=0.30):
        # Price limits: charge strictly below the first, discharge strictly above the second.
        self.charge_th, self.discharge_th = charge_threshold, discharge_threshold
        # SoC limits: stop charging at the first, stop discharging at the second.
        self.soc_charge_max, self.soc_discharge_min = soc_charge_max, soc_discharge_min

    def choose_action(self, price, soc):
        """Return 0 (charge), 2 (discharge) or 1 (hold) for the given price and SoC.

        Charging takes precedence: the discharge rule is only consulted when
        the charge rule does not fire.
        """
        cheap_with_headroom = price < self.charge_th and soc < self.soc_charge_max
        if cheap_with_headroom:
            return 0  # Charge

        pricey_with_reserve = price > self.discharge_th and soc > self.soc_discharge_min
        return 2 if pricey_with_reserve else 1  # Discharge, else Hold


def evaluate_agent(env, agent_fn, n_episodes=500, season='summer'):
    """Evaluate an agent over multiple episodes.

    Parameters
    ----------
    env : environment
        Must provide ``reset(season=...)``, ``get_price()``, ``step(action)``
        returning ``(next_state, reward, done)``, plus the attributes ``soc``
        and ``energy_per_step``.
    agent_fn : RuleBasedAgent, callable, or object with ``choose_action``
        * ``RuleBasedAgent`` is called as ``choose_action(price, soc)``.
        * A callable is called as ``agent_fn(state)``.
        * Any other object with a ``choose_action`` method is called as
          ``choose_action(state)`` (its own exploration settings apply).
    n_episodes : int, default 500
    season : str, default 'summer'
        Passed to ``env.reset`` for every episode.

    Returns
    -------
    (total_rewards, daily_revenues)
        Two lists with one entry per episode: the summed environment reward,
        and the net cash flow ``price * energy_per_step`` of the trades the
        environment executed (discharges positive, charges negative).
    """
    total_rewards = []
    daily_revenues = []

    # Resolve the policy interface once instead of on every step.
    policy_method = getattr(agent_fn, 'choose_action', None)
    has_policy_method = callable(policy_method)

    for _ in range(n_episodes):
        state = env.reset(season=season)
        episode_reward = 0
        episode_revenue = 0

        for step in range(96):
            price = env.get_price()
            soc_before = env.soc

            if has_policy_method:
                if isinstance(agent_fn, RuleBasedAgent):
                    action = policy_method(price, soc_before)
                elif callable(agent_fn):
                    action = agent_fn(state)
                else:
                    # Agent objects that are not callable themselves used to
                    # crash here; use their choose_action(state) method.
                    action = policy_method(state)
            else:
                action = agent_fn(state)

            next_state, reward, done = env.step(action)
            episode_reward += reward

            # Book cash flow only for trades that actually moved the battery.
            # The environment rejects a charge/discharge that would leave the
            # SoC window (it only applies a small penalty), so counting by the
            # requested action alone would credit energy that never flowed.
            if action == 2 and env.soc < soc_before:  # Discharge executed
                episode_revenue += price * env.energy_per_step
            elif action == 0 and env.soc > soc_before:  # Charge executed
                episode_revenue -= price * env.energy_per_step

            state = next_state
            if done:
                break

        total_rewards.append(episode_reward)
        daily_revenues.append(episode_revenue)

    return total_rewards, daily_revenues


# Evaluate baseline
# Uses evaluate_agent's default season ('summer'). The results are kept in
# module-level names because the comparison and summary cells read them later.
rule_agent = RuleBasedAgent()
rule_rewards, rule_revenues = evaluate_agent(env, rule_agent, n_episodes=500)

print("Rule-Based Agent Performance (500 episodes):")
# Rewards are tracked in $K by the environment, hence the * 1000; revenues are
# accumulated by evaluate_agent as price * energy_per_step and printed as-is.
print(f"  Mean daily reward:  ${np.mean(rule_rewards)*1000:.0f}")
print(f"  Mean daily revenue: ${np.mean(rule_revenues):.0f}")
print(f"  Std daily revenue:  ${np.std(rule_revenues):.0f}")

## 4. Q-Learning Agent

In [None]:
class QLearningBatteryAgent:
    """Tabular Q-Learning agent for battery dispatch.

    Holds an ``(n_states, n_actions)`` table of action values initialised to
    zero, selects actions epsilon-greedily and applies the one-step
    Q-learning update.
    """

    def __init__(self, n_states, n_actions=3, alpha=0.05, gamma=0.95,
                 epsilon=1.0, epsilon_decay=0.9995, epsilon_min=0.02):
        # Problem size
        self.n_states, self.n_actions = n_states, n_actions
        # Learning rate and discount factor
        self.alpha, self.gamma = alpha, gamma
        # Exploration schedule: current rate, per-call multiplier, lower bound
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        # Action-value table, one row per state
        self.Q = np.zeros((n_states, n_actions))

    def choose_action(self, state):
        """Pick a uniformly random action with probability epsilon, else the greedy one."""
        exploring = np.random.random() < self.epsilon
        if not exploring:
            return np.argmax(self.Q[state])
        return np.random.randint(self.n_actions)

    def update(self, state, action, reward, next_state, done):
        """Move Q[state, action] a fraction ``alpha`` of the way toward the TD target.

        On a terminal transition the target is the reward alone; otherwise the
        discounted best value of ``next_state`` is added.
        """
        target = reward if done else reward + self.gamma * self.Q[next_state].max()
        current_estimate = self.Q[state, action]
        self.Q[state, action] += self.alpha * (target - current_estimate)

    def decay_epsilon(self):
        """Multiply epsilon by the decay factor, never dropping below epsilon_min."""
        shrunk = self.epsilon * self.epsilon_decay
        self.epsilon = shrunk if shrunk > self.epsilon_min else self.epsilon_min


# Training
# The Q-table is sized from the environment's flattened state space.
agent = QLearningBatteryAgent(n_states=env.n_states, n_actions=3)

n_episodes = 20000
episode_rewards = []   # total reward per episode, in the environment's $K units
epsilon_history = []   # exploration rate after each episode's decay

print("Training Q-Learning agent...")
for ep in range(n_episodes):
    # Each episode is one simulated day in a season drawn uniformly at random
    # (NumPy's global RNG), so the agent sees all four price levels.
    state = env.reset(season=np.random.choice(['summer', 'winter', 'spring', 'fall']))
    total_reward = 0

    # At most 96 steps; the loop exits early when the environment reports done.
    for step in range(96):
        action = agent.choose_action(state)
        next_state, reward, done = env.step(action)
        agent.update(state, action, reward, next_state, done)
        total_reward += reward
        state = next_state
        if done:
            break

    # Epsilon is decayed once per episode, not once per step.
    agent.decay_epsilon()
    episode_rewards.append(total_reward)
    epsilon_history.append(agent.epsilon)

    # Progress report every 5000 episodes: mean of the last 500 episode
    # rewards, converted from $K to $ for display.
    if (ep + 1) % 5000 == 0:
        recent = np.mean(episode_rewards[-500:])
        print(f"  Episode {ep+1:6d} | Reward (last 500): {recent*1000:8.0f} | "
              f"Epsilon: {agent.epsilon:.4f}")

print("Training complete!")

In [None]:
# Training curves
# Left panel: smoothed episode reward. Right panel: exploration rate.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Trailing mean over the current episode and up to `window` preceding ones
# (so up to window + 1 values; fewer at the start of training).
window = 200
moving_avg = [np.mean(episode_rewards[max(0,i-window):i+1]) for i in range(len(episode_rewards))]

# Rewards are stored in $K; scale to dollars for the axis.
ax1.plot([r * 1000 for r in moving_avg], color='#2171b5', linewidth=1)
ax1.set_xlabel('Episode', fontsize=11)
ax1.set_ylabel('Daily Reward ($)', fontsize=11)
ax1.set_title('Q-Learning Training Progress', fontsize=13, fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2.plot(epsilon_history, color='#d94701')
ax2.set_xlabel('Episode', fontsize=11)
ax2.set_ylabel('Epsilon', fontsize=11)
ax2.set_title('Exploration Rate', fontsize=13, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Comparison: Rule-Based vs Q-Learning

In [None]:
# Evaluate Q-Learning agent (greedy)
def q_greedy(state):
    """Return the highest-valued action for ``state`` from the trained Q-table (no exploration)."""
    return np.argmax(agent.Q[state])


def _pct_improvement(new_value, baseline):
    """Percentage change of ``new_value`` relative to ``baseline``.

    The change is divided by ``abs(baseline)`` so that a positive result
    always means "new_value is higher than baseline", even when the baseline
    is negative (a plain ratio ``new / baseline - 1`` flips sign in that case).
    Returns NaN when the baseline is exactly zero, where no percentage exists.
    """
    if baseline == 0:
        return float('nan')
    return (new_value - baseline) / abs(baseline) * 100


ql_rewards, ql_revenues = evaluate_agent(env, q_greedy, n_episodes=500)

# Rewards are in $K (scaled to $ for display); revenues are already in $.
reward_gain = _pct_improvement(np.mean(ql_rewards), np.mean(rule_rewards))
revenue_gain = _pct_improvement(np.mean(ql_revenues), np.mean(rule_revenues))

print("=" * 60)
print("  PERFORMANCE COMPARISON")
print("=" * 60)
print()
print(f"{'Metric':<30} {'Rule-Based':>12} {'Q-Learning':>12} {'Improvement':>12}")
print("-" * 66)
print(f"{'Mean daily reward ($)':<30} {np.mean(rule_rewards)*1000:>12.0f} {np.mean(ql_rewards)*1000:>12.0f} {reward_gain:>+11.1f}%")
print(f"{'Mean daily revenue ($)':<30} {np.mean(rule_revenues):>12.0f} {np.mean(ql_revenues):>12.0f} {revenue_gain:>+11.1f}%")
print(f"{'Std daily revenue ($)':<30} {np.std(rule_revenues):>12.0f} {np.std(ql_revenues):>12.0f}")

# Visualize comparison
# Left panel: distribution of daily revenue over the evaluation episodes.
# Right panel: both agents' state of charge over one shared simulated day.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].hist(np.array(rule_revenues), bins=30, alpha=0.6, label='Rule-Based', color='#d94701')
axes[0].hist(np.array(ql_revenues), bins=30, alpha=0.6, label='Q-Learning', color='#2171b5')
axes[0].set_xlabel('Daily Revenue ($)', fontsize=11)
axes[0].set_ylabel('Count', fontsize=11)
axes[0].set_title('Revenue Distribution Comparison', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# One-day comparison
# NOTE(review): `env` is reset twice here; the second call replaces the day
# created by the first, and neither `state_rule` nor `state_ql` is read below.
state_rule = env.reset(season='summer')
state_ql = env.reset(season='summer')
# Use same prices for fair comparison
# A second environment is created and its episode state is copied from `env`,
# so both agents start from the same prices, SoC and day type.
env_ql = BatteryStorageEnv(sim)
env_ql.current_prices = env.current_prices.copy()
env_ql.soc = env.soc
env_ql.is_weekend = env.is_weekend
env_ql.current_step = 0

prices_day = env.current_prices.copy()
# SoC traces hold the starting value plus one value per step (97 points).
soc_rule = [env.soc]
soc_ql = [env_ql.soc]
actions_rule = []
actions_ql = []

for step in range(96):
    price = env.get_price()
    a_rule = rule_agent.choose_action(price, env.soc)
    # Greedy action for the Q-learning agent in its own environment's state.
    a_ql = np.argmax(agent.Q[env_ql._state_to_int(env_ql._get_state())])

    env.step(a_rule)
    env_ql.step(a_ql)

    soc_rule.append(env.soc)
    soc_ql.append(env_ql.soc)
    actions_rule.append(a_rule)
    actions_ql.append(a_ql)

# 97 time stamps for the SoC traces; prices use the first 96 of them.
hours = np.arange(97) / 4
# Prices go on the left axis, SoC (as a percentage) on a twin right axis.
ax2_twin = axes[1].twinx()
axes[1].plot(hours[:-1], prices_day, color='gray', alpha=0.4, label='Price')
ax2_twin.plot(hours, [s*100 for s in soc_rule], color='#d94701', linewidth=2, label='Rule SoC')
ax2_twin.plot(hours, [s*100 for s in soc_ql], color='#2171b5', linewidth=2, label='QL SoC')
axes[1].set_xlabel('Hour', fontsize=11)
axes[1].set_ylabel('Price ($/MWh)', fontsize=11, color='gray')
ax2_twin.set_ylabel('State of Charge (%)', fontsize=11)
axes[1].set_title('One-Day Behavior Comparison', fontsize=13, fontweight='bold')
# Merge the legend entries of both axes into a single legend.
lines1, labels1 = axes[1].get_legend_handles_labels()
lines2, labels2 = ax2_twin.get_legend_handles_labels()
ax2_twin.legend(lines1 + lines2, labels1 + labels2, fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Policy Analysis

In [None]:
# Analyze what the Q-Learning agent learned
# Three panels: greedy action mix per hour, greedy action mix per SoC bin, and
# the largest Q-value per (price bin, hour).
# NOTE(review): every state in the table is counted, including rows whose
# Q-values are all still zero. np.argmax returns the first index on ties, so
# such rows are counted as action 0 (Charge) — confirm this bias is acceptable.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Action probabilities by hour
# Count the greedy action of every (SoC, price, day type) combination per hour.
action_by_hour = np.zeros((24, 3))
for soc_idx in range(env.n_soc):
    for price_idx in range(env.n_price):
        for hour in range(24):
            for day_idx in range(2):
                state = env._state_to_int((soc_idx, price_idx, hour, day_idx))
                best_action = np.argmax(agent.Q[state])
                action_by_hour[hour, best_action] += 1

# Normalize
# Each row then sums to 1, i.e. the share of states choosing each action.
action_by_hour = action_by_hour / action_by_hour.sum(axis=1, keepdims=True)

# Stacked bars: Charge at the bottom, then Hold, then Discharge.
axes[0].bar(range(24), action_by_hour[:, 0], label='Charge', color='#2ca02c', alpha=0.8)
axes[0].bar(range(24), action_by_hour[:, 1], bottom=action_by_hour[:, 0],
            label='Hold', color='#7f7f7f', alpha=0.8)
axes[0].bar(range(24), action_by_hour[:, 2],
            bottom=action_by_hour[:, 0] + action_by_hour[:, 1],
            label='Discharge', color='#d62728', alpha=0.8)
axes[0].set_xlabel('Hour of Day', fontsize=11)
axes[0].set_ylabel('Action Proportion', fontsize=11)
axes[0].set_title('Learned Policy by Hour', fontsize=13, fontweight='bold')
axes[0].legend()

# Action by SoC (averaged over other dimensions)
# Same sweep over all states as above, grouped by SoC bin instead of hour.
action_by_soc = np.zeros((env.n_soc, 3))
for soc_idx in range(env.n_soc):
    for price_idx in range(env.n_price):
        for hour in range(24):
            for day_idx in range(2):
                state = env._state_to_int((soc_idx, price_idx, hour, day_idx))
                best_action = np.argmax(agent.Q[state])
                action_by_soc[soc_idx, best_action] += 1

action_by_soc = action_by_soc / action_by_soc.sum(axis=1, keepdims=True)
# Tick labels show each SoC bin edge as a percentage.
soc_labels = [f'{s:.0%}' for s in env.soc_bins]

axes[1].bar(range(env.n_soc), action_by_soc[:, 0], label='Charge', color='#2ca02c', alpha=0.8)
axes[1].bar(range(env.n_soc), action_by_soc[:, 1], bottom=action_by_soc[:, 0],
            label='Hold', color='#7f7f7f', alpha=0.8)
axes[1].bar(range(env.n_soc), action_by_soc[:, 2],
            bottom=action_by_soc[:, 0] + action_by_soc[:, 1],
            label='Discharge', color='#d62728', alpha=0.8)
axes[1].set_xlabel('State of Charge', fontsize=11)
# Label every second bin to keep the axis readable.
axes[1].set_xticks(range(0, env.n_soc, 2))
axes[1].set_xticklabels(soc_labels[::2], rotation=45)
axes[1].set_ylabel('Action Proportion', fontsize=11)
axes[1].set_title('Learned Policy by SoC', fontsize=13, fontweight='bold')
axes[1].legend()

# Q-value heatmap: best Q-value by (hour, price)
# For each (price bin, hour) take the maximum Q-value over all SoC bins and
# actions. Only day type 0 (weekday) is included.
q_heatmap = np.zeros((len(env.price_bins), 24))
for price_idx in range(env.n_price):
    for hour in range(24):
        max_q = -np.inf
        for soc_idx in range(env.n_soc):
            state = env._state_to_int((soc_idx, price_idx, hour, 0))
            q = np.max(agent.Q[state])
            max_q = max(max_q, q)
        q_heatmap[price_idx, hour] = max_q

# origin='lower' puts price bin 0 at the bottom of the image.
im = axes[2].imshow(q_heatmap, aspect='auto', cmap='RdYlBu',
                     origin='lower', interpolation='nearest')
axes[2].set_xlabel('Hour of Day', fontsize=11)
axes[2].set_ylabel('Price Bin Index', fontsize=11)
axes[2].set_title('Max Q-value by Hour and Price', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=axes[2])

plt.suptitle('Q-Learning Policy Analysis', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

# NOTE(review): the findings below are fixed strings; they are not computed
# from the Q-table, so check them against the plots above after each run.
print("Key findings:")
print("- The agent charges primarily during hours 0-5 (lowest prices)")
print("- Discharging peaks during hours 14-17 (highest prices)")
print("- At low SoC, the agent strongly prefers charging")
print("- At high SoC, the agent strongly prefers discharging")
print("- The agent has learned the daily arbitrage pattern from pure experience!")

## Summary

In [None]:
# Headline numbers, taken from the evaluation results computed earlier in the
# notebook (rule_revenues / ql_revenues) and from the environment itself.
rule_mean_revenue = np.mean(rule_revenues)
ql_mean_revenue = np.mean(ql_revenues)

print("=" * 60)
print("  CASE STUDY RESULTS SUMMARY")
print("=" * 60)
print()
# Battery size is read from the environment so the line stays correct if the
# environment is built with non-default parameters.
print(f"Environment: Battery storage dispatch ({env.capacity:g} MWh / {env.power:g} MW)")
print("Method: Tabular Q-Learning")
print(f"State space: {env.n_states:,} states")
print(f"Training episodes: {n_episodes:,}")
print()
print(f"Rule-based mean daily revenue: ${rule_mean_revenue:,.0f}")
print(f"Q-Learning mean daily revenue: ${ql_mean_revenue:,.0f}")
# Relative change measured against |baseline|, so a positive number always
# means Q-Learning earned more, even if the baseline mean is negative. A zero
# baseline has no meaningful percentage and is reported as NaN.
if rule_mean_revenue != 0:
    improvement = (ql_mean_revenue - rule_mean_revenue) / abs(rule_mean_revenue) * 100
else:
    improvement = float('nan')
print(f"Improvement: {improvement:+.1f}%")
print()
# NOTE(review): the statements below are fixed strings, not derived from the
# trained Q-table; verify them against the policy-analysis plots.
print("The Q-Learning agent learned to:")
print("  1. Charge during the cheapest hours (2-5 AM)")
print("  2. Wait for peak prices before discharging (3-6 PM)")
print("  3. Adapt its behavior based on current SoC")
print("  4. All of this from pure trial-and-error experience!")