#### Verify the Encoder Only 

We aim to write a script that verifies the performance of `TimeAwareEncoder`, i.e., that trains an encoder suited to the dataset. 
  
Assume the true state $x(t|t)$ is known (in real conditions it is not); it is stored in the 5th column of the .csv file.  
Typically (in previous subnet work) $n_a = n_b = 20$; the script below uses $n_a = n_b = 5$.
![flowchart](/Users/bloodytaken/graduate_project/deepSI/notes/flowchart.png)

Design a method to test whether `TimeAwareEncoder` works by training an encoder built on that network.  
The true state $x(t|t)$ should be compared with the estimate $\hat{x}(t|t)$ to check whether it is accurate enough.


In [None]:
import torch
# x_true = torch.tensor(df['TrueState_'].values, dtype=torch.float32)

Encoder: 
$$x_0 = \psi_\theta (u_\text{past}, y_\text{past}, \Delta t_\text{past})$$

In [None]:
import torch
import torch.nn as nn
from torchviz import make_dot

class MLP_res_net_with_time(nn.Module):
    '''Modified MLP_res_net with time interval (delta_t) as additional input.'''
    def __init__(self, input_size: str | int | list, output_size: str | int | list, n_hidden_layers=3, n_hidden_nodes=128,
                 activation=nn.GELU, zero_bias=True):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.scalar_output = output_size == 'scalar'

        # Function to convert input shape
        def to_num(s):
            if isinstance(s, int):
                return s
            if s == 'scalar':
                return 1
            a = 1
            for si in s:
                a *= (1 if si == 'scalar' else si)
            return a
        
        if isinstance(input_size, list):
            input_size = sum(to_num(s) for s in input_size)

        output_size = 1 if self.scalar_output else output_size
        self.net_res = nn.Linear(input_size, output_size)

        # Sequential MLP with residual connections
        seq = [nn.Linear(input_size, n_hidden_nodes), activation()]
        for _ in range(n_hidden_layers - 1):
            seq.append(nn.Linear(n_hidden_nodes, n_hidden_nodes))
            seq.append(activation())
        seq.append(nn.Linear(n_hidden_nodes, output_size))
        self.net_nonlin = nn.Sequential(*seq)

        if zero_bias:
            for m in self.modules(): 
                if isinstance(m, nn.Linear):
                    nn.init.constant_(m.bias, val=0)  # Set bias to zero


    def forward(self, *ars):
        if len(ars) == 1:
            net_in = ars[0]
            net_in = net_in.view(net_in.shape[0], -1)  # Adds a dim when needed
        else:
            net_in = torch.cat([a.view(a.shape[0], -1) for a in ars], dim=1)  # Flattens everything

        out = self.net_nonlin(net_in) + self.net_res(net_in)
        return out[:, 0] if self.scalar_output else out


In [None]:
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm import trange
from networks import MLP_res_net


