LSTM

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# -------------------- Data Loading --------------------
# Load the MATLAB file once; the result is indexed below by 'Cell<n>' keys.
# The path is relative, so the .mat file must be in the working directory.
data = scipy.io.loadmat('Oxford_Battery_Degradation_Dataset_1.mat')

def soc_from_q(q, phase):
    """Compute an SOC% trace from a charge trace.

    The trace is min-max normalised between its first and last samples, so a
    charge phase maps onto 0 -> 100 and a discharge phase onto 100 -> 0.
    Samples that overshoot the end points are clipped into [0, 100].

    Args:
        q: Array-like of charge samples for one phase.
        phase: 'C1ch' for charging or 'C1dc' for discharging.

    Returns:
        A float ndarray of SOC values in percent with the same shape as ``q``,
        or None when no meaningful SOC can be derived: unknown phase, empty
        trace, non-finite end points, or first and last samples (almost) equal.
    """
    if phase not in ('C1ch', 'C1dc'):
        return None

    q = np.asarray(q, dtype=float)
    if q.size == 0:
        return None

    q0, q1 = float(q[0]), float(q[-1])
    # NaN/inf end points would turn the whole label trace into NaN and poison
    # the training loss, so treat them like a degenerate trace.
    if not (np.isfinite(q0) and np.isfinite(q1)):
        return None
    if np.isclose(q1, q0):
        return None

    qn = np.clip((q - q0) / (q1 - q0), 0, 1)
    if phase == 'C1ch':
        return 100.0 * qn
    return 100.0 * (1.0 - qn)

def extract_sequences(data, L=128, cell_ids=range(1, 9)):
    """Extract V, T, I features and SOC labels, resampled to a fixed length.

    For every requested cell the cycles are visited in numeric order and each
    charge ('C1ch') / discharge ('C1dc') block that carries 't', 'v' and 'q'
    traces is linearly resampled onto ``L`` evenly spaced time points.

    Args:
        data: Mapping with one ``'Cell<id>'`` entry per cell, each a nested
            structured array indexed as ``cell[cycle][0, 0][phase][0, 0]``
            (e.g. the dict returned by ``scipy.io.loadmat``).
        L: Number of samples per resampled sequence.
        cell_ids: Iterable of cell numbers to read. Defaults to cells 1-8.

    Returns:
        Tuple ``(X, y, phases)``:
        X has shape [N, L, 3] with channels (voltage, temperature, current),
        y has shape [N, L, 1] with SOC in percent, and phases has shape [N]
        holding the phase key of each sequence. If no block qualifies, all
        three are empty 1-D arrays.

    Raises:
        KeyError: If a requested cell is missing from ``data``.
    """
    X_list, y_list, phase_list = [], [], []

    for ci in cell_ids:
        cell = data[f'Cell{ci}']
        # Cycle fields are ordered by the integer following their 3-char prefix.
        for cyc_name in sorted(cell.dtype.names, key=lambda s: int(s[3:])):
            cyc = cell[cyc_name][0, 0]
            for phase in ('C1ch', 'C1dc'):
                if phase not in cyc.dtype.names:
                    continue
                blk = cyc[phase][0, 0]
                names = blk.dtype.names
                if not all(k in names for k in ('t', 'v', 'q')):
                    continue

                t = blk['t'][0, 0].ravel().astype(float)
                v = blk['v'][0, 0].ravel().astype(float)
                q = blk['q'][0, 0].ravel().astype(float)

                if t.size < 5:
                    continue

                # Temperature: constant 25.0 when the channel is absent,
                # otherwise gaps are forward- then backward-filled.
                if 'T' in names:
                    T = blk['T'][0, 0].ravel().astype(float)
                else:
                    T = np.full_like(t, 25.0)
                T = pd.Series(T).ffill().bfill().values

                # Current: constant +/-0.74 (sign by phase) when absent.
                if 'i' in names:
                    I = blk['i'][0, 0].ravel().astype(float)
                else:
                    I = np.full_like(t, 0.74 if phase == 'C1ch' else -0.74)

                # np.interp raises when a channel and the time base differ in
                # length; skip such malformed blocks instead of aborting the
                # whole extraction.
                if not (v.size == q.size == T.size == I.size == t.size):
                    continue

                soc = soc_from_q(q, phase)
                if soc is None:
                    continue

                # Resample every channel onto L evenly spaced time points.
                # NOTE(review): np.interp assumes t is increasing and does not
                # check it - confirm for this dataset.
                t_new = np.linspace(t[0], t[-1], L)
                v_new = np.interp(t_new, t, v)
                T_new = np.interp(t_new, t, T)
                I_new = np.interp(t_new, t, I)
                y_new = np.interp(t_new, t, soc)

                X_list.append(np.stack([v_new, T_new, I_new], axis=-1))  # [L, 3]
                y_list.append(y_new[:, None])  # [L, 1]
                phase_list.append(phase)  # Track charging vs discharging

    return np.array(X_list), np.array(y_list), np.array(phase_list)

# -------------------- Prepare Data --------------------
print("Loading data...")
X, y, phases = extract_sequences(data, L=128)
print(f"Total sequences: {len(X)}, Shape: {X.shape}")

# Train/test split (80/20)
# Positional split with no shuffling: the first 80% of the sequences, in the
# order extract_sequences emitted them, are used for training and the
# remaining 20% for testing.
n_train = int(0.8 * len(X))
X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]
phases_test = phases[n_train:]

# Normalize features
# Per-channel mean/std are computed over all time steps of the training
# sequences only and then applied to both splits; the 1e-8 term avoids a
# division by zero for a constant channel.
mu = X_train.reshape(-1, X_train.shape[-1]).mean(axis=0)
sd = X_train.reshape(-1, X_train.shape[-1]).std(axis=0) + 1e-8
X_train = (X_train - mu) / sd
X_test = (X_test - mu) / sd

# Normalize labels to [0, 1]
# y_train / y_test themselves stay in percent; only the *_norm copies are
# fed to the network.
y_train_norm = y_train / 100.0
y_test_norm = y_test / 100.0

print(f"Train: {len(X_train)} | Test: {len(X_test)}")

# -------------------- DataLoaders --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                  torch.tensor(y_train_norm, dtype=torch.float32)),
    batch_size=32, shuffle=True)

# shuffle=False keeps test batches in the same order as X_test / phases_test.
test_loader = DataLoader(
    TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                  torch.tensor(y_test_norm, dtype=torch.float32)),
    batch_size=32, shuffle=False)

# -------------------- LSTM Model --------------------
class LSTM_SOC(nn.Module):
    """Stacked LSTM that emits one value per time step of the input sequence.

    Args:
        input_dim: Number of features per time step.
        hidden: Hidden-state size of every LSTM layer.
        layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between stacked layers. It is
            ignored for a single-layer network, because nn.LSTM only applies
            dropout between layers and warns when asked to do so with one.
    """

    def __init__(self, input_dim=3, hidden=128, layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden, num_layers=layers,
                            batch_first=True,
                            dropout=dropout if layers > 1 else 0)
        # Shared linear head applied independently at every time step.
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x):
        """Map x of shape [batch, seq_len, input_dim] to [batch, seq_len, 1]."""
        out, _ = self.lstm(x)  # out: [batch, seq_len, hidden]
        return self.fc(out)

model = LSTM_SOC(input_dim=3, hidden=128, layers=2).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

# -------------------- Training --------------------
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

epochs = 100
best_loss = float('inf')

print("Training...")
for epoch in range(epochs):
    # Train
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so the epoch figure is a true
        # per-sequence average even when the last batch is smaller.
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)
    
    # Validate
    # NOTE(review): this pass runs on test_loader, so the checkpoint saved
    # below is selected on the same data later reported as test metrics;
    # consider holding out a separate validation split.
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            test_loss += loss_fn(model(xb), yb).item() * xb.size(0)
    test_loss /= len(test_loader.dataset)
    
    # Save best model
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(model.state_dict(), 'best_lstm.pt')
    
    # Progress line every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.6f} | Test Loss: {test_loss:.6f}")

print(f"\nBest Test Loss: {best_loss:.6f}")

# -------------------- Validation --------------------
# Restore the best checkpoint. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU still loads on a CPU-only
# run (and vice versa).
model.load_state_dict(torch.load('best_lstm.pt', map_location=device))
model.eval()

# Get predictions
# test_loader is not shuffled, so the concatenated arrays line up with
# X_test / y_test / phases_test index by index.
predictions, actuals = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        pred = model(xb).cpu().numpy()
        predictions.append(pred)
        actuals.append(yb.numpy())

predictions = np.concatenate(predictions) * 100  # Convert back to %
actuals = np.concatenate(actuals) * 100

# Calculate overall metrics (over every time step of every test sequence)
mae = np.mean(np.abs(predictions - actuals))
rmse = np.sqrt(np.mean((predictions - actuals)**2))
max_error = np.max(np.abs(predictions - actuals))

print(f"\n{'='*60}")
print("Overall Validation Metrics:")
print(f"{'='*60}")
print(f"MAE:        {mae:.4f}%")
print(f"RMSE:       {rmse:.4f}%")
print(f"Max Error:  {max_error:.4f}%")

# -------------------- SOC Range Analysis --------------------
print(f"\n{'='*60}")
print("SOC Range Analysis")
print(f"{'='*60}")

# Define SOC ranges to test
# Each entry is (lower bound %, upper bound %, display label).
# NOTE(review): the bounds are applied inclusively on both sides below, so a
# point at exactly 30% or 60% is counted in two neighbouring bins, and points
# above 90% only appear in the "Full Range" row.
soc_ranges = [
    (0, 30, "Low SOC (0-30%)"),
    (30, 60, "Mid SOC (30-60%)"),
    (60, 90, "High SOC (60-90%)"),
    (0, 100, "Full Range (0-100%)")
]

range_metrics = []

for soc_min, soc_max, range_name in soc_ranges:
    # Filter predictions and actuals for this SOC range
    # The mask has the same shape as `actuals`, so it selects individual time
    # steps (not whole sequences) and the indexed results are flat 1-D arrays.
    mask = (actuals >= soc_min) & (actuals <= soc_max)
    
    if np.sum(mask) == 0:
        print(f"\n{range_name}: No data points in this range")
        continue
    
    range_preds = predictions[mask]
    range_actuals = actuals[mask]
    
    # Calculate metrics for this range
    range_mae = np.mean(np.abs(range_preds - range_actuals))
    range_rmse = np.sqrt(np.mean((range_preds - range_actuals)**2))
    range_max_error = np.max(np.abs(range_preds - range_actuals))
    n_points = np.sum(mask)
    
    # Kept as a list of dicts; it feeds the summary table here and the bar
    # charts in the range visualisation further down.
    range_metrics.append({
        'Range': range_name,
        'MAE': range_mae,
        'RMSE': range_rmse,
        'Max Error': range_max_error,
        'N Points': n_points
    })
    
    print(f"\n{range_name}:")
    print(f"  Points:     {n_points:,}")
    print(f"  MAE:        {range_mae:.4f}%")
    print(f"  RMSE:       {range_rmse:.4f}%")
    print(f"  Max Error:  {range_max_error:.4f}%")

# Create DataFrame for easier visualization
metrics_df = pd.DataFrame(range_metrics)
print(f"\n{'='*60}")
print("Summary Table:")
print(f"{'='*60}")
print(metrics_df.to_string(index=False))
print(f"{'='*60}\n")

# -------------------- Phase-specific Analysis --------------------
print(f"\n{'='*60}")
print("Charging vs Discharging Analysis")
print(f"{'='*60}")

# One dict per (phase, SOC range) combination that has data; consumed by the
# phase-specific bar charts later on.
phase_metrics = []

for phase_type in ['C1ch', 'C1dc']:
    phase_name = "Charging" if phase_type == 'C1ch' else "Discharging"
    
    # Get indices for this phase
    # phases_test has one entry per test sequence, so this mask selects whole
    # sequences along the first axis of predictions / actuals.
    phase_mask = phases_test == phase_type
    
    if np.sum(phase_mask) == 0:
        continue
    
    # Get predictions for this phase
    phase_preds = predictions[phase_mask]
    phase_actuals = actuals[phase_mask]
    
    # Overall phase metrics
    phase_mae = np.mean(np.abs(phase_preds - phase_actuals))
    phase_rmse = np.sqrt(np.mean((phase_preds - phase_actuals)**2))
    
    print(f"\n{phase_name}:")
    print(f"  Sequences:  {np.sum(phase_mask)}")
    print(f"  MAE:        {phase_mae:.4f}%")
    print(f"  RMSE:       {phase_rmse:.4f}%")
    
    # Break down by SOC ranges
    for soc_min, soc_max, range_name in soc_ranges[:3]:  # Only low, mid, high
        # Element-wise mask over time steps of this phase's sequences; bounds
        # are inclusive on both sides.
        range_mask = (phase_actuals >= soc_min) & (phase_actuals <= soc_max)
        
        if np.sum(range_mask) > 0:
            sub_preds = phase_preds[range_mask]
            sub_actuals = phase_actuals[range_mask]
            sub_mae = np.mean(np.abs(sub_preds - sub_actuals))
            sub_rmse = np.sqrt(np.mean((sub_preds - sub_actuals)**2))
            
            print(f"    {range_name}:")
            print(f"      Points: {np.sum(range_mask):,} | MAE: {sub_mae:.4f}% | RMSE: {sub_rmse:.4f}%")
            
            phase_metrics.append({
                'Phase': phase_name,
                'Range': range_name,
                'MAE': sub_mae,
                'RMSE': sub_rmse,
                'N Points': np.sum(range_mask)
            })

