## Model Export Framework for PLC Deployment  

This module is developed as part of the research project:  

**_“Towards AI-Based Anomaly Detection at the Edge:  
Evaluating Real-Time Cyber Defense in Programmable Logic Controllers”_**  

The implementation establishes a deterministic and reproducible framework for deploying LSTM-based anomaly detection models on Programmable Logic Controllers (PLCs) in research and educational contexts.  

It enables consistent model loading from PyTorch checkpoints through a parameter-aligned interface that supports both full model serialization and `state_dict` checkpoints.  
All LSTM parameters — including input and recurrent weights, biases, and fully connected layer coefficients — are extracted per gate (`i, f, g, o`) and converted into PLC-native data structures.  
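
The per-gate extraction can be illustrated with a minimal sketch. It relies on PyTorch's `nn.LSTM` layout, in which the four gates are stacked row-wise in the order `i, f, g, o`, so each stacked parameter has `4 * hidden_size` rows. The toy data and the list-based `split_gates` variant below are illustrative assumptions; the notebook itself slices NumPy arrays and first sums `bias_ih_l0` and `bias_hh_l0` into a single bias vector.

```python
# Sketch only: list-based variant of the gate split, using toy data.
def split_gates(rows, units):
    """Slice a stacked (4 * units)-row parameter into one block per gate."""
    if len(rows) != 4 * units:
        raise ValueError(f"expected {4 * units} rows, got {len(rows)}")
    return {
        gate: rows[k * units:(k + 1) * units]
        for k, gate in enumerate(("i", "f", "g", "o"))
    }

# Toy stand-in for weight_ih_l0 with hidden size 2 and 3 input features.
# Row r is filled with the value r so the slices are easy to read.
hidden_size = 2
stacked = [[float(r)] * 3 for r in range(4 * hidden_size)]
gates = split_gates(stacked, hidden_size)

# Show the first column of each gate block.
print({g: [row[0] for row in block] for g, block in gates.items()})
```

Each gate block keeps the original column count, so an input weight block has shape `(hidden_size, input_size)` and a recurrent weight block has shape `(hidden_size, hidden_size)`.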

Two standardized export formats are produced:  
- **Siemens DATA_BLOCK (`.db`)** – for initializing data block (DB) variables on Siemens© PLCs.  
- **Structured Text (`.st`)** – for initializing Structured Text (ST) variables on Beckhoff© PLCs.  
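
To make the difference between the two formats concrete, the sketch below renders one small matrix both ways. It mirrors the notebook's formatting helpers: the `.db` export declares each array and then assigns its elements one by one, while the `.st` export writes a bracketed initializer list. The function names `db_lines` and `st_initializer` and the toy matrix are hypothetical, and the sketch uses nested lists instead of NumPy arrays.

```python
# Sketch only: hypothetical helpers that imitate the two export layouts.
def db_lines(name, matrix):
    """Return the array declaration and the element-wise assignments."""
    rows, cols = len(matrix), len(matrix[0])
    decl = f"{name} : ARRAY[0..{rows - 1}, 0..{cols - 1}] OF REAL;"
    assigns = [
        f"{name}[{r},{c}] := {matrix[r][c]:.8f};"
        for r in range(rows)
        for c in range(cols)
    ]
    return decl, assigns

def st_initializer(name, matrix):
    """Return a bracketed initializer with one matrix row per text line."""
    body = ",\n".join(
        "    " + ", ".join(f"{v:.8f}" for v in row) for row in matrix
    )
    return f"{name} := [\n{body}\n],\n"

toy = [[0.5, -0.25], [1.0, 2.0]]
decl, assigns = db_lines("W_i", toy)
st_text = st_initializer("W_i", toy)

print(decl)
print("\n".join(assigns))
print(st_text)
```

Both layouts use eight decimal places, matching the `:.8f` format specifier in the notebook.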

The framework includes a user-confirmed overwrite mechanism for safe file handling and employs a hyperparameter-based naming convention to ensure traceability between the original neural network and its PLC-deployable form.  
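
The overwrite guard can be sketched outside the notebook as a small function. This is an illustration under assumptions, not the notebook's implementation: the notebook uses `ipywidgets` Yes/No buttons, whereas the hypothetical `write_with_confirmation` below takes any callable that returns `True` or `False`.

```python
import tempfile
from pathlib import Path

# Sketch only: `write_with_confirmation` and `confirm` are hypothetical names.
def write_with_confirmation(path, text, confirm):
    """Write text to path; call confirm(path) first if the file already exists."""
    path = Path(path)
    if path.exists() and not confirm(path):
        return False  # user declined, existing file is left untouched
    path.write_text(text, encoding="utf-8")
    return True

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "demo.db"
    # File does not exist yet, so it is written without asking.
    first = write_with_confirmation(target, "v1", confirm=lambda p: False)
    # File exists and the user declines: content must stay "v1".
    declined = write_with_confirmation(target, "v2", confirm=lambda p: False)
    content_after_decline = target.read_text(encoding="utf-8")
    # File exists and the user accepts: content is replaced.
    accepted = write_with_confirmation(target, "v3", confirm=lambda p: True)
    content_after_accept = target.read_text(encoding="utf-8")

print(first, declined, accepted)
```

In a terminal script, `confirm` could wrap `input()`; keeping it as a parameter makes the guard testable without user interaction.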

This mechanism constitutes a core technical component of the above-mentioned research, supporting AI-based anomaly detection at the industrial edge and enabling real-time, resource-efficient, and auditable cyber defense in critical automation systems.  

---

## User Guide — Configuration and Execution  

This module automatically locates and exports the correct trained LSTM model based on a minimal configuration defined at the beginning of the script.  
The user only needs to configure a few parameters; all paths and exports are handled automatically.

---

### Step 1: Configure Model Parameters  

Before running the script, set these constants according to the trained model you want to export:

```python
MODEL_NAME = "LSTM_SWaT"
MODEL_VERSION = "v1"
HIDDEN_SIZE = 8
SEQUENCE_LENGTH = 10
BEST_EPOCH = 2
```

---

### Step 2: Run the Script  

Run the notebook cells (or the script) manually, in order; “Run All” is not required.  
If an export file already exists, confirm or decline the overwrite when prompted.  
No additional user action is required — the export files will be created automatically.  

---

### Step 3: Check the Output  


After completion, verify that the export files exist in the model folder.
The exact filenames depend on the user-defined configuration parameters
(model name, hidden size, input window, version, and epoch).

Example output files:
- `LSTM_SWaT_HS8_IW10_v1_Epoch_02.db`
- `LSTM_SWaT_HS8_IW10_v1_Epoch_02.st`

These represent the PLC-ready exports of your trained model.  
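
The naming convention can be reproduced with a short sketch that uses the same f-string pattern as the notebook's `MODEL_FILENAME`. The helper name `export_stem` is hypothetical; the values are the ones from Step 1.

```python
# Sketch only: `export_stem` is a hypothetical helper name.
def export_stem(model_name, hidden_size, sequence_length, version, epoch):
    """Build the shared filename stem; the epoch is zero-padded to two digits."""
    return (
        f"{model_name}_HS{hidden_size}_IW{sequence_length}"
        f"_{version}_Epoch_{epoch:02d}"
    )

stem = export_stem("LSTM_SWaT", 8, 10, "v1", 2)
outputs = [f"{stem}.db", f"{stem}.st"]
print(outputs)
```

Because the checkpoint (`.pt`) and both exports share one stem, each PLC file can be traced back to the hyperparameters and epoch of the model it came from.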

---

## Legal Notice  

- **Author:** Zoltán Dobrády  
- **Website:** [www.cyberseclab.eu](https://www.cyberseclab.eu)  
- **Contact:** zoltan.dobrady@hotmail.com  
- **License:** Creative Commons BY-NC 4.0  
- **Version:** v1.0.1  
- **Copyright:** © 2025–2026  

This software is intended exclusively for **educational and research purposes**, or other **non-commercial applications**.  
Use in **industrial production environments** is prohibited.  
The **author must be credited** in all derivative or redistributive works.  
**Commercial use** requires **explicit written permission** from the author.  


In [1]:
import torch
import torch.nn as nn
from torch.serialization import add_safe_globals
import os
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display

BASE_DIR = Path.cwd()
PROJECT_DIR = BASE_DIR.parent
DATASET_DIR = PROJECT_DIR / "dataset"

In [2]:
MODEL_NAME = "LSTM_SWaT"
MODEL_VERSION = "v1"

HIDDEN_SIZE = 8
SEQUENCE_LENGTH = 10
BEST_EPOCH = 1 

MODEL_FILENAME = f"{MODEL_NAME}_HS{HIDDEN_SIZE}_IW{SEQUENCE_LENGTH}_{MODEL_VERSION}_Epoch_{BEST_EPOCH:02d}.pt"
MODEL_DIR = f"{PROJECT_DIR}/Models/HS{HIDDEN_SIZE}_IW{SEQUENCE_LENGTH}_{MODEL_VERSION}"
NORM_REFERENCE_FILE = f"{MODEL_DIR}/{MODEL_NAME}_HS{HIDDEN_SIZE}_IW{SEQUENCE_LENGTH}_{MODEL_VERSION}_featureNormRef.csv"
MODEL = f"{MODEL_DIR}/{MODEL_FILENAME}"

NUM_LAYERS = 1
NUM_FEATURES = 40 

In [3]:
class TempLSTMAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.0)
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.dropout(out[:, -1, :])
        out = self.fc(out)
        return out

def get_array_definition(name, arr):
    if arr.ndim == 2:
        rows, cols = arr.shape
        return f"      {name} : ARRAY[0..{rows-1}, 0..{cols-1}] OF REAL;"
    elif arr.ndim == 1:
        size = arr.shape[0]
        return f"      {name} : ARRAY[0..{size-1}] OF REAL;"
    else:
        raise ValueError("Only 1D or 2D arrays are supported.")

def format_structured_weights(name, matrix):
    rows, cols = matrix.shape
    lines = []
    for r in range(rows):
        for c in range(cols):
            lines.append(f"      {name}[{r},{c}] := {matrix[r, c]:.8f};")
    return "\n".join(lines)

def format_structured_bias(name, vector):
    lines = []
    for i, val in enumerate(vector):
        lines.append(f"      {name}[{i}] := {val:.8f};")
    return "\n".join(lines)


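def load_model_from_pt(path, input_size, hidden_size, num_layers):
    """Load a checkpoint that is either a state_dict or a fully pickled model."""
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    obj = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(obj, dict):
        # state_dict checkpoint: rebuild the architecture, then load the weights
        model = TempLSTMAutoencoder(input_size, hidden_size, num_layers)
        model.load_state_dict(obj)
    else:
        # fully serialized model: use the unpickled module as-is
        model = obj

    model.eval()
    return model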

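def split_gates(mat, units):
    # PyTorch stacks the LSTM gates row-wise in the order i, f, g, o,
    # so each gate occupies `units` consecutive rows of the stacked parameter.
    return {
        "i": mat[0*units:1*units],
        "f": mat[1*units:2*units],
        "g": mat[2*units:3*units],
        "o": mat[3*units:4*units],
    }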

In [4]:


add_safe_globals([TempLSTMAutoencoder])
model = load_model_from_pt(MODEL, NUM_FEATURES, HIDDEN_SIZE, NUM_LAYERS)

print("Loaded model from:", MODEL)
print("Input features:", model.lstm.input_size)
print("Hidden size:", model.lstm.hidden_size)
print("Number of LSTM layers:", model.lstm.num_layers)
print("FC output size:", model.fc.out_features)


W = model.lstm.weight_ih_l0.detach().cpu().numpy()
R = model.lstm.weight_hh_l0.detach().cpu().numpy()
B = (model.lstm.bias_ih_l0.detach().cpu() + model.lstm.bias_hh_l0.detach().cpu()).numpy()
FC_W = model.fc.weight.detach().cpu().numpy()
FC_B = model.fc.bias.detach().cpu().numpy()

W_gates = split_gates(W, HIDDEN_SIZE)
R_gates = split_gates(R, HIDDEN_SIZE)
B_gates = split_gates(B, HIDDEN_SIZE)

print("W shape:", W.shape)
print("R shape:", R.shape)
print("B shape:", B.shape)
print("FC_W shape:", FC_W.shape)
print("FC_B shape:", FC_B.shape)

Loaded model from: c:\Repos\Towards-AI-Based-Anomaly-Detection-at-the-Edge-Evaluating-Real-Time-CyberDefense-in-PLC/Models/HS8_IW10_v1/LSTM_SWaT_HS8_IW10_v1_Epoch_01.pt
Input features: 40
Hidden size: 8
Number of LSTM layers: 1
FC output size: 40
W shape: (32, 40)
R shape: (32, 8)
B shape: (32,)
FC_W shape: (40, 8)
FC_B shape: (40,)


In [5]:
# --- Generate declarations
var_declarations = []
var_declarations.append(get_array_definition("W_i", W_gates["i"]))
var_declarations.append(get_array_definition("W_f", W_gates["f"]))
var_declarations.append(get_array_definition("W_g", W_gates["g"]))
var_declarations.append(get_array_definition("W_o", W_gates["o"]))

var_declarations.append(get_array_definition("R_i", R_gates["i"]))
var_declarations.append(get_array_definition("R_f", R_gates["f"]))
var_declarations.append(get_array_definition("R_g", R_gates["g"]))
var_declarations.append(get_array_definition("R_o", R_gates["o"]))

var_declarations.append(get_array_definition("B_i", B_gates["i"]))
var_declarations.append(get_array_definition("B_f", B_gates["f"]))
var_declarations.append(get_array_definition("B_g", B_gates["g"]))
var_declarations.append(get_array_definition("B_o", B_gates["o"]))

var_declarations.append(get_array_definition("FC_W", FC_W))
var_declarations.append(get_array_definition("FC_B", FC_B))

# --- Generate values
structured_output = ""
structured_output += format_structured_weights("W_i", W_gates["i"]) + "\n"
structured_output += format_structured_weights("W_f", W_gates["f"]) + "\n"
structured_output += format_structured_weights("W_g", W_gates["g"]) + "\n"
structured_output += format_structured_weights("W_o", W_gates["o"]) + "\n"

structured_output += format_structured_weights("R_i", R_gates["i"]) + "\n"
structured_output += format_structured_weights("R_f", R_gates["f"]) + "\n"
structured_output += format_structured_weights("R_g", R_gates["g"]) + "\n"
structured_output += format_structured_weights("R_o", R_gates["o"]) + "\n"

structured_output += format_structured_bias("B_i", B_gates["i"]) + "\n"
structured_output += format_structured_bias("B_f", B_gates["f"]) + "\n"
structured_output += format_structured_bias("B_g", B_gates["g"]) + "\n"
structured_output += format_structured_bias("B_o", B_gates["o"]) + "\n"

structured_output += format_structured_weights("FC_W", FC_W) + "\n"
structured_output += format_structured_bias("FC_B", FC_B) + "\n"

In [6]:
# --- build DB names from constants ---
DB_BLOCK_NAME = f"{MODEL_NAME}_HS{HIDDEN_SIZE}_IW{SEQUENCE_LENGTH}_{MODEL_VERSION}"
db_filename = Path(MODEL_FILENAME).with_suffix(".db").name
db_path = Path(MODEL_DIR) / db_filename

# --- assemble DB text ---
db_text = (
    f"""// --- Siemens Structured Format ---
DATA_BLOCK "{DB_BLOCK_NAME}"
{{ S7_Optimized_Access := 'TRUE' }}
VERSION : {MODEL_VERSION}
NON_RETAIN
   VAR
"""
    + "\n".join(var_declarations)
    + "\n   END_VAR\nBEGIN\n\n"
    + structured_output
    + "\nEND_DATA_BLOCK\n"
)

USER_CONFIRMED_OVERWRITE = False
out = widgets.Output()

if os.path.exists(db_path):
    print(f"[WARNING] File already exists at: {db_path}")
    print("Press YES to continue or NO to skip saving.")

    yes_button = widgets.Button(description="Yes", button_style="danger")
    no_button  = widgets.Button(description="No",  button_style="success")

    def _disable_buttons():
        yes_button.disabled = True
        no_button.disabled  = True

    def on_yes_clicked(b):
        global USER_CONFIRMED_OVERWRITE
        _disable_buttons()
        USER_CONFIRMED_OVERWRITE = True
        # Write and report inside the Output widget so messages printed
        # from the button callback are shown in the notebook.
        with out:
            out.clear_output()
            print("[WARNING] Existing DB file will be OVERWRITTEN!")
            db_path.write_text(db_text, encoding="utf-8")
            print(f"[INFO] DB saved: {db_path}")

    def on_no_clicked(b):
        _disable_buttons()
        with out:
            out.clear_output()
            print("[INFO] DB file exists, not saved.")

    yes_button.on_click(on_yes_clicked)
    no_button.on_click(on_no_clicked)

    display(widgets.HBox([yes_button, no_button]))
    display(out)

else:
    USER_CONFIRMED_OVERWRITE = True
    db_path.write_text(db_text, encoding="utf-8")
    print(f"[INFO] DB created at: {db_path}")




[INFO] DB created at: c:\Repos\Towards-AI-Based-Anomaly-Detection-at-the-Edge-Evaluating-Real-Time-CyberDefense-in-PLC\Models\HS8_IW10_v1\LSTM_SWaT_HS8_IW10_v1_Epoch_01.db


In [7]:
def format_st_matrix(name, matrix):
    rows, cols = matrix.shape
    text = f"{name} := [\n"
    for r in range(rows):
        row = ", ".join(f"{matrix[r, c]:.8f}" for c in range(cols))
        text += f"    {row},\n"
    text = text.rstrip(",\n") + "\n],\n"
    return text

def format_st_vector(name, vector):
    text = f"{name} := [" + ", ".join(f"{x:.8f}" for x in vector) + "],\n"
    return text

H = HIDDEN_SIZE
W_gates = split_gates(W, H)
R_gates = split_gates(R, H)
B_gates = split_gates(B, H)

output = ""
output += format_st_matrix("W_i", W_gates["i"])
output += format_st_matrix("W_f", W_gates["f"])
output += format_st_matrix("W_g", W_gates["g"])
output += format_st_matrix("W_o", W_gates["o"])

output += format_st_matrix("R_i", R_gates["i"])
output += format_st_matrix("R_f", R_gates["f"])
output += format_st_matrix("R_g", R_gates["g"])
output += format_st_matrix("R_o", R_gates["o"])

output += format_st_vector("B_i", B_gates["i"])
output += format_st_vector("B_f", B_gates["f"])
output += format_st_vector("B_g", B_gates["g"])
output += format_st_vector("B_o", B_gates["o"])

output += format_st_matrix("FC_W", FC_W)
output += format_st_vector("FC_B", FC_B)


In [8]:
ST_USER_CONFIRMED_OVERWRITE = False
st_out = widgets.Output()

ST_FILENAME = Path(MODEL_FILENAME).with_suffix(".st").name
st_path = Path(MODEL_DIR) / ST_FILENAME

if os.path.exists(st_path):
    print(f"[WARNING] File already exists at: {st_path}")
    print("Press YES to continue or NO to skip saving.")

    st_yes_button = widgets.Button(description="Yes", button_style="danger")
    st_no_button  = widgets.Button(description="No",  button_style="success")

    def _st_disable_buttons():
        st_yes_button.disabled = True
        st_no_button.disabled  = True

    def st_on_yes_clicked(b):
        global ST_USER_CONFIRMED_OVERWRITE
        _st_disable_buttons()
        ST_USER_CONFIRMED_OVERWRITE = True
        # Write and report inside the Output widget so messages printed
        # from the button callback are shown in the notebook.
        with st_out:
            st_out.clear_output()
            print("[WARNING] Existing ST file will be OVERWRITTEN!")
            st_path.write_text(output, encoding="utf-8")
            print(f"[INFO] ST saved: {st_path}")

    def st_on_no_clicked(b):
        _st_disable_buttons()
        with st_out:
            st_out.clear_output()
            print("[INFO] ST file exists, not saved.")

    st_yes_button.on_click(st_on_yes_clicked)
    st_no_button.on_click(st_on_no_clicked)

    display(widgets.HBox([st_yes_button, st_no_button]))
    display(st_out)

else:
    ST_USER_CONFIRMED_OVERWRITE = True
    st_path.write_text(output, encoding="utf-8")
    print(f"[INFO] ST created at: {st_path}")

[INFO] ST created at: c:\Repos\Towards-AI-Based-Anomaly-Detection-at-the-Edge-Evaluating-Real-Time-CyberDefense-in-PLC\Models\HS8_IW10_v1\LSTM_SWaT_HS8_IW10_v1_Epoch_01.st