# ----------------------------------------------------------------
# 1. Normaliser
# ----------------------------------------------------------------
class Normalizer:
    """Named z-score normaliser with an optional log1p pre-transform.

    Statistics (mean and std over all elements) are stored under a caller-chosen
    name, so one instance can normalise several signals and map each back later.
    """
    def __init__(self):
        # name -> {'mean': tensor, 'std': tensor};  name -> whether log1p was applied
        self.stats, self.log_flags = {}, {}

    def fit_transform(self, name, tensor: torch.Tensor, use_log: bool = False):
        """Fit mean/std of *tensor*, store them under *name*, and return the normalised tensor.

        If *use_log* is True, ``log1p`` is applied before fitting and normalising.
        A zero or non-finite std (constant signal, or a single element where the
        unbiased std is NaN) is replaced by 1 so the output stays finite.
        """
        if use_log:
            tensor = torch.log1p(tensor)
        mean, std = tensor.mean(), tensor.std()
        if not torch.isfinite(std) or std == 0:
            std = torch.ones_like(std)
        self.stats[name] = dict(mean=mean, std=std)
        self.log_flags[name] = use_log
        return (tensor - mean) / std

    def transform(self, name: str, tensor: torch.Tensor):
        """Normalise *tensor* with the statistics previously fitted under *name*.

        Raises:
            ValueError: if nothing has been fitted under *name*.
        """
        mean, std = self._get_stats(name)
        if self.log_flags[name]:
            tensor = torch.log1p(tensor)
        return (tensor - mean) / std

    def inverse_transform(self, name: str, norm_tensor: torch.Tensor):
        """Undo the normalisation (and log1p, if it was used) fitted under *name*.

        Raises:
            ValueError: if nothing has been fitted under *name*.
        """
        mean, std = self._get_stats(name)
        out = norm_tensor * std + mean
        return torch.expm1(out) if self.log_flags[name] else out

    def _get_stats(self, name: str):
        """Return the stored (mean, std) for *name* or raise ValueError."""
        if name not in self.stats:
            raise ValueError(f"No stats found for '{name}'")
        return self.stats[name]['mean'], self.stats[name]['std']


# ----------------------------------------------------------------
# 2. Create windowed (past-L) data
# ----------------------------------------------------------------
def create_past_data(u, y, dt, true_state, past_len: int):
    """Return (u_past, y_past, dt_past, next_state).

    One sample is built for every time index ``i`` in
    ``range(past_len, len(u) - 1)``:

    * ``u_past[k]`` / ``y_past[k]``: ``u[i-past_len+1 : i+1]`` (resp. ``y``),
      newest sample first, flattened to one row.
    * ``dt_past[k]``: ``dt[i-past_len+2 : i+2]``, i.e. the same window shifted
      one step ahead, newest first; exact zeros are replaced by 1e-9.
    * ``next_state[k]``: ``true_state[i + 1]``.

    Note: the loop starts at ``past_len``, so the earliest possible window
    (ending at ``i = past_len - 1``) is not used; this is kept as-is so that
    sample counts and splits stay unchanged.

    Args:
        u, y: time series indexed along dim 0 (any trailing shape).
        dt: sampling intervals along dim 0, shape (T,) or (T, 1).
        true_state: state trajectory indexed along dim 0.
        past_len: window length (>= 1).

    Raises:
        ValueError: if past_len < 1 or the series is too short for one window.
    """
    n_total = len(u)
    if past_len < 1:
        raise ValueError(f"past_len must be >= 1, got {past_len}")
    if n_total - 1 <= past_len:
        raise ValueError(
            f"need more than {past_len + 1} samples to build a window of "
            f"length {past_len}, got {n_total}")

    # Row k of `idx` holds i, i-1, ..., i-past_len+1 for i = past_len + k,
    # so every gathered window is already ordered newest-first.
    ends = torch.arange(past_len, n_total - 1)
    idx = ends.unsqueeze(1) - torch.arange(past_len).unsqueeze(0)
    n_win = len(ends)

    u_past = u[idx].reshape(n_win, -1)
    y_past = y[idx].reshape(n_win, -1)

    # dt is read one step ahead (idx + 1), matching the former dt[1:] shift;
    # advanced indexing returns a copy, so the in-place fill cannot touch dt.
    dt_past = dt[idx + 1].reshape(n_win, -1)
    dt_past[dt_past == 0] = 1e-9          # avoid exact zeros

    state_next = true_state[ends + 1]
    return u_past, y_past, dt_past, state_next


# ----------------------------------------------------------------
# 3. Load data and build DataLoaders
# ----------------------------------------------------------------
CSV_PATH = '/Users/bloodytaken/graduate_project/data/MSD/mass_spring_damper_data_state.csv'
# Alternative dataset:
# CSV_PATH = '/Users/bloodytaken/graduate_project/data/MSD_linear_noiseless_k_040.csv'
df = pd.read_csv(CSV_PATH)