# -------------------- Range-specific Visualization --------------------
# 2x3 grid: top row = one bar chart per metric across the SOC ranges collected
# in range_metrics, bottom row = signed-error histogram for the low / mid /
# high ranges.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Plot metrics comparison
ax = axes[0, 0]
x_pos = np.arange(len(range_metrics))
ax.bar(x_pos, [m['MAE'] for m in range_metrics], alpha=0.7, color='steelblue')
ax.set_xticks(x_pos)
# Tick labels keep only the text before the "(" of each range label.
ax.set_xticklabels([m['Range'].split('(')[0].strip() for m in range_metrics], rotation=45, ha='right')
ax.set_ylabel('MAE (%)')
ax.set_title('MAE by SOC Range', fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

ax = axes[0, 1]
ax.bar(x_pos, [m['RMSE'] for m in range_metrics], alpha=0.7, color='coral')
ax.set_xticks(x_pos)
ax.set_xticklabels([m['Range'].split('(')[0].strip() for m in range_metrics], rotation=45, ha='right')
ax.set_ylabel('RMSE (%)')
ax.set_title('RMSE by SOC Range', fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

ax = axes[0, 2]
ax.bar(x_pos, [m['Max Error'] for m in range_metrics], alpha=0.7, color='lightcoral')
ax.set_xticks(x_pos)
ax.set_xticklabels([m['Range'].split('(')[0].strip() for m in range_metrics], rotation=45, ha='right')
ax.set_ylabel('Max Error (%)')
ax.set_title('Max Error by SOC Range', fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot error distribution by range
# Only the first three ranges are drawn (the full-range entry is skipped); a
# range without any points leaves its panel empty.
for idx, (soc_min, soc_max, range_name) in enumerate(soc_ranges[:3]):
    ax = axes[1, idx]
    mask = (actuals >= soc_min) & (actuals <= soc_max)
    
    if np.sum(mask) > 0:
        # Signed error (prediction minus truth), so the histogram shows bias.
        range_errors = (predictions[mask] - actuals[mask]).flatten()
        ax.hist(range_errors, bins=30, edgecolor='black', alpha=0.7, color='teal')
        ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
        ax.set_xlabel('Error (%)')
        ax.set_ylabel('Frequency')
        ax.set_title(f'{range_name}\nMean Error: {np.mean(range_errors):.3f}%', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

plt.suptitle('SOC Range Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Phase-specific Visualization --------------------
# Grouped bar charts (MAE on the left, RMSE on the right) comparing charging
# and discharging per SOC range; skipped when phase_metrics is empty.
if len(phase_metrics) > 0:
    phase_df = pd.DataFrame(phase_metrics)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Grouped bar chart for MAE
    ax = axes[0]
    # Ranges in order of first appearance in phase_df.
    ranges_unique = phase_df['Range'].unique()
    x = np.arange(len(ranges_unique))
    width = 0.35
    
    charging_data = phase_df[phase_df['Phase'] == 'Charging']
    discharging_data = phase_df[phase_df['Phase'] == 'Discharging']
    
    # A phase/range combination without a row is drawn as a zero-height bar.
    charge_mae = [charging_data[charging_data['Range'] == r]['MAE'].values[0] 
                  if len(charging_data[charging_data['Range'] == r]) > 0 else 0 
                  for r in ranges_unique]
    discharge_mae = [discharging_data[discharging_data['Range'] == r]['MAE'].values[0] 
                     if len(discharging_data[discharging_data['Range'] == r]) > 0 else 0 
                     for r in ranges_unique]
    
    # The two phases are offset by half a bar width around each tick.
    ax.bar(x - width/2, charge_mae, width, label='Charging', alpha=0.8, color='green')
    ax.bar(x + width/2, discharge_mae, width, label='Discharging', alpha=0.8, color='orange')
    ax.set_xticks(x)
    ax.set_xticklabels([r.split('(')[0].strip() for r in ranges_unique], rotation=45, ha='right')
    ax.set_ylabel('MAE (%)')
    ax.set_title('MAE: Charging vs Discharging by SOC Range', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    # Grouped bar chart for RMSE
    ax = axes[1]
    charge_rmse = [charging_data[charging_data['Range'] == r]['RMSE'].values[0] 
                   if len(charging_data[charging_data['Range'] == r]) > 0 else 0 
                   for r in ranges_unique]
    discharge_rmse = [discharging_data[discharging_data['Range'] == r]['RMSE'].values[0] 
                      if len(discharging_data[discharging_data['Range'] == r]) > 0 else 0 
                      for r in ranges_unique]
    
    ax.bar(x - width/2, charge_rmse, width, label='Charging', alpha=0.8, color='green')
    ax.bar(x + width/2, discharge_rmse, width, label='Discharging', alpha=0.8, color='orange')
    ax.set_xticks(x)
    ax.set_xticklabels([r.split('(')[0].strip() for r in ranges_unique], rotation=45, ha='right')
    ax.set_ylabel('RMSE (%)')
    ax.set_title('RMSE: Charging vs Discharging by SOC Range', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()

# -------------------- Partial Curve Testing --------------------
# Console banner introducing the partial-curve evaluation.
print(f"\n{'='*60}")
print("Testing on Partial Charge/Discharge Curves")
print(f"{'='*60}")

def create_partial_curve(X_seq, y_seq, soc_start, soc_end):
    """
    Cut the SOC window [soc_start, soc_end] out of one full sequence.

    The window bounds are inclusive and may be given in either order
    (ascending for a charge, descending for a discharge). The returned slice
    runs from the first to the last sample whose SOC lies inside the window.

    Args:
        X_seq: Input features [L, 3]
        y_seq: SOC labels [L, 1]
        soc_start: SOC% at which the partial curve begins
        soc_end: SOC% at which the partial curve ends

    Returns:
        (X_partial, y_partial) sliced from the inputs, or (None, None) when
        fewer than 10 samples fall inside the window.
    """
    soc_trace = y_seq[:, 0]

    # Order the bounds so one comparison serves both directions.
    if soc_start < soc_end:
        lower, upper = soc_start, soc_end
    else:
        lower, upper = soc_end, soc_start
    in_window = (soc_trace >= lower) & (soc_trace <= upper)

    # Too few samples to be a meaningful partial curve.
    if np.count_nonzero(in_window) < 10:
        return None, None

    hits = np.flatnonzero(in_window)
    window = slice(hits[0], hits[-1] + 1)
    return X_seq[window], y_seq[window]

# Define partial curve test cases
# Each entry is (soc_start %, soc_end %, phase key, display name); discharge
# cases list the higher SOC first.
partial_curve_tests = [
    # Charging scenarios
    (0, 30, 'C1ch', 'Charge 0-30%'),
    (30, 60, 'C1ch', 'Charge 30-60%'),
    (60, 90, 'C1ch', 'Charge 60-90%'),
    (20, 80, 'C1ch', 'Charge 20-80%'),
    
    # Discharging scenarios
    (100, 70, 'C1dc', 'Discharge 100-70%'),
    (70, 40, 'C1dc', 'Discharge 70-40%'),
    (40, 10, 'C1dc', 'Discharge 40-10%'),
    (80, 20, 'C1dc', 'Discharge 80-20%'),
]

partial_results = []

for soc_start, soc_end, phase_type, test_name in partial_curve_tests:
    # Find test sequences of the right phase
    phase_mask = phases_test == phase_type
    phase_indices = np.where(phase_mask)[0]
    
    if len(phase_indices) == 0:
        print(f"\n{test_name}: No {phase_type} sequences available")
        continue
    
    # Process multiple sequences for this test case
    test_preds, test_actuals = [], []
    valid_sequences = 0
    
    for idx in phase_indices[:50]:  # Test on up to 50 sequences
        # X_test is already normalised; y_test is still in percent, which is
        # what create_partial_curve compares the SOC bounds against.
        X_seq = X_test[idx]
        y_seq = y_test[idx]
        
        # Create partial curve
        X_partial, y_partial = create_partial_curve(X_seq, y_seq, soc_start, soc_end)
        
        if X_partial is None:
            continue
        
        valid_sequences += 1
        
        # Get prediction for partial curve
        # The model only sees the cut-out window (as a batch of one), not the
        # part of the sequence before it; the output is scaled back to percent.
        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100
        
        true = y_partial[:, 0]
        
        test_preds.append(pred)
        test_actuals.append(true)
    
    if valid_sequences == 0:
        print(f"\n{test_name}: No valid partial curves found in range")
        continue
    
    # Concatenate all predictions
    all_preds = np.concatenate(test_preds)
    all_actuals = np.concatenate(test_actuals)
    
    # Calculate metrics
    # NOTE: mae / rmse / max_error rebind the module-level names that held the
    # overall test metrics computed earlier.
    mae = np.mean(np.abs(all_preds - all_actuals))
    rmse = np.sqrt(np.mean((all_preds - all_actuals)**2))
    max_error = np.max(np.abs(all_preds - all_actuals))
    
    partial_results.append({
        'Test': test_name,
        'Phase': phase_type,
        'SOC Range': f"{soc_start}-{soc_end}%",
        'Sequences': valid_sequences,
        'Points': len(all_preds),
        'MAE': mae,
        'RMSE': rmse,
        'Max Error': max_error
    })
    
    print(f"\n{test_name}:")
    print(f"  Valid Sequences: {valid_sequences}")
    print(f"  Total Points:    {len(all_preds):,}")
    print(f"  MAE:             {mae:.4f}%")
    print(f"  RMSE:            {rmse:.4f}%")
    print(f"  Max Error:       {max_error:.4f}%")

# Create summary table
# partial_df only exists when at least one test case produced results; later
# code that uses it is guarded by the same condition.
if len(partial_results) > 0:
    partial_df = pd.DataFrame(partial_results)
    print(f"\n{'='*60}")
    print("Partial Curve Testing Summary:")
    print(f"{'='*60}")
    print(partial_df.to_string(index=False))
    print(f"{'='*60}\n")

# -------------------- Partial Curve Visualization --------------------
print("\nGenerating partial curve visualizations...")

# 4x2 grid: charging scenarios in the top two rows, discharging below.
fig, axes = plt.subplots(4, 2, figsize=(14, 16))

# Select interesting partial curve examples
# Each entry is (soc_start %, soc_end %, phase key, panel title, row, column).
visualization_tests = [
    (0, 30, 'C1ch', 'Charge 0-30%', 0, 0),
    (30, 60, 'C1ch', 'Charge 30-60%', 0, 1),
    (60, 90, 'C1ch', 'Charge 60-90%', 1, 0),
    (20, 80, 'C1ch', 'Charge 20-80%', 1, 1),
    (100, 70, 'C1dc', 'Discharge 100-70%', 2, 0),
    (70, 40, 'C1dc', 'Discharge 70-40%', 2, 1),
    (40, 10, 'C1dc', 'Discharge 40-10%', 3, 0),
    (80, 20, 'C1dc', 'Discharge 80-20%', 3, 1),
]

for soc_start, soc_end, phase_type, test_name, row, col in visualization_tests:
    ax = axes[row, col]

    # Candidate test sequences of the requested phase
    phase_indices = np.where(phases_test == phase_type)[0]

    # Plot the first candidate (among the first 100) that yields a partial
    # curve; create_partial_curve returns (None, None) when the window holds
    # fewer than 10 samples, so no further length check is needed here.
    for idx in phase_indices[:100]:
        X_partial, y_partial = create_partial_curve(
            X_test[idx], y_test[idx], soc_start, soc_end)
        if X_partial is None:
            continue

        # Predict on the cut-out window only and scale back to percent
        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100

        true = y_partial[:, 0]
        error = np.abs(true - pred)
        mae = np.mean(error)
        max_error = np.max(error)

        # True vs predicted trace with the gap between them shaded
        time_steps = np.arange(len(true))
        ax.plot(time_steps, true, 'b-', linewidth=2.5, label='True SOC', alpha=0.8)
        ax.plot(time_steps, pred, 'r--', linewidth=2, label='Predicted SOC', alpha=0.8)
        ax.fill_between(time_steps, true, pred, alpha=0.2, color='orange')

        # Horizontal guides at the requested window bounds (same styling for
        # charging and discharging)
        ax.axhline(soc_start, color='green', linestyle=':', linewidth=1.5, alpha=0.7, label=f'Start: {soc_start}%')
        ax.axhline(soc_end, color='red', linestyle=':', linewidth=1.5, alpha=0.7, label=f'End: {soc_end}%')

        ax.set_title(f'{test_name}\nMAE: {mae:.2f}% | Max Error: {max_error:.2f}%',
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Time Step', fontsize=9)
        ax.set_ylabel('SOC (%)', fontsize=9)
        ax.legend(fontsize=7, loc='best')
        ax.grid(True, alpha=0.3)
        # Show the window plus a 10-point margin, clamped to [0, 100]
        ax.set_ylim(max(0, min(soc_start, soc_end) - 10),
                    min(100, max(soc_start, soc_end) + 10))
        break
    else:
        # No candidate produced a usable partial curve: label the empty panel
        ax.text(0.5, 0.5, f'No valid\n{test_name}\ndata found',
                ha='center', va='center', fontsize=10, transform=ax.transAxes)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

plt.suptitle('LSTM Performance on Partial Charge/Discharge Curves',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Comparison: Full vs Partial Curves --------------------
# Bar charts of MAE (left) and RMSE (right) for every partial-curve test case.
# Each test case gets its own tick: charging cases first, discharging cases
# after them, so every label sits under the bar it describes.
if len(partial_results) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Separate charging and discharging
    charging_partial = partial_df[partial_df['Phase'] == 'C1ch']
    discharging_partial = partial_df[partial_df['Phase'] == 'C1dc']

    # partial_results is non-empty here, so at least one group has rows.
    all_tests = pd.concat(
        [grp for grp in (charging_partial, discharging_partial) if len(grp) > 0])
    x_all = np.arange(len(all_tests))
    n_charge = len(charging_partial)
    tick_labels = [t.replace('Charge ', 'C ').replace('Discharge ', 'D ')
                   for t in all_tests['Test']]

    for ax, metric in zip(axes, ['MAE', 'RMSE']):
        if n_charge > 0:
            ax.bar(x_all[:n_charge], charging_partial[metric], 0.6,
                   label='Charging', alpha=0.8, color='green')
        if len(discharging_partial) > 0:
            ax.bar(x_all[n_charge:], discharging_partial[metric], 0.6,
                   label='Discharging', alpha=0.8, color='orange')

        ax.set_xticks(x_all)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_title(f'{metric} for Partial Curves', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

# -------------------- Mid-SOC Range Predictions Visualization --------------------
print("\nGenerating visualizations for mid-SOC ranges...")

# 3x2 grid: one row per SOC range, charging on the left, discharging on the
# right.
fig, axes = plt.subplots(3, 2, figsize=(14, 12))

# Each entry is (lower bound %, upper bound %, display label, colormap name).
# NOTE(review): the colormap name is unpacked below but never used.
soc_test_ranges = [
    (0, 30, "Low SOC (0-30%)", 'blues'),
    (30, 60, "Mid SOC (30-60%)", 'greens'),
    (60, 90, "High SOC (60-90%)", 'oranges')
]

for row_idx, (soc_min, soc_max, range_name, colormap) in enumerate(soc_test_ranges):
    # Find sequences that fall primarily in this SOC range
    # "Primarily" is decided by the mean SOC over the whole sequence.
    # NOTE(review): soc_from_q normalises every trace between its first and
    # last sample, so each sequence presumably spans the full 0-100% range and
    # its mean may rarely land in the low or high band - confirm that those
    # rows are ever populated.
    candidates = []
    for idx in range(len(X_test)):
        mean_soc = np.mean(y_test[idx, :, 0])
        if soc_min <= mean_soc <= soc_max:
            candidates.append(idx)
    
    # Without candidates both panels of this row are left untouched (empty).
    if len(candidates) == 0:
        print(f"No sequences found primarily in {range_name}")
        continue
    
    # Show one charging and one discharging example from this range
    for col_idx, phase_type in enumerate(['C1ch', 'C1dc']):
        phase_candidates = [c for c in candidates if phases_test[c] == phase_type]
        
        if len(phase_candidates) == 0:
            # Replace the panel with a centred placeholder message.
            ax = axes[row_idx, col_idx]
            ax.text(0.5, 0.5, f'No {range_name}\n{"Charging" if phase_type == "C1ch" else "Discharging"} data',
                   ha='center', va='center', fontsize=12)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.axis('off')
            continue
        
        # Pick a random sample from this range and phase
        # The draw is unseeded, so the figure can differ between runs.
        idx = np.random.choice(phase_candidates)
        
        # Predict on the full sequence (batch of one) and scale to percent.
        with torch.no_grad():
            x_sample = torch.tensor(X_test[idx:idx+1], dtype=torch.float32, device=device)
            pred = model(x_sample).cpu().numpy()[0, :, 0] * 100
        
        true = y_test[idx, :, 0]
        error = np.abs(true - pred)
        mean_error = np.mean(error)
        max_error = np.max(error)
        phase_label = "Charging" if phase_type == 'C1ch' else "Discharging"
        
        ax = axes[row_idx, col_idx]
        
        # Plot with shaded error region
        # The band is the prediction +/- the absolute error at each step.
        time_steps = np.arange(len(true))
        ax.plot(time_steps, true, 'k-', linewidth=2.5, label='True SOC', alpha=0.8)
        ax.plot(time_steps, pred, 'r--', linewidth=2, label='LSTM Prediction', alpha=0.8)
        ax.fill_between(time_steps, pred - error, pred + error, alpha=0.2, color='red')
        
        # Highlight the target SOC range
        ax.axhspan(soc_min, soc_max, alpha=0.1, color='green', label=f'{range_name}')
        
        ax.set_title(f'{range_name} - {phase_label}\nMAE: {mean_error:.2f}% | Max Error: {max_error:.2f}%', 
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Time Step', fontsize=9)
        ax.set_ylabel('SOC (%)', fontsize=9)
        ax.legend(fontsize=8, loc='best')
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 100)

plt.suptitle('SOC Estimation in Different SOC Ranges (Charging & Discharging)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Sample Predictions Visualization --------------------
# 2x2 grid of true vs predicted SOC for four randomly drawn test sequences.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Sample predictions
for i in range(4):
    # Unseeded draw with replacement: panels differ between runs and the same
    # sequence can be picked more than once.
    idx = np.random.randint(0, len(X_test))
    with torch.no_grad():
        x_sample = torch.tensor(X_test[idx:idx+1], dtype=torch.float32, device=device)
        pred = model(x_sample).cpu().numpy()[0, :, 0] * 100
    
    true = y_test[idx, :, 0]
    # Here `error` is the scalar mean absolute error of this one sequence.
    error = np.mean(np.abs(true - pred))
    phase_label = "Charging" if phases_test[idx] == 'C1ch' else "Discharging"
    
    # Fill the grid row by row.
    ax = axes[i//2, i%2]
    ax.plot(true, 'k-', linewidth=2, label='True SOC')
    ax.plot(pred, 'r--', linewidth=2, label='LSTM Prediction')
    ax.set_title(f'{phase_label} - Sample {i+1} | MAE: {error:.2f}%', 
                 fontsize=11, fontweight='bold')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('SOC (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('LSTM SOC Estimation Results (Random Samples)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Overall Error Distribution --------------------
plt.figure(figsize=(10, 5))
# Signed error (prediction minus truth) for every time step of the test set.
errors = (predictions - actuals).flatten()

# Left panel: histogram of the signed errors.
plt.subplot(1, 2, 1)
plt.hist(errors, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Prediction Error (%)')
plt.ylabel('Frequency')
plt.title('Error Distribution')
plt.grid(True, alpha=0.3)

# Right panel: predicted vs true SOC with the identity line as reference.
plt.subplot(1, 2, 2)
plt.scatter(actuals.flatten(), predictions.flatten(), alpha=0.5, s=1)
plt.plot([0, 100], [0, 100], 'r--', linewidth=2, label='Perfect Prediction')
plt.xlabel('True SOC (%)')
plt.ylabel('Predicted SOC (%)')
plt.title('Prediction vs Actual')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nDone!")

Loading data...
Total sequences: 1038, Shape: (1038, 128, 3)
Train: 830 | Test: 208
Device: cpu

Parameters: 200,321

Training...
Epoch 010 | Train Loss: 0.001991 | Test Loss: 0.001479


KeyboardInterrupt: 

GRU

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# -------------------- Data Loading --------------------
# Load the MATLAB file once; the result is indexed below by 'Cell<n>' keys.
# The path is relative, so the .mat file must be in the working directory.
data = scipy.io.loadmat('Oxford_Battery_Degradation_Dataset_1.mat')

def soc_from_q(q, phase):
    """Compute an SOC% trace from a charge trace.

    The trace is min-max normalised between its first and last samples, so a
    charge phase maps onto 0 -> 100 and a discharge phase onto 100 -> 0.
    Samples that overshoot the end points are clipped into [0, 100].

    Args:
        q: Array-like of charge samples for one phase.
        phase: 'C1ch' for charging or 'C1dc' for discharging.

    Returns:
        A float ndarray of SOC values in percent with the same shape as ``q``,
        or None when no meaningful SOC can be derived: unknown phase, empty
        trace, non-finite end points, or first and last samples (almost) equal.
    """
    if phase not in ('C1ch', 'C1dc'):
        return None

    q = np.asarray(q, dtype=float)
    if q.size == 0:
        return None

    q0, q1 = float(q[0]), float(q[-1])
    # NaN/inf end points would turn the whole label trace into NaN and poison
    # the training loss, so treat them like a degenerate trace.
    if not (np.isfinite(q0) and np.isfinite(q1)):
        return None
    if np.isclose(q1, q0):
        return None

    qn = np.clip((q - q0) / (q1 - q0), 0, 1)
    if phase == 'C1ch':
        return 100.0 * qn
    return 100.0 * (1.0 - qn)

def extract_sequences(data, L=128, cell_ids=range(1, 9)):
    """Extract V, T, I features and SOC labels, resampled to a fixed length.

    For every requested cell the cycles are visited in numeric order and each
    charge ('C1ch') / discharge ('C1dc') block that carries 't', 'v' and 'q'
    traces is linearly resampled onto ``L`` evenly spaced time points.

    Args:
        data: Mapping with one ``'Cell<id>'`` entry per cell, each a nested
            structured array indexed as ``cell[cycle][0, 0][phase][0, 0]``
            (e.g. the dict returned by ``scipy.io.loadmat``).
        L: Number of samples per resampled sequence.
        cell_ids: Iterable of cell numbers to read. Defaults to cells 1-8.

    Returns:
        Tuple ``(X, y, phases)``:
        X has shape [N, L, 3] with channels (voltage, temperature, current),
        y has shape [N, L, 1] with SOC in percent, and phases has shape [N]
        holding the phase key of each sequence. If no block qualifies, all
        three are empty 1-D arrays.

    Raises:
        KeyError: If a requested cell is missing from ``data``.
    """
    X_list, y_list, phase_list = [], [], []

    for ci in cell_ids:
        cell = data[f'Cell{ci}']
        # Cycle fields are ordered by the integer following their 3-char prefix.
        for cyc_name in sorted(cell.dtype.names, key=lambda s: int(s[3:])):
            cyc = cell[cyc_name][0, 0]
            for phase in ('C1ch', 'C1dc'):
                if phase not in cyc.dtype.names:
                    continue
                blk = cyc[phase][0, 0]
                names = blk.dtype.names
                if not all(k in names for k in ('t', 'v', 'q')):
                    continue

                t = blk['t'][0, 0].ravel().astype(float)
                v = blk['v'][0, 0].ravel().astype(float)
                q = blk['q'][0, 0].ravel().astype(float)

                if t.size < 5:
                    continue

                # Temperature: constant 25.0 when the channel is absent,
                # otherwise gaps are forward- then backward-filled.
                if 'T' in names:
                    T = blk['T'][0, 0].ravel().astype(float)
                else:
                    T = np.full_like(t, 25.0)
                T = pd.Series(T).ffill().bfill().values

                # Current: constant +/-0.74 (sign by phase) when absent.
                if 'i' in names:
                    I = blk['i'][0, 0].ravel().astype(float)
                else:
                    I = np.full_like(t, 0.74 if phase == 'C1ch' else -0.74)

                # np.interp raises when a channel and the time base differ in
                # length; skip such malformed blocks instead of aborting the
                # whole extraction.
                if not (v.size == q.size == T.size == I.size == t.size):
                    continue

                soc = soc_from_q(q, phase)
                if soc is None:
                    continue

                # Resample every channel onto L evenly spaced time points.
                # NOTE(review): np.interp assumes t is increasing and does not
                # check it - confirm for this dataset.
                t_new = np.linspace(t[0], t[-1], L)
                v_new = np.interp(t_new, t, v)
                T_new = np.interp(t_new, t, T)
                I_new = np.interp(t_new, t, I)
                y_new = np.interp(t_new, t, soc)

                X_list.append(np.stack([v_new, T_new, I_new], axis=-1))  # [L, 3]
                y_list.append(y_new[:, None])  # [L, 1]
                phase_list.append(phase)  # Track charging vs discharging

    return np.array(X_list), np.array(y_list), np.array(phase_list)

# -------------------- Prepare Data --------------------
print("Loading data...")
X, y, phases = extract_sequences(data, L=128)
print(f"Total sequences: {len(X)}, Shape: {X.shape}")

# Train/test split (80/20)
# Positional split with no shuffling: the first 80% of the sequences, in the
# order extract_sequences emitted them, are used for training and the
# remaining 20% for testing.
n_train = int(0.8 * len(X))
X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]
phases_test = phases[n_train:]

# Normalize features
# Per-channel mean/std are computed over all time steps of the training
# sequences only and then applied to both splits; the 1e-8 term avoids a
# division by zero for a constant channel.
mu = X_train.reshape(-1, X_train.shape[-1]).mean(axis=0)
sd = X_train.reshape(-1, X_train.shape[-1]).std(axis=0) + 1e-8
X_train = (X_train - mu) / sd
X_test = (X_test - mu) / sd

# Normalize labels to [0, 1]
# y_train / y_test themselves stay in percent; only the *_norm copies are
# fed to the network.
y_train_norm = y_train / 100.0
y_test_norm = y_test / 100.0

print(f"Train: {len(X_train)} | Test: {len(X_test)}")

# -------------------- DataLoaders --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                  torch.tensor(y_train_norm, dtype=torch.float32)),
    batch_size=32, shuffle=True)

# shuffle=False keeps test batches in the same order as X_test / phases_test.
test_loader = DataLoader(
    TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                  torch.tensor(y_test_norm, dtype=torch.float32)),
    batch_size=32, shuffle=False)

# -------------------- GRU Model --------------------
class GRU_SOC(nn.Module):
    """Stacked GRU followed by a per-time-step linear head.

    Takes a batch of sequences shaped [batch, seq_len, input_dim] and returns
    one value per time step, shaped [batch, seq_len, 1].
    """

    def __init__(self, input_dim=3, hidden=128, layers=2, dropout=0.2):
        super().__init__()
        # Dropout is only passed on when there is more than one layer.
        inter_layer_dropout = 0 if layers <= 1 else dropout
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.fc = nn.Linear(in_features=hidden, out_features=1)

    def forward(self, x):
        # The GRU returns (per-step outputs, final hidden state); only the
        # per-step outputs [batch, seq_len, hidden] are needed.
        hidden_states = self.gru(x)[0]
        # Project every time step down to a single value.
        return self.fc(hidden_states)

# Build the network and move its parameters to the selected device
model = GRU_SOC(input_dim=3, hidden=128, layers=2, dropout=0.2).to(device)
print(f"Model: GRU with {sum(p.numel() for p in model.parameters()):,} parameters\n")

# -------------------- Training --------------------
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

epochs = 100
best_loss = float('inf')  # lowest test-split loss seen so far

print("Training GRU model...")
for epoch in range(epochs):
    # Train
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight it by batch size so the epoch figure is a per-sequence mean
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)
    
    # Validate
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            test_loss += loss_fn(model(xb), yb).item() * xb.size(0)
    test_loss /= len(test_loader.dataset)
    
    # Save best model
    # NOTE(review): the checkpoint is selected on the same split that is later
    # reported as test performance, so the reported metrics are optimistically
    # biased; a separate validation split would avoid that.
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(model.state_dict(), 'best_gru.pt')
    
    # Progress line every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.6f} | Test Loss: {test_loss:.6f}")

print(f"\nBest Test Loss: {best_loss:.6f}")

# -------------------- Validation --------------------
# Reload the best checkpoint written during training. map_location remaps the
# stored tensors onto the active device, so a checkpoint written on a GPU can
# still be loaded when this section is re-run on a CPU-only session.
model.load_state_dict(torch.load('best_gru.pt', map_location=device))
model.eval()

# Collect per-time-step predictions and labels for the whole test split
predictions, actuals = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        pred = model(xb).cpu().numpy()
        predictions.append(pred)
        actuals.append(yb.numpy())

# Both arrays are [n_sequences, seq_len, 1]; scale from [0, 1] back to percent
predictions = np.concatenate(predictions) * 100
actuals = np.concatenate(actuals) * 100

# Overall metrics, pooled over every time step of every test sequence
mae = np.mean(np.abs(predictions - actuals))
rmse = np.sqrt(np.mean((predictions - actuals)**2))
max_error = np.max(np.abs(predictions - actuals))

print(f"\n{'='*60}")
print("Overall Validation Metrics (GRU):")
print(f"{'='*60}")
print(f"MAE:        {mae:.4f}%")
print(f"RMSE:       {rmse:.4f}%")
print(f"Max Error:  {max_error:.4f}%")

# -------------------- SOC Range Analysis --------------------
print(f"\n{'='*60}")
print("SOC Range Analysis")
print(f"{'='*60}")

# (lower bound %, upper bound %, label); both bounds are inclusive
soc_ranges = [
    (0, 30, "Low SOC (0-30%)"),
    (30, 60, "Mid SOC (30-60%)"),
    (60, 90, "High SOC (60-90%)"),
    (0, 100, "Full Range (0-100%)")
]

range_metrics = []

for soc_min, soc_max, range_name in soc_ranges:
    # Select every time step whose true SOC lies inside this band
    mask = (actuals >= soc_min) & (actuals <= soc_max)
    n_points = np.sum(mask)

    if n_points == 0:
        print(f"\n{range_name}: No data points in this range")
        continue

    range_preds, range_actuals = predictions[mask], actuals[mask]
    range_abs_err = np.abs(range_preds - range_actuals)

    range_mae = np.mean(range_abs_err)
    range_rmse = np.sqrt(np.mean((range_preds - range_actuals)**2))
    range_max_error = np.max(range_abs_err)

    range_metrics.append({
        'Range': range_name,
        'MAE': range_mae,
        'RMSE': range_rmse,
        'Max Error': range_max_error,
        'N Points': n_points
    })

    print(f"\n{range_name}:")
    print(f"  Points:     {n_points:,}")
    print(f"  MAE:        {range_mae:.4f}%")
    print(f"  RMSE:       {range_rmse:.4f}%")
    print(f"  Max Error:  {range_max_error:.4f}%")

# Tabulate the per-band metrics collected above
metrics_df = pd.DataFrame(range_metrics)
print(f"\n{'='*60}")
print("Summary Table:")
print(f"{'='*60}")
print(metrics_df.to_string(index=False))
print(f"{'='*60}\n")

# -------------------- Phase-specific Analysis --------------------
print(f"\n{'='*60}")
print("Charging vs Discharging Analysis")
print(f"{'='*60}")

phase_metrics = []

for phase_type in ['C1ch', 'C1dc']:
    if phase_type == 'C1ch':
        phase_name = "Charging"
    else:
        phase_name = "Discharging"

    # Boolean selector over test sequences (first axis of predictions/actuals)
    phase_mask = phases_test == phase_type
    n_sequences = np.sum(phase_mask)
    if n_sequences == 0:
        continue

    phase_preds = predictions[phase_mask]
    phase_actuals = actuals[phase_mask]

    # Metrics pooled over all time steps of this phase
    phase_residual = phase_preds - phase_actuals
    phase_mae = np.mean(np.abs(phase_residual))
    phase_rmse = np.sqrt(np.mean(phase_residual**2))

    print(f"\n{phase_name}:")
    print(f"  Sequences:  {n_sequences}")
    print(f"  MAE:        {phase_mae:.4f}%")
    print(f"  RMSE:       {phase_rmse:.4f}%")

    # Per-band breakdown; the full-range entry (4th) is skipped here
    for soc_min, soc_max, range_name in soc_ranges[:3]:
        range_mask = (phase_actuals >= soc_min) & (phase_actuals <= soc_max)
        n_in_range = np.sum(range_mask)
        if not n_in_range > 0:
            continue

        sub_preds = phase_preds[range_mask]
        sub_actuals = phase_actuals[range_mask]
        sub_mae = np.mean(np.abs(sub_preds - sub_actuals))
        sub_rmse = np.sqrt(np.mean((sub_preds - sub_actuals)**2))

        print(f"    {range_name}:")
        print(f"      Points: {n_in_range:,} | MAE: {sub_mae:.4f}% | RMSE: {sub_rmse:.4f}%")

        phase_metrics.append({
            'Phase': phase_name,
            'Range': range_name,
            'MAE': sub_mae,
            'RMSE': sub_rmse,
            'N Points': n_in_range
        })

# -------------------- Partial Curve Testing --------------------
print(f"\n{'='*60}")
print("Testing on Partial Charge/Discharge Curves")
print(f"{'='*60}")

def create_partial_curve(X_seq, y_seq, soc_start, soc_end):
    """
    Cut the SOC window [soc_start, soc_end] out of one full sequence.

    The window runs from the first to the last time step whose SOC lies
    between the two bounds (inclusive); the bounds may be given in either
    order, so the same call works for charging and discharging traces.

    Args:
        X_seq: Input features [L, n_features]
        y_seq: SOC labels in percent [L, 1]
        soc_start: SOC% at which the partial curve starts
        soc_end: SOC% at which the partial curve ends

    Returns:
        (X_partial, y_partial) slices of the inputs, or (None, None) when
        fewer than 10 time steps fall inside the window.
    """
    soc = y_seq[:, 0]

    # Order the bounds so a single inclusive test covers both directions
    if soc_start < soc_end:
        lower, upper = soc_start, soc_end
    else:
        lower, upper = soc_end, soc_start

    inside = np.flatnonzero((soc >= lower) & (soc <= upper))
    if inside.size < 10:  # too short to be a meaningful test curve
        return None, None

    # Contiguous span from the first to the last in-range step
    window = slice(inside[0], inside[-1] + 1)
    return X_seq[window], y_seq[window]

# Scenarios to evaluate: (SOC% at start, SOC% at end, phase key, display name)
partial_curve_tests = [
    # Charging scenarios
    (0, 30, 'C1ch', 'Charge 0-30%'),
    (30, 60, 'C1ch', 'Charge 30-60%'),
    (60, 90, 'C1ch', 'Charge 60-90%'),
    (20, 80, 'C1ch', 'Charge 20-80%'),

    # Discharging scenarios
    (100, 70, 'C1dc', 'Discharge 100-70%'),
    (70, 40, 'C1dc', 'Discharge 70-40%'),
    (40, 10, 'C1dc', 'Discharge 40-10%'),
    (80, 20, 'C1dc', 'Discharge 80-20%'),
]

partial_results = []

for soc_start, soc_end, phase_type, test_name in partial_curve_tests:
    # Positions of the test sequences recorded in the requested phase
    phase_indices = np.flatnonzero(phases_test == phase_type)

    if phase_indices.size == 0:
        print(f"\n{test_name}: No {phase_type} sequences available")
        continue

    # Run the model on the SOC window of (at most) the first 50 such sequences
    test_preds, test_actuals = [], []

    for idx in phase_indices[:50]:
        X_partial, y_partial = create_partial_curve(
            X_test[idx], y_test[idx], soc_start, soc_end)

        if X_partial is None:
            continue

        # The model only ever sees the partial window, as a batch of one
        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100

        test_preds.append(pred)
        test_actuals.append(y_partial[:, 0])

    valid_sequences = len(test_preds)
    if valid_sequences == 0:
        print(f"\n{test_name}: No valid partial curves found in range")
        continue

    # Pool every time step of every usable window before scoring
    all_preds = np.concatenate(test_preds)
    all_actuals = np.concatenate(test_actuals)

    mae = np.mean(np.abs(all_preds - all_actuals))
    rmse = np.sqrt(np.mean((all_preds - all_actuals)**2))
    max_error = np.max(np.abs(all_preds - all_actuals))

    partial_results.append({
        'Test': test_name,
        'Phase': phase_type,
        'SOC Range': f"{soc_start}-{soc_end}%",
        'Sequences': valid_sequences,
        'Points': len(all_preds),
        'MAE': mae,
        'RMSE': rmse,
        'Max Error': max_error
    })

    print(f"\n{test_name}:")
    print(f"  Valid Sequences: {valid_sequences}")
    print(f"  Total Points:    {len(all_preds):,}")
    print(f"  MAE:             {mae:.4f}%")
    print(f"  RMSE:            {rmse:.4f}%")
    print(f"  Max Error:       {max_error:.4f}%")

# Summary table (only built when at least one scenario produced results)
if len(partial_results) > 0:
    partial_df = pd.DataFrame(partial_results)
    print(f"\n{'='*60}")
    print("Partial Curve Testing Summary:")
    print(f"{'='*60}")
    print(partial_df.to_string(index=False))
    print(f"{'='*60}\n")

# -------------------- Partial Curve Visualization --------------------
print("\nGenerating partial curve visualizations...")

fig, axes = plt.subplots(4, 2, figsize=(14, 16))

# One panel per scenario:
# (SOC% at start, SOC% at end, phase key, title, grid row, grid column)
visualization_tests = [
    (0, 30, 'C1ch', 'Charge 0-30%', 0, 0),
    (30, 60, 'C1ch', 'Charge 30-60%', 0, 1),
    (60, 90, 'C1ch', 'Charge 60-90%', 1, 0),
    (20, 80, 'C1ch', 'Charge 20-80%', 1, 1),
    (100, 70, 'C1dc', 'Discharge 100-70%', 2, 0),
    (70, 40, 'C1dc', 'Discharge 70-40%', 2, 1),
    (40, 10, 'C1dc', 'Discharge 40-10%', 3, 0),
    (80, 20, 'C1dc', 'Discharge 80-20%', 3, 1),
]

for soc_start, soc_end, phase_type, test_name, row, col in visualization_tests:
    ax = axes[row, col]
    phase_indices = np.flatnonzero(phases_test == phase_type)

    # Plot the first of (at most) 100 candidates that yields a usable window;
    # the else-branch of the loop runs only when none of them did.
    for idx in phase_indices[:100]:
        X_partial, y_partial = create_partial_curve(
            X_test[idx], y_test[idx], soc_start, soc_end)

        if X_partial is None or len(X_partial) < 10:
            continue

        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100

        true = y_partial[:, 0]
        error = np.abs(true - pred)
        mae, max_error = np.mean(error), np.max(error)

        # True vs predicted trace, with the gap between them shaded
        time_steps = np.arange(len(true))
        ax.plot(time_steps, true, 'b-', linewidth=2.5, label='True SOC', alpha=0.8)
        ax.plot(time_steps, pred, 'r--', linewidth=2, label='GRU Prediction', alpha=0.8)
        ax.fill_between(time_steps, true, pred, alpha=0.2, color='orange')

        # Dotted guides at the requested window bounds (same styling for both phases)
        ax.axhline(soc_start, color='green', linestyle=':', linewidth=1.5, alpha=0.7, label=f'Start: {soc_start}%')
        ax.axhline(soc_end, color='red', linestyle=':', linewidth=1.5, alpha=0.7, label=f'End: {soc_end}%')

        ax.set_title(f'{test_name}\nMAE: {mae:.2f}% | Max Error: {max_error:.2f}%',
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Time Step', fontsize=9)
        ax.set_ylabel('SOC (%)', fontsize=9)
        ax.legend(fontsize=7, loc='best')
        ax.grid(True, alpha=0.3)

        # Pad the window by 10 points on each side, clamped to 0..100
        y_low = max(0, min(soc_start, soc_end) - 10)
        y_high = min(100, max(soc_start, soc_end) + 10)
        ax.set_ylim(y_low, y_high)
        break
    else:
        ax.text(0.5, 0.5, f'No valid\n{test_name}\ndata found',
                ha='center', va='center', fontsize=10, transform=ax.transAxes)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

plt.suptitle('GRU Performance on Partial Charge/Discharge Curves',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Comparison: Full vs Partial Curves --------------------
# Bar charts of MAE and RMSE for every partial-curve scenario that produced
# results. partial_df only exists when partial_results is non-empty, hence the guard.
if len(partial_results) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Separate charging and discharging
    charging_partial = partial_df[partial_df['Phase'] == 'C1ch']
    discharging_partial = partial_df[partial_df['Phase'] == 'C1dc']

    # Tick i carries the label of row i of all_tests (charging rows first,
    # discharging rows after them), so every bar must sit on its own tick:
    # charging bars take slots 0..n_charge-1 and discharging bars follow.
    # Drawing both groups at 0..n-1 would put bars under the wrong labels.
    all_tests = pd.concat([charging_partial, discharging_partial])
    x_charge = np.arange(len(charging_partial))
    x_discharge = len(charging_partial) + np.arange(len(discharging_partial))
    tick_positions = np.arange(len(all_tests))
    tick_labels = [t.replace('Charge ', 'C ').replace('Discharge ', 'D ')
                   for t in all_tests['Test']]

    # Same layout for both metrics; only the plotted column and captions differ
    for ax, metric in zip(axes, ('MAE', 'RMSE')):
        if len(charging_partial) > 0:
            ax.bar(x_charge, charging_partial[metric], 0.6,
                   label='Charging', alpha=0.8, color='green')
        if len(discharging_partial) > 0:
            ax.bar(x_discharge, discharging_partial[metric], 0.6,
                   label='Discharging', alpha=0.8, color='orange')

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_title(f'{metric} for Partial Curves (GRU)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

# -------------------- Range-specific Visualization --------------------
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Top row: one bar chart per metric, one bar per SOC band in range_metrics
x_pos = np.arange(len(range_metrics))
short_labels = [m['Range'].split('(')[0].strip() for m in range_metrics]
metric_panels = [
    ('MAE', 'steelblue'),
    ('RMSE', 'coral'),
    ('Max Error', 'lightcoral'),
]

for ax, (metric, bar_color) in zip(axes[0], metric_panels):
    ax.bar(x_pos, [m[metric] for m in range_metrics], alpha=0.7, color=bar_color)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(short_labels, rotation=45, ha='right')
    ax.set_ylabel(f'{metric} (%)')
    ax.set_title(f'{metric} by SOC Range (GRU)', fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

# Bottom row: signed-error histogram for the low / mid / high bands
for ax, (soc_min, soc_max, range_name) in zip(axes[1], soc_ranges[:3]):
    mask = (actuals >= soc_min) & (actuals <= soc_max)

    # A band without any points leaves its panel empty
    if not mask.any():
        continue

    range_errors = (predictions[mask] - actuals[mask]).flatten()
    ax.hist(range_errors, bins=30, edgecolor='black', alpha=0.7, color='teal')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
    ax.set_xlabel('Error (%)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{range_name}\nMean Error: {np.mean(range_errors):.3f}%', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('GRU SOC Range Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Phase-specific Visualization --------------------
# Grouped bars (charging next to discharging) per SOC band, for MAE and RMSE
if len(phase_metrics) > 0:
    phase_df = pd.DataFrame(phase_metrics)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ranges_unique = phase_df['Range'].unique()
    x = np.arange(len(ranges_unique))
    width = 0.35

    charging_data = phase_df[phase_df['Phase'] == 'Charging']
    discharging_data = phase_df[phase_df['Phase'] == 'Discharging']

    def _values_per_range(frame, metric):
        """Return frame's metric for each band in ranges_unique (0 where the band is absent)."""
        collected = []
        for r in ranges_unique:
            rows = frame[frame['Range'] == r]
            collected.append(rows[metric].values[0] if len(rows) > 0 else 0)
        return collected

    charge_mae = _values_per_range(charging_data, 'MAE')
    discharge_mae = _values_per_range(discharging_data, 'MAE')
    charge_rmse = _values_per_range(charging_data, 'RMSE')
    discharge_rmse = _values_per_range(discharging_data, 'RMSE')

    band_labels = [r.split('(')[0].strip() for r in ranges_unique]
    panels = [
        (axes[0], 'MAE', charge_mae, discharge_mae),
        (axes[1], 'RMSE', charge_rmse, discharge_rmse),
    ]

    for ax, metric, charge_values, discharge_values in panels:
        # Shift the two groups half a bar width to either side of the tick
        ax.bar(x - width/2, charge_values, width, label='Charging', alpha=0.8, color='green')
        ax.bar(x + width/2, discharge_values, width, label='Discharging', alpha=0.8, color='orange')
        ax.set_xticks(x)
        ax.set_xticklabels(band_labels, rotation=45, ha='right')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_title(f'{metric}: Charging vs Discharging by SOC Range (GRU)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

# -------------------- Sample Predictions Visualization --------------------
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Four randomly drawn test sequences, one per panel (row-major order)
for i, ax in enumerate(axes.flat):
    idx = np.random.randint(0, len(X_test))

    # Keep the leading batch axis by slicing instead of indexing
    x_sample = torch.tensor(X_test[idx:idx+1], dtype=torch.float32, device=device)
    with torch.no_grad():
        pred = model(x_sample).cpu().numpy()[0, :, 0] * 100

    true = y_test[idx, :, 0]
    error = np.mean(np.abs(true - pred))
    if phases_test[idx] == 'C1ch':
        phase_label = "Charging"
    else:
        phase_label = "Discharging"

    ax.plot(true, 'k-', linewidth=2, label='True SOC')
    ax.plot(pred, 'r--', linewidth=2, label='GRU Prediction')
    ax.set_title(f'{phase_label} - Sample {i+1} | MAE: {error:.2f}%',
                 fontsize=11, fontweight='bold')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('SOC (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('GRU SOC Estimation Results (Random Samples)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# -------------------- Overall Error Distribution --------------------
# Signed error (prediction minus truth) of every test time step, in SOC percent
errors = (predictions - actuals).flatten()

err_fig, (hist_ax, scatter_ax) = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: histogram of the signed errors
hist_ax.hist(errors, bins=50, edgecolor='black', alpha=0.7)
hist_ax.set_xlabel('Prediction Error (%)')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Error Distribution (GRU)')
hist_ax.grid(True, alpha=0.3)

# Right panel: predicted vs true SOC with the identity line as reference
scatter_ax.scatter(actuals.flatten(), predictions.flatten(), alpha=0.5, s=1)
scatter_ax.plot([0, 100], [0, 100], 'r--', linewidth=2, label='Perfect Prediction')
scatter_ax.set_xlabel('True SOC (%)')
scatter_ax.set_ylabel('Predicted SOC (%)')
scatter_ax.set_title('Prediction vs Actual (GRU)')
scatter_ax.legend()
scatter_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nDone!")

GRU-PINN

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# -------------------- Data Loading --------------------
# Load the MATLAB file; the relative path is resolved against the current
# working directory. The code below indexes the result by 'Cell1'..'Cell8'.
data = scipy.io.loadmat('Oxford_Battery_Degradation_Dataset_1.mat')

def soc_from_q(q, phase):
    """Compute a SOC% trace from a charge trace.

    The trace is rescaled so that its first sample maps to 0 and its last
    sample to 1 (values outside that span are clipped). A charging phase
    ('C1ch') then runs 0% -> 100%, a discharging phase ('C1dc') 100% -> 0%.

    Args:
        q: Charge samples (array-like, e.g. a NumPy array or a list).
        phase: 'C1ch' for charging or 'C1dc' for discharging.

    Returns:
        Array of SOC values in percent with the same shape as q, or None
        when no SOC can be derived: empty trace, non-finite or (nearly)
        equal end points, or an unknown phase.
    """
    q = np.asarray(q, dtype=float)
    if q.size == 0:
        return None

    q0, q1 = float(q[0]), float(q[-1])
    # NaN/inf end points would turn the whole trace into NaN labels, and
    # equal end points would divide by zero: both are reported as "no SOC".
    if not (np.isfinite(q0) and np.isfinite(q1)) or np.isclose(q1, q0):
        return None

    qn = np.clip((q - q0) / (q1 - q0), 0, 1)
    if phase == 'C1ch':
        return 100.0 * qn
    if phase == 'C1dc':
        return 100.0 * (1.0 - qn)
    return None

def extract_sequences_with_physics(data, L=128):
    """Extract fixed-length feature/label sequences with physics-based features.

    Walks Cell1..Cell8 of ``data``; within each cell the cycles are visited
    in ascending order of the integer that follows the first three characters
    of the cycle name. For every cycle the 'C1ch' and 'C1dc' blocks that
    carry 't', 'v' and 'q' are resampled onto L evenly spaced time points
    and expanded into 13 feature channels.

    Blocks are skipped when they have fewer than 5 samples, when their time
    axis does not advance from first to last sample, or when soc_from_q
    cannot derive a SOC trace.

    Args:
        data: Mapping with keys 'Cell1'..'Cell8' (here: the result of
            scipy.io.loadmat on the dataset file).
        L: Number of time points each sequence is resampled to.

    Returns:
        (X, y, phases) as NumPy arrays: X is [n_sequences, L, 13],
        y is [n_sequences, L, 1] holding SOC in percent, and phases holds
        the phase key ('C1ch' / 'C1dc') of every sequence.

    Feature channels (last axis of X):
        0 voltage, 1 temperature, 2 current, 3 dV/dt, 4 cumulative charge,
        5 dT/dt, 6 approximate resistance, 7 power, 8 cumulative energy,
        9 charge direction, 10 normalized capacity, 11 OCV deviation,
        12 temperature-compensated voltage.

    NOTE(review): channel 10 is normalized by the total charge of the whole
    trace, so at any time step it already encodes information from the end
    of the sequence; confirm this is acceptable for the intended use.
    """
    X_list, y_list, phase_list = [], [], []

    for ci in range(1, 9):
        cell = data[f'Cell{ci}']
        for cyc_name in sorted(cell.dtype.names, key=lambda s: int(s[3:])):
            cyc = cell[cyc_name][0, 0]
            for phase in ['C1ch', 'C1dc']:
                if phase not in cyc.dtype.names:
                    continue
                blk = cyc[phase][0, 0]
                if not all(k in blk.dtype.names for k in ['t', 'v', 'q']):
                    continue

                t = blk['t'][0, 0].ravel().astype(float)
                v = blk['v'][0, 0].ravel().astype(float)
                q = blk['q'][0, 0].ravel().astype(float)

                if t.size < 5:
                    continue

                # The resampling grid must advance in time: np.gradient
                # divides by the grid spacing and np.interp needs increasing
                # sample points. A zero, negative or NaN span would only
                # yield NaN/inf features, so such blocks are dropped.
                if not t[-1] > t[0]:
                    continue

                # Temperature (fill missing with forward/backward fill)
                if 'T' in blk.dtype.names:
                    T = blk['T'][0, 0].ravel().astype(float)
                else:
                    T = np.full_like(t, 25.0)
                T = pd.Series(T).ffill().bfill().values

                # Current (use defaults if missing)
                if 'i' in blk.dtype.names:
                    I = blk['i'][0, 0].ravel().astype(float)
                else:
                    I = np.full_like(t, 0.74 if phase == 'C1ch' else -0.74)

                soc = soc_from_q(q, phase)
                if soc is None:
                    continue

                # Resample to fixed length L
                t_new = np.linspace(t[0], t[-1], L)
                v_new = np.interp(t_new, t, v)
                T_new = np.interp(t_new, t, T)
                I_new = np.interp(t_new, t, I)
                y_new = np.interp(t_new, t, soc)

                # ============ PHYSICS-BASED FEATURES ============

                # 1. Voltage derivative (dV/dt) - rate of voltage change
                dv_dt = np.gradient(v_new, t_new)

                # 2. Current integral (Coulomb counting proxy)
                # Cumulative charge from start; prepend makes the first step 0
                dt = np.diff(t_new, prepend=t_new[0])
                cumulative_q = np.cumsum(I_new * dt)

                # 3. Temperature derivative (dT/dt) - thermal dynamics
                dT_dt = np.gradient(T_new, t_new)

                # 4. Voltage per current (approximate resistance)
                # Deviation from the mean voltage divided by current; set to
                # 0 wherever |I| <= 0.01
                resistance = np.where(np.abs(I_new) > 0.01,
                                      (v_new - np.mean(v_new)) / (I_new + 1e-6),
                                      0.0)

                # 5. Power (P = V * I)
                power = v_new * I_new

                # 6. Energy integral (sum of P * dt)
                energy = np.cumsum(power * dt)

                # 7. Charge direction indicator: sign of the current
                charge_indicator = np.sign(I_new)

                # 8. Normalized capacity (charge throughput)
                # Fraction of the sequence's total integrated current, rising
                # for 'C1ch' and falling for 'C1dc', clipped to [0, 1]
                if phase == 'C1ch':
                    q_norm = (cumulative_q - cumulative_q[0]) / (cumulative_q[-1] - cumulative_q[0] + 1e-6)
                else:
                    q_norm = 1.0 - (cumulative_q - cumulative_q[0]) / (cumulative_q[-1] - cumulative_q[0] + 1e-6)
                q_norm = np.clip(q_norm, 0, 1)

                # 9. Voltage-SOC correlation feature (empirical OCV)
                # Simplified linear model: OCV = 3.0 + 1.2 * q_norm
                estimated_ocv = 3.0 + 1.2 * (q_norm)
                ocv_deviation = v_new - estimated_ocv

                # 10. Temperature-normalized voltage
                # Linear correction of 0.0005 per degree away from T_ref
                T_ref = 25.0
                v_temp_compensated = v_new - 0.0005 * (T_new - T_ref)

                # Stack all features: [L, 13]
                X = np.stack([
                    v_new,                  # 0: Voltage
                    T_new,                  # 1: Temperature
                    I_new,                  # 2: Current
                    dv_dt,                  # 3: Voltage derivative
                    cumulative_q,           # 4: Cumulative charge
                    dT_dt,                  # 5: Temperature derivative
                    resistance,             # 6: Approximate resistance
                    power,                  # 7: Power
                    energy,                 # 8: Cumulative energy
                    charge_indicator,       # 9: Charge direction
                    q_norm,                 # 10: Normalized capacity
                    ocv_deviation,          # 11: OCV deviation
                    v_temp_compensated      # 12: Temperature-compensated voltage
                ], axis=-1)

                y = y_new[:, None]  # [L, 1]

                X_list.append(X)
                y_list.append(y)
                phase_list.append(phase)

    return np.array(X_list), np.array(y_list), np.array(phase_list)

# -------------------- Prepare Data --------------------
print("Loading data with physics-based features...")
X, y, phases = extract_sequences_with_physics(data, L=128)
print(f"Total sequences: {len(X)}, Shape: {X.shape}")
print(f"Feature dimensions: {X.shape[-1]} (13 physics-informed features)")

# Positional 80/20 split: the first 80% of the extracted sequences train the
# model, the remainder is held out.
n_train = int(0.8 * len(X))
X_train, y_train = X[:n_train], y[:n_train]
X_test, y_test, phases_test = X[n_train:], y[n_train:], phases[n_train:]

# Per-channel z-scoring with training-split statistics; the small constant
# keeps the division finite for a constant channel.
train_samples = X_train.reshape(-1, X_train.shape[-1])
mu = train_samples.mean(axis=0)
sd = train_samples.std(axis=0) + 1e-8
X_train, X_test = (X_train - mu) / sd, (X_test - mu) / sd

# SOC labels arrive in percent; the network is trained on the [0, 1] scale
y_train_norm, y_test_norm = y_train / 100.0, y_test / 100.0

print(f"Train: {len(X_train)} | Test: {len(X_test)}")

# -------------------- DataLoaders --------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}\n")

# Training batches are reshuffled every epoch
train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train_norm, dtype=torch.float32),
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Test batches keep their order so predictions line up with phases_test
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test_norm, dtype=torch.float32),
)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# -------------------- Physics-Informed GRU Model --------------------
class PhysicsInformedGRU(nn.Module):
    """
    GRU regressor over the physics-augmented feature set.

    Pipeline, applied independently at every time step unless noted:
    1. A learned sigmoid gate re-weights each input feature
    2. A linear + ReLU encoder projects the gated features to 64 channels
    3. A stacked GRU processes the encoded sequence along time
    4. A two-layer head maps each GRU state to one output value
    """

    def __init__(self, input_dim=13, hidden=128, layers=2, dropout=0.2):
        super().__init__()

        # Per-step encoder applied to the gated input features
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Recurrent core; inter-layer dropout only applies to a stacked GRU
        gru_dropout = dropout if layers > 1 else 0
        self.gru = nn.GRU(64, hidden, num_layers=layers,
                          batch_first=True, dropout=gru_dropout)

        # Regression head: hidden -> 64 -> 1
        self.fc1 = nn.Linear(hidden, 64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(64, 1)

        # Produces one gate logit per input feature
        self.feature_attention = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        """
        Run the model on a batch of sequences.

        Args:
            x: [batch, seq_len, input_dim]

        Returns:
            soc: [batch, seq_len, 1] (raw head output, no final activation)
        """
        # Gate each feature with a weight in (0, 1) computed from the input itself
        gated = x * torch.sigmoid(self.feature_attention(x))

        # Encode, then take the per-step outputs of the GRU (index 0 of its result)
        hidden_seq = self.gru(self.feature_encoder(gated))[0]

        # Head: linear -> ReLU -> dropout -> linear
        head = self.dropout(self.relu(self.fc1(hidden_seq)))
        return self.fc2(head)

class PhysicsInformedLoss(nn.Module):
    """
    Composite loss: data-fit MSE plus two physics-motivated penalties.

        total = lambda_mse       * MSE(pred, target)
              + lambda_monotonic * mean(relu(-(SOC[k+1] - SOC[k]) * direction[k+1]))
              + lambda_coulomb   * mean((pred - q_norm)^2)

    Args:
        lambda_mse: weight of the MSE term.
        lambda_monotonic: weight of the wrong-direction SOC step penalty.
        lambda_coulomb: weight of the Coulomb-counting consistency term.
        direction_idx: column of ``features`` holding the charge-direction
            indicator (default 9, the previously hard-coded column).
        q_norm_idx: column of ``features`` holding the normalized capacity
            (default 10, the previously hard-coded column).
    """
    def __init__(self, lambda_mse=1.0, lambda_monotonic=0.1, lambda_coulomb=0.05,
                 direction_idx=9, q_norm_idx=10):
        super().__init__()
        self.lambda_mse = lambda_mse
        self.lambda_monotonic = lambda_monotonic
        self.lambda_coulomb = lambda_coulomb
        self.direction_idx = direction_idx
        self.q_norm_idx = q_norm_idx
        self.mse = nn.MSELoss()

    def forward(self, pred, target, features):
        """
        Args:
            pred: [batch, seq_len, 1] - predicted SOC
            target: [batch, seq_len, 1] - true SOC
            features: [batch, seq_len, input_dim] - input features

        Returns:
            (total_loss, loss_mse, loss_monotonic, loss_coulomb), all scalar
            tensors; the last three are the unweighted components.
        """
        # 1. Standard MSE loss
        loss_mse = self.mse(pred, target)

        # 2. Monotonicity constraint: penalize SOC steps whose sign disagrees
        #    with the charge-direction feature.
        # NOTE(review): assumes the direction column is positive while
        # charging and negative while discharging in the tensor passed here
        # (i.e. after any feature scaling) - confirm against the feature code.
        charge_direction = features[:, :, self.direction_idx].unsqueeze(-1)  # [batch, seq_len, 1]
        soc_diff = pred[:, 1:, :] - pred[:, :-1, :]  # [batch, seq_len-1, 1]
        if soc_diff.numel() > 0:
            expected_sign = charge_direction[:, 1:, :]  # [batch, seq_len-1, 1]
            loss_monotonic = torch.relu(-soc_diff * expected_sign).mean()
        else:
            # A single-step (or empty) sequence has no consecutive pairs; the
            # mean of an empty tensor is NaN and would poison the total loss.
            loss_monotonic = pred.new_zeros(())

        # 3. Coulomb counting consistency: predicted SOC vs. the normalized
        #    capacity feature, compared directly.
        # NOTE(review): assumes q_norm is on the same scale as pred - confirm.
        q_norm = features[:, :, self.q_norm_idx].unsqueeze(-1)  # [batch, seq_len, 1]
        loss_coulomb = torch.mean((pred - q_norm) ** 2)

        total_loss = (self.lambda_mse * loss_mse +
                      self.lambda_monotonic * loss_monotonic +
                      self.lambda_coulomb * loss_coulomb)

        return total_loss, loss_mse, loss_monotonic, loss_coulomb

# Build the network with this cell's hyperparameters, then move it to the
# selected device.
model = PhysicsInformedGRU(input_dim=13, hidden=128, layers=2, dropout=0.2)
model = model.to(device)
param_count = sum(p.numel() for p in model.parameters())
print(f"Physics-Informed GRU with {param_count:,} parameters\n")

# Legend for the 13 input channels, printed one entry per line
print(
    "Physics-based features included:",
    "  1. Voltage (V)",
    "  2. Temperature (T)",
    "  3. Current (I)",
    "  4. Voltage derivative (dV/dt)",
    "  5. Cumulative charge (∫I dt)",
    "  6. Temperature derivative (dT/dt)",
    "  7. Approximate resistance (V/I)",
    "  8. Power (V*I)",
    "  9. Cumulative energy (∫P dt)",
    " 10. Charge direction indicator",
    " 11. Normalized capacity (Coulomb counting)",
    " 12. OCV deviation",
    " 13. Temperature-compensated voltage\n",
    sep="\n",
)

# -------------------- Training --------------------
# Adam with a small L2 penalty; the learning rate is halved whenever the
# monitored loss has not improved for 10 consecutive epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                        factor=0.5, patience=10)
physics_loss_fn = PhysicsInformedLoss(lambda_mse=1.0, lambda_monotonic=0.1, lambda_coulomb=0.05)

epochs = 100
best_loss = float('inf')

print("Training Physics-Informed GRU...")
print(f"{'='*80}")

for epoch in range(epochs):
    # Train
    model.train()
    # Running sums of per-batch losses, weighted by batch size so that the
    # division below yields a per-sequence average over the whole dataset.
    train_loss = 0.0
    train_mse = 0.0
    train_mono = 0.0
    train_coulomb = 0.0
    
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = model(xb)
        
        # Physics-informed loss
        # The model inputs are passed again as `features`: the loss reads its
        # direction and capacity columns from them.
        loss, mse, mono, coulomb = physics_loss_fn(pred, yb, xb)
        
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_loss += loss.item() * xb.size(0)
        train_mse += mse.item() * xb.size(0)
        train_mono += mono.item() * xb.size(0)
        train_coulomb += coulomb.item() * xb.size(0)
    
    train_loss /= len(train_loader.dataset)
    train_mse /= len(train_loader.dataset)
    train_mono /= len(train_loader.dataset)
    train_coulomb /= len(train_loader.dataset)
    
    # Validate
    # NOTE(review): this pass runs on test_loader, so the LR schedule and the
    # checkpoint selection below are both driven by test-set loss; metrics
    # later reported on the same loader are therefore optimistic.
    model.eval()
    test_loss = 0.0
    test_mse = 0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            loss, mse, _, _ = physics_loss_fn(pred, yb, xb)
            test_loss += loss.item() * xb.size(0)
            test_mse += mse.item() * xb.size(0)
    
    test_loss /= len(test_loader.dataset)
    test_mse /= len(test_loader.dataset)
    
    # Learning rate scheduling
    scheduler.step(test_loss)
    
    # Save best model
    # The weights are overwritten in 'best_pigru.pt' every time the total
    # (physics-weighted) loss on test_loader reaches a new minimum.
    if test_loss < best_loss:
        best_loss = test_loss
        torch.save(model.state_dict(), 'best_pigru.pt')
    
    # Progress line every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:03d} | "
              f"Train: {train_loss:.6f} (MSE: {train_mse:.6f}, Mono: {train_mono:.6f}, Coulomb: {train_coulomb:.6f}) | "
              f"Test: {test_loss:.6f}")

print(f"\n{'='*80}")
print(f"Best Test Loss: {best_loss:.6f}\n")

# -------------------- Validation --------------------
# Restore the checkpoint written during training and switch to inference mode
best_state = torch.load('best_pigru.pt')
model.load_state_dict(best_state)
model.eval()

# Collect per-batch predictions and labels as NumPy arrays
pred_batches, true_batches = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        pred_batches.append(model(xb.to(device)).cpu().numpy())
        true_batches.append(yb.numpy())

# Scale both by 100 so that errors are expressed in SOC percent
predictions = np.concatenate(pred_batches) * 100
actuals = np.concatenate(true_batches) * 100

# Aggregate error statistics over every time step of every test sequence
residuals = predictions - actuals
mae = np.mean(np.abs(residuals))
rmse = np.sqrt(np.mean(residuals**2))
max_error = np.max(np.abs(residuals))

banner = f"{'='*60}"
print(banner)
print("Overall Validation Metrics (Physics-Informed GRU):")
print(banner)
print(f"MAE:        {mae:.4f}%")
print(f"RMSE:       {rmse:.4f}%")
print(f"Max Error:  {max_error:.4f}%")

# -------------------- SOC Range Analysis --------------------
print(f"\n{'='*60}")
print("SOC Range Analysis")
print(f"{'='*60}")

# (lower bound, upper bound, label); bounds are inclusive on both sides, so a
# point sitting exactly on a shared boundary is counted in both ranges.
soc_ranges = [
    (0, 30, "Low SOC (0-30%)"),
    (30, 60, "Mid SOC (30-60%)"),
    (60, 90, "High SOC (60-90%)"),
    (0, 100, "Full Range (0-100%)")
]

range_metrics = []

for soc_min, soc_max, range_name in soc_ranges:
    # Select the time steps whose true SOC lies inside this range
    in_range = (actuals >= soc_min) & (actuals <= soc_max)
    n_points = np.sum(in_range)

    if n_points == 0:
        print(f"\n{range_name}: No data points in this range")
        continue

    # Error statistics restricted to the selected points
    range_err = predictions[in_range] - actuals[in_range]
    range_mae = np.mean(np.abs(range_err))
    range_rmse = np.sqrt(np.mean(range_err**2))
    range_max_error = np.max(np.abs(range_err))

    range_metrics.append({
        'Range': range_name,
        'MAE': range_mae,
        'RMSE': range_rmse,
        'Max Error': range_max_error,
        'N Points': n_points
    })

    print(f"\n{range_name}:")
    print(f"  Points:     {n_points:,}")
    print(f"  MAE:        {range_mae:.4f}%")
    print(f"  RMSE:       {range_rmse:.4f}%")
    print(f"  Max Error:  {range_max_error:.4f}%")

# Tabular summary of the ranges that contained data
metrics_df = pd.DataFrame(range_metrics)
print(f"\n{'='*60}")
print("Summary Table:")
print(f"{'='*60}")
print(metrics_df.to_string(index=False))
print(f"{'='*60}\n")

# -------------------- Phase-specific Analysis --------------------
print(f"\n{'='*60}")
print("Charging vs Discharging Analysis")
print(f"{'='*60}")

phase_metrics = []

for phase_type, phase_name in (('C1ch', "Charging"), ('C1dc', "Discharging")):
    # Boolean selector over test sequences belonging to this phase
    phase_mask = phases_test == phase_type
    n_sequences = np.sum(phase_mask)

    if n_sequences == 0:
        continue

    phase_preds = predictions[phase_mask]
    phase_actuals = actuals[phase_mask]

    # Metrics over every time step of the selected sequences
    phase_err = phase_preds - phase_actuals
    phase_mae = np.mean(np.abs(phase_err))
    phase_rmse = np.sqrt(np.mean(phase_err**2))

    print(f"\n{phase_name}:")
    print(f"  Sequences:  {n_sequences}")
    print(f"  MAE:        {phase_mae:.4f}%")
    print(f"  RMSE:       {phase_rmse:.4f}%")

    # Per-range breakdown (first three entries only: low, mid, high)
    for soc_min, soc_max, range_name in soc_ranges[:3]:
        range_mask = (phase_actuals >= soc_min) & (phase_actuals <= soc_max)
        n_in_range = np.sum(range_mask)

        if n_in_range == 0:
            continue

        sub_err = phase_preds[range_mask] - phase_actuals[range_mask]
        sub_mae = np.mean(np.abs(sub_err))
        sub_rmse = np.sqrt(np.mean(sub_err**2))

        print(f"    {range_name}:")
        print(f"      Points: {n_in_range:,} | MAE: {sub_mae:.4f}% | RMSE: {sub_rmse:.4f}%")

        phase_metrics.append({
            'Phase': phase_name,
            'Range': range_name,
            'MAE': sub_mae,
            'RMSE': sub_rmse,
            'N Points': n_in_range
        })

# -------------------- Partial Curve Testing --------------------
# Console banner announcing the partial charge/discharge curve evaluation
# that follows.
print(f"\n{'='*60}")
print("Testing on Partial Charge/Discharge Curves")
print(f"{'='*60}")

def create_partial_curve(X_seq, y_seq, soc_start, soc_end):
    """
    Cut the contiguous window of a sequence that spans a given SOC interval.

    The interval is inclusive and may be given in either order
    (``soc_start < soc_end`` for charging, otherwise discharging).  The window
    runs from the first to the last time step whose SOC (``y_seq[:, 0]``) lies
    inside the interval.

    Returns:
        ``(X_window, y_window)``, or ``(None, None)`` when fewer than 10 time
        steps fall inside the interval.
    """
    soc = y_seq[:, 0]

    # Both directions reduce to the same inclusive [low, high] membership test
    if soc_start < soc_end:
        low, high = soc_start, soc_end
    else:
        low, high = soc_end, soc_start
    inside = (soc >= low) & (soc <= high)

    hits = np.flatnonzero(inside)
    if hits.size < 10:  # too short to be a meaningful partial curve
        return None, None

    window = slice(hits[0], hits[-1] + 1)
    return X_seq[window], y_seq[window]

# Partial-curve scenarios: (SOC at start, SOC at end, phase key, label)
partial_curve_tests = [
    # Charging scenarios
    (0, 30, 'C1ch', 'Charge 0-30%'),
    (30, 60, 'C1ch', 'Charge 30-60%'),
    (60, 90, 'C1ch', 'Charge 60-90%'),
    (20, 80, 'C1ch', 'Charge 20-80%'),

    # Discharging scenarios
    (100, 70, 'C1dc', 'Discharge 100-70%'),
    (70, 40, 'C1dc', 'Discharge 70-40%'),
    (40, 10, 'C1dc', 'Discharge 40-10%'),
    (80, 20, 'C1dc', 'Discharge 80-20%'),
]

partial_results = []

for soc_start, soc_end, phase_type, test_name in partial_curve_tests:
    # Test-set positions of the sequences recorded in the requested phase
    phase_indices = np.where(phases_test == phase_type)[0]

    if len(phase_indices) == 0:
        print(f"\n{test_name}: No {phase_type} sequences available")
        continue

    test_preds, test_actuals = [], []

    # Evaluate at most the first 50 matching sequences
    for idx in phase_indices[:50]:
        X_partial, y_partial = create_partial_curve(
            X_test[idx], y_test[idx], soc_start, soc_end
        )
        if X_partial is None:
            continue

        # The model only sees the cropped window, fed as a batch of one
        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100

        test_preds.append(pred)
        test_actuals.append(y_partial[:, 0])

    # One entry was appended per usable window
    valid_sequences = len(test_preds)
    if valid_sequences == 0:
        print(f"\n{test_name}: No valid partial curves found in range")
        continue

    # Pool all windows of this scenario before computing the statistics
    all_preds = np.concatenate(test_preds)
    all_actuals = np.concatenate(test_actuals)
    pooled_err = all_preds - all_actuals

    mae_partial = np.mean(np.abs(pooled_err))
    rmse_partial = np.sqrt(np.mean(pooled_err**2))
    max_error_partial = np.max(np.abs(pooled_err))

    partial_results.append({
        'Test': test_name,
        'Phase': phase_type,
        'SOC Range': f"{soc_start}-{soc_end}%",
        'Sequences': valid_sequences,
        'Points': len(all_preds),
        'MAE': mae_partial,
        'RMSE': rmse_partial,
        'Max Error': max_error_partial
    })

    print(f"\n{test_name}:")
    print(f"  Valid Sequences: {valid_sequences}")
    print(f"  Total Points:    {len(all_preds):,}")
    print(f"  MAE:             {mae_partial:.4f}%")
    print(f"  RMSE:            {rmse_partial:.4f}%")
    print(f"  Max Error:       {max_error_partial:.4f}%")

# Summary table (only built when at least one scenario produced results)
if len(partial_results) > 0:
    partial_df = pd.DataFrame(partial_results)
    print(f"\n{'='*60}")
    print("Partial Curve Testing Summary:")
    print(f"{'='*60}")
    print(partial_df.to_string(index=False))
    print(f"{'='*60}\n")

# -------------------- Visualizations --------------------
print("Generating visualizations...\n")

# 1. Partial Curve Visualization
# One panel per scenario on a 4x2 grid: the two upper rows are charging
# scenarios, the two lower rows discharging scenarios.
fig, axes = plt.subplots(4, 2, figsize=(14, 16))

# (SOC at start, SOC at end, phase key, panel title, grid row, grid column)
visualization_tests = [
    (0, 30, 'C1ch', 'Charge 0-30%', 0, 0),
    (30, 60, 'C1ch', 'Charge 30-60%', 0, 1),
    (60, 90, 'C1ch', 'Charge 60-90%', 1, 0),
    (20, 80, 'C1ch', 'Charge 20-80%', 1, 1),
    (100, 70, 'C1dc', 'Discharge 100-70%', 2, 0),
    (70, 40, 'C1dc', 'Discharge 70-40%', 2, 1),
    (40, 10, 'C1dc', 'Discharge 40-10%', 3, 0),
    (80, 20, 'C1dc', 'Discharge 80-20%', 3, 1),
]

for soc_start, soc_end, phase_type, test_name, row, col in visualization_tests:
    ax = axes[row, col]
    
    # Find a suitable test sequence
    phase_mask = phases_test == phase_type
    phase_indices = np.where(phase_mask)[0]
    
    # Only the first usable sequence (among the first 100 of this phase) is
    # plotted; `found` records whether any was usable.
    found = False
    for idx in phase_indices[:100]:
        X_seq = X_test[idx]
        y_seq = y_test[idx]
        
        # Create partial curve
        X_partial, y_partial = create_partial_curve(X_seq, y_seq, soc_start, soc_end)
        
        if X_partial is None or len(X_partial) < 10:
            continue
        
        # Get prediction
        # The cropped window is fed as a batch of one; the output is scaled
        # by 100 to percent.
        with torch.no_grad():
            x_input = torch.tensor(X_partial[np.newaxis, :, :], dtype=torch.float32, device=device)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100
        
        true = y_partial[:, 0]
        error = np.abs(true - pred)
        mae_sample = np.mean(error)
        max_error_sample = np.max(error)
        
        # Plot
        # The x-axis is the sample index inside the cropped window, not time
        time_steps = np.arange(len(true))
        ax.plot(time_steps, true, 'b-', linewidth=2.5, label='True SOC', alpha=0.8)
        ax.plot(time_steps, pred, 'r--', linewidth=2, label='PI-GRU Prediction', alpha=0.8)
        ax.fill_between(time_steps, true, pred, alpha=0.2, color='orange')
        
        # Add horizontal lines for SOC range
        ax.axhline(soc_start, color='green', linestyle=':', linewidth=1.5, alpha=0.7, label=f'Start: {soc_start}%')
        ax.axhline(soc_end, color='red', linestyle=':', linewidth=1.5, alpha=0.7, label=f'End: {soc_end}%')
        
        ax.set_title(f'{test_name}\nMAE: {mae_sample:.2f}% | Max Error: {max_error_sample:.2f}%', 
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Time Step', fontsize=9)
        ax.set_ylabel('SOC (%)', fontsize=9)
        ax.legend(fontsize=7, loc='best')
        ax.grid(True, alpha=0.3)
        # Show the requested range plus a 10-point margin, clamped to [0, 100]
        ax.set_ylim(max(0, min(soc_start, soc_end) - 10), 
                    min(100, max(soc_start, soc_end) + 10))
        
        found = True
        break
    
    # Placeholder text for panels without any usable sequence
    if not found:
        ax.text(0.5, 0.5, f'No valid\n{test_name}\ndata found',
               ha='center', va='center', fontsize=10, transform=ax.transAxes)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

plt.suptitle('Physics-Informed GRU Performance on Partial Curves', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# 2. Comparison: Partial Curves MAE/RMSE
# Two bar charts (MAE, RMSE) with one bar per partial-curve scenario.
# Every scenario gets its own x position, charging first and discharging
# after, so that each tick label sits under the bar it describes.  Previously
# both groups were drawn at positions 0..n-1 while the ticks covered all
# scenarios in sequence, which mislabelled the bars and left ticks empty.
if len(partial_results) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    charging_partial = partial_df[partial_df['Phase'] == 'C1ch']
    discharging_partial = partial_df[partial_df['Phase'] == 'C1dc']

    # Concatenate only the non-empty groups, in plotting order
    all_tests = pd.concat(
        [grp for grp in (charging_partial, discharging_partial) if len(grp) > 0]
    )

    # Discharging bars start right after the last charging bar
    x_charge = np.arange(len(charging_partial))
    x_discharge = len(charging_partial) + np.arange(len(discharging_partial))
    tick_positions = np.arange(len(all_tests))
    tick_labels = [t.replace('Charge ', 'C ').replace('Discharge ', 'D ')
                   for t in all_tests['Test']]
    bar_width = 0.6

    for ax, metric in zip(axes, ('MAE', 'RMSE')):
        if len(charging_partial) > 0:
            ax.bar(x_charge, charging_partial[metric], bar_width,
                   label='Charging', alpha=0.8, color='green')
        if len(discharging_partial) > 0:
            ax.bar(x_discharge, discharging_partial[metric], bar_width,
                   label='Discharging', alpha=0.8, color='orange')

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_title(f'{metric} for Partial Curves (PI-GRU)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

# 3. Range-specific Visualization
# Top row: one bar chart per aggregate metric.  Bottom row: signed-error
# histograms for the first three SOC ranges.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

x_pos = np.arange(len(range_metrics))
# Tick labels are the range names without the parenthesised bounds
short_labels = [m['Range'].split('(')[0].strip() for m in range_metrics]

bar_panels = [('MAE', 'steelblue'), ('RMSE', 'coral'), ('Max Error', 'lightcoral')]
for ax, (metric, bar_color) in zip(axes[0], bar_panels):
    ax.bar(x_pos, [m[metric] for m in range_metrics], alpha=0.7, color=bar_color)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(short_labels, rotation=45, ha='right')
    ax.set_ylabel(f'{metric} (%)')
    ax.set_title(f'{metric} by SOC Range (PI-GRU)', fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

# Error distribution by range
for ax, (soc_min, soc_max, range_name) in zip(axes[1], soc_ranges[:3]):
    mask = (actuals >= soc_min) & (actuals <= soc_max)

    # Panels for ranges without data are left empty
    if np.sum(mask) == 0:
        continue

    range_errors = (predictions[mask] - actuals[mask]).flatten()
    ax.hist(range_errors, bins=30, edgecolor='black', alpha=0.7, color='teal')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
    ax.set_xlabel('Error (%)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{range_name}\nMean Error: {np.mean(range_errors):.3f}%', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Physics-Informed GRU SOC Range Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# 4. Phase-specific Visualization
# Grouped bars (charging vs. discharging) per SOC range, one panel per metric
if len(phase_metrics) > 0:
    phase_df = pd.DataFrame(phase_metrics)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ranges_unique = phase_df['Range'].unique()
    x = np.arange(len(ranges_unique))
    width = 0.35
    range_ticks = [r.split('(')[0].strip() for r in ranges_unique]

    # (phase label, horizontal offset of its bars, bar color)
    bar_groups = [
        ('Charging', -width/2, 'green'),
        ('Discharging', width/2, 'orange'),
    ]

    for ax, metric in zip(axes, ('MAE', 'RMSE')):
        for group_label, offset, bar_color in bar_groups:
            group_rows = phase_df[phase_df['Phase'] == group_label]

            # A missing phase/range combination is drawn as a zero-height bar
            heights = []
            for r in ranges_unique:
                matches = group_rows[group_rows['Range'] == r][metric]
                heights.append(matches.values[0] if len(matches) > 0 else 0)

            ax.bar(x + offset, heights, width, label=group_label, alpha=0.8, color=bar_color)

        ax.set_xticks(x)
        ax.set_xticklabels(range_ticks, rotation=45, ha='right')
        ax.set_ylabel(f'{metric} (%)')
        ax.set_title(f'{metric}: Charging vs Discharging (PI-GRU)', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

# 5. Sample Predictions
# Four randomly drawn test sequences (with replacement), one per panel
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

for i, ax in enumerate(axes.flat):
    idx = np.random.randint(0, len(X_test))

    # Predict the full sequence as a batch of one and scale to percent
    x_sample = torch.tensor(X_test[idx:idx+1], dtype=torch.float32, device=device)
    with torch.no_grad():
        pred = model(x_sample).cpu().numpy()[0, :, 0] * 100

    true = y_test[idx, :, 0]
    error = np.mean(np.abs(true - pred))
    if phases_test[idx] == 'C1ch':
        phase_label = "Charging"
    else:
        phase_label = "Discharging"

    ax.plot(true, 'k-', linewidth=2, label='True SOC')
    ax.plot(pred, 'r--', linewidth=2, label='PI-GRU Prediction')
    ax.set_title(f'{phase_label} - Sample {i+1} | MAE: {error:.2f}%',
                 fontsize=11, fontweight='bold')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('SOC (%)')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Physics-Informed GRU SOC Estimation (Random Samples)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# 6. Overall Error Distribution
# Left: histogram of signed errors.  Right: predicted vs. true SOC scatter
# with the identity line as reference.
err_fig, (hist_ax, scatter_ax) = plt.subplots(1, 2, figsize=(10, 5))
errors = (predictions - actuals).flatten()

hist_ax.hist(errors, bins=50, edgecolor='black', alpha=0.7)
hist_ax.set_xlabel('Prediction Error (%)')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Error Distribution (PI-GRU)')
hist_ax.grid(True, alpha=0.3)

scatter_ax.scatter(actuals.flatten(), predictions.flatten(), alpha=0.5, s=1)
scatter_ax.plot([0, 100], [0, 100], 'r--', linewidth=2, label='Perfect Prediction')
scatter_ax.set_xlabel('True SOC (%)')
scatter_ax.set_ylabel('Predicted SOC (%)')
scatter_ax.set_title('Prediction vs Actual (PI-GRU)')
scatter_ax.legend()
scatter_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nDone! Physics-Informed GRU training and evaluation complete.")

GRU-PINN-AUGMENTATION

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# ============================================================
# -------------------- Data Loading --------------------------
# ============================================================
# Load the MATLAB file (path relative to the working directory).  The result
# is indexed by extract_sequences() below through its 'Cell1'..'Cell8' keys.
data = scipy.io.loadmat('Oxford_Battery_Degradation_Dataset_1.mat')

def soc_from_q(q, phase):
    """
    Convert a charge trace into SOC percent.

    The trace is rescaled so that its first sample maps to 0 and its last to
    1, then clipped to [0, 1].  For ``'C1ch'`` the result rises from 0 to
    100; for ``'C1dc'`` it is mirrored and falls from 100 to 0.

    Returns:
        Array of SOC values in percent, or ``None`` when the first and last
        samples are (numerically) equal or ``phase`` is not recognised.
    """
    first, last = float(q[0]), float(q[-1])
    if np.isclose(last, first):
        # No net charge moved: the trace cannot be normalized
        return None

    progress = np.clip((q - first) / (last - first), 0, 1)

    if phase == 'C1ch':
        return 100.0 * progress
    if phase == 'C1dc':
        return 100.0 * (1.0 - progress)
    return None

def extract_sequences(data, L=128):
    """
    Build fixed-length (V, T, I) -> SOC sequences from the loaded dataset.

    Walks ``Cell1`` .. ``Cell8``, each cycle (ordered by the integer that
    follows the first three characters of its field name) and the ``'C1ch'``
    / ``'C1dc'`` phases.  A phase block is skipped when it lacks ``t``, ``v``
    or ``q``, has fewer than 5 samples, or yields no SOC trace.  Every kept
    block is linearly resampled onto ``L`` evenly spaced time points.

    Returns:
        ``(X, y, phases)`` as arrays: features ``[N, L, 3]`` in the order
        voltage, temperature, current; SOC labels ``[N, L, 1]`` in percent;
        and the phase key of each sequence.
    """
    features, labels, phase_tags = [], [], []
    required = ('t', 'v', 'q')

    for cell_no in range(1, 9):
        cell = data[f'Cell{cell_no}']
        cycle_names = sorted(cell.dtype.names, key=lambda s: int(s[3:]))

        for cyc_name in cycle_names:
            cyc = cell[cyc_name][0, 0]

            for phase in ('C1ch', 'C1dc'):
                if phase not in cyc.dtype.names:
                    continue

                blk = cyc[phase][0, 0]
                fields = blk.dtype.names
                if any(key not in fields for key in required):
                    continue

                t, v, q = (blk[key][0, 0].ravel().astype(float) for key in required)
                if t.size < 5:
                    continue

                # Temperature: constant 25.0 when absent; gaps are filled
                # forward first, then backward
                if 'T' in fields:
                    temperature = blk['T'][0, 0].ravel().astype(float)
                else:
                    temperature = np.full_like(t, 25.0)
                temperature = pd.Series(temperature).ffill().bfill().values

                # Current: constant +/-0.74 (sign by phase) when absent
                if 'i' in fields:
                    current = blk['i'][0, 0].ravel().astype(float)
                else:
                    current = np.full_like(t, 0.74 if phase == 'C1ch' else -0.74)

                soc = soc_from_q(q, phase)
                if soc is None:
                    continue

                # Linear interpolation onto L evenly spaced time points
                grid = np.linspace(t[0], t[-1], L)
                channels = [np.interp(grid, t, signal)
                            for signal in (v, temperature, current)]
                soc_on_grid = np.interp(grid, t, soc)

                features.append(np.stack(channels, axis=-1))  # [L, 3]
                labels.append(soc_on_grid[:, None])           # [L, 1]
                phase_tags.append(phase)

    return np.array(features), np.array(labels), np.array(phase_tags)

# ============================================================
# ------------- Data Augmentation for Mid-SOC ----------------
# ============================================================
def _pad_with_last(seq, length):
    """Extend ``seq`` along axis 0 to ``length`` rows by repeating its last row."""
    missing = length - seq.shape[0]
    if missing <= 0:
        return seq
    return np.vstack([seq, np.tile(seq[-1:], (missing, 1))])


def augment_with_partial_curves(X, y, phases, augmentation_factor=3):
    """
    Augment dataset by extracting partial curves from different SOC ranges.
    This creates more training examples, especially for mid-SOC regions.

    For every sequence, ``augmentation_factor`` SOC windows are drawn at
    random (mid-SOC windows have higher probability) from the table matching
    THAT sequence's phase.  A window is used only if it covers at least 20
    time steps; it is then brought to 64 steps (padded with its last row, or
    index-subsampled).  Finally every sequence is padded with its last row to
    the longest length present so the result stacks into regular arrays.

    Uses the global NumPy random state (``np.random.choice``).

    Args:
        X: [N, L, 3] - features
        y: [N, L, 1] - SOC labels (in %)
        phases: [N] - phase labels ('C1ch' = charging, anything else is
            treated as discharging)
        augmentation_factor: how many partial curves to attempt per sequence

    Returns:
        Augmented X, y, phases (originals first, then the partial curves).
        Empty input is returned unchanged as arrays.
    """
    # (start_soc, end_soc, weight) - weight determines sampling probability
    charge_ranges = [
        (0, 40, 1.0),    # Early charging
        (10, 50, 2.0),   # Early-mid charging (overlap with mid)
        (20, 60, 3.0),   # Mid charging (HIGH WEIGHT)
        (30, 70, 3.0),   # Mid charging (HIGH WEIGHT)
        (40, 80, 2.0),   # Mid-late charging
        (50, 90, 1.0),   # Late charging
        (60, 100, 1.0),  # Very late charging
    ]
    discharge_ranges = [
        (100, 60, 1.0),  # Early discharging
        (90, 50, 2.0),   # Early-mid discharging
        (80, 40, 3.0),   # Mid discharging (HIGH WEIGHT)
        (70, 30, 3.0),   # Mid discharging (HIGH WEIGHT)
        (60, 20, 2.0),   # Mid-late discharging
        (50, 10, 1.0),   # Late discharging
        (40, 0, 1.0),    # Very late discharging
    ]
    target_len = 64   # length every partial curve is brought to
    min_points = 20   # windows with fewer steps are discarded

    if len(X) == 0:
        # Nothing to augment; also avoids max() over an empty list below
        return np.array(X), np.array(y), np.array(phases)

    def _probabilities(table):
        weights = np.array([w for _, _, w in table])
        return weights / np.sum(weights)

    # Sampling probabilities are loop-invariant: compute them once per table
    charge_probs = _probabilities(charge_ranges)
    discharge_probs = _probabilities(discharge_ranges)

    # Keep all original sequences
    X_aug, y_aug, phase_aug = list(X), list(y), list(phases)

    print(f"\nAugmenting data with partial curves...")
    print(f"Original sequences: {len(X)}")

    augmented_count = 0

    for x_seq, y_seq, phase in zip(X, y, phases):
        soc_values = y_seq[:, 0]
        charging = phase == 'C1ch'

        # Pick the table by the phase of THIS sequence.  Choosing it once
        # from phases[0] gave every sequence of the other phase inverted
        # bounds, hence an empty mask and no augmentation at all.
        if charging:
            ranges, probs = charge_ranges, charge_probs
        else:
            ranges, probs = discharge_ranges, discharge_probs

        for _ in range(augmentation_factor):
            soc_start, soc_end, _weight = ranges[np.random.choice(len(ranges), p=probs)]

            if charging:
                mask = (soc_values >= soc_start) & (soc_values <= soc_end)
            else:
                mask = (soc_values <= soc_start) & (soc_values >= soc_end)

            if np.sum(mask) < min_points:  # Need reasonable number of points
                continue

            # Contiguous window from the first to the last matching step
            hits = np.flatnonzero(mask)
            start_idx, end_idx = hits[0], hits[-1] + 1
            X_partial = x_seq[start_idx:end_idx, :]
            y_partial = y_seq[start_idx:end_idx, :]

            if len(X_partial) < target_len:
                X_partial = _pad_with_last(X_partial, target_len)
                y_partial = _pad_with_last(y_partial, target_len)
            else:
                # Subsample evenly spaced indices down to target_len
                pick = np.linspace(0, len(X_partial) - 1, target_len, dtype=int)
                X_partial = X_partial[pick]
                y_partial = y_partial[pick]

            X_aug.append(X_partial)
            y_aug.append(y_partial)
            phase_aug.append(phase)
            augmented_count += 1

    print(f"Added {augmented_count} augmented partial curves")
    print(f"Total sequences after augmentation: {len(X_aug)}")

    # Sequences have different lengths - pad the shorter ones so that the
    # lists stack into regular arrays
    max_len = max(seq.shape[0] for seq in X_aug)
    X_padded = [_pad_with_last(seq, max_len) for seq in X_aug]
    y_padded = [_pad_with_last(seq, max_len) for seq in y_aug]

    return np.array(X_padded), np.array(y_padded), np.array(phase_aug)

# ============================================================
# -------------------- Prepare Data --------------------------
# ============================================================
print("Loading data...")
X_raw, y_raw, phases_raw = extract_sequences(data, L=128)
print(f"Raw sequences: {len(X_raw)}, Shape: {X_raw.shape}")

# Ordered 80/20 split (no shuffling), done BEFORE augmentation so that no
# partial curve derived from a test sequence can reach the training set
n_train = int(0.8 * len(X_raw))
X_train_raw, y_train_raw, phases_train_raw = (
    X_raw[:n_train], y_raw[:n_train], phases_raw[:n_train]
)
X_test, y_test, phases_test = (
    X_raw[n_train:], y_raw[n_train:], phases_raw[n_train:]
)

print(f"\nBefore augmentation:")
print(f"  Train: {len(X_train_raw)} | Test: {len(X_test)}")

# Partial curves are added to the training portion only
X_train_aug, y_train_aug, phases_train_aug = augment_with_partial_curves(
    X_train_raw, y_train_raw, phases_train_raw, augmentation_factor=3
)

print(f"\nAfter augmentation:")
print(f"  Train: {len(X_train_aug)} | Test: {len(X_test)}")

# Per-feature mean/std taken from the ORIGINAL (un-augmented) training
# sequences; the small constant guards against division by zero
train_samples = X_train_raw.reshape(-1, X_train_raw.shape[-1])
mu = train_samples.mean(axis=0)
sd = train_samples.std(axis=0) + 1e-8

X_train = (X_train_aug - mu) / sd
X_test = (X_test - mu) / sd

# Labels: percent -> fraction
y_train_norm = y_train_aug / 100.0
y_test_norm = y_test / 100.0

# ============================================================
# -------------------- DataLoaders ---------------------------
# ============================================================
# Prefer the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}\n")

# Features and normalized labels as float32 tensors
train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train_norm, dtype=torch.float32),
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test_norm, dtype=torch.float32),
)

# Training batches are reshuffled every epoch; test order stays fixed
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ============================================================
# -------------------- GRU-PINN Model ------------------------
# ============================================================
class GRU_PINN(nn.Module):
    """
    Physics-informed GRU for SOC estimation.

    - Base GRU learns a per-time-step mapping from (V, T, I) -> SOC
    - Physics loss enforces discrete Coulomb counting:
        SOC_{k+1} - SOC_k ≈ alpha * I_k
    - Bounds loss softly enforces SOC in [0, 1]

    Args:
        input_dim: Number of input features per time step.
        hidden: Hidden size of each GRU layer.
        layers: Number of stacked GRU layers.
        dropout: Dropout between GRU layers; forced to 0 when ``layers == 1``
            because there is no inter-layer connection to drop.
    """
    def __init__(self, input_dim=3, hidden=128, layers=2, dropout=0.2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden, num_layers=layers,
                          batch_first=True,
                          dropout=dropout if layers > 1 else 0.0)
        # Per-time-step regression head applied to every GRU output.
        self.fc = nn.Linear(hidden, 1)

        # Physics parameter: SOC increment per unit current per step (in normalized SOC units).
        # Learnable, so the optimizer fits it jointly with the network weights.
        self.alpha = nn.Parameter(torch.tensor(1e-3, dtype=torch.float32))

    def forward(self, x):
        """Map ``x`` of shape [batch, seq_len, input_dim] to SOC of shape [batch, seq_len, 1]."""
        out, _ = self.gru(x)      # [batch, seq_len, hidden]
        soc_hat = self.fc(out)    # [batch, seq_len, 1], normalized SOC
        return soc_hat

    def physics_loss(self, x, y_pred):
        """
        Physics-informed loss based on discrete Coulomb counting:
            SOC_{k+1} - SOC_k ≈ alpha * I_k

        Args:
            x: Input features [B, L, F]; feature index 2 is read as current.
            y_pred: Predicted SOC [B, L, 1].

        Returns:
            Scalar tensor: mean squared Coulomb-counting residual over all
            consecutive step pairs, or 0 when there are no such pairs
            (sequence length < 2 or empty batch).
        """
        # NOTE(review): the training script in this file feeds standardized
        # features, so the current read here is z-scored rather than physical;
        # alpha can absorb the scale but not the mean shift - confirm intended.
        current = x[..., 2]    # [B, L]
        soc = y_pred[..., 0]   # [B, L]

        dsoc = soc[:, 1:] - soc[:, :-1]                    # [B, L-1]
        residual = dsoc - self.alpha * current[:, :-1]     # [B, L-1]

        # mean() over an empty tensor is NaN and would poison the total loss;
        # sum() of an empty tensor is 0 and stays attached to the autograd graph.
        if residual.numel() == 0:
            return residual.sum()
        return (residual ** 2).mean()

    def bounds_loss(self, y_pred):
        """Softly enforce SOC in [0, 1] via a squared penalty on any excursion."""
        below = torch.relu(-y_pred)        # amount below 0
        above = torch.relu(y_pred - 1.0)   # amount above 1
        return (below**2 + above**2).mean()

model = GRU_PINN(input_dim=3, hidden=128, layers=2, dropout=0.2)
model = model.to(device)
print(f"Model: GRU-PINN with {sum(p.numel() for p in model.parameters()):,} parameters\n")

# ============================================================
# -------------------- Training (GRU-PINN) -------------------
# ============================================================
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

epochs = 150
best_loss = float('inf')

# Weights of the physics and bounds terms in the composite loss
lambda_phys = 0.1
lambda_bounds = 0.01

# Dataset sizes are fixed for the whole run; used to turn summed,
# batch-size-weighted losses into per-sample averages.
n_train_samples = len(train_loader.dataset)
n_test_samples = len(test_loader.dataset)

print("Training GRU-PINN model with augmented data...")
print(f"{'='*80}")

for epoch in range(epochs):
    # ---- optimisation pass over the (augmented) training set ----
    model.train()
    train_loss = train_mse = train_phys = train_bounds = 0.0

    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)
        batch_n = xb.size(0)

        optimizer.zero_grad()
        y_pred = model(xb)

        # Data fit + Coulomb-counting residual + soft [0, 1] bounds
        data_loss = loss_fn(y_pred, yb)
        phys_loss = model.physics_loss(xb, y_pred)
        bnd_loss = model.bounds_loss(y_pred)
        loss = data_loss + lambda_phys * phys_loss + lambda_bounds * bnd_loss

        loss.backward()
        optimizer.step()

        train_loss += batch_n * loss.item()
        train_mse += batch_n * data_loss.item()
        train_phys += batch_n * phys_loss.item()
        train_bounds += batch_n * bnd_loss.item()

    train_loss /= n_train_samples
    train_mse /= n_train_samples
    train_phys /= n_train_samples
    train_bounds /= n_train_samples

    # ---- evaluation pass over the held-out set (no gradients) ----
    model.eval()
    test_loss = test_mse = 0.0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            batch_n = xb.size(0)
            y_pred = model(xb)

            data_loss = loss_fn(y_pred, yb)
            phys_loss = model.physics_loss(xb, y_pred)
            bnd_loss = model.bounds_loss(y_pred)
            loss = data_loss + lambda_phys * phys_loss + lambda_bounds * bnd_loss

            test_loss += batch_n * loss.item()
            test_mse += batch_n * data_loss.item()

    test_loss /= n_test_samples
    test_mse /= n_test_samples

    # Checkpoint whenever the plain data MSE on the held-out set improves.
    # NOTE(review): selection is driven by the test loader, so the reported
    # best test MSE is optimistically biased - consider a validation split.
    if test_mse < best_loss:
        best_loss = test_mse
        torch.save(model.state_dict(), 'best_gru_pinn_augmented.pt')

    # Progress report every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch+1:03d} | "
            f"Train: {train_loss:.6f} (MSE: {train_mse:.6f}, Phys: {train_phys:.6f}, Bnd: {train_bounds:.6f}) | "
            f"Test: {test_loss:.6f} (MSE: {test_mse:.6f}) | "
            f"alpha: {model.alpha.item():.6e}"
        )

print(f"\n{'='*80}")
print(f"Best Test MSE: {best_loss:.6f}")
print(f"Learned alpha: {model.alpha.item():.6e}")

# ============================================================
# -------------------- Validation ----------------------------
# ============================================================
# Restore the best checkpoint. map_location remaps the stored tensors onto the
# active device, so a checkpoint written during a GPU run can still be loaded
# when this cell is re-executed on a CPU-only machine (and vice versa).
model.load_state_dict(
    torch.load('best_gru_pinn_augmented.pt', map_location=device)
)
model.eval()

# Get predictions on test set (loader is unshuffled, so order matches y_test)
predictions, actuals = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        pred = model(xb).cpu().numpy()
        predictions.append(pred)
        actuals.append(yb.numpy())

# Back from normalized [0, 1] SOC to percent
predictions = np.concatenate(predictions) * 100.0
actuals = np.concatenate(actuals) * 100.0

# Overall metrics over every time step of every test sequence
mae = np.mean(np.abs(predictions - actuals))
rmse = np.sqrt(np.mean((predictions - actuals)**2))
max_error = np.max(np.abs(predictions - actuals))

print(f"\n{'='*60}")
print("Overall Validation Metrics (GRU-PINN + Augmentation):")
print(f"{'='*60}")
print(f"MAE:        {mae:.4f}%")
print(f"RMSE:       {rmse:.4f}%")
print(f"Max Error:  {max_error:.4f}%")

# ============================================================
# -------------------- SOC Range Analysis --------------------
# ============================================================
print(f"\n{'='*60}")
print("SOC Range Analysis (Emphasis on Mid-SOC)")
print(f"{'='*60}")

# (lower bound, upper bound, label) in SOC percent; both bounds are inclusive,
# so a sample sitting exactly on a shared edge is counted in both neighbours.
# NOTE(review): no named band covers 90-100%; it only appears in "Full Range".
soc_ranges = [
    (0, 30, "Low SOC (0-30%)"),
    (30, 60, "Mid SOC (30-60%)"),  # This should be much better now!
    (60, 90, "High SOC (60-90%)"),
    (0, 100, "Full Range (0-100%)")
]

range_metrics = []

for soc_min, soc_max, range_name in soc_ranges:
    # Select every time step whose TRUE SOC falls inside the band
    mask = (actuals >= soc_min) & (actuals <= soc_max)
    if not np.any(mask):
        print(f"\n{range_name}: No data points in this range")
        continue

    range_preds = predictions[mask]
    range_actuals = actuals[mask]
    range_err = range_preds - range_actuals

    n_points = np.sum(mask)
    range_mae = np.mean(np.abs(range_err))
    range_rmse = np.sqrt(np.mean(range_err**2))
    range_max_error = np.max(np.abs(range_err))

    range_metrics.append({
        'Range': range_name,
        'MAE': range_mae,
        'RMSE': range_rmse,
        'Max Error': range_max_error,
        'N Points': n_points
    })

    print(f"\n{range_name}:")
    print(f"  Points:     {n_points:,}")
    print(f"  MAE:        {range_mae:.4f}%")
    print(f"  RMSE:       {range_rmse:.4f}%")
    print(f"  Max Error:  {range_max_error:.4f}%")

# Tabular recap of the bands that actually contained data
metrics_df = pd.DataFrame(range_metrics)
print(f"\n{'='*60}")
print("Summary Table:")
print(f"{'='*60}")
print(metrics_df.to_string(index=False))
print(f"{'='*60}\n")

# ============================================================
# ---------------- Charging vs Discharging -------------------
# ============================================================
print(f"\n{'='*60}")
print("Charging vs Discharging Analysis")
print(f"{'='*60}")

phase_metrics = []

# Phase tag as stored in phases_test -> human-readable label
for phase_type, phase_name in (('C1ch', "Charging"), ('C1dc', "Discharging")):
    # Sequence-level selection: one flag per test sequence
    phase_mask = phases_test == phase_type
    if not np.any(phase_mask):
        continue

    phase_preds = predictions[phase_mask]
    phase_actuals = actuals[phase_mask]
    phase_err = phase_preds - phase_actuals

    phase_mae = np.mean(np.abs(phase_err))
    phase_rmse = np.sqrt(np.mean(phase_err**2))

    print(f"\n{phase_name}:")
    print(f"  Sequences:  {np.sum(phase_mask)}")
    print(f"  MAE:        {phase_mae:.4f}%")
    print(f"  RMSE:       {phase_rmse:.4f}%")

    # Break the phase down by the three named SOC bands (full range excluded)
    for soc_min, soc_max, range_name in soc_ranges[:3]:
        range_mask = (phase_actuals >= soc_min) & (phase_actuals <= soc_max)
        if not np.any(range_mask):
            continue

        sub_preds = phase_preds[range_mask]
        sub_actuals = phase_actuals[range_mask]
        sub_err = sub_preds - sub_actuals
        sub_mae = np.mean(np.abs(sub_err))
        sub_rmse = np.sqrt(np.mean(sub_err**2))

        print(f"    {range_name}:")
        print(f"      Points: {np.sum(range_mask):,} | MAE: {sub_mae:.4f}% | RMSE: {sub_rmse:.4f}%")

        phase_metrics.append({
            'Phase': phase_name,
            'Range': range_name,
            'MAE': sub_mae,
            'RMSE': sub_rmse,
            'N Points': np.sum(range_mask)
        })

# ============================================================
# --------------- Partial Curve Testing ----------------------
# ============================================================
# Section banner for the partial-curve evaluation (three lines on stdout).
print(
    f"\n{'='*60}",
    "Testing on Partial Charge/Discharge Curves",
    f"{'='*60}",
    sep="\n",
)

def create_partial_curve(X_seq, y_seq, soc_start, soc_end, min_points=10):
    """Extract a partial curve from a full charge/discharge sequence.

    Args:
        X_seq: Feature array whose first axis is time.
        y_seq: SOC labels, 2-D with SOC in column 0, in the same units as
            ``soc_start`` / ``soc_end``.
        soc_start: SOC at which the partial curve begins.
        soc_end: SOC at which it ends. ``soc_start < soc_end`` describes
            charging, otherwise discharging; both select the same closed
            SOC interval.
        min_points: Minimum number of in-range samples required; fewer than
            this yields ``(None, None)``. Defaults to 10.

    Returns:
        ``(X_partial, y_partial)``: the contiguous slice running from the
        first to the last in-range sample, or ``(None, None)`` when too few
        samples fall inside the interval. If SOC is not monotonic, samples
        between those two endpoints that lie outside the interval are kept.
    """
    soc_values = y_seq[:, 0]

    # Direction does not matter for selection: keep SOC within the closed
    # interval spanned by the two endpoints.
    lower, upper = min(soc_start, soc_end), max(soc_start, soc_end)
    mask = (soc_values >= lower) & (soc_values <= upper)

    indices = np.where(mask)[0]
    # The explicit emptiness check keeps min_points <= 0 from indexing into
    # an empty index array below.
    if indices.size == 0 or indices.size < min_points:
        return None, None

    start_idx, end_idx = indices[0], indices[-1] + 1
    return X_seq[start_idx:end_idx], y_seq[start_idx:end_idx]

# (soc_start, soc_end, phase tag, label); start > end denotes a discharge window
partial_curve_tests = [
    (0, 30, 'C1ch', 'Charge 0-30%'),
    (30, 60, 'C1ch', 'Charge 30-60%'),
    (60, 90, 'C1ch', 'Charge 60-90%'),
    (20, 80, 'C1ch', 'Charge 20-80%'),
    (100, 70, 'C1dc', 'Discharge 100-70%'),
    (70, 40, 'C1dc', 'Discharge 70-40%'),
    (40, 10, 'C1dc', 'Discharge 40-10%'),
    (80, 20, 'C1dc', 'Discharge 80-20%'),
]

partial_results = []

for soc_start, soc_end, phase_type, test_name in partial_curve_tests:
    # Test sequences recorded in the requested phase
    phase_mask = phases_test == phase_type
    phase_indices = np.where(phase_mask)[0]

    if len(phase_indices) == 0:
        print(f"\n{test_name}: No {phase_type} sequences available")
        continue

    test_preds, test_actuals = [], []

    # Only the first 50 matching sequences are evaluated per window.
    with torch.no_grad():
        for idx in phase_indices[:50]:
            # X_test is already standardized; y_test is still SOC in percent.
            X_partial, y_partial = create_partial_curve(
                X_test[idx], y_test[idx], soc_start, soc_end
            )
            if X_partial is None:
                continue

            # Feed the window alone (batch of one) so the GRU starts cold
            # at the window's first sample.
            x_input = torch.tensor(
                X_partial, dtype=torch.float32, device=device
            ).unsqueeze(0)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100.0
            true = y_partial[:, 0]

            test_preds.append(pred)
            test_actuals.append(true)

    valid_sequences = len(test_preds)
    if valid_sequences == 0:
        print(f"\n{test_name}: No valid partial curves found in range")
        continue

    # Pool every time step of every valid window before computing metrics
    all_preds = np.concatenate(test_preds)
    all_actuals = np.concatenate(test_actuals)
    all_err = all_preds - all_actuals

    mae_pc = np.mean(np.abs(all_err))
    rmse_pc = np.sqrt(np.mean(all_err**2))
    max_error_pc = np.max(np.abs(all_err))

    partial_results.append({
        'Test': test_name,
        'Phase': phase_type,
        'SOC Range': f"{soc_start}-{soc_end}%",
        'Sequences': valid_sequences,
        'Points': len(all_preds),
        'MAE': mae_pc,
        'RMSE': rmse_pc,
        'Max Error': max_error_pc
    })

    print(f"\n{test_name}:")
    print(f"  Valid Sequences: {valid_sequences}")
    print(f"  Total Points:    {len(all_preds):,}")
    print(f"  MAE:             {mae_pc:.4f}%")
    print(f"  RMSE:            {rmse_pc:.4f}%")
    print(f"  Max Error:       {max_error_pc:.4f}%")

# Summary table, only when at least one window produced results
if partial_results:
    partial_df = pd.DataFrame(partial_results)
    print(f"\n{'='*60}")
    print("Partial Curve Testing Summary:")
    print(f"{'='*60}")
    print(partial_df.to_string(index=False))
    print(f"{'='*60}\n")

# ============================================================
# -------------- Visualizations ------------------------------
# ============================================================
print("\nGenerating visualizations...")

# 1. Partial Curve Visualization: one panel per SOC window (4 rows x 2 columns)
fig, axes = plt.subplots(4, 2, figsize=(14, 16))

# (soc_start, soc_end, phase tag, label, grid row, grid column)
visualization_tests = [
    (0, 30, 'C1ch', 'Charge 0-30%', 0, 0),
    (30, 60, 'C1ch', 'Charge 30-60%', 0, 1),
    (60, 90, 'C1ch', 'Charge 60-90%', 1, 0),
    (20, 80, 'C1ch', 'Charge 20-80%', 1, 1),
    (100, 70, 'C1dc', 'Discharge 100-70%', 2, 0),
    (70, 40, 'C1dc', 'Discharge 70-40%', 2, 1),
    (40, 10, 'C1dc', 'Discharge 40-10%', 3, 0),
    (80, 20, 'C1dc', 'Discharge 80-20%', 3, 1),
]

for soc_start, soc_end, phase_type, test_name, row, col in visualization_tests:
    ax = axes[row, col]

    phase_mask = phases_test == phase_type
    phase_indices = np.where(phase_mask)[0]

    # Plot the FIRST usable window among the first 100 matching sequences;
    # the else-branch runs only when no candidate produced a plot.
    for idx in phase_indices[:100]:
        X_partial, y_partial = create_partial_curve(
            X_test[idx], y_test[idx], soc_start, soc_end
        )
        if X_partial is None or len(X_partial) < 10:
            continue

        with torch.no_grad():
            x_input = torch.tensor(
                X_partial, dtype=torch.float32, device=device
            ).unsqueeze(0)
            pred = model(x_input).cpu().numpy()[0, :, 0] * 100.0

        true = y_partial[:, 0]
        error = np.abs(true - pred)
        mae_pc = np.mean(error)
        max_error_pc = np.max(error)

        # True vs predicted SOC, with the gap between them shaded
        time_steps = np.arange(len(true))
        ax.plot(time_steps, true, 'b-', linewidth=2.5, label='True SOC', alpha=0.8)
        ax.plot(time_steps, pred, 'r--', linewidth=2, label='GRU-PINN + Aug', alpha=0.8)
        ax.fill_between(time_steps, true, pred, alpha=0.2, color='orange')

        # Dotted guides at the requested window limits
        ax.axhline(soc_start, color='green', linestyle=':', linewidth=1.5, alpha=0.7, label=f'Start: {soc_start}%')
        ax.axhline(soc_end, color='red', linestyle=':', linewidth=1.5, alpha=0.7, label=f'End: {soc_end}%')

        ax.set_title(f'{test_name}\nMAE: {mae_pc:.2f}% | Max Error: {max_error_pc:.2f}%',
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Time Step', fontsize=9)
        ax.set_ylabel('SOC (%)', fontsize=9)
        ax.legend(fontsize=7, loc='best')
        ax.grid(True, alpha=0.3)

        # 10-point margin around the window, clamped to the valid 0-100% range
        y_low = max(0, min(soc_start, soc_end) - 10)
        y_high = min(100, max(soc_start, soc_end) + 10)
        ax.set_ylim(y_low, y_high)
        break
    else:
        ax.text(0.5, 0.5, f'No valid\n{test_name}\ndata found',
                ha='center', va='center', fontsize=10, transform=ax.transAxes)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

plt.suptitle('GRU-PINN + Data Augmentation: Partial Curves Performance',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# 2. Range comparison
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# ---- Top row: one bar chart per metric, one bar per SOC range with data ----
x_pos = np.arange(len(range_metrics))
tick_labels = [m['Range'].split('(')[0].strip() for m in range_metrics]

# range_metrics omits ranges that had no data points, so the mid-SOC entry is
# not guaranteed to sit at index 1 (or to exist at all). Locate it by label
# instead of by position to avoid highlighting the wrong bar / IndexError.
mid_range_name = soc_ranges[1][2]

# (metric key, y-axis label, panel title, base bar colour)
bar_panels = [
    ('MAE', 'MAE (%)', 'MAE by SOC Range (Mid-SOC Improved!)', 'steelblue'),
    ('RMSE', 'RMSE (%)', 'RMSE by SOC Range', 'coral'),
    ('Max Error', 'Max Error (%)', 'Max Error by SOC Range', 'lightcoral'),
]

for ax, (metric_key, y_label, panel_title, bar_color) in zip(axes[0], bar_panels):
    bars = ax.bar(x_pos, [m[metric_key] for m in range_metrics], alpha=0.7, color=bar_color)
    # Highlight mid-SOC bar (if that range produced metrics)
    for bar, m in zip(bars, range_metrics):
        if m['Range'] == mid_range_name:
            bar.set_color('orange')
            bar.set_alpha(0.9)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

# ---- Bottom row: signed-error histograms for the three named SOC bands ----
for idx, (soc_min, soc_max, range_name) in enumerate(soc_ranges[:3]):
    ax = axes[1, idx]
    mask = (actuals >= soc_min) & (actuals <= soc_max)

    # A band without data leaves its panel empty
    if np.sum(mask) > 0:
        range_errors = (predictions[mask] - actuals[mask]).flatten()
        color = 'orange' if idx == 1 else 'teal'  # Highlight mid-SOC
        ax.hist(range_errors, bins=30, edgecolor='black', alpha=0.7, color=color)
        ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
        ax.set_xlabel('Error (%)')
        ax.set_ylabel('Frequency')
        title_prefix = '[MID-SOC] ' if idx == 1 else ''
        ax.set_title(f'{title_prefix}{range_name}\nMean Error: {np.mean(range_errors):.3f}%', fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)

plt.suptitle('GRU-PINN + Augmentation: SOC Range Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n✅ Training complete with data augmentation!")
print("📊 Check the mid-SOC (30-60%) metrics - they should be significantly improved!")
print("\nDone!")