## Model Architecture

In [1]:
import torch
import torch.nn as nn


# ------------------------------------------------------------
# Minimal Transformer Encoder Layer WITHOUT LayerNorm
# ------------------------------------------------------------
class SimpleTransformerLayer(nn.Module):
    def __init__(self, d_model=1000, nhead=10, dim_feedforward=2048):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # No dropout, no LayerNorm
        self.act = nn.ReLU()

    def forward(self, x, attn_mask):
        # ---- Self Attention ----
        attn_out, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
        x = x + attn_out  # residual

        # ---- Feed-forward ----
        ff = self.linear2(self.act(self.linear1(x)))
        x = x + ff  # residual

        return x


# ------------------------------------------------------------
# Decoder-Only Transformer (GPT-style)
# ------------------------------------------------------------
class DecoderOnlyTransformer(nn.Module):
    def __init__(
        self,
        in_dim=11,
        model_dim=1000,
        num_layers=3,
        dim_feedforward=2048,
        nhead=10,
        out_dim=2,
    ):
        super().__init__()

        self.model_dim = model_dim

        # project sensor input → transformer dimension
        self.input_proj = nn.Linear(in_dim, model_dim)

        # build N layers
        self.layers = nn.ModuleList(
            [
                SimpleTransformerLayer(
                    d_model=model_dim, nhead=nhead, dim_feedforward=dim_feedforward
                )
                for _ in range(num_layers)
            ]
        )

        # final output head (predict 2 values)
        self.fc = nn.Linear(model_dim, out_dim)

    def _causal_mask(self, T, device):
        """
        Returns a [T, T] causal mask with -inf above diagonal.
        """
        mask = torch.triu(torch.ones(T, T, device=device) * float("-inf"), diagonal=1)
        return mask

    def forward(self, x):
        """
        x: [B, T, in_dim]
        returns last-step prediction: [B, out_dim]
        """
        B, T, _ = x.size()

        # input embedding
        x = self.input_proj(x)  # → [B, T, model_dim]

        # causal mask
        mask = self._causal_mask(T, x.device)

        # transformer stack
        for layer in self.layers:
            x = layer(x, mask)  # [B, T, model_dim]

        # Get only last timestep (like LSTM last hidden state)
        last = x[:, -1, :]  # [B, model_dim]

        # Final prediction
        return self.fc(last)  # [B, out_dim]


In [3]:
# Instantiate and load weights
model = SimpleTransformerLayer()
ckpt_path = "best_t_model.pt"  # <-- change to your checkpoint
torch.save(model, "best_t_model.pt")

In [10]:
import torch
import numpy as np

ckpt_path = "best_t_model.pt"   # <-- change me

obj = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# 1) Pick the right sub-dict containing tensors
if isinstance(obj, dict):
    if "model_state_dict" in obj:
        sd = obj["model_state_dict"]
    elif "state_dict" in obj:
        sd = obj["state_dict"]
    elif any(torch.is_tensor(v) for v in obj.values()):
        sd = obj
    else:
        raise ValueError("Could not find model weights in checkpoint dict.")
elif hasattr(obj, "state_dict"):
    sd = obj.state_dict()
else:
    raise ValueError("Unrecognized checkpoint format")

# 2) Flatten one level if some entries are nested dicts
flat = {}
for k, v in sd.items():
    if isinstance(v, dict):
        for kk, vv in v.items():
            if torch.is_tensor(vv):
                flat[f"{k}.{kk}"] = vv
    elif torch.is_tensor(v):
        flat[k] = v

# 3) Strip common wrappers (DDP/Lightning/etc.)
def strip_prefix(k):
    for p in ("module.", "model.", "net.", "student."):
        if k.startswith(p):
            return k[len(p):]
    return k

flat = {strip_prefix(k): v for k, v in flat.items()}

# 4) Save to NPZ
npz_path = "state_dict_npz.npz"
np.savez(npz_path, **{k: v.detach().cpu().numpy() for k, v in flat.items()})
print(f"✅ Saved {len(flat)} tensors → {npz_path}")
print("Sample keys:", list(flat.keys()))


✅ Saved 8 tensors → state_dict_npz.npz
Sample keys: ['self_attn.in_proj_weight', 'self_attn.in_proj_bias', 'self_attn.out_proj.weight', 'self_attn.out_proj.bias', 'linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias']


In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ---- config (static) ----
BATCH, SEQ = 1, 100
MODEL_DIM = 1000
IN, OUT   = 11, 2
NPZ        = "state_dict_npz.npz"
SAVE_DIR   = "tf_export_lstm3_unrolled_static"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"   # optional for stable numerics

sd = np.load(NPZ)