u_raw = torch.tensor(df['Input'].values,  dtype=torch.float32).view(-1, 1)
y_raw = torch.tensor(df['Output'].values, dtype=torch.float32).view(-1, 1)
dt_raw = torch.tensor(df['Delta_t'].values, dtype=torch.float32).view(-1, 1)
state  = torch.tensor(df[['TrueState_1', 'TrueState_2']].values, dtype=torch.float32)

# NOTE(review): mean/std are fitted on the whole record (test part included),
# so some test-set information leaks into the scaling; fit on the training
# slice instead if a strictly held-out test is required.
normal = Normalizer()
u_norm = normal.fit_transform('u', u_raw)
y_norm = normal.fit_transform('y', y_raw)

nx, nb, na = 2, 5, 5     # state dim, input taps, output taps
past_len = max(nb, na)   # create_past_data uses this one window length for u, y and dt
u_past, y_past, dt_past, state_next = create_past_data(
        u_norm, y_norm, dt_raw, state, past_len=past_len)

# ---- chronological split: 60 % train / 15 % val / 25 % test
N = len(u_past)
train_end, val_end = int(0.60 * N), int(0.75 * N)

train_ds, val_ds, test_ds = (
    TensorDataset(u_past[sl], y_past[sl], dt_past[sl], state_next[sl])
    for sl in (slice(None, train_end), slice(train_end, val_end), slice(val_end, None))
)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)

# ----------------------------------------------------------------
# 4. Instantiate models and optimisers
# ----------------------------------------------------------------
# create_past_data builds u, y and dt windows of one common length max(nb, na),
# so every declared input size must use that length; (nb, 1) / (na, 1) are
# only correct when nb == na.
past_len = max(nb, na)
model_time = MLP_res_net_with_time(
    input_size=[(past_len, 1), (past_len, 1), (past_len, 1)], output_size=nx)
# NOTE(review): MLP_res_net is imported from `networks`; confirm its default
# depth/width/activation match MLP_res_net_with_time (3 x 128, GELU) so the
# comparison isolates the effect of feeding Δt.
model_plain = MLP_res_net(
    input_size=[(past_len, 1), (past_len, 1)], output_size=nx)

criterion = nn.MSELoss()
lr = 1e-3
optim_time  = optim.Adam(model_time.parameters(),  lr=lr, weight_decay=1e-4)
optim_plain = optim.Adam(model_plain.parameters(), lr=lr, weight_decay=1e-4)

# ----------------------------------------------------------------
# 5. Training helper with NRMS metric
# ----------------------------------------------------------------
def compute_vector_nrms(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """
    Compute NRMS using vector-wise RMSE and vector-wise standard deviation.

    The RMSE is taken over all elements of the residual and divided by the
    standard deviation of all elements of ``y_true`` (flattened), so every
    component shares a single normaliser.  If ``y_true`` is constant the
    result is inf (or NaN for a perfect fit).

    Raises:
        ValueError: if the shapes differ.  Mismatched shapes such as (N,)
            vs (N, 1) would otherwise broadcast silently to (N, N).
    """
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"shape mismatch: y_true {tuple(y_true.shape)} vs y_pred {tuple(y_pred.shape)}")
    rmse = (y_true - y_pred).pow(2).mean().sqrt()     # scalar
    std = y_true.reshape(-1).std()                    # reshape also handles non-contiguous input
    return (rmse / std).item()


