In [16]:
import torch
from icecream import ic
from IPython.display import display, HTML

style = """
<style>
/* Classic Notebook */
.output_wrapper, .output, .output_area {
    max-height: 800px !important;
}
.output_scroll {
    height: auto !important;
    max-height: 800px !important;
    overflow: auto !important;
}

/* JupyterLab */
.jp-OutputArea-output {
    max-height: none !important;
    overflow: visible !important;
}
.jp-OutputArea-output .jp-RenderedText {
    white-space: pre-wrap;
}

/* Tracebacks / error boxes */
div.traceback, .jp-OutputArea-output .jp-Error, .error {
    max-height: 800px !important;
    overflow: auto !important;
    white-space: pre-wrap;
}
</style>
"""
display(HTML(style))



In [37]:
# model_name = 'pendulum'
# model_name = 'cvs' # no opt
# model_name='double_pendulum' # no opt
model_name = 'pendulum_friction'

model_name2="lstm"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open(f'checkpoints/{model_name}/{model_name2}_model.pkl', 'rb') as f:
    ckpt = torch.load(f, weights_only=False)
with open(f'data/{model_name}/processed_data.pkl', 'rb') as processed_data_file:
    x = torch.load(processed_data_file, weights_only=False)
test_data = x['test']
train_data = x['train']
args = ckpt['args']   # saved argparse.Namespace
model_state = ckpt['model']
data_args = ckpt.get('data_args', None)
print(ckpt.keys())
print(x.keys())
# --- build model and load weights ---
print(train_data.shape)
print(test_data.shape)
assert 'opt' in list(ckpt.keys())


dict_keys(['args', 'model', 'opt', 'data_args'])
dict_keys(['train', 'test'])
(450, 100, 28, 28)
(50, 100, 28, 28)


In [38]:
print(model_name)
with open(f'data/{model_name}/data_args.pkl', 'rb') as data_args:
    data_args = torch.load(data_args, weights_only=False)
ic(data_args)

pendulum_friction


