# Saturation Problem Analysis: rdNALA Prediction Plateau
## New Discovery: Different (V, E) combinations converge to identical steady states

**Key Finding**:
- `10V + E=0.25M` → Low CALA (~0.016 mol/L)
- `30V + E=0.25M` → High CALA (~1.47 mol/L)
- `10V + E=1.0M` → High CALA (~1.47 mol/L) ← **Same as 30V + 0.25M!**

**New Hypothesis**: The model has a **saturation problem**: all high-separation-efficiency conditions converge to the same plateau, regardless of the specific (V, E) combination.

**This notebook investigates**:
1. Does 10V+1M really converge to the same state as 30V+0.25M?
2. What happens in LSTM hidden states across these conditions?
3. How does Layer Normalization affect feature discrimination?
4. What are the rdNALA prediction patterns?

In [1]:
# import module libraries
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import pandas as pd
from scipy.spatial.distance import cosine

In [None]:
# Copy model class definitions from 7.bmed_1stage_simulator.ipynb
# (LayerNormLSTM, StateExtr, PhysRegr, CurrRegr, PhysConstr, BMEDModel)

# LSTM with layer normalization
class LayerNormLSTM(nn.Module):
    """LSTM cell with layer normalization on every gate and on the cell state.

    The input and the previous hidden state are each projected (bias-free) to
    ``4 * hidden_node`` features. The summed projection is split into four
    equal slices (input gate, forget gate, candidate, output gate); each slice
    passes through its own ``LayerNorm`` before the non-linearity. The updated
    cell state is layer-normalized before it is returned and used for the
    hidden state.
    """

    def __init__(self, input_node, hidden_node):
        super().__init__()
        self.input_node = input_node
        self.hidden_node = hidden_node

        gate_width = 4*hidden_node
        self.w_i = nn.Linear(input_node, gate_width, bias=False)
        self.w_h = nn.Linear(hidden_node, gate_width, bias=False)

        # Registers ln_i, ln_f, ln_w, ln_o, ln_c (in that order) as submodules.
        for tag in ('i', 'f', 'w', 'o', 'c'):
            setattr(self, f'ln_{tag}', nn.LayerNorm(hidden_node))

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: tensor whose last dimension is ``input_node``.
            hidden: pair ``(h_prev, c_prev)`` with last dimension ``hidden_node``.

        Returns:
            ``(h_new, c_new)``; ``c_new`` is the layer-normalized cell state.
        """
        h_prev, c_prev = hidden

        # Sum both projections first, then cut the result into the four gate slices.
        pre_act = self.w_i(input) + self.w_h(h_prev)
        pre_i, pre_f, pre_w, pre_o = pre_act.chunk(4, dim=-1)

        in_gate = torch.sigmoid(self.ln_i(pre_i))
        forget_gate = torch.sigmoid(self.ln_f(pre_f))
        candidate = torch.tanh(self.ln_w(pre_w))
        out_gate = torch.sigmoid(self.ln_o(pre_o))

        c_new = self.ln_c(forget_gate * c_prev + in_gate * candidate)
        h_new = out_gate * torch.tanh(c_new)

        return h_new, c_new

# State feature extractor using LayerNorm LSTM
class StateExtr(nn.Module):
    """Stacked LayerNorm-LSTM feature extractor for padded sequences.

    Runs ``n_layer`` ``LayerNormLSTM`` cells over the sequence one timestep at a
    time, applying dropout between layers. The top-layer outputs are zeroed
    beyond each sequence's length, layer-normalized and passed through dropout.
    """

    def __init__(self, input_node, hidden_node, n_layer, dropout):
        super().__init__()
        self.hidden_node = hidden_node
        self.n_layer = n_layer
        self.input_node = input_node

        # The first cell reads the raw features; every further cell reads the
        # hidden state of the layer below it.
        cell_inputs = [input_node] + [hidden_node] * (n_layer - 1)
        self.lstm_cells = nn.ModuleList(
            [LayerNormLSTM(width, hidden_node) for width in cell_inputs]
        )

        self.dropout = nn.Dropout(dropout)
        self.layernorm = nn.LayerNorm(hidden_node)

    def forward(self, x, seq_len):
        """Extract per-timestep features.

        Args:
            x: 3-D tensor unpacked as ``(batch, max_len, features)``.
            seq_len: tensor of per-sample valid lengths; positions at or beyond
                a sample's length are zeroed before the final LayerNorm.

        Returns:
            Tensor with one ``hidden_node``-wide feature vector per timestep.
        """
        batch_size, max_len, _ = x.size()
        device = x.device

        def fresh_state():
            return torch.zeros(batch_size, self.hidden_node, device=device)

        h_states = [fresh_state() for _ in range(self.n_layer)]
        c_states = [fresh_state() for _ in range(self.n_layer)]

        top_layer = len(self.lstm_cells) - 1
        outputs = []
        for x_t in x.unbind(dim=1):
            layer_input = x_t
            for layer_idx, lstm_cell in enumerate(self.lstm_cells):
                state = (h_states[layer_idx], c_states[layer_idx])
                h_new, c_new = lstm_cell(layer_input, state)
                h_states[layer_idx], c_states[layer_idx] = h_new, c_new

                # Dropout only sits between layers, never after the top one.
                layer_input = h_new if layer_idx == top_layer else self.dropout(h_new)

            outputs.append(layer_input)

        output_tensor = torch.stack(outputs, dim=1)

        # The validity mask is built on the CPU and then moved to x's device.
        lengths = seq_len.detach().cpu().long()
        positions = torch.arange(max_len, device='cpu')
        mask = (positions[None, :] < lengths[:, None]).float().to(device).unsqueeze(-1)

        normed_output = self.layernorm(output_tensor * mask)
        return self.dropout(normed_output)

# Physical change regressor
class PhysRegr(nn.Module):
    """MLP head that maps hidden features to ``output_node`` values in (0, 1).

    The network is ``n_layer`` blocks of Linear -> ReLU -> Dropout (the first
    one reading ``input_node`` features, the rest ``hidden_node``), followed by
    a Linear projection to ``output_node`` and a Sigmoid.
    """

    def __init__(self, input_node, output_node, n_layer, hidden_node, dropout):
        super().__init__()

        block_inputs = [input_node] + [hidden_node] * (n_layer - 1)

        stack = []
        for in_width in block_inputs:
            stack += [
                nn.Linear(in_width, hidden_node),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        stack += [nn.Linear(hidden_node, output_node), nn.Sigmoid()]

        self.layers = nn.Sequential(*stack)

    def forward(self, hidden_states):
        """Apply the MLP to the last dimension of ``hidden_states``."""
        return self.layers(hidden_states)

# Current regressor
class CurrRegr(nn.Module):
    """MLP head that maps a state vector to a single unbounded value.

    The network is ``n_layer`` blocks of Linear -> ReLU -> Dropout (the first
    one reading ``input_node`` features, the rest ``hidden_node``), followed by
    a Linear projection to one output. No output activation is applied.
    """

    def __init__(self, input_node, hidden_node, n_layer, dropout):
        super().__init__()

        block_inputs = [input_node] + [hidden_node] * (n_layer - 1)

        stack = []
        for in_width in block_inputs:
            stack += [
                nn.Linear(in_width, hidden_node),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        stack.append(nn.Linear(hidden_node, 1))

        self.layers = nn.Sequential(*stack)

    def forward(self, hidden_states):
        """Apply the MLP to the last dimension of ``hidden_states``."""
        return self.layers(hidden_states)

# Physical Constraint Layer
class PhysConstr(nn.Module):
    """Convert predicted transfer ratios into the next normalized state.

    The layer denormalizes the current state, adds the feed, moves volume and
    moles between the F, A and B compartments according to ``phys_chng``,
    discharges any volume above the initial volume, and re-normalizes the
    result. The current ``I`` for the new state is predicted by ``curr_regr``
    and clamped to be non-negative in physical units.
    """

    def __init__(self, range_mm, curr_regr, eps=1e-2):
        """Store the current regressor and the min/max table.

        Args:
            range_mm: mapping ``name -> {'min': ..., 'max': ...}`` containing
                every feature listed in ``_range2tensor``.
            curr_regr: module applied to the 9-feature next state to predict
                the normalized current.
            eps: kept on the instance; it is not referenced anywhere else in
                this class.
        """
        super().__init__()
        
        self.eps = eps
        self.curr_regr = curr_regr
        # Buffer (not a parameter): follows .to(device) and is saved in state_dict.
        self.register_buffer('range_mm_tensor',self._range2tensor(range_mm))

    def _range2tensor(self, range_mm):
        """Pack ``range_mm`` into a ``(10, 2)`` tensor of ``[min, max]`` rows.

        Row order is fixed by ``feature_names`` and matches the ``*_idx``
        constants used in ``forward``.
        """
        feature_names = ['V', 'E', 'VF', 'VA', 'VB', 'CFLA', 'CALA', 'CFK', 'CBK', 'I']
        ranges = torch.zeros(len(feature_names), 2)

        for i, name in enumerate(feature_names):
            ranges[i, 0] = range_mm[name]['min']
            ranges[i, 1] = range_mm[name]['max']

        return ranges

    def _norm_tensor(self, data, feature_idx):
        """Min-max normalize ``data`` with the range stored in row ``feature_idx``."""
        min_val = self.range_mm_tensor[feature_idx, 0]
        max_val = self.range_mm_tensor[feature_idx, 1]
        return (data - min_val) / (max_val - min_val)

    def _denorm_tensor(self, norm_data, feature_idx):
        """Inverse of ``_norm_tensor``: map normalized values back to physical units."""
        min_val = self.range_mm_tensor[feature_idx, 0]
        max_val = self.range_mm_tensor[feature_idx, 1]
        return norm_data * (max_val - min_val) + min_val

    def forward(self, phys_chng, cur_state, fin, initV):
        """Compute the next normalized state and the discharge of this step.

        Args:
            phys_chng: tensor whose last dimension holds, in order,
                ``rdVA, rdVB, rdNALA, rdNBK``. In this file it is produced by
                ``PhysRegr``, which ends in a Sigmoid.
            cur_state: normalized state; the last dimension is read at indices
                0..8 as ``V, E, VF, VA, VB, CFLA, CALA, CFK, CBK``.
            fin: 5-item sequence ``(FvF, FvA, FvB, CiLA, CiK)``: volumes added
                to F, A and B in this call and the LA / K concentrations of the
                F feed.
            initV: 3-item sequence ``(VFi, VAi, VBi)``; any volume above these
                values is discharged.

        Returns:
            ``(next_state, discharge)``. ``next_state`` is the 10-feature
            normalized state (``V`` and ``E`` copied through, ``I`` predicted).
            ``discharge`` is a dict with the volume excess per compartment
            (``VF``, ``VA``, ``VB``; stored unmasked, so it may be negative),
            the discharged moles (``NFLA``, ``NALA``, ``NFK``, ``NBK``; zero
            where there is no excess) and the physical concentrations
            (``CFLA``, ``CALA``, ``CFK``, ``CBK``).
        """
        # V_idx and E_idx are defined for completeness; they are not used below.
        V_idx, E_idx, VF_idx, VA_idx, VB_idx = 0, 1, 2, 3, 4
        CFLA_idx, CALA_idx, CFK_idx, CBK_idx, I_idx = 5, 6, 7, 8, 9

        # Physical-unit volumes and concentrations of the current state.
        VF = self._denorm_tensor(cur_state[..., 2:3], VF_idx)
        VA = self._denorm_tensor(cur_state[..., 3:4], VA_idx)
        VB = self._denorm_tensor(cur_state[..., 4:5], VB_idx)
        CFLA = self._denorm_tensor(cur_state[..., 5:6], CFLA_idx)
        CALA = self._denorm_tensor(cur_state[..., 6:7], CALA_idx)
        CFK = self._denorm_tensor(cur_state[..., 7:8], CFK_idx)
        CBK = self._denorm_tensor(cur_state[..., 8:9], CBK_idx)

        FvF, FvA, FvB, CiLA, CiK = fin
        VFi, VAi, VBi = initV

        # Feed: volume enters all three compartments, solutes enter only F.
        dVF_in, dVA_in, dVB_in = FvF, FvA, FvB
        dNFLA_in, dNFK_in = FvF * CiLA, FvF * CiK

        rdVA = phys_chng[..., 0:1]
        rdVB = phys_chng[..., 1:2]
        rdNALA = phys_chng[..., 2:3]
        rdNBK = phys_chng[..., 3:4]

        # Moles currently held in each compartment.
        NFLA = CFLA * VF
        NALA = CALA * VA
        NFK = CFK * VF
        NBK = CBK * VB

        VF_after_feed = VF + dVF_in
        VA_after_feed = VA + dVA_in
        VB_after_feed = VB + dVB_in
        NFLA_after_feed = NFLA + dNFLA_in
        NALA_after_feed = NALA
        NFK_after_feed = NFK + dNFK_in
        NBK_after_feed = NBK

        # Volume ratios are centred on 0.5, so the F->A / F->B volume transfer
        # can have either sign; mole transfers are a plain fraction of the
        # moles available in F after feeding.
        dVA = VF_after_feed * (rdVA - 0.5)
        dVB = VF_after_feed * (rdVB - 0.5)
        dNALA = NFLA_after_feed * rdNALA
        dNBK = NFK_after_feed * rdNBK

        # Balances before discharge ("_bf").
        nVF_bf = VF_after_feed - dVA - dVB
        nVA_bf = VA_after_feed + dVA
        nVB_bf = VB_after_feed + dVB

        nNFLA_bf = NFLA_after_feed - dNALA
        nNALA_bf = NALA_after_feed + dNALA
        nNFK_bf = NFK_after_feed - dNBK
        nNBK_bf = NBK_after_feed + dNBK

        # NOTE(review): these divisions are not guarded against a zero volume,
        # and self.eps is not applied here — confirm that is intended.
        nCFLA = nNFLA_bf / nVF_bf
        nCALA = nNALA_bf / nVA_bf
        nCFK = nNFK_bf / nVF_bf
        nCBK = nNBK_bf / nVB_bf

        # Excess over the initial volume; only positive excess is discharged.
        dVF_out = nVF_bf - VFi
        dVA_out = nVA_bf - VAi
        dVB_out = nVB_bf - VBi

        nVF = torch.where(dVF_out > 0, nVF_bf - dVF_out, nVF_bf)
        nVA = torch.where(dVA_out > 0, nVA_bf - dVA_out, nVA_bf)
        nVB = torch.where(dVB_out > 0, nVB_bf - dVB_out, nVB_bf)
        
        # Discharged moles leave at the post-transfer concentration, so the
        # concentrations computed above are unchanged by the discharge.
        dNFLA_out = torch.where(dVF_out > 0, nCFLA * dVF_out, torch.zeros_like(dVF_out))
        dNFK_out = torch.where(dVF_out > 0, nCFK * dVF_out, torch.zeros_like(dVF_out))
        dNALA_out = torch.where(dVA_out > 0, nCALA * dVA_out, torch.zeros_like(dVA_out))
        dNBK_out = torch.where(dVB_out > 0, nCBK * dVB_out, torch.zeros_like(dVB_out))

        # Post-discharge mole counts; computed but not used further in this method.
        nNFLA = nNFLA_bf - dNFLA_out
        nNALA = nNALA_bf - dNALA_out
        nNFK = nNFK_bf - dNFK_out
        nNBK = nNBK_bf - dNBK_out

        # V and E are carried over unchanged (still normalized).
        V = cur_state[..., 0:1]
        E = cur_state[..., 1:2]

        nVF_norm = self._norm_tensor(nVF, VF_idx)
        nVA_norm = self._norm_tensor(nVA, VA_idx)
        nVB_norm = self._norm_tensor(nVB, VB_idx)
        nCFLA_norm = self._norm_tensor(nCFLA, CFLA_idx)
        nCALA_norm = self._norm_tensor(nCALA, CALA_idx)
        nCFK_norm = self._norm_tensor(nCFK, CFK_idx)
        nCBK_norm = self._norm_tensor(nCBK, CBK_idx)

        # 9-feature state (no current) used as input of the current regressor.
        temp_state = torch.cat([
            V, E, nVF_norm, nVA_norm, nVB_norm, nCFLA_norm, nCALA_norm, nCFK_norm, nCBK_norm
        ], dim=-1)

        # Clamp in physical units so the current is never negative, then
        # return to the normalized scale.
        nI_pred = self.curr_regr(temp_state)
        nI_real = self._denorm_tensor(nI_pred, I_idx)
        nI_real = torch.clamp(nI_real, min=0.0)
        nI_norm = self._norm_tensor(nI_real, I_idx)

        next_state = torch.cat([
            V, E, nVF_norm, nVA_norm, nVB_norm, nCFLA_norm, nCALA_norm, nCFK_norm, nCBK_norm, nI_norm
        ], dim=-1)

        discharge = {
            'VF': dVF_out,
            'VA': dVA_out,
            'VB': dVB_out,
            'NFLA': dNFLA_out,
            'NALA': dNALA_out,
            'NFK': dNFK_out,
            'NBK': dNBK_out,
            'CFLA': nCFLA,
            'CALA': nCALA,
            'CFK': nCFK,
            'CBK': nCBK
        }

        return next_state, discharge

# BMED model with hidden state tracking
class BMEDModelWithTracking(nn.Module):
    """BMED simulator that can record its per-step internals.

    Combines the LSTM feature extractor (``StateExtr``), the transfer-ratio
    regressor (``PhysRegr``), the current regressor (``CurrRegr``) and the
    balance layer (``PhysConstr``) into an autoregressive roll-out. With
    ``track=True`` the LSTM features and the regressor outputs of every step
    are stored on the instance as detached CPU tensors.
    """

    def __init__(self, state_extr_params, phys_regr_params, curr_regr_params, range_mm):
        super().__init__()
        self.state_extr = StateExtr(**state_extr_params)
        self.phys_regr = PhysRegr(**phys_regr_params)
        self.curr_regr = CurrRegr(**curr_regr_params)
        # The balance layer shares the very same current regressor instance.
        self.phys_constr = PhysConstr(range_mm, self.curr_regr)

        # Recurrent state of the roll-out; created by _reset_hidden_states.
        self._hidden_states = None
        self._cell_states = None

        self._clear_tracking()

    def _clear_tracking(self):
        """Start fresh (empty) tracking lists."""
        self.tracked_hidden_states = []
        self.tracked_phys_regr_inputs = []
        self.tracked_phys_chng = []

    def _reset_hidden_states(self, batch_size, device):
        """Zero the hidden and cell state of every LSTM layer."""
        n_layer = self.state_extr.n_layer
        width = self.state_extr.hidden_node
        self._hidden_states = [torch.zeros(batch_size, width, device=device) for _ in range(n_layer)]
        self._cell_states = [torch.zeros(batch_size, width, device=device) for _ in range(n_layer)]

    def cont_sim(self, init_state, target_len, fin, initV, track=False):
        """Roll the model forward for ``target_len`` steps from ``init_state``.

        Args:
            init_state: 2-D tensor ``(batch, features)``; the last feature is
                dropped before the state is fed to the LSTM.
            target_len: number of states to return, including ``init_state``.
            fin: feed condition forwarded to ``PhysConstr``.
            initV: initial volumes forwarded to ``PhysConstr``.
            track: when true, replace the tracking lists and record every step.

        Returns:
            ``(pred, discharge_record)``: ``pred`` has shape
            ``(batch, target_len, features)``; ``discharge_record`` holds one
            discharge dict per transition (``target_len - 1`` entries).
        """
        batch_size, feature_size = init_state.size(0), init_state.size(1)
        device = init_state.device

        self._reset_hidden_states(batch_size, device)
        if track:
            self._clear_tracking()

        pred = torch.zeros(batch_size, target_len, feature_size, device=device)
        discharge_record = []
        cur_state = init_state.clone()
        final_step = target_len - 1

        for t in range(target_len):
            pred[:, t, :] = cur_state

            # The last stored state has no successor to compute.
            if t >= final_step:
                continue

            hidden_output = self._lstm_single_step(cur_state[:, :-1])
            if track:
                self.tracked_hidden_states.append(hidden_output.detach().cpu())
                self.tracked_phys_regr_inputs.append(hidden_output.detach().cpu())

            phys_chng = self.phys_regr(hidden_output.unsqueeze(1))
            if track:
                self.tracked_phys_chng.append(phys_chng.detach().cpu())

            next_state, discharge = self.phys_constr(
                phys_chng, cur_state.unsqueeze(1), fin, initV
            )
            cur_state = next_state.squeeze(1)
            discharge_record.append(discharge)

        return pred, discharge_record

    def _lstm_single_step(self, x_t):
        """Push one timestep through the stacked cells, updating the stored state."""
        cells = self.state_extr.lstm_cells
        top_layer = len(cells) - 1

        layer_input = x_t
        for layer_idx, lstm_cell in enumerate(cells):
            state = (self._hidden_states[layer_idx], self._cell_states[layer_idx])
            h_new, c_new = lstm_cell(layer_input, state)
            self._hidden_states[layer_idx], self._cell_states[layer_idx] = h_new, c_new

            # Dropout only sits between layers, never after the top one.
            layer_input = h_new if layer_idx == top_layer else self.state_extr.dropout(h_new)

        normed_output = self.state_extr.layernorm(layer_input)
        return self.state_extr.dropout(normed_output)

    def forward(self, init_state, target_len, fin, initV, track=False):
        """Alias of ``cont_sim`` so the module can be called directly."""
        return self.cont_sim(init_state, target_len, fin, initV, track=track)

print('Model classes loaded with hidden state tracking capability')

Model classes loaded with hidden state tracking capability


In [3]:
def normalize(inputs, range_mm):
    """Min-max scale raw inputs feature by feature.

    ``inputs`` is read positionally in the order
    V, E, VF, VA, VB, CFLA, CALA, CFK, CBK; ``range_mm[name]`` supplies the
    ``'min'`` and ``'max'`` of each feature. Pairing stops at the shorter of
    the two sequences. Returns a list of scaled values.
    """
    features = ('V', 'E', 'VF', 'VA', 'VB', 'CFLA', 'CALA', 'CFK', 'CBK')

    def scale(name, value):
        low = range_mm[name]['min']
        high = range_mm[name]['max']
        return (value - low) / (high - low)

    return [scale(name, value) for name, value in zip(features, inputs)]

def denormalize(outputs, range_mm):
    """Map normalized model outputs back to physical units.

    ``outputs`` is indexed as a 3-D array whose last axis follows the order
    V, E, VF, VA, VB, CFLA, CALA, CFK, CBK, I. Features present in ``range_mm``
    are rescaled with their ``'min'``/``'max'``; features missing from it are
    copied unchanged. The result starts as zeros with the shape and dtype of
    ``outputs``, so any column beyond the ten listed stays zero.
    """
    columns = ('V', 'E', 'VF', 'VA', 'VB', 'CFLA', 'CALA', 'CFK', 'CBK', 'I')
    restored = np.zeros_like(outputs)

    for col, key in enumerate(columns):
        if key not in range_mm:
            restored[:, :, col] = outputs[:, :, col]
            continue
        low = range_mm[key]['min']
        high = range_mm[key]['max']
        restored[:, :, col] = outputs[:, :, col] * (high - low) + low

    return restored

In [4]:
# Load trained model
model_path = 'BMED_FR_250930.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Model: {model_path}')
print(f'Device: {device}')

# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoint files that come from a trusted source.
# `model` is the checkpoint dict (config + weights), not an nn.Module.
model = torch.load(model_path, map_location=device, weights_only=False)
model_config = model['model_config']
state_extr_params = model_config['state_extr_params']
phys_regr_params = model_config['phys_regr_params']
curr_regr_params = model_config['curr_regr_params']
model_range_mm = model_config['range_mm']

simulator = BMEDModelWithTracking(
    state_extr_params = state_extr_params,
    phys_regr_params = phys_regr_params,
    curr_regr_params = curr_regr_params,
    range_mm = model_range_mm
).to(device)

# strict=False tolerates key mismatches between the checkpoint and this
# tracking variant of the model. Report them instead of dropping them silently:
# a missing key means that parameter keeps its random initialization, which
# would invalidate every result below.
load_result = simulator.load_state_dict(model['model_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint does not fully match the model')
    print(f'   Missing keys (left at initial values): {list(load_result.missing_keys)}')
    print(f'   Unexpected keys (ignored): {list(load_result.unexpected_keys)}')
# eval() disables dropout so the roll-outs are deterministic.
simulator.eval()

print('Load model parameters with tracking capability')

Model: BMED_FR_250930.pth
Device: cuda
Load model parameters with tracking capability


## Test Cases: Three Conditions

**Case 1**: `10V + E=0.25M` (Low driving force)
- Expected: Low CALA (~0.016 mol/L)

**Case 2**: `30V + E=0.25M` (High driving force from voltage)
- Expected: High CALA (~1.47 mol/L)

**Case 3**: `10V + E=1.0M` (High driving force from electrolyte)
- User discovery: Converges to **same** CALA as Case 2 (~1.47 mol/L)
- **This suggests saturation problem**

All use:
- CFLA = 3 mol/L
- QF = QA = QB = 10 mL/min
- Simulation time = 200 timesteps (50 hours)

In [5]:
# === Case 1: 10V + E=0.25M ===
print('=== Case 1: 10V + E=0.25M (Low driving force) ===')
input_init_case1 = [10, 0.25, 3, 10, 10, 10]  # [V, E, CFLA, QF, QA, QB]
# Initial state in the feature order expected by normalize():
# [V, E, VF, VA, VB, CFLA, CALA, CFK, CBK]; CFK starts at 2*E, CALA and CBK at 0.
cond_init_case1 = [input_init_case1[0], input_init_case1[1], 0.7, 0.7, 0.7, 
                   input_init_case1[2], 0, input_init_case1[1]*2, 0]
# Number of stored states per run; also reused by the Case 2 / Case 3 cells.
simulation_time = 200

QF, QA, QB = input_init_case1[3], input_init_case1[4], input_init_case1[5]
# Feed condition, unpacked by PhysConstr.forward as (FvF, FvA, FvB, CiLA, CiK).
# With Q in mL/min (per the notebook text) Q*60/1000*0.25 is the volume in L
# added per step, assuming one step is 0.25 h (200 steps = 50 hours) — TODO confirm.
# NOTE(review): the feed K concentration is 2*CFLA here while the initial CFK
# above is 2*E — confirm this difference is intended.
cond_flow = [QF*60/1000*0.25, QA*60/1000*0.25, QB*60/1000*0.25, input_init_case1[2], 2*input_init_case1[2]]
# Volumes above which PhysConstr discharges; same 0.7 as the initial VF/VA/VB.
initV = [0.7, 0.7, 0.7]

norm_inputs_case1 = normalize(cond_init_case1, model_range_mm)
# Append 0.0 as the tenth feature (normalized current 'I') and add a batch axis.
init_state_case1 = torch.tensor([norm_inputs_case1 + [0.0]]).float().to(device)

with torch.no_grad():
    pred_case1, discharge_case1 = simulator(init_state_case1, simulation_time, cond_flow, initV, track=True)

pred_case1_real = denormalize(pred_case1.cpu().numpy(), model_range_mm)
# These are the list objects created by this run; the next tracked run binds
# new lists on the simulator, so the references taken here keep this case's data.
hidden_case1 = simulator.tracked_hidden_states
phys_chng_case1 = simulator.tracked_phys_chng

# Feature index 6 is CALA and index 9 is I (see feature_names in denormalize).
print(f'Case 1 Steady State CALA: {pred_case1_real[0, -1, 6]:.6f} mol/L')
print(f'Case 1 Steady State Current: {pred_case1_real[0, -1, 9]:.6f} A')
print(f'Hidden states tracked: {len(hidden_case1)} timesteps')

=== Case 1: 10V + E=0.25M (Low driving force) ===
Case 1 Steady State CALA: 0.016107 mol/L
Case 1 Steady State Current: 0.695053 A
Hidden states tracked: 199 timesteps


In [6]:
# === Case 2: 30V + E=0.25M ===
print('\n=== Case 2: 30V + E=0.25M (High voltage) ===')
input_init_case2 = [30, 0.25, 3, 10, 10, 10]  # [V, E, CFLA, QF, QA, QB]
# Initial state in the feature order expected by normalize():
# [V, E, VF, VA, VB, CFLA, CALA, CFK, CBK]
cond_init_case2 = [input_init_case2[0], input_init_case2[1], 0.7, 0.7, 0.7,
                   input_init_case2[2], 0, input_init_case2[1]*2, 0]

# Build the feed condition from THIS case's inputs. Previously the list left
# over from the Case 1 cell was reused, so editing CFLA/QF/QA/QB above had no
# effect on the feed. The values are identical for the inputs used here.
# Order matches PhysConstr.forward: (FvF, FvA, FvB, CiLA, CiK).
QF, QA, QB = input_init_case2[3], input_init_case2[4], input_init_case2[5]
cond_flow = [QF*60/1000*0.25, QA*60/1000*0.25, QB*60/1000*0.25, input_init_case2[2], 2*input_init_case2[2]]

norm_inputs_case2 = normalize(cond_init_case2, model_range_mm)
# Append 0.0 as the tenth feature (normalized current 'I') and add a batch axis.
init_state_case2 = torch.tensor([norm_inputs_case2 + [0.0]]).float().to(device)

# simulation_time and initV are defined in the Case 1 cell.
with torch.no_grad():
    pred_case2, discharge_case2 = simulator(init_state_case2, simulation_time, cond_flow, initV, track=True)

pred_case2_real = denormalize(pred_case2.cpu().numpy(), model_range_mm)
# References to the lists created by this run (a later tracked run binds new lists).
hidden_case2 = simulator.tracked_hidden_states
phys_chng_case2 = simulator.tracked_phys_chng

# Feature index 6 is CALA and index 9 is I (see feature_names in denormalize).
print(f'Case 2 Steady State CALA: {pred_case2_real[0, -1, 6]:.6f} mol/L')
print(f'Case 2 Steady State Current: {pred_case2_real[0, -1, 9]:.6f} A')
print(f'Hidden states tracked: {len(hidden_case2)} timesteps')


=== Case 2: 30V + E=0.25M (High voltage) ===
Case 2 Steady State CALA: 1.480612 mol/L
Case 2 Steady State Current: 2.184147 A
Hidden states tracked: 199 timesteps


In [7]:
# === Case 3: 10V + E=1.0M ===
print('\n=== Case 3: 10V + E=1.0M (High electrolyte) ===')
input_init_case3 = [10, 1.0, 3, 10, 10, 10]  # [V, E, CFLA, QF, QA, QB]
# Initial state in the feature order expected by normalize():
# [V, E, VF, VA, VB, CFLA, CALA, CFK, CBK]
cond_init_case3 = [input_init_case3[0], input_init_case3[1], 0.7, 0.7, 0.7,
                   input_init_case3[2], 0, input_init_case3[1]*2, 0]

# Build the feed condition from THIS case's inputs. Previously the list left
# over from the Case 1 cell was reused, so editing CFLA/QF/QA/QB above had no
# effect on the feed. The values are identical for the inputs used here.
# Order matches PhysConstr.forward: (FvF, FvA, FvB, CiLA, CiK).
QF, QA, QB = input_init_case3[3], input_init_case3[4], input_init_case3[5]
cond_flow = [QF*60/1000*0.25, QA*60/1000*0.25, QB*60/1000*0.25, input_init_case3[2], 2*input_init_case3[2]]

norm_inputs_case3 = normalize(cond_init_case3, model_range_mm)
# Append 0.0 as the tenth feature (normalized current 'I') and add a batch axis.
init_state_case3 = torch.tensor([norm_inputs_case3 + [0.0]]).float().to(device)

# simulation_time and initV are defined in the Case 1 cell.
with torch.no_grad():
    pred_case3, discharge_case3 = simulator(init_state_case3, simulation_time, cond_flow, initV, track=True)

pred_case3_real = denormalize(pred_case3.cpu().numpy(), model_range_mm)
# References to the lists created by this run (a later tracked run binds new lists).
hidden_case3 = simulator.tracked_hidden_states
phys_chng_case3 = simulator.tracked_phys_chng

# Feature index 6 is CALA and index 9 is I (see feature_names in denormalize).
print(f'Case 3 Steady State CALA: {pred_case3_real[0, -1, 6]:.6f} mol/L')
print(f'Case 3 Steady State Current: {pred_case3_real[0, -1, 9]:.6f} A')
print(f'Hidden states tracked: {len(hidden_case3)} timesteps')


=== Case 3: 10V + E=1.0M (High electrolyte) ===
Case 3 Steady State CALA: 1.483344 mol/L
Case 3 Steady State Current: 1.787252 A
Hidden states tracked: 199 timesteps


In [8]:
# === Three-Way Comparison Summary ===
print('\n' + '='*70)
print('SATURATION PROBLEM DIAGNOSIS')
print('='*70)

# Time axis with 0.25 per step (hours, per the notebook's "200 timesteps (50 hours)").
time_steps = np.arange(simulation_time) * 0.25
# Steady-state window = last 20% of the run. It is derived from simulation_time
# so that changing the run length cannot leave the window pointing past the end
# of the predictions (which would average a short or empty slice).
# For simulation_time = 200 this is the original window 160..200.
ss_start, ss_end = simulation_time * 4 // 5, simulation_time

# Steady state values (feature index 6 = CALA, index 9 = I)
CALA_case1_ss = np.mean(pred_case1_real[0, ss_start:ss_end, 6])
CALA_case2_ss = np.mean(pred_case2_real[0, ss_start:ss_end, 6])
CALA_case3_ss = np.mean(pred_case3_real[0, ss_start:ss_end, 6])

I_case1_ss = np.mean(pred_case1_real[0, ss_start:ss_end, 9])
I_case2_ss = np.mean(pred_case2_real[0, ss_start:ss_end, 9])
I_case3_ss = np.mean(pred_case3_real[0, ss_start:ss_end, 9])

print(f'\nCase 1 (10V + 0.25M): CALA={CALA_case1_ss:.6f} mol/L, I={I_case1_ss:.4f} A')
print(f'Case 2 (30V + 0.25M): CALA={CALA_case2_ss:.6f} mol/L, I={I_case2_ss:.4f} A')
print(f'Case 3 (10V + 1.0M):  CALA={CALA_case3_ss:.6f} mol/L, I={I_case3_ss:.4f} A')

# Case 2 vs Case 3 comparison (should be similar according to user's discovery).
# The percentage is relative to Case 2.
CALA_diff_2_3 = abs(CALA_case2_ss - CALA_case3_ss)
CALA_diff_2_3_pct = (CALA_diff_2_3 / CALA_case2_ss) * 100

print(f'\n🔍 Case 2 vs Case 3 (SATURATION CHECK):')
print(f'   CALA Difference: {CALA_diff_2_3:.6f} mol/L ({CALA_diff_2_3_pct:.2f}%)')

# Less than 5% apart is treated as "same plateau".
if CALA_diff_2_3_pct < 5.0:
    print('   ❌ SATURATION CONFIRMED: Different (V,E) → Same CALA')
    print('      30V + 0.25M ≈ 10V + 1.0M (both high efficiency → plateau)')
else:
    print('   ✅ Saturation NOT observed (>5% difference)')

# Case 1 vs Case 2/3 (each percentage is relative to the high-efficiency case)
print(f'\n📊 Case 1 vs High-Efficiency Cases:')
CALA_diff_1_2 = abs(CALA_case1_ss - CALA_case2_ss)
CALA_diff_1_2_pct = (CALA_diff_1_2 / CALA_case2_ss) * 100
print(f'   Case 1 vs Case 2: {CALA_diff_1_2:.6f} mol/L ({CALA_diff_1_2_pct:.2f}%)')

CALA_diff_1_3 = abs(CALA_case1_ss - CALA_case3_ss)
CALA_diff_1_3_pct = (CALA_diff_1_3 / CALA_case3_ss) * 100
print(f'   Case 1 vs Case 3: {CALA_diff_1_3:.6f} mol/L ({CALA_diff_1_3_pct:.2f}%)')

print('\n' + '='*70)


SATURATION PROBLEM DIAGNOSIS

Case 1 (10V + 0.25M): CALA=0.016107 mol/L, I=0.6951 A
Case 2 (30V + 0.25M): CALA=1.480633 mol/L, I=2.1841 A
Case 3 (10V + 1.0M):  CALA=1.483366 mol/L, I=1.7873 A

🔍 Case 2 vs Case 3 (SATURATION CHECK):
   CALA Difference: 0.002733 mol/L (0.18%)
   ❌ SATURATION CONFIRMED: Different (V,E) → Same CALA
      30V + 0.25M ≈ 10V + 1.0M (both high efficiency → plateau)

📊 Case 1 vs High-Efficiency Cases:
   Case 1 vs Case 2: 1.464526 mol/L (98.91%)
   Case 1 vs Case 3: 1.467259 mol/L (98.91%)



## Hidden State Analysis: Why do Case 2 and Case 3 converge to same CALA?

**Key Question**: Case 2 (30V+0.25M) and Case 3 (10V+1M) have different (V, E) inputs but produce nearly identical CALA (~1.48 mol/L, about 0.18% apart).

**Hypothesis**: LSTM hidden states become similar over time, causing PhysRegr to predict similar rdNALA values despite different inputs.

**This section investigates**:
1. **Cosine similarity** between hidden states (Case 2 vs Case 3)
2. **Euclidean distance** between hidden states over time
3. **rdNALA prediction patterns** across all three cases
4. **Early vs Late** timestep comparison

In [9]:
# === Hidden State Similarity Analysis ===
# scipy's `cosine` is a distance (1 - similarity); it is converted back below.
from scipy.spatial.distance import cosine, euclidean

print('\n' + '='*70)
print('HIDDEN STATE SIMILARITY ANALYSIS')
print('='*70)

# Select key timesteps
# These are indices into the tracked hidden-state lists, which hold one entry
# per transition (199 entries for the 200-step runs above).
early_t = [0, 5, 10, 20]
late_t = [100, 120, 140, 160, 180, 195]

def compute_similarity(h1_list, h2_list, timesteps):
    """Compare two hidden-state trajectories at the requested timesteps.

    Each list holds one tensor per timestep; the tensors are flattened before
    comparison. Timesteps at or beyond the end of either list are skipped, so
    the returned lists can be shorter than ``timesteps``.

    Returns:
        ``(similarities, distances)``: cosine similarities and Euclidean
        distances, in the order of the timesteps that were kept.
    """
    usable = min(len(h1_list), len(h2_list))
    pairs = [
        (h1_list[t].numpy().flatten(), h2_list[t].numpy().flatten())
        for t in timesteps
        if t < usable
    ]

    # scipy's cosine() is a distance, hence the "1 -".
    similarities = [1 - cosine(a, b) for a, b in pairs]
    distances = [euclidean(a, b) for a, b in pairs]
    return similarities, distances

print('\n📊 Case 2 (30V+0.25M) vs Case 3 (10V+1M):')
print('-' * 70)

# Early timesteps
early_sims, early_dists = compute_similarity(hidden_case2, hidden_case3, early_t)
print('\n🔹 Early Timesteps (t=0-20):')
# compute_similarity drops out-of-range timesteps, so the labels are trimmed to
# the number of results; this pairing is only correct when the dropped
# timesteps are the trailing ones (true for the ascending lists used here).
for t, sim, dist in zip(early_t[:len(early_sims)], early_sims, early_dists):
    print(f'   t={t:3d}: Cosine Sim={sim:.4f}, Euclidean Dist={dist:.4f}')

avg_early_sim = np.mean(early_sims)
avg_early_dist = np.mean(early_dists)
print(f'   Average: Cosine Sim={avg_early_sim:.4f}, Euclidean Dist={avg_early_dist:.4f}')

# Late timesteps
late_sims, late_dists = compute_similarity(hidden_case2, hidden_case3, late_t)
print('\n🔹 Late Timesteps (t=100-195):')
# Same label trimming as for the early timesteps.
for t, sim, dist in zip(late_t[:len(late_sims)], late_sims, late_dists):
    print(f'   t={t:3d}: Cosine Sim={sim:.4f}, Euclidean Dist={dist:.4f}')

avg_late_sim = np.mean(late_sims)
avg_late_dist = np.mean(late_dists)
print(f'   Average: Cosine Sim={avg_late_sim:.4f}, Euclidean Dist={avg_late_dist:.4f}')

# Convergence is decided on the average cosine similarity alone; the Euclidean
# distances are reported but do not enter the decision.
print('\n🔍 Convergence Analysis:')
if avg_late_sim > avg_early_sim:
    print(f'   ✅ Hidden states CONVERGE over time')
    print(f'      Similarity: {avg_early_sim:.4f} → {avg_late_sim:.4f} (+{(avg_late_sim - avg_early_sim):.4f})')
else:
    print(f'   ⚠️  Hidden states do NOT converge')

print('\n' + '='*70)


HIDDEN STATE SIMILARITY ANALYSIS

📊 Case 2 (30V+0.25M) vs Case 3 (10V+1M):
----------------------------------------------------------------------

🔹 Early Timesteps (t=0-20):
   t=  0: Cosine Sim=0.9976, Euclidean Dist=0.4741
   t=  5: Cosine Sim=0.9853, Euclidean Dist=0.9803
   t= 10: Cosine Sim=0.9944, Euclidean Dist=0.6097
   t= 20: Cosine Sim=0.9977, Euclidean Dist=0.3853
   Average: Cosine Sim=0.9938, Euclidean Dist=0.6124

🔹 Late Timesteps (t=100-195):
   t=100: Cosine Sim=1.0000, Euclidean Dist=0.0361
   t=120: Cosine Sim=1.0000, Euclidean Dist=0.0304
   t=140: Cosine Sim=1.0000, Euclidean Dist=0.0289
   t=160: Cosine Sim=1.0000, Euclidean Dist=0.0285
   t=180: Cosine Sim=1.0000, Euclidean Dist=0.0284
   t=195: Cosine Sim=1.0000, Euclidean Dist=0.0284
   Average: Cosine Sim=1.0000, Euclidean Dist=0.0301

🔍 Convergence Analysis:
   ✅ Hidden states CONVERGE over time
      Similarity: 0.9938 → 1.0000 (+0.0062)