def train_model(model, optimiser, name, *, use_dt=None, epochs=30000,
                early_stop_patience=5000, log_every=100,
                train_dl=None, val_dl=None):
    """Train *model* with LR-on-plateau scheduling and early stopping.

    Each epoch does one pass over the training loader, then computes the mean
    per-batch ``criterion`` loss on the validation loader.  The learning rate
    is halved on validation plateaus (``ReduceLROnPlateau``, patience 10).
    Weights with the lowest validation loss are saved to ``best_<name>.pth``.
    Training stops after ``early_stop_patience`` consecutive epochs without
    improvement.

    Args:
        model: network called as ``model(u, y, dt)`` or ``model(u, y)``.
        optimiser: optimiser over ``model``'s parameters.
        name: run label used in log lines and in the checkpoint filename.
        use_dt: whether to pass the Δt window to the model.  ``None`` keeps
            the original convention: Δt is passed iff ``'time'`` is in *name*.
        epochs: maximum number of epochs.
        early_stop_patience: epochs without validation improvement before stopping.
        log_every: print a progress line every this many epochs (and after the first).
        train_dl, val_dl: loaders yielding ``(u, y, dt, state)`` batches.
            They default to the module-level ``train_loader`` / ``val_loader``.

    Returns:
        The best (lowest) mean validation loss reached.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if use_dt is None:
        use_dt = 'time' in name
    train_dl = train_loader if train_dl is None else train_dl
    val_dl = val_loader if val_dl is None else val_dl
    if len(train_dl) == 0 or len(val_dl) == 0:
        raise ValueError(f"{name}: training and validation loaders must be non-empty")

    def predict(u_b, y_b, dt_b):
        # Single place that decides whether the Δt window is fed to the model.
        return model(u_b, y_b, dt_b) if use_dt else model(u_b, y_b)

    ckpt_path = f'best_{name}.pth'
    best_val = float('inf')
    early_cnt = 0
    # `verbose=` is deprecated in recent PyTorch; LR drops are reported manually below.
    sched = optim.lr_scheduler.ReduceLROnPlateau(
        optimiser, mode='min', factor=0.5, patience=10)

    # The context manager closes the bar even when early stopping breaks out.
    with trange(epochs, desc=f'{name} training', dynamic_ncols=True) as pbar:
        for epoch in pbar:
            # ---- train
            model.train()
            tr_loss = 0.0
            for u_b, y_b, dt_b, s_b in train_dl:
                optimiser.zero_grad()
                loss = criterion(predict(u_b, y_b, dt_b), s_b)
                loss.backward()
                optimiser.step()
                tr_loss += loss.item()
            tr_loss /= len(train_dl)

            # ---- val
            model.eval()
            v_loss = 0.0
            y_true_all, y_pred_all = [], []
            with torch.no_grad():
                for u_b, y_b, dt_b, s_b in val_dl:
                    pred = predict(u_b, y_b, dt_b)
                    y_true_all.append(s_b)
                    y_pred_all.append(pred)
                    v_loss += criterion(pred, s_b).item()
            v_loss /= len(val_dl)

            lr_before = optimiser.param_groups[0]['lr']
            sched.step(v_loss)
            lr_now = optimiser.param_groups[0]['lr']
            if lr_now < lr_before:
                pbar.write(f"[{name}] Epoch {epoch+1:03d}  reducing learning rate to {lr_now:.4e}")

            if v_loss < best_val:
                best_val = v_loss
                torch.save(model.state_dict(), ckpt_path)
                early_cnt = 0
            else:
                early_cnt += 1

            if early_cnt >= early_stop_patience:
                pbar.write(f'Early stop: {name}')
                break

            if (epoch + 1) % log_every == 0 or epoch == 0:
                # NRMS over the full validation set; only needed when logging.
                nrms_val = compute_vector_nrms(torch.cat(y_true_all, dim=0),
                                               torch.cat(y_pred_all, dim=0))
                # pbar.write keeps the progress bar intact, unlike print().
                pbar.write(f"[{name}] Epoch {epoch+1:03d}  train {tr_loss:.4e}  val {v_loss:.4e}  NRMS: {nrms_val:.4f}")

    return best_val


# ----------------------------------------------------------------
# 6. Run training
# ----------------------------------------------------------------
# Train both encoders; train_model saves each one's best weights to best_<name>.pth.
runs = [
    (model_time, optim_time, 'MLP_res_net_with_time'),
    (model_plain, optim_plain, 'MLP_res_net'),
]
for run_model, run_optim, run_name in runs:
    train_model(run_model, run_optim, run_name)

# Reload the best checkpoints.  weights_only=True limits unpickling to
# tensors/containers, which is all a state_dict needs.
for run_model, _, run_name in runs:
    run_model.load_state_dict(torch.load(f'best_{run_name}.pth', weights_only=True))
    run_model.eval()

#### Test results visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Collect ground truth and both encoders' predictions over the whole test set.
true_s, pred_time, pred_plain = [], [], []

with torch.no_grad():
    for u_b, y_b, dt_b, s_b in test_loader:
        true_s.append(s_b)
        pred_time.append(model_time(u_b, y_b, dt_b))
        pred_plain.append(model_plain(u_b, y_b))

true_s     = torch.cat(true_s).cpu().numpy()   # [N, 2]: columns are (x, ẋ)
pred_time  = torch.cat(pred_time).cpu().numpy()
pred_plain = torch.cat(pred_plain).cpu().numpy()
idx = np.arange(len(true_s))

# One subplot per state component with consistent legend wording; the
# reference curve is the true state, not a measurement.
components = [
    ('$x$', 'position x'),
    ('ẋ', 'velocity ẋ'),
]
fig1, ax1 = plt.subplots(len(components), 1, figsize=(8, 6), sharex=True)
for k, (sym, ylabel) in enumerate(components):
    ax = ax1[k]
    ax.plot(idx, true_s[:, k],    'k-',  lw=1, label=f'ground-truth {sym}')
    ax.plot(idx, pred_time[:, k], 'C1-', lw=1, label=f'prediction {sym} (with Δt)')
    ax.plot(idx, pred_time[:, k] - true_s[:, k],
            ls='--', lw=1, alpha=1, label=f'{sym} residual')
    ax.set_ylabel(ylabel)
    ax.legend(loc='upper right')
ax1[-1].set_xlabel('time index')
plt.tight_layout()   # no suptitle, so no top margin needs reserving
plt.show()


In [None]:
# Position residuals (prediction - ground truth) of both encoders on the test set.
res_plain_x = pred_plain[:, 0] - true_s[:, 0]
res_time_x = pred_time[:, 0] - true_s[:, 0]

plt.figure(figsize=(5, 3))   # 5:3 aspect ratio
plt.axhline(0, color='grey', lw=0.5)

# The plain encoder is drawn underneath (lower zorder, more transparent) and
# the Δt-aware encoder on top, so the latter stays visible where they overlap.
layers = [
    (res_plain_x, dict(alpha=0.6, zorder=1, label='plain Encoder')),
    (res_time_x, dict(alpha=0.9, zorder=5, label='Encoder with Δt')),
]
for residual, style in layers:
    plt.plot(idx, residual, lw=1.0, **style)

plt.xlabel('time index')
plt.ylabel('position residual  (x)')
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()


In [None]:
# Convert the NumPy arrays gathered for plotting back into tensors for the metric.
true_s_tensor, pred_time_tensor, pred_plain_tensor = (
    torch.tensor(arr) for arr in (true_s, pred_time, pred_plain)
)

# Test-set NRMS of each encoder against the true state.
nrms_time, nrms_plain = (
    compute_vector_nrms(true_s_tensor, prediction)
    for prediction in (pred_time_tensor, pred_plain_tensor)
)

print(f"NRMS (with Δt): {nrms_time:.4f}")
print(f"NRMS (plain):   {nrms_plain:.4f}")



k=0
NRMS (with Δt): 0.0001
NRMS (plain):   0.0001

k=0.1
NRMS (with Δt): 0.0411
NRMS (plain):   0.2272

k=0.5
NRMS (with Δt): 0.0635
NRMS (plain):   0.5889