[38;5;247mic[39m[38;5;245m|[39m[38;5;245m [39m[38;5;247mdata_args[39m[38;5;245m:[39m[38;5;245m [39m[38;5;245m{[39m[38;5;36m'[39m[38;5;36mmask_rate[39m[38;5;36m'[39m[38;5;245m:[39m[38;5;245m [39m[38;5;36m0.01[39m[38;5;245m,[39m[38;5;245m [39m[38;5;36m'[39m[38;5;36mmodel[39m[38;5;36m'[39m[38;5;245m:[39m[38;5;245m [39m[38;5;36m'[39m[38;5;36mpendulum[39m[38;5;36m'[39m[38;5;245m,[39m[38;5;245m [39m[38;5;36m'[39m[38;5;36mnoise_std[39m[38;5;36m'[39m[38;5;245m:[39m[38;5;245m [39m[38;5;36m0.0[39m[38;5;245m}[39m


{'mask_rate': 0.01, 'noise_std': 0.0, 'model': 'pendulum'}

In [39]:
import inspect
if model_name == "pendulum":
    from models.LSTM import create_lstm_pendulum as the_goku_model
    import models.LSTM as specific_goku_model


elif model_name == "pendulum_friction":
    from models.LSTM import create_lstm_pendulum_friction as the_goku_model
elif model_name == "cvs":
    from models.LSTM import create_lstm_cvs as the_goku_model
elif model_name == "double_pendulum":
    print("HELLLO")
    from models.LSTM import create_lstm_double_pendulum as the_goku_model
else:
    raise Exception("NO MODEL SELECTED")
print(model_name)
sig = inspect.signature(the_goku_model)
print(sig)
# parameter defaults (including kw-only)
spec = inspect.getfullargspec(the_goku_model)
defaults = {}
if spec.defaults:
    defaults.update({name: val for name, val in zip(spec.args[-len(spec.defaults):], spec.defaults)})
if spec.kwonlydefaults:
    defaults.update(spec.kwonlydefaults)
print("defaults:", defaults)

pendulum_friction
(input_dim=[28, 28], hidden_dim=16, num_layers=2)
defaults: {'input_dim': [28, 28], 'hidden_dim': 16, 'num_layers': 2}


In [40]:

goku = the_goku_model()
optimizer = torch.optim.Adam(goku.parameters(), lr=1e-3) # change to whatever optimizer was used
goku.load_state_dict(model_state)
optimizer.load_state_dict(ckpt['opt'])


In [41]:
# python
# make sure names and device match
goku = goku.to(device)                  # your model instance


goku.eval()
x = torch.as_tensor(test_data).float().to(device)   # (N,T,...) or (T,N,...)

# time array (use saved args if available)
delta_t = getattr(args, 'delta_t', 1.0)
seq_len = x.shape[1]
t = torch.arange(0.0, seq_len * float(delta_t), step=float(delta_t), device=device)

with torch.no_grad():
    out = goku(x)
    pred = out[0] if isinstance(out, tuple) else out

# undo normalization if data_args present (zscore or minmax)
def undo_norm(z, data_args):
    if not data_args: return z
    norm = data_args.get('norm', None)
    if norm == 'zscore' and 'x_mean' in data_args and 'x_std' in data_args:
        mean = torch.as_tensor(data_args['x_mean'], device=z.device).view(1,1,-1)
        std  = torch.as_tensor(data_args['x_std'], device=z.device).view(1,1,-1)
        return z * std + mean
    if norm in ('zero_to_one','minmax') and 'x_min' in data_args and 'x_max' in data_args:
        mn = torch.as_tensor(data_args['x_min'], device=z.device).view(1,1,-1)
        mx = torch.as_tensor(data_args['x_max'], device=z.device).view(1,1,-1)
        return z * (mx - mn) + mn
    return z

pred = undo_norm(pred, data_args)
true = undo_norm(x, data_args)

# compute per-timestep MAE / RMSE (average over all axes except time)
# Now compute error using robust alignment logic
import numpy as _np

print("pred.shape, true.shape:", pred.shape, true.shape, "device:", pred.device, true.device)

# If model returned (T, N, ...) transpose to (N, T, ...)
if pred.shape[0] != true.shape[0] and pred.shape[1] == true.shape[0]:
    perm = [1, 0] + list(range(2, pred.ndim))
    pred = pred.permute(*perm).contiguous()
    print("Permuted pred ->", pred.shape)

# batch dim check
if pred.shape[0] != true.shape[0]:
    raise RuntimeError(f"Batch size mismatch: pred {pred.shape[0]} vs true {true.shape[0]}")

# Align time length by slicing to min time (do this before subtraction)
min_time = min(pred.shape[1], true.shape[1])
if pred.shape[1] != true.shape[1]:
    print(f"Aligning time dimension: slicing to {min_time} timesteps (pred {pred.shape[1]}, true {true.shape[1]})")
    pred = pred[:, :min_time].contiguous()
    true = true[:, :min_time].contiguous()

# helper to match feature shapes (attempt reshape/broadcast)
def reconcile_feature_shapes(pred, true):
    pred_feat = tuple(pred.shape[2:])
    true_feat = tuple(true.shape[2:])
    if pred_feat == true_feat:
        return pred, true
    pred_prod = int(_np.prod(pred_feat)) if len(pred_feat) else 1
    true_prod = int(_np.prod(true_feat)) if len(true_feat) else 1

    # Case A: pred flattened (single dim equal to product of true feature dims)
    if len(pred_feat) == 1 and pred_prod == true_prod:
        new_shape = (pred.shape[0], pred.shape[1]) + true_feat
        try:
            pred = pred.view(*new_shape)
            print(f"Reshaped pred features {pred_feat} -> {true_feat}")
            return pred, true
        except Exception:
            pass

    # Case B: pred is scalar/channel (broadcast across true)
    if pred_prod == 1:
        pred = pred.view(pred.shape[0], pred.shape[1], *([1] * len(true_feat)))
        pred = pred.expand(-1, -1, *true_feat)
        print("Broadcasted scalar/single-channel pred ->", true_feat)
        return pred, true

    # Case C: pred has 1 spatial dimension and true has 2, try expand if dims match
    if len(pred_feat) == 1 and len(true_feat) == 2:
        p0 = pred_feat[0]
        h, w = true_feat
        if p0 == h:
            pred = pred.unsqueeze(-1).expand(-1, -1, h, w)
            print("Expanded pred H->HxW by repeating across W")
            return pred, true
        if p0 == w:
            pred = pred.unsqueeze(-2).expand(-1, -1, h, w)
            print("Expanded pred W->HxW by repeating across H")
            return pred, true

    # Case D: pred has same number of feature dims but unequal sizes -> try broadcasting if feasible
    if len(pred_feat) == len(true_feat):
        can_broadcast = all(p == t or p == 1 for p, t in zip(pred_feat, true_feat))
        if can_broadcast:
            expand_sizes = [t if p == 1 else p for p, t in zip(pred_feat, true_feat)]
            pred = pred.view(pred.shape[0], pred.shape[1], *pred_feat)
            pred = pred.expand(-1, -1, *expand_sizes)
            print("Broadcasted pred features ->", true_feat)
            return pred, true

    # Unfixable mismatch
    print("Cannot reconcile feature shapes automatically.")
    print(f"pred feature shape: {pred_feat} (prod={pred_prod}), true feature shape: {true_feat} (prod={true_prod})")
    raise RuntimeError(
        "Feature shape mismatch: pred {} vs true {}. "
        "If this is expected, add a reshape/broadcast rule; otherwise modify model/data to return matching shapes.".format(pred_feat, true_feat)
    )

pred, true = reconcile_feature_shapes(pred, true)

time_dim = 1
reduce_dims = tuple(i for i in range(pred.ndim) if i != time_dim)

# compute error AFTER aligning time and reconciling features
err = pred - true
mae_per_t = err.abs().mean(dim=reduce_dims).cpu().numpy()
rmse_per_t = err.pow(2).mean(dim=reduce_dims).sqrt().cpu().numpy()
print("Mean MAE:", mae_per_t.mean(), "Mean RMSE:", rmse_per_t.mean())


pred.shape, true.shape: torch.Size([50, 28, 28]) torch.Size([50, 100, 28, 28]) device: cuda:0 cuda:0
Aligning time dimension: slicing to 28 timesteps (pred 28, true 100)
Expanded pred H->HxW by repeating across W
Mean MAE: 0.42544726 Mean RMSE: 0.42823845


In [24]:
print(mae_per_t.shape)
print(mae_per_t.shape)

NameError: name 'mae_per_t' is not defined

In [25]:
# quick plotting
import numpy as np
import matplotlib.pyplot as plt
time = np.arange(100)
plt.figure(figsize=(6,3))
plt.plot(time, mae_per_t, label='MAE')
plt.plot(time, rmse_per_t, label='RMSE')
plt.xlabel('time'); plt.ylabel('error'); plt.legend(); plt.grid(True); plt.show()


NameError: name 'mae_per_t' is not defined

<Figure size 600x300 with 0 Axes>