# Explicit Control Law Approximator

A learning-based approach to approximating the solution of a Model Predictive Control (MPC) problem, formulated as a quadratic programming (QP) problem.

The method developed below trains an artificial neural network (Multilayer Perceptron, MLP) that tries to recover the primal solution of the QP problem through supervised learning, based on the problem's Lagrangian and a dataset. The authors generated the samples in this dataset by simulating the control scenario and solving each instance with a traditional solver.
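A minimal sketch of what the explicit control law looks like at inference time: an MLP maps the initial state x to the full decision vector z, whose size is (Nx + Nu) · Nh. The layer widths, ReLU activations, He-style initialization and toy dimensions below are assumptions for illustration, not the architecture used in this work. NumPy is used so that the sketch runs without PyTorch.

```python
import numpy as np

def init_mlp(layer_sizes, seed=0):
    """Random (hypothetical) weights for an MLP with the given layer sizes."""
    rng = np.random.default_rng(seed)
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        W = rng.normal(0.0, np.sqrt(2.0 / n_in), size=(n_in, n_out))  # He-style init
        b = np.zeros(n_out)
        params.append((W, b))
    return params

def mlp_forward(params, x):
    """z = pi(x | theta): ReLU hidden layers followed by a linear output layer."""
    h = x
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)
    W, b = params[-1]
    return h @ W + b

# Hypothetical sizes: Nx = 4 states, Nu = 2 inputs, Nh = 10 horizon steps
n_states, n_actuations, n_horizon = 4, 2, 10
dim_z = (n_states + n_actuations) * n_horizon  # same formula as dim_z in the notebook
params = init_mlp([n_states, 64, 64, dim_z])

x_batch = np.random.default_rng(1).normal(size=(8, n_states))
z_pred = mlp_forward(params, x_batch)
print(z_pred.shape)  # (8, 60)
```

The output layer is linear because the (standardized) primal can take any real value; feasibility is handled by the loss, not by the architecture.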

Two learning setups are presented throughout this work:

- #### Lagrangian:
Initial method, proposed by the reference paper, which defined the loss function from the problem's Lagrangian. It was eventually dropped in favor of the next method, which became the final version.

- #### Constrained Primal
Method used to replace the Lagrangian loss function. **This was the final version used.**

## Authors:
* Artur
* Carlos
* Vicenzo D'Arezzo Zilio - 13671790

# 1. Initial Setup

In [None]:
%pip install -r requirements.txt

In [None]:
import random
import os
import math
import zipfile
import copy

from itertools import product
from typing import Callable, Dict, Any
from tqdm import tqdm

import matplotlib.pyplot as plt

import pandas as pd
import numpy as np

from scipy.sparse import issparse
from scipy.sparse.linalg import eigsh
from numpy.linalg import eigvals, norm

import torch, platform
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.autograd import gradcheck
from torch.utils.data import random_split

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType

In [None]:
DATA_PATH = './content/states_with_bounds.npz'

## 1.1 Execution Environment

Hardware context in which the PyTorch operations run. Change it according to the target setup.
- MPS: native GPU backend on Apple devices
- CUDA: GPU backend available on Colab

In [None]:
print("Torch:", torch.__version__, "| CUDA:", torch.cuda.is_available(), "| MPS:", torch.backends.mps.is_available(), "| Py:", platform.python_version())

In [None]:
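# Prefer CUDA (e.g. Colab), then Apple MPS, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")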

# 2. Data Extraction and Preprocessing

In [None]:
with zipfile.ZipFile(DATA_PATH, 'r') as zf:
    print("Conteúdo:", zf.namelist())

## 2.1 Data Loading
Based on the structure of the simulation-generated dataset, whose contents are listed by the previous cell, a dedicated class is defined to store and validate the loaded data.
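The slicing performed by the loading classes assumes a fixed column layout for each row of `data`: state, primal, raw dual, lower bounds, upper bounds. A minimal NumPy sketch of that layout follows; `split_row_blocks` is a hypothetical helper and the toy sizes are made up. The dual split uses the same sign convention as the classes (a negative raw dual means the lower bound is active).

```python
import numpy as np

def split_row_blocks(data, n_states, dim_z, n_constraints):
    """Split the flat data matrix into the blocks used by the loading classes.

    Assumed column layout (taken from the slicing in the class):
    [ x (n_states) | z (dim_z) | dual (n_con) | lower (n_con) | upper (n_con) ]
    """
    expected = n_states + dim_z + 3 * n_constraints
    if data.shape[1] != expected:
        raise ValueError(f"expected {expected} columns, got {data.shape[1]}")
    edges = np.cumsum([n_states, dim_z, n_constraints, n_constraints])
    x, z, dual, lb, ub = np.split(data, edges, axis=1)
    lamb = np.maximum(-dual, 0.0)  # lower-bound multipliers
    nu = np.maximum(dual, 0.0)     # upper-bound multipliers
    return x, z, lamb, nu, lb, ub

# Toy example: n_states=2, dim_z=3, n_constraints=2 -> 2 + 3 + 3*2 = 11 columns
data = np.arange(2 * 11, dtype=np.float32).reshape(2, 11)
x, z, lamb, nu, lb, ub = split_row_blocks(data, n_states=2, dim_z=3, n_constraints=2)
print(x.shape, z.shape, lb.shape, ub.shape)  # (2, 2) (2, 3) (2, 2) (2, 2)
```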

### 2.1.1 Lagrangian Problem
Several preprocessing transformations are applied, mainly because of the problem's numerical instability and dimensionality.
Data visualization was carried out for this version, so the visualization methods are also defined here.

- Tikhonov regularization on P - to make P strictly positive definite
- Spectral normalization of P and q - to bring them to unit scale (the largest eigenvalue of P becomes 1)
- Z-score standardization of the initial state and the primal, and scaling of the duals by their standard deviation - to improve model performance and keep the quantities numerically balanced
- Diagonal scaling of the primal based on diag(P) - to reduce the condition number of P
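A minimal sketch of the first two transformations on a toy, rank-deficient P (the values are made up). It illustrates that Tikhonov regularization is what makes P positive definite and its condition number finite, while dividing by the spectral norm only rescales the entries: the condition number is scale-invariant, so it is identical before and after that step.

```python
import numpy as np

def tikhonov_and_spectral_normalize(P, q, gamma=1e-5):
    """P_reg = P + gamma*I, then divide P_reg and q by ||P_reg||_2."""
    P_reg = P + gamma * np.eye(P.shape[0])
    spec = np.linalg.norm(P_reg, ord=2)  # largest singular value
    return P_reg / spec, q / spec

# Toy PSD but singular P (rank 1), hypothetical values
v = np.array([[1.0], [2.0], [3.0]])
P = 1e3 * (v @ v.T)
q = np.array([1.0, -1.0, 0.5])

P_n, q_n = tikhonov_and_spectral_normalize(P, q, gamma=1e-2)
print(np.linalg.eigvalsh(P_n).min() > 0)   # True: strictly positive definite
print(np.linalg.norm(P_n, ord=2))          # 1.0 up to round-off
print(np.linalg.cond(P_n), np.linalg.cond(P + 1e-2 * np.eye(3)))  # equal values
```

In practice this means the spectral step mainly keeps the magnitudes of the quadratic and linear terms comparable, while conditioning is governed by `gamma` and by the later diagonal scaling.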

In [None]:
class Init_MPC_Problem:

    def __init__(self, np_raw_path):

      raw = np.load(np_raw_path)

      # Dimensions
      self.n_states = int(raw["Nx"])
      self.n_actuations = int(raw["Nu"])
      self.n_horizon = int(raw["Nh"])
      self.dim_z = (self.n_states + self.n_actuations) * self.n_horizon

      # Constraints
      self.n_constraints =  int(raw["n_con"])
      self.n_eq_constraints = int(raw["n_eq"])
      self.n_ineq_constraints = int(raw["n_in"])
      
      # Iterating through data to extract each instance:
      # -> Decision variables, Duals and Bounds

      self.input_data = raw["data"].astype(np.float32)
      it_pos = 0

      self.state_list = self.input_data[:, it_pos: it_pos + self.n_states]
      it_pos += self.n_states
      
      self.z_list = self.input_data[:, it_pos : it_pos + self.dim_z]
      it_pos += self.dim_z

      self.raw_dual = self.input_data[:, it_pos : it_pos + self.n_constraints]
      it_pos += self.n_constraints

      self.lower_bounds = self.input_data[:, it_pos : it_pos + self.n_constraints]
      it_pos += self.n_constraints

      self.upper_bounds = self.input_data[:, it_pos : it_pos + self.n_constraints]

      # Quadratic Programming Matrices
      self.P = raw["P"].astype(np.float64)
      self.q = raw["q"].astype(np.float64)
      self.C = raw["C"].astype(np.float64)
      self.U = raw["U"].astype(np.float64)

      # Constraints Matrix
      self.D = np.concatenate((self.C,  self.U), axis=0)

      # Duals Extraction
      self.lamb = np.maximum(-self.raw_dual, 0)
      self.nu = np.maximum(self.raw_dual, 0)

      # Dimension Check
      assert self.P.shape == (self.dim_z, self.dim_z), f"P must be [{self.dim_z},{self.dim_z}], got {self.P.shape}"
      assert self.q.shape == (self.dim_z,), f"q must be [{self.dim_z}], got {self.q.shape}"
      assert self.C.shape == (self.n_eq_constraints, self.dim_z), f"C must be [{self.n_eq_constraints},{self.dim_z}], got {self.C.shape}"
      assert self.U.shape == (self.n_ineq_constraints, self.dim_z), f"U must be [{self.n_ineq_constraints},{self.dim_z}], got {self.U.shape}"

      N_samples = self.input_data.shape[0]
      assert self.state_list.shape[0] == N_samples, f"state_list must have {N_samples} samples, got {self.state_list.shape[0]}"
      assert self.z_list.shape[0] == N_samples, f"z_list must have {N_samples} samples, got {self.z_list.shape[0]}"
      assert self.raw_dual.shape[0] == N_samples, f"raw_dual must have {N_samples} samples, got {self.raw_dual.shape[0]}"
      assert self.upper_bounds.shape[0] == N_samples, f"upper_bounds must have {N_samples} samples, got {self.upper_bounds.shape[0]}"
      assert self.lower_bounds.shape[0] == N_samples, f"lower_bounds must have {N_samples} samples, got {self.lower_bounds.shape[0]}"
      
      assert self.state_list.shape[1] == self.n_states, f"state_list must have {self.n_states} features, got {self.state_list.shape[1]}"
      assert self.z_list.shape[1] == self.dim_z, f"z_list must have {self.dim_z} features, got {self.z_list.shape[1]}"
      assert self.raw_dual.shape[1] == self.n_constraints, f"raw_dual must have {self.n_constraints} features, got {self.raw_dual.shape[1]}"
      assert self.upper_bounds.shape[1] == self.n_constraints, f"upper_bounds must have {self.n_constraints} features, got {self.upper_bounds.shape[1]}"
      assert self.lower_bounds.shape[1] == self.n_constraints, f"lower_bounds must have {self.n_constraints} features, got {self.lower_bounds.shape[1]}"

      total_expected_cols = self.n_states + self.dim_z  + self.n_constraints * 3
      assert self.input_data.shape[1] == total_expected_cols, f"Total input columns must be {total_expected_cols}, got {self.input_data.shape[1]}. Check slicing logic."
      print(f"[CHECK] - Dimensions and Sampling Ok")

      # P nature checking
      eigvals = np.linalg.eigvalsh(self.P)
      assert np.all(eigvals >= -1e-8), f"P not positive semidefinite, min eig = {eigvals.min()}"
      print(f"[CHECK] - P Matrices are Positive Semidefinite")

    def check_P_matrix_condition(self):
      """
      The condition number gives an indication of the tightness
      of the bounds on the curvature of f, and therefore
      of the difficulty to optimize it: the bigger the value of
      κ is, the slower the convergence of the algorithm.
      """
      condition_number = np.linalg.cond(self.P)
      print(f"[CHECK] - Condition Number of P: {condition_number}")
      if condition_number > 1e6:
        print(f"[WARNING] - Condition Number of P matrix is too high, this may lead to numerical instability")

    def apply_pre_processing(self):

      # Condition transformations
      self.apply_tikhonov_regularization()
      self.normalizating_qp_matrices()

      # Distribution transformations
      self.apply_state_standarization()
      self.apply_decision_variable_standarization()
      self.apply_duals_scalling()
      self.decision_variable_scalling()
      pass

    def apply_state_standarization(self):
      """
      x' = (x - x_mean) / x_std
      """
      self.state_mean = self.state_list.mean(axis=0)
      self.state_std = self.state_list.std(axis=0)
      self.state_list = (self.state_list - self.state_mean) / self.state_std
      print(f"[INFO] - State Standardization applied")

    def apply_decision_variable_standarization(self):
      """
      z' = (z - z_mean) / z_std
      """

      self.z_mean = self.z_list.mean(axis=0)
      self.z_std = self.z_list.std(axis=0)
      self.z_list = (self.z_list - self.z_mean) / self.z_std

      # The change of variables z = mean + diag(std) z' must be propagated
      # to the constraints:
      # a <= D z <= b  ->  a <= D mean + (D @ diag(std)) z' <= b
      # a' = a - D mean (same for b)
      # D' = D @ diag(std)

      self.lower_bounds = self.lower_bounds - self.D @ self.z_mean # Old D
      self.upper_bounds = self.upper_bounds - self.D @ self.z_mean # Old D
      self.D = self.D @ np.diag(self.z_std)

      # ...and to the objective function (the constant term is dropped):
      # P' = diag(std) @ P @ diag(std)
      # q' = diag(std) @ (P @ mean + q)

      self.q = np.diag(self.z_std) @ (self.P @ self.z_mean + self.q) # Old P
      self.P = np.diag(self.z_std) @ self.P @ np.diag(self.z_std)

      print(f"[INFO] - Decision Variable Standardization applied")

    def apply_duals_scalling(self):
      """
      nu' = nu / (std(nu) + eps)
      lambda' = lambda / (std(lambda) + eps)
      Only scaled (not centered), so the duals remain non-negative.
      """
      # eps in the denominator avoids division by zero for constraints that are never active
      self.nu = self.nu / (self.nu.std(axis=0) + 1e-6)
      self.lamb = self.lamb / (self.lamb.std(axis=0) + 1e-6)
      print(f"[INFO] - Duals Scaling by Std Dev applied")

    def apply_tikhonov_regularization(self, gamma=1e-5):
      """
      P_reg = P + gamma * I.
      Turns the positive semidefinite P into a positive definite matrix
      (every eigenvalue becomes >= gamma) and improves numerical stability
      by decreasing the condition number of P.
      """
      # Add the regularization term to P
      self.P = self.P + gamma * np.eye(self.P.shape[0], dtype=np.float64)
      print(f"[INFO] - Tikhonov Regularization applied = {gamma}")

    def normalizating_qp_matrices(self):
      """
        -> Normalizing P by its spectral norm
      """
      spec_norm_P = np.linalg.norm(self.P, ord=2)

      self.P = self.P / spec_norm_P
      self.q = self.q / spec_norm_P
      
      print(f"[INFO] - Normalizing P and q by P's spectral norm = {spec_norm_P}")

    def decision_variable_scalling(self):
      """
      Maps the optimization problem from z to z' = S^-1 z, where
      S = diag(1 / sqrt(diag(P))) is a diagonal (Jacobi) scaling chosen
      to reduce the condition number of P.

      S (or S^-1) must be applied to every quantity related to z:
      P' = S^T P S, q' = S^T q, D' = D S and z' = S^-1 z.
      """

      # Calculating scale:
      P_diag = np.diag(self.P)
      scale = np.sqrt(P_diag)
      scale[scale < 1e-9] = 1.0

      # Applying scale:
      self.z_scalling = np.diag(1.0 / scale)
      self.z_inv_scalling = 1/np.diag(self.z_scalling)

      self.P = self.z_scalling.T @ self.P @ self.z_scalling
      self.q = self.z_scalling.T @ self.q
      self.D = self.D @ self.z_scalling

      self.z_list = self.z_inv_scalling * self.z_list

      print(f"[INFO] - Decision Variable Scaling applied")

    def visualize_bounds(self):
      x = self.lower_bounds
      z = self.upper_bounds

      plt.figure(figsize=(10, 4))
      plt.boxplot(x, showfliers=False)
      plt.title("Lower Bounds (z_inf)")
      plt.xlabel("Dimension")
      plt.ylabel("Value")
      plt.grid(True)
      plt.show()

      plt.figure(figsize=(10, 4))
      plt.boxplot(z, showfliers=False)
      plt.title("Upper Bounds (z_sup)")
      plt.xlabel("Dimension")
      plt.ylabel("Value")
      plt.grid(True)
      plt.show()

    def visualize_duals(self):

        plt.figure(figsize=(10, 4))
        plt.boxplot(self.lamb, showfliers=False)
        plt.title("Lambda (lower bound dual)")
        plt.xlabel("Dimension")
        plt.ylabel("Value")
        plt.grid(True)
        plt.show()

        lamb_mean = self.lamb.mean(axis=0)

        plt.figure(figsize=(10, 4))
        plt.bar(range(len(lamb_mean)), lamb_mean, capsize=5)
        plt.title("Lambda – Mean value per dimension")
        plt.xlabel("Dimension")
        plt.ylabel("Mean")
        plt.grid(True, axis='y')
        plt.show()

        plt.figure(figsize=(10, 4))
        plt.boxplot(self.nu, showfliers=False)
        plt.title("Nu (upper bound dual)")
        plt.xlabel("Dimension")
        plt.ylabel("Value")
        plt.grid(True)
        plt.show()

        nu_mean = self.nu.mean(axis=0)

        plt.figure(figsize=(10, 4))
        plt.bar(range(len(nu_mean)), nu_mean, capsize=5)
        plt.title("Nu – Mean value per dimension")
        plt.xlabel("Dimension")
        plt.ylabel("Mean")
        plt.grid(True, axis='y')
        plt.show()


    def visualize_decision_vars(self):
      x = self.state_list
      z = self.z_list

      def summarize_array(arr, name):
        df = pd.DataFrame(arr, columns=[f"{name}_{i}" for i in range(arr.shape[1])])
        desc = df.describe(percentiles=[0.01, 0.05, 0.5, 0.95, 0.99]).T
        print(f"\nStatistical summary of {name}:")
        print(desc[["mean", "std", "min", "1%", "5%", "50%", "95%", "99%", "max"]])
        return desc

      summarize_array(x, "x")
      summarize_array(z, "z")

      plt.figure(figsize=(10, 4))
      plt.boxplot(x, showfliers=False)
      plt.title("State Distribution (x)")
      plt.xlabel("Dimension")
      plt.ylabel("Value")
      plt.grid(True)
      plt.show()

      plt.figure(figsize=(10, 4))
      plt.boxplot(z, showfliers=False)
      plt.title("Decision Variable Distribution (z)")
      plt.xlabel("Dimension")
      plt.ylabel("Value")
      plt.grid(True)
      plt.show()

    def visualize_C_and_U(self):

      def _visualize_matrix(M, title_prefix="Matrix"):
        diag = np.diag(M)

        fig, axs = plt.subplots(1, 3, figsize=(16, 4))

        # 1. Diagonal values
        axs[0].plot(diag, marker='o')
        axs[0].set_title(f"{title_prefix} – diagonal")
        axs[0].set_xlabel("Index")
        axs[0].set_ylabel("Value")

        # 2. Matrix heatmap
        im = axs[1].imshow(M, aspect='auto', cmap='viridis')
        axs[1].set_title(f"{title_prefix} – heatmap")
        plt.colorbar(im, ax=axs[1])

        # 3. Absolute-value heatmap (where the magnitudes concentrate)
        im2 = axs[2].imshow(np.abs(M), aspect='auto', cmap='inferno')
        axs[2].set_title(f"{title_prefix} – magnitude |M|")
        plt.colorbar(im2, ax=axs[2])

        plt.tight_layout()
        plt.show()

      print("=== C MATRIX ===")
      _visualize_matrix(self.C, title_prefix="C")

      print("=== U MATRIX ===")
      _visualize_matrix(self.U, title_prefix="U")

      print("=== D MATRIX ===")
      _visualize_matrix(self.D, title_prefix="D")

    def visualize_P_and_q(self):

      P = self.P
      q = self.q
      P = P if not issparse(P) else P.tocsr()
      q = q.reshape(-1)

      # ---- 1. SPY Plot ----
      plt.figure(figsize=(6, 6))
      if issparse(P):
        plt.spy(P, markersize=1)
      else:
        plt.imshow(P != 0, aspect="auto", cmap="gray_r")
      plt.title("Sparsity Pattern of P")
      plt.show()

      # ---- 2. Heatmap for dense or sampled ----
      if not issparse(P) and P.shape[0] <= 3000:
        plt.figure(figsize=(6, 5))
        plt.imshow(P, aspect='auto')
        plt.title("Heatmap of P values")
        plt.colorbar()
        plt.show()

      # ---- 3. Histogram of values ----
      P_vals = P.data if issparse(P) else P.flatten()
      plt.figure(figsize=(6,4))
      plt.hist(P_vals, bins=100)
      plt.title("Distribution of P entries")
      plt.show()

      plt.figure(figsize=(6,4))
      plt.hist(q, bins=50)
      plt.title("Distribution of q entries")
      plt.show()

      # ---- Boxplot & Violinplot for q ----
      plt.figure(figsize=(6,4))
      plt.boxplot(q)
      plt.title("Boxplot of q values")
      plt.show()

      plt.figure(figsize=(6,4))
      plt.violinplot(q, showmeans=True)
      plt.title("Violin Plot of q values")
      plt.show()

      # ---- 4. Symmetry check ----
      if not issparse(P):
        sym_error = norm(P - P.T)
        print(f"Symmetry error ||P - P^T||: {sym_error}")

      # ---- 5. Eigenvalues (smallest few) ----
      print("Estimating smallest eigenvalues of P...")
      try:
        if issparse(P):
          evals = eigsh(P, k=6, which='SA')[0]
        else:
          evals = np.sort(eigvals(P).real)[:6]
        print("Smallest eigenvalues:", evals)

        plt.figure(figsize=(6,4))
        plt.hist(evals, bins=10)
        plt.title("Histogram of smallest eigenvalues of P")
        plt.show()

        # ---- Boxplot & Violinplot for eigenvalues ----
        plt.figure(figsize=(6,4))
        plt.boxplot(evals)
        plt.title("Boxplot of smallest eigenvalues of P")
        plt.show()


        plt.figure(figsize=(6,4))
        plt.violinplot(evals, showmeans=True)
        plt.title("Violin Plot of smallest eigenvalues of P")
        plt.show()

      except Exception as e:
        print("Eigenvalue analysis failed:", e)

      # ---- 6. Row norm profile ----
      if not issparse(P):
        row_norms = np.linalg.norm(P, axis=1)
      else:
        row_norms = np.sqrt(P.multiply(P).sum(axis=1)).A.flatten()


      plt.figure(figsize=(8,4))
      plt.plot(row_norms)
      plt.title("Row Norm Profile of P")
      plt.show()

      # ---- 7. Inspect q structure ----
      plt.figure(figsize=(8,4))
      plt.plot(q)
      plt.title("q vector profile")
      plt.show()
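The comments in `apply_decision_variable_standarization` and `decision_variable_scalling` rely on changes of variables that must leave the QP unchanged. The NumPy check below, on random toy data (an assumption, not the project's matrices), verifies that after z = mean + diag(std) z' and z' = T z'' with T = diag(1/sqrt(diag(P'))), the objective differs only by the constant f(mean), the constraint slack D z - lb is preserved, and the final P has a unit diagonal.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = rng.normal(size=(n, n))
P = A @ A.T + 1e-3 * np.eye(n)   # random SPD matrix (toy data)
q = rng.normal(size=n)
D = rng.normal(size=(m, n))
lb = -np.ones(m)

mean = rng.normal(size=n)
std = rng.uniform(0.5, 2.0, size=n)
S = np.diag(std)

# Standardization: z = mean + S z'
P1 = S @ P @ S
q1 = S @ (P @ mean + q)
D1 = D @ S
lb1 = lb - D @ mean

# Jacobi scaling: z' = T z''
T = np.diag(1.0 / np.sqrt(np.diag(P1)))
P2 = T @ P1 @ T
q2 = T @ q1
D2 = D1 @ T

def f(z, P, q):
    """QP objective 0.5 z^T P z + q^T z."""
    return 0.5 * z @ P @ z + q @ z

z = rng.normal(size=n)
z1 = (z - mean) / std             # z'  = S^-1 (z - mean)
z2 = np.sqrt(np.diag(P1)) * z1    # z'' = T^-1 z'

const = f(mean, P, q)             # constant dropped by the transformation
print(np.isclose(f(z, P, q), f(z2, P2, q2) + const))  # True
print(np.allclose(D @ z - lb, D2 @ z2 - lb1))          # True
print(np.allclose(np.diag(P2), 1.0))                   # True
```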

In [None]:
LAG_MPC = Init_MPC_Problem(DATA_PATH)

#### 2.1.1.1 Initial Visualization

In [None]:
LAG_MPC.visualize_decision_vars()

In [None]:
LAG_MPC.visualize_duals()

In [None]:
LAG_MPC.visualize_bounds()

In [None]:
LAG_MPC.visualize_P_and_q()

In [None]:
LAG_MPC.visualize_C_and_U()

#### 2.1.1.2 Applying the Preprocessing

In [None]:
LAG_MPC.check_P_matrix_condition()

In [None]:
LAG_MPC.apply_pre_processing()

In [None]:
LAG_MPC.check_P_matrix_condition()

In [None]:
LAG_MPC.visualize_decision_vars()

In [None]:
LAG_MPC.visualize_duals()

### 2.1.2 Final Problem
The more involved transformations are not applied here, because the final loss function does not use the ill-conditioned matrices of the problem.
The only transformation applied is:

- Z-score standardization of the initial states and of the primal (with the corresponding shift of the bounds and scaling of D)
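Because the network is trained on standardized primals, its outputs have to be mapped back with z = z' · z_std + z_mean (using the `z_mean` and `z_std` attributes stored by `MPC_Problem`) before they are used as control actions. The `ZScore` class below is a hypothetical helper sketching that round trip; the small `eps` that guards constant columns is an assumption and is not part of the notebook's classes.

```python
import numpy as np

class ZScore:
    """Minimal z-score scaler (hypothetical helper, not part of the notebook)."""

    def __init__(self, eps=1e-8):
        self.eps = eps

    def fit(self, X):
        self.mean = X.mean(axis=0)
        # eps guards against constant columns (std = 0)
        self.std = X.std(axis=0) + self.eps
        return self

    def transform(self, X):
        return (X - self.mean) / self.std

    def inverse_transform(self, Xs):
        return Xs * self.std + self.mean

rng = np.random.default_rng(0)
Z = rng.normal(loc=5.0, scale=3.0, size=(100, 6))  # toy primal samples
scaler = ZScore().fit(Z)
Zs = scaler.transform(Z)
print(np.allclose(Zs.mean(axis=0), 0.0, atol=1e-9))  # True
print(np.allclose(scaler.inverse_transform(Zs), Z))   # True
```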

In [None]:
class MPC_Problem:

  def __init__(self, np_raw_path):

    raw = np.load(np_raw_path)

    # Dimensions
    self.n_states = int(raw["Nx"])
    self.n_actuations = int(raw["Nu"])
    self.n_horizon = int(raw["Nh"])
    self.dim_z = (self.n_states + self.n_actuations) * self.n_horizon

    # Constraints
    self.n_constraints =  int(raw["n_con"])
    self.n_eq_constraints = int(raw["n_eq"])
    self.n_ineq_constraints = int(raw["n_in"])
    
    # Iterating through data to extract each instance:
    # -> Decision variables, Duals and Bounds

    self.input_data = raw["data"].astype(np.float32)
    it_pos = 0

    self.state_list = self.input_data[:, it_pos: it_pos + self.n_states]
    it_pos += self.n_states
    
    self.z_list = self.input_data[:, it_pos : it_pos + self.dim_z]
    it_pos += self.dim_z

    self.raw_dual = self.input_data[:, it_pos : it_pos + self.n_constraints]
    it_pos += self.n_constraints

    self.lower_bounds = self.input_data[:, it_pos : it_pos + self.n_constraints]
    it_pos += self.n_constraints

    self.upper_bounds = self.input_data[:, it_pos : it_pos + self.n_constraints]

    # Quadratic Programming Matrices
    self.P = raw["P"].astype(np.float64)
    self.q = raw["q"].astype(np.float64)
    self.C = raw["C"].astype(np.float64)
    self.U = raw["U"].astype(np.float64)

    # Constraints Matrix
    self.D = np.concatenate((self.C,  self.U), axis=0)

    # Duals Extraction
    self.lamb = np.maximum(-self.raw_dual, 0)
    self.nu = np.maximum(self.raw_dual, 0)

    # Dimension Check
    assert self.P.shape == (self.dim_z, self.dim_z), f"P must be [{self.dim_z},{self.dim_z}], got {self.P.shape}"
    assert self.q.shape == (self.dim_z,), f"q must be [{self.dim_z}], got {self.q.shape}"
    assert self.C.shape == (self.n_eq_constraints, self.dim_z), f"C must be [{self.n_eq_constraints},{self.dim_z}], got {self.C.shape}"
    assert self.U.shape == (self.n_ineq_constraints, self.dim_z), f"U must be [{self.n_ineq_constraints},{self.dim_z}], got {self.U.shape}"

    N_samples = self.input_data.shape[0]
    assert self.state_list.shape[0] == N_samples, f"state_list must have {N_samples} samples, got {self.state_list.shape[0]}"
    assert self.z_list.shape[0] == N_samples, f"z_list must have {N_samples} samples, got {self.z_list.shape[0]}"
    assert self.raw_dual.shape[0] == N_samples, f"raw_dual must have {N_samples} samples, got {self.raw_dual.shape[0]}"
    assert self.upper_bounds.shape[0] == N_samples, f"upper_bounds must have {N_samples} samples, got {self.upper_bounds.shape[0]}"
    assert self.lower_bounds.shape[0] == N_samples, f"lower_bounds must have {N_samples} samples, got {self.lower_bounds.shape[0]}"
    
    assert self.state_list.shape[1] == self.n_states, f"state_list must have {self.n_states} features, got {self.state_list.shape[1]}"
    assert self.z_list.shape[1] == self.dim_z, f"z_list must have {self.dim_z} features, got {self.z_list.shape[1]}"
    assert self.raw_dual.shape[1] == self.n_constraints, f"raw_dual must have {self.n_constraints} features, got {self.raw_dual.shape[1]}"
    assert self.upper_bounds.shape[1] == self.n_constraints, f"upper_bounds must have {self.n_constraints} features, got {self.upper_bounds.shape[1]}"
    assert self.lower_bounds.shape[1] == self.n_constraints, f"lower_bounds must have {self.n_constraints} features, got {self.lower_bounds.shape[1]}"

    total_expected_cols = self.n_states + self.dim_z  + self.n_constraints * 3
    assert self.input_data.shape[1] == total_expected_cols, f"Total input columns must be {total_expected_cols}, got {self.input_data.shape[1]}. Check slicing logic."
    print(f"[CHECK] - Dimensions and Sampling Ok")

    # P nature checking
    eigvals = np.linalg.eigvalsh(self.P)
    assert np.all(eigvals >= -1e-8), f"P not positive semidefinite, min eig = {eigvals.min()}"
    print(f"[CHECK] - P Matrices are Positive Semidefinite")
    
  def apply_primal_std(self):
    """
    t(z) -> z' | z' = (z - mean) / std
    """
    self.z_mean = self.z_list.mean(axis=0)
    self.z_std = self.z_list.std(axis=0)
    self.z_list = (self.z_list - self.z_mean) / self.z_std

    # The change of variables z = mean + diag(std) z' must be propagated
    # to the constraints:
    # a <= D z <= b  ->  a <= D mean + (D @ diag(std)) z' <= b
    # a' = a - D mean (same for b)
    # D' = D @ diag(std)

    self.lower_bounds = self.lower_bounds - self.D @ self.z_mean # Old D
    self.upper_bounds = self.upper_bounds - self.D @ self.z_mean # Old D
    self.D = self.D @ np.diag(self.z_std)

    # P and q are intentionally not transformed: the final loss does not use them
  
  def apply_state_std(self):
    """
    t(x) -> x' | x' = (x - mean) / std
    """

    self.state_mean = self.state_list.mean(axis=0)
    self.state_std = self.state_list.std(axis=0)
    self.state_list = (self.state_list - self.state_mean) / self.state_std
       

In [None]:
MPC = MPC_Problem(DATA_PATH)

#### 2.1.2.1 Applying the Preprocessing

In [None]:
MPC.apply_primal_std()
MPC.apply_state_std()

## 2.2 Dataset Definition
Conversion of the preprocessed data into a PyTorch dataset, used later to train the network.

In [None]:
class MPC_DATASET(Dataset):

  def __init__(self, MPC):

    self.states = MPC.state_list
    self.decision_variables = MPC.z_list
    self.nu_lagrangians = MPC.nu
    self.lambda_lagrangians = MPC.lamb
    self.lower_bounds = MPC.lower_bounds
    self.upper_bounds = MPC.upper_bounds

    self.D = torch.from_numpy(MPC.D).float()

    # Pre-processing parameters

    # print(f"[CHECK] States list dimension: {self.states.shape}")
    # print(f"[CHECK] Decision Variables lis dimension: {self.decision_variables.shape}")
    # print(f"[CHECK] Nu Lagrangians dimension: {self.nu_lagrangians.shape}")
    # print(f"[CHECK] Lambda Lagrangians dimension: {self.lambda_lagrangians.shape}")

    # print(f"[CHECK] States list element list[i]: {self.states[0].shape}")
    # print(f"[CHECK] Decision Variables list element list[i]: {self.decision_variables[0].shape}")
    # print(f"[CHECK] Nu Lagrangians element list[i]: {self.nu_lagrangians[0].shape}")
    # print(f"[CHECK] Lambda Lagrangians element list[i]: {self.lambda_lagrangians[0].shape}")

  def __len__(self):
    return self.states.shape[0]

  def __getitem__(self, i):

    """
    Returning:
    -> x std
    -> z std
    -> z_inf
    -> z_sup
    """
    x = self.states[i]
    z = self.decision_variables[i]
    lb = self.lower_bounds[i]
    ub = self.upper_bounds[i]

    return {
        "x": torch.from_numpy(x).float(),
        "z": torch.from_numpy(z).float(),
        "lb": torch.from_numpy(lb).float(),
        "ub": torch.from_numpy(ub).float()
    }
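`MPC_DATASET` follows PyTorch's map-style protocol: `__len__` plus a `__getitem__` that returns one sample, which `DataLoader` then collates into batches (for dict samples, the default collation stacks each key separately). The torch-free sketch below mimics that flow with NumPy; `NumpyMPCDataset`, `collate` and `iterate_batches` are hypothetical names, and shuffling with a fixed seed is an assumption.

```python
import numpy as np

class NumpyMPCDataset:
    """Torch-free sketch of the map-style protocol used by MPC_DATASET."""

    def __init__(self, x, z, lb, ub):
        assert len(x) == len(z) == len(lb) == len(ub)
        self.x, self.z, self.lb, self.ub = x, z, lb, ub

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, i):
        return {"x": self.x[i], "z": self.z[i], "lb": self.lb[i], "ub": self.ub[i]}

def collate(samples):
    """Stack a list of per-sample dicts into one dict of (B, ...) arrays."""
    return {k: np.stack([s[k] for s in samples]) for k in samples[0]}

def iterate_batches(ds, batch_size, seed=0):
    """Shuffle the indices once and yield collated mini-batches."""
    idx = np.random.default_rng(seed).permutation(len(ds))
    for start in range(0, len(ds), batch_size):
        yield collate([ds[int(i)] for i in idx[start:start + batch_size]])

rng = np.random.default_rng(0)
ds = NumpyMPCDataset(rng.normal(size=(10, 2)), rng.normal(size=(10, 3)),
                     -np.ones((10, 4)), np.ones((10, 4)))
batches = list(iterate_batches(ds, batch_size=4))
print([b["x"].shape for b in batches])  # [(4, 2), (4, 2), (2, 2)]
```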

# 3. Loss Function

## 3.1 Lagrangian Loss Function

Defined as the mean, over the training samples, of the squared difference between the Lagrangian evaluated at the estimated primal and at the optimal primal, both with the optimal duals. Using the augmented Lagrangian makes it possible to capture both the objective function value (through the primal) and the constraints (through each dual and the extra terms of the augmented Lagrangian).

$$
\ell(\theta) :=
\frac{1}{N}\sum_{i=1}^{N}
\Big(
    \mathcal{L}\big(\tilde{\pi}(x_i \mid \theta), \nu_i^{\ast}, \lambda_i^{\ast} \mid x_i\big)
    - \mathcal{L}\big(z_i^{\ast}, \nu_i^{\ast}, \lambda_i^{\ast} \mid x_i\big)
\Big)^2
$$

### Computing the Lagrangian
Computed by interpreting C and U as the matrices of the equality and inequality constraints, stacked into $\mathbf{D} = [C; U]$ with bounds $z_{inf} \le \mathbf{D}z \le z_{sup}$.

$$
\mathcal{L}(z, \nu, \lambda) = \frac{1}{2}z^T\mathbf{P}z + \mathbf{q}^Tz + \lambda^T(\mathbf{D}z - z_{inf}) + \nu^T(z_{sup} - \mathbf{D}z) + \frac{\rho}{2}\big\|\max(0,\, z_{inf} - \mathbf{D}z)\big\|^2 + \frac{\rho}{2}\big\|\max(0,\, \mathbf{D}z - z_{sup})\big\|^2
$$
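A NumPy sketch of the batched augmented Lagrangian and of the loss ℓ(θ), mirroring `AugmentedLagrangianLoss.augmented_lagrangian` (including its sign convention for the λ and ν terms) and averaging over the batch. The toy problem in the usage example is made up: with P = I, q = 0 and box bounds [-1, 1], the optimum is z* = 0 with all constraints inactive, so zero duals are consistent with it.

```python
import numpy as np

def augmented_lagrangian(z, lam, nu, lb, ub, P, q, D, rho):
    """Batched augmented Lagrangian. z: (B, n); lam, nu, lb, ub: (B, m). Returns (B,)."""
    Dz = z @ D.T
    quad = 0.5 * np.einsum("bi,ij,bj->b", z, P, z)
    lin = z @ q
    lag_inf = np.sum(lam * (Dz - lb), axis=1)
    lag_sup = np.sum(nu * (ub - Dz), axis=1)
    aug_inf = 0.5 * rho * np.sum(np.maximum(0.0, lb - Dz) ** 2, axis=1)
    aug_sup = 0.5 * rho * np.sum(np.maximum(0.0, Dz - ub) ** 2, axis=1)
    return quad + lin + lag_inf + lag_sup + aug_inf + aug_sup

def lagrangian_loss(z_pred, z_opt, lam, nu, lb, ub, P, q, D, rho):
    """Batch mean of (L(z_pred) - L(z_opt))^2, both evaluated with the optimal duals."""
    diff = (augmented_lagrangian(z_pred, lam, nu, lb, ub, P, q, D, rho)
            - augmented_lagrangian(z_opt, lam, nu, lb, ub, P, q, D, rho))
    return np.mean(diff ** 2)

n = m = 2
P, q, D = np.eye(n), np.zeros(n), np.eye(m)
lb, ub = -np.ones((1, m)), np.ones((1, m))
lam = nu = np.zeros((1, m))        # inactive constraints -> zero duals
z_opt = np.zeros((1, n))           # minimizer of 0.5||z||^2 inside the box
z_pred = np.array([[2.0, 0.0]])    # violates the upper bound of the first row
print(lagrangian_loss(z_opt, z_opt, lam, nu, lb, ub, P, q, D, rho=10.0))   # 0.0
print(lagrangian_loss(z_pred, z_opt, lam, nu, lb, ub, P, q, D, rho=10.0))  # 49.0
```

For `z_pred`, the quadratic term contributes 2 and the upper augmented term contributes ρ/2 · 1² = 5, so ℒ = 7 against ℒ* = 0, giving a loss of 49.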

### Problems Encountered
Numerous problems arose with this loss function. The following tests were performed:
- Correlation between loss value and gradient for random values (OK)
- Convergence in an overfitting scenario, i.e. the ability to learn (OK)
- Probability distribution of the loss value for random values (not OK)
    - High but stable expected value
    - Values frequently exploding in the quadratic, dual and augmented terms
    - NOTE: plots are included in the report
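A related sanity check for the gradient side of these tests is a finite-difference comparison against the analytic gradient of the augmented Lagrangian, similar in spirit to `torch.autograd.gradcheck` (imported in the setup cell). The sketch below is a NumPy version under stated assumptions: a single sample, a symmetric P (so the gradient of ½ zᵀPz is Pz), the same sign convention as the Lagrangian equation above, and random toy data. The function names are hypothetical.

```python
import numpy as np

def aug_lagrangian(z, lam, nu, lb, ub, P, q, D, rho):
    """Single-sample augmented Lagrangian (same sign convention as the notebook)."""
    Dz = D @ z
    return (0.5 * z @ P @ z + q @ z
            + lam @ (Dz - lb) + nu @ (ub - Dz)
            + 0.5 * rho * np.sum(np.maximum(0.0, lb - Dz) ** 2)
            + 0.5 * rho * np.sum(np.maximum(0.0, Dz - ub) ** 2))

def aug_lagrangian_grad(z, lam, nu, lb, ub, P, q, D, rho):
    """Analytic gradient with respect to z (assumes P symmetric)."""
    Dz = D @ z
    return (P @ z + q + D.T @ lam - D.T @ nu
            - rho * D.T @ np.maximum(0.0, lb - Dz)
            + rho * D.T @ np.maximum(0.0, Dz - ub))

def finite_difference_grad(fun, z, h=1e-6):
    """Central differences, one coordinate at a time."""
    g = np.zeros_like(z)
    for i in range(z.size):
        e = np.zeros_like(z)
        e[i] = h
        g[i] = (fun(z + e) - fun(z - e)) / (2 * h)
    return g

rng = np.random.default_rng(0)
n, m = 5, 4
A = rng.normal(size=(n, n))
P = A @ A.T                                  # symmetric PSD (toy data)
q, D = rng.normal(size=n), rng.normal(size=(m, n))
lb, ub = -0.5 * np.ones(m), 0.5 * np.ones(m)
lam, nu = rng.uniform(size=m), rng.uniform(size=m)
args = (lam, nu, lb, ub, P, q, D, 10.0)

z = rng.normal(size=n)
g_fd = finite_difference_grad(lambda v: aug_lagrangian(v, *args), z)
g_an = aug_lagrangian_grad(z, *args)
print(np.max(np.abs(g_fd - g_an)))  # should be small
```

The squared hinge terms are continuously differentiable, so the comparison stays valid even when some constraints sit near their bounds.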


In [None]:
class AugmentedLagrangianLoss(nn.Module):
    def __init__(self, P, q, D, rho):
        super().__init__()

        if q.dim() == 1:
            q = q.unsqueeze(-1)  # (n,1)

        self.register_buffer("P", P.clone().detach().float())   # (n,n)
        self.register_buffer("q", q.clone().detach().float())   # (n,1)
        self.register_buffer("D", D.clone().detach().float())   # (m,n)
        self.rho = float(rho)

        self.quadratic_list = []
        self.linear_list = []
        self.lambda_list = []
        self.nu_list = []
        self.aug_inf_list = []
        self.aug_sup_list = []

        self.quadratic_list_opt = []
        self.linear_list_opt = []
        self.lambda_list_opt = []
        self.nu_list_opt = []
        self.aug_inf_list_opt = []
        self.aug_sup_list_opt = []

    def visualize(self):
        """
        Plots side by side:
        - non-optimal values (regular lists)
        - optimal values (*_opt lists)
        """

        # List names and titles
        items = [
            ("quadratic_list", "quadratic_list_opt", "Quadratic Term"),
            ("linear_list", "linear_list_opt", "Linear Term"),
            ("lambda_list", "lambda_list_opt", "Lambda Term"),
            ("nu_list", "nu_list_opt", "Nu Term"),
            ("aug_inf_list", "aug_inf_list_opt", "Augmented (Inf)"),
            ("aug_sup_list", "aug_sup_list_opt", "Augmented (Sup)"),
        ]

        for attr_non_opt, attr_opt, title in items:

            # --- Retrieve the lists ---
            list_non_opt = getattr(self, attr_non_opt)
            list_opt = getattr(self, attr_opt)

            # Convert everything to 1D numpy arrays
            def to_numpy_list(L):
                processed = []
                for x in L:
                    if hasattr(x, "detach"):   # Tensor
                        x = x.detach().cpu().numpy()
                    if isinstance(x, np.ndarray):
                        processed.append(x.flatten())
                    else:
                        processed.append(np.array([float(x)]))
                if len(processed) == 0:
                    return np.array([])
                return np.concatenate(processed)

            arr_non_opt = to_numpy_list(list_non_opt)
            arr_opt = to_numpy_list(list_opt)

            # --- Create a figure with two subplots ---
            fig, axes = plt.subplots(1, 2, figsize=(12, 4))
            fig.suptitle(title, fontsize=14, fontweight="bold")

            # Non-optimal (predicted) subplot
            axes[0].plot(arr_non_opt)
            axes[0].set_title("Pred (non-optimal)")
            axes[0].set_xlabel("Index")
            axes[0].set_ylabel("Value")
            axes[0].grid(True, alpha=0.3)

            # Optimal subplot
            axes[1].plot(arr_opt)
            axes[1].set_title("Opt (optimal)")
            axes[1].set_xlabel("Index")
            axes[1].grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()

    def _ensure_batch_bound(self, b, B):
        """Ensure bound is (B,m)."""
        if b is None:
            return None
        if b.dim() == 1:
            return b.unsqueeze(0).expand(B, -1)
        elif b.dim() == 2:
            return b
        else:
            raise ValueError("bound must be (m,) or (B,m).")

    def augmented_lagrangian(self, z, lam_inf, lam_sup, lower_bound, upper_bound, pred=False):
        """
        z:        (B,n)
        lam_inf:  (B,m)
        lam_sup:  (B,m)
        lower_bound: (m,) or (B,m)
        upper_bound: (m,) or (B,m)

        output: (B,) augmented Lagrangian value
        """
        dtype = z.dtype
        P = self.P.to(dtype)
        q = self.q.to(dtype).squeeze(-1)   # (n,)
        D = self.D.to(dtype)

        B = z.shape[0]
        lb = self._ensure_batch_bound(lower_bound, B)  # (B,m)
        ub = self._ensure_batch_bound(upper_bound, B)  # (B,m)

        # ---------------------------------------------------------
        # QUADRATIC: 0.5 zᵀ P z
        # ---------------------------------------------------------
        zP = torch.matmul(z, P)                   # (B,n)
        quad = 0.5 * torch.sum(zP * z, dim=1)     # (B,)

        # ---------------------------------------------------------
        # LINEAR: qᵀ z
        # ---------------------------------------------------------
        lin = torch.sum(q * z, dim=1)             # (B,)

        # ---------------------------------------------------------
        # D z
        # ---------------------------------------------------------
        Dz = torch.matmul(z, D.t())               # (B,m)

        # ---------------------------------------------------------
        # Lagrangian linear terms (as implemented below)
        # λ_infᵀ (Dz - l)
        # λ_supᵀ (u - Dz)
        # ---------------------------------------------------------
        lag_inf = torch.sum(lam_inf * (Dz - lb), dim=1)
        lag_sup = torch.sum(lam_sup * (ub - Dz), dim=1)

        # ---------------------------------------------------------
        # Augmented terms: ρ/2 ||max(0, l - Dz)||²
        #                  ρ/2 ||max(0, Dz - u)||²
        # ---------------------------------------------------------
        viol_inf = torch.relu(lb - Dz)            # (B,m)
        viol_sup = torch.relu(Dz - ub)            # (B,m)

        aug_inf = 0.5 * self.rho * torch.sum(viol_inf ** 2, dim=1)
        aug_sup = 0.5 * self.rho * torch.sum(viol_sup ** 2, dim=1)

        if pred:
            self.quadratic_list.append(quad.mean().item())
            self.linear_list.append(lin.mean().item())
            self.lambda_list.append(lag_inf.mean().item())
            self.nu_list.append(lag_sup.mean().item())
            self.aug_inf_list.append(aug_inf.mean().item())
            self.aug_sup_list.append(aug_sup.mean().item())
        else:
            self.quadratic_list_opt.append(quad.mean().item())
            self.linear_list_opt.append(lin.mean().item())
            self.lambda_list_opt.append(lag_inf.mean().item())
            self.nu_list_opt.append(lag_sup.mean().item())
            self.aug_inf_list_opt.append(aug_inf.mean().item())
            self.aug_sup_list_opt.append(aug_sup.mean().item())


        return quad + lin + lag_inf + lag_sup + aug_inf + aug_sup
    
    def augmented_lagrangian_parsed(self, z, lam_inf, lam_sup, lower_bound, upper_bound):
        """
        z:        (B,n)
        lam_inf:  (B,m)
        lam_sup:  (B,m)
        lower_bound: (m,) or (B,m)
        upper_bound: (m,) or (B,m)

        output: (B,) augmented Lagrangian value
        """
        dtype = z.dtype
        P = self.P.to(dtype)
        q = self.q.to(dtype).squeeze(-1)   # (n,)
        D = self.D.to(dtype)

        B = z.shape[0]
        lb = self._ensure_batch_bound(lower_bound, B)  # (B,m)
        ub = self._ensure_batch_bound(upper_bound, B)  # (B,m)

        # ---------------------------------------------------------
        # QUADRATIC: 0.5 zᵀ P z
        # ---------------------------------------------------------
        zP = torch.matmul(z, P)                   # (B,n)
        quad = 0.5 * torch.sum(zP * z, dim=1)     # (B,)

        # ---------------------------------------------------------
        # LINEAR: qᵀ z
        # ---------------------------------------------------------
        lin = torch.sum(q * z, dim=1)             # (B,)

        # ---------------------------------------------------------
        # D z
        # ---------------------------------------------------------
        Dz = torch.matmul(z, D.t())               # (B,m)

        # ---------------------------------------------------------
        # Lagrangian linear terms
        # λ_infᵀ (l - Dz)
        # λ_supᵀ (Dz - u)
        # ---------------------------------------------------------
        lag_inf = torch.sum(lam_inf * (lb - Dz), dim=1)
        lag_sup = torch.sum(lam_sup * (Dz - ub), dim=1)

        # ---------------------------------------------------------
        # Augmented terms: ρ/2 ||max(0, l - Dz)||²
        #                  ρ/2 ||max(0, Dz - u)||²
        # ---------------------------------------------------------
        viol_inf = torch.relu(lb - Dz)            # (B,m)
        viol_sup = torch.relu(Dz - ub)            # (B,m)

        aug_inf = 0.5 * self.rho * torch.sum(viol_inf ** 2, dim=1)
        aug_sup = 0.5 * self.rho * torch.sum(viol_sup ** 2, dim=1)

        return quad, lin, lag_inf/1e4, lag_sup/1e4, aug_inf/1e7, aug_sup/1e7

    def forward(self, z_pred, z_opt, lam_inf, lam_sup, lower_bound, upper_bound):
        """
        Loss = mean( (L_aug(z_pred) - L_aug(z_opt))² )
        """
        L_pred = self.augmented_lagrangian(z_pred, lam_inf, lam_sup,
                                           lower_bound, upper_bound, True)
        L_opt = self.augmented_lagrangian(z_opt, lam_inf, lam_sup,
                                          lower_bound, upper_bound)
        

        return ((L_pred - L_opt) ** 2).mean()
    
    def foward_parsed(self, z_pred, z_opt, lam_inf, lam_sup, lower_bound, upper_bound):
        """
        Loss = mean( (L_aug(z_pred) - L_aug(z_opt))² )
        """
        q_1, l_1, lag_min_1, lag_max_1, aug_min_1, aug_max_1 = \
            self.augmented_lagrangian_parsed(
                z_pred, lam_inf, lam_sup, lower_bound, upper_bound
            )

        q_2, l_2, lag_min_2, lag_max_2, aug_min_2, aug_max_2 = \
            self.augmented_lagrangian_parsed(
                z_opt, lam_inf, lam_sup, lower_bound, upper_bound
            )
        
        print(f"LAG PRED: {q_1 + l_1 + lag_min_1 + lag_max_1 + aug_min_1 + aug_max_1}")
        print(f"LAG OPT: {q_2 + l_2 + lag_min_2 + lag_max_2 + aug_min_2 + aug_max_2}")

        def mse(x, y):
            return ((x - y) ** 2).mean()

        return {
            "mse_q": mse(q_1, q_2).item(),
            "mse_l": mse(l_1, l_2).item(),
            "mse_lag_min": mse(lag_min_1, lag_min_2).item(),
            "mse_lag_max": mse(lag_max_1, lag_max_2).item(),
            "mse_aug_min": mse(aug_min_1, aug_min_2).item(),
            "mse_aug_max": mse(aug_max_1, aug_max_2).item(),
        }
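For reference, the comments of `augmented_lagrangian` (and the signs used in `augmented_lagrangian_parsed`) describe the following augmented Lagrangian. Here $l$ and $u$ are the lower and upper bounds on $\mathbf{D}z$, $\lambda_{inf}, \lambda_{sup}$ the multipliers from the dataset and $\rho$ the penalty parameter (`self.rho`). The discarded loss is the batch mean of $\big(\mathcal{L}_\rho(z_{pred}) - \mathcal{L}_\rho(z_{opt})\big)^2$, with both terms evaluated at the same multipliers:

$$
\mathcal{L}_\rho(z, \lambda_{inf}, \lambda_{sup}) =
\tfrac{1}{2} z^\top \mathbf{P} z + q^\top z
+ \lambda_{inf}^\top (l - \mathbf{D}z) + \lambda_{sup}^\top (\mathbf{D}z - u)
+ \tfrac{\rho}{2}\big\|\max(0,\, l - \mathbf{D}z)\big\|_2^2
+ \tfrac{\rho}{2}\big\|\max(0,\, \mathbf{D}z - u)\big\|_2^2
$$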

## 3.2 Final Loss: Primal and Constraints

The goal was to reproduce the behavior of the previous function, which is sensitive to both the optimality and the feasibility of the solution, without the complex computations that can introduce instability. This led to a second loss function:
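$$
\ell(\theta) :=
\frac{1}{Nn}\sum_{i=1}^{N}
\big\|\tilde{\pi}(x_i \mid \theta) - z_i^*\big\|_2^2
+ \frac{\lambda}{Nm}\sum_{i=1}^{N}
\Big(
    \big\|\max\big(0,\, z_{inf,i} - \mathbf{D}\tilde{\pi}(x_i \mid \theta)\big)\big\|_2^2
    + \big\|\max\big(0,\, \mathbf{D}\tilde{\pi}(x_i \mid \theta) - z_{sup,i}\big)\big\|_2^2
\Big)
$$

where $N$ is the number of samples in the batch, $n$ the dimension of $z$, $m$ the number of rows of $\mathbf{D}$, $\lambda$ the penalty weight (`lam` in `QPLoss`), and $\max$ is applied elementwise.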

- The difference between the estimated and the optimal primal drives the optimality requirement.
- The violation terms, borrowed from the augmented Lagrangian, penalize predictions that violate the problem constraints.
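A minimal NumPy sketch of the same computation, on a toy problem with two variables and box constraints ($\mathbf{D} = I$, $-1 \le z \le 1$). The numbers are illustrative, not from the dataset, and `qp_loss` is a hypothetical helper that mirrors `QPLoss.forward`:

```python
import numpy as np

def qp_loss(z_pred, z_star, D, lb, ub, lam=1e-4):
    """Primal MSE plus lam * mean squared constraint violation (mirrors QPLoss.forward)."""
    mse = np.mean((z_pred - z_star) ** 2)             # averaged over the batch and the n entries of z
    Dz = z_pred @ D.T                                 # (B, m) constrained quantities
    viol_sup = np.maximum(0.0, Dz - ub)               # amount above the upper bound
    viol_inf = np.maximum(0.0, lb - Dz)               # amount below the lower bound
    penalty = np.mean(viol_sup ** 2 + viol_inf ** 2)  # averaged over the batch and the m constraints
    return mse + lam * penalty

# Toy example: n = 2 variables, m = 2 box constraints (D = I), -1 <= z <= 1
D = np.eye(2)
lb = np.array([[-1.0, -1.0]])
ub = np.array([[1.0, 1.0]])
z_star = np.array([[0.5, 1.0]])   # feasible target
z_pred = np.array([[0.5, 3.0]])   # second entry is 2 above the upper bound

print(qp_loss(z_pred, z_star, D, lb, ub, lam=1.0))  # 4.0 = MSE 2.0 + penalty 2.0
print(qp_loss(z_star, z_star, D, lb, ub))           # 0.0 at a feasible optimum
```

With the default $\lambda = 10^{-4}$ the same prediction gives $2 + 2\cdot 10^{-4} = 2.0002$, so at this scale the primal error dominates the loss and the violation term acts as a small correction.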

In [None]:
class QPLoss(nn.Module):
    def __init__(self, D, lam=1e-4):
        """
        D   : constraint matrix (m, n), fixed for the whole dataset
        lam : weight of the constraint-violation penalty
        """
        super().__init__()
        self.lam = lam
        
        self.register_buffer("D", D)

        self.primal_list = []
        self.penalty_list = []

    def visualize(self, show_total=True):
        """
        Plot the evolution of the loss terms during training.
        """

        if len(self.primal_list) == 0:
            print("No values stored. Call forward(..., store=True) during training.")
            return

        plt.figure(figsize=(10, 5))

        plt.plot(self.primal_list, label="Primal (MSE)", linewidth=2)
        plt.plot(self.penalty_list, label="Penalty (λ * constraints)", linewidth=2)

        if show_total:
            total = [p + q for p, q in zip(self.primal_list, self.penalty_list)]
            plt.plot(total, label="Total Loss", linestyle="--", linewidth=2)

        plt.xlabel("Training iterations")
        plt.ylabel("Value")
        plt.title("Evolution of the QPLoss terms")
        plt.legend()
        plt.grid(True)
        plt.show()

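    def forward(self, z_pred, z_star, a, b, store=False):
        """
        z_pred, z_star : (B, n) predicted and optimal primal solutions
        a, b           : (B, m) or (m,) lower and upper bounds on D z
        """
        # 1. Primal MSE (optimality term)
        mse = torch.mean((z_pred - z_star) ** 2)

        # 2. Constraint violations; D is cast to the device and dtype of z_pred
        D = self.D.to(device=z_pred.device, dtype=z_pred.dtype)
        Dz = z_pred @ D.T

        viol_sup = torch.relu(Dz - b)   # amount above the upper bound
        viol_inf = torch.relu(a - Dz)   # amount below the lower bound

        penalty = torch.mean(viol_sup ** 2 + viol_inf ** 2)

        loss = mse + self.lam * penalty

        if store:
            self.primal_list.append(mse.item())
            self.penalty_list.append((self.lam * penalty).item())

        return loss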

### 3.2.1 Loss Function Validation

#### 3.2.1.1 Loss Distribution for Random Predictions

In [None]:

dataset = MPC_DATASET(MPC)
criterion = QPLoss(dataset.D, lam=1e-4).to(DEVICE)

def sample_random_z(batch, n, scale=1.0):
    return scale * torch.randn(batch, n)

loss_values = []
N = min(1000, len(dataset))

with torch.no_grad():
    for i in range(0, N):

        print(f"\n########### SAMPLE {i} ############")

        sample = dataset[i]
        z_opt = sample["z"].to(DEVICE)
        lb    = sample["lb"].to(DEVICE)
        ub    = sample["ub"].to(DEVICE)

        z_opt = z_opt.unsqueeze(0)
        lb    = lb.unsqueeze(0)
        ub    = ub.unsqueeze(0)

        z_rand = sample_random_z(batch=64, n=z_opt.shape[1]).to(DEVICE)

        print(f"[z_pred example]: {z_rand[0, :5].cpu().numpy()} ...")
        print(f"[z_opt example]:  {z_opt[0, :5].cpu().numpy()} ...")

        loss = criterion(z_rand, z_opt, lb, ub, store=True)
        loss_values.append(loss.item())

plt.hist(loss_values, bins=40)
plt.title(f"Loss distribution for random z ({N} instances)")
plt.show()

criterion.visualize()

#### 3.2.1.2 Optimal Point as a Minimum of the Loss

In [None]:
def test_loss_zero_at_optimum(criterion, dataset, device):
    zero_losses = []

    with torch.no_grad():
        for i in range(len(dataset)):
            sample = dataset[i]
            z_opt = sample["z"].to(device).unsqueeze(0)
            lb    = sample["lb"].to(device).unsqueeze(0)
            ub    = sample["ub"].to(device).unsqueeze(0)

            # loss at the optimum: the MSE term is zero by construction
            loss = criterion(z_opt, z_opt, lb, ub).item()
            zero_losses.append(loss)

    print("\n=== Test: loss at the optimum ===")
    print(f"Mean    : {sum(zero_losses)/len(zero_losses):.6e}")
    print(f"Maximum : {max(zero_losses):.6e}")
    print(f"Minimum : {min(zero_losses):.6e}")

    plt.hist(zero_losses, bins=40)
    plt.title("Loss(z_opt, z_opt), expected to be zero")
    plt.show()

dataset = MPC_DATASET(MPC)
criterion = QPLoss(dataset.D).to(DEVICE)
test_loss_zero_at_optimum(criterion, dataset, DEVICE)
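Why zero is the right target for this test: both terms of $\ell$ are sums of squares, so the loss is nonnegative, and (for $\lambda > 0$) it vanishes exactly when every prediction equals its target and that target satisfies the bounds:

$$
\ell(\theta) \ge 0, \qquad
\ell(\theta) = 0 \iff \tilde{\pi}(x_i \mid \theta) = z_i^* \ \text{and}\ z_{inf,i} \le \mathbf{D} z_i^* \le z_{sup,i} \quad \forall i
$$

Hence a feasible $z_i^*$ is a global (not only local) minimizer of the per-sample loss. Nonzero values in the histogram would point to targets that slightly violate the bounds, for example within the solver's tolerance.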

#### 3.2.1.3 Local Behavior of the Loss Around the Optimum

In [None]:
def test_local_curvature(criterion, dataset, device, epsilons=None):
    if epsilons is None:
        epsilons = [1e-3, 1e-2, 1e-1, 1.0]

    losses_per_eps = {eps: [] for eps in epsilons}

    with torch.no_grad():
        for i in range(min(10000, len(dataset))):
            sample = dataset[i]
            z_opt = sample["z"].to(device).unsqueeze(0)
            lb    = sample["lb"].to(device).unsqueeze(0)
            ub    = sample["ub"].to(device).unsqueeze(0)

            n = z_opt.shape[1]

            direction = torch.randn(1, n).to(device)
            direction = direction / (torch.norm(direction) + 1e-9)

            for eps in epsilons:
                z_test = z_opt + eps * direction
                loss = criterion(z_test, z_opt, lb, ub).item()
                losses_per_eps[eps].append(loss)

    # Plot
    plt.figure(figsize=(7,5))
    mean_losses = []

    for eps in epsilons:
        mean_loss = sum(losses_per_eps[eps]) / len(losses_per_eps[eps])
        mean_losses.append(mean_loss)
        plt.plot([eps], [mean_loss], "o", label=f"eps={eps}, mean={mean_loss:.2e}")

    plt.plot(epsilons, mean_losses, "-k", linewidth=1)
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Perturbation ε (log scale)")
    plt.ylabel("Mean loss (log scale)")
    plt.title("Local monotonicity of the loss away from the optimal primal")
    plt.legend()
    plt.grid(True, which="both", ls="--")
    plt.show()


dataset = MPC_DATASET(MPC)
criterion = QPLoss(dataset.D).to(DEVICE)
test_local_curvature(criterion, dataset, DEVICE)
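What the curvature plot should look like can be worked out in advance. For a unit direction $d$ and no constraint becoming active within radius $\varepsilon$, only the MSE term changes, giving $\ell = \varepsilon^2 \|d\|^2 / n = \varepsilon^2 / n$: a straight line of slope 2 on the log-log plot. If a bound is crossed, the penalty only adds to this, so the curve can only lie above that line. Also, as a function of $z$ the loss is convex (squared affine terms plus squared ReLUs of affine terms), so in $z$-space there are no spurious local minima. This says nothing about the landscape in the network parameters $\theta$. The sketch below checks the slope-2 prediction on hypothetical data, assuming no active constraints:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20                                   # hypothetical dimension of z
z_star = rng.normal(size=(1, n))
d = rng.normal(size=(1, n))
d /= np.linalg.norm(d)                   # unit-norm direction, as in test_local_curvature

eps = np.array([1e-3, 1e-2, 1e-1, 1.0])
# Primal (MSE) term only: assumes no constraint becomes active along the perturbation
losses = np.array([np.mean(((z_star + e * d) - z_star) ** 2) for e in eps])

# Slope of log(loss) versus log(eps) between consecutive points
slopes = np.diff(np.log10(losses)) / np.diff(np.log10(eps))
print(losses * n)   # approximately eps**2
print(slopes)       # approximately 2 everywhere
```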

#### 3.2.1.4 Gradient Norm Stability

In [None]:
def test_gradient_explosion(criterion, dataset, device, sigmas=None):
    if sigmas is None:
        sigmas = [1e-2, 1e-1, 1, 10, 100]

    grad_norms_by_sigma = {s: [] for s in sigmas}

    for i in range(min(1000, len(dataset))):
        sample = dataset[i]
        z_opt = sample["z"].to(device).unsqueeze(0)
        lb    = sample["lb"].to(device).unsqueeze(0)
        ub    = sample["ub"].to(device).unsqueeze(0)

        n = z_opt.shape[1]

        for sigma in sigmas:
            # random z_pred with controlled scale (standard deviation = sigma)
            z_pred = (sigma * torch.randn(1, n)).to(device)
            z_pred.requires_grad_(True)

            loss = criterion(z_pred, z_opt, lb, ub)
            loss.backward()

            grad = z_pred.grad.detach()
            grad_norm = torch.norm(grad).item()

            grad_norms_by_sigma[sigma].append(grad_norm)

    # ---- Plot ----
    plt.figure(figsize=(7,5))

    means = []
    for sigma in sigmas:
        mean_norm = sum(grad_norms_by_sigma[sigma]) / len(grad_norms_by_sigma[sigma])
        means.append(mean_norm)
        plt.plot(sigma, mean_norm, "o", label=f"sigma={sigma}, mean={mean_norm:.2e}")

    plt.plot(sigmas, means, "-k")
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Scale of the random z_pred (sigma)")
    plt.ylabel("Mean gradient norm")
    plt.title("Exploding-gradient test for QPLoss")
    plt.grid(True, which="both", ls="--")
    plt.legend()
    plt.show()

dataset = MPC_DATASET(MPC)
criterion = QPLoss(dataset.D).to(DEVICE)
test_gradient_explosion(criterion, dataset, DEVICE)
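To interpret the gradient-norm plot, the gradient of the loss with respect to `z_pred` (batch size $B$) can be written in closed form, row-wise, as $\frac{2}{Bn}(z - z^*) + \frac{2\lambda}{Bm}\big(\max(0, \mathbf{D}z - b) - \max(0, a - \mathbf{D}z)\big)\mathbf{D}$. Both terms are piecewise linear in $z$. For large σ the mean norm is therefore expected to grow roughly linearly (slope about 1 on the log-log plot) rather than explode. The sketch below uses a standalone copy of the loss and a hypothetical helper `qp_loss_grad`, on random data, and checks the closed form against autograd:

```python
import torch

def qp_loss(z_pred, z_star, D, a, b, lam=1e-4):
    # Same formula as QPLoss.forward: MSE + lam * mean squared bound violation
    mse = torch.mean((z_pred - z_star) ** 2)
    Dz = z_pred @ D.T
    penalty = torch.mean(torch.relu(Dz - b) ** 2 + torch.relu(a - Dz) ** 2)
    return mse + lam * penalty

def qp_loss_grad(z_pred, z_star, D, a, b, lam=1e-4):
    # Closed-form gradient with respect to z_pred, shape (B, n)
    B, n = z_pred.shape
    m = D.shape[0]
    Dz = z_pred @ D.T
    g_mse = 2.0 * (z_pred - z_star) / (B * n)
    g_pen = (2.0 * lam / (B * m)) * (torch.relu(Dz - b) - torch.relu(a - Dz)) @ D
    return g_mse + g_pen

torch.manual_seed(0)
B, n, m = 4, 6, 10                                   # hypothetical sizes
D = torch.randn(m, n, dtype=torch.float64)
a = -torch.ones(B, m, dtype=torch.float64)           # lower bounds
b = torch.ones(B, m, dtype=torch.float64)            # upper bounds
z_star = torch.zeros(B, n, dtype=torch.float64)
z_pred = (10.0 * torch.randn(B, n, dtype=torch.float64)).requires_grad_(True)  # large sigma: bounds violated

qp_loss(z_pred, z_star, D, a, b).backward()
print(torch.allclose(z_pred.grad, qp_loss_grad(z_pred.detach(), z_star, D, a, b)))
```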

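# 4. Network Definition and Exploration

The base network is built following the paper's specifications, with the hidden-layer dimensions parameterized.

- Activation function: ReLU, recommended in the paper for approximating convex QP problems. For linear MPC with a positive definite quadratic cost and linear constraints, the optimal solution is a continuous piecewise affine function of the initial state (the classical explicit MPC result), and a ReLU MLP is itself a continuous piecewise affine function.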
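The piecewise affine property can be made concrete. In the sketch below, `local_affine_map` is a hypothetical helper and the layer sizes are illustrative. It takes a small ReLU MLP with the same layer pattern as `PlannerNet` and recovers the affine map $W x + b$ that the network computes in the activation region containing a given input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Small ReLU MLP with the same layer pattern as PlannerNet (sizes are hypothetical)
net = nn.Sequential(
    nn.Linear(3, 8), nn.ReLU(),
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 2),
)

@torch.no_grad()
def local_affine_map(net, x):
    """Return (W, b) with net(y) == W @ y + b for every y in x's activation region."""
    W = torch.eye(x.numel())
    b = torch.zeros(x.numel())
    h = x.clone()
    for layer in net:
        if isinstance(layer, nn.Linear):
            W = layer.weight @ W               # compose the affine maps
            b = layer.weight @ b + layer.bias
        elif isinstance(layer, nn.ReLU):
            mask = (h > 0).to(h.dtype)         # which units are active at x
            W = mask[:, None] * W              # inactive units contribute nothing
            b = mask * b
        h = layer(h)                           # propagate x to know the next activation pattern
    return W, b

x = torch.randn(3)
W, b = local_affine_map(net, x)
with torch.no_grad():
    print(torch.allclose(net(x), W @ x + b, atol=1e-6))  # True
```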

In [None]:
class PlannerNet(nn.Module):

    def __init__(self, input_dim, output_dim, hidden=(64, 64)):

        super().__init__()
        layers = []
        prev = input_dim

        for h in hidden:
            linear = nn.Linear(prev, h)
            nn.init.xavier_uniform_(linear.weight,
                                    gain=nn.init.calculate_gain("relu"))
            nn.init.zeros_(linear.bias)
            layers += [linear, nn.ReLU()]
            prev = h

        out = nn.Linear(prev, output_dim)
        nn.init.xavier_uniform_(out.weight,
                                gain=1.0)
        nn.init.zeros_(out.bias)

        layers.append(out)
        self.net = nn.Sequential(*layers)

    def forward(self, x_std):
        return self.net(x_std)

## 4.1 Exploratory Training

To validate the model, a baseline training run is performed to explore the behavior of the loss and of the gradient, in order to check training stability before the final training.

In [None]:
def grad_norm(model):
    total = 0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().norm(2).item()**2
    return total**0.5


def grad_direction(model, prev_grad):
    g = []
    for p in model.parameters():
        if p.grad is not None:
            g.append(p.grad.detach().flatten())
    g = torch.cat(g)

    if prev_grad is None:
        return None, g

    cos = torch.nn.functional.cosine_similarity(g, prev_grad, dim=0).item()
    return cos, g

@torch.no_grad()
def evaluate_model(model, loss_fn, loader, device="cpu"):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in loader:
            x = batch["x"].to(device)
            z = batch["z"].to(device)
            lb = batch["lb"].to(device)
            ub = batch["ub"].to(device)

            z_pred = model(x)
            loss = loss_fn(z_pred, z, lb, ub, store=False)

            total_loss += loss.item()

    return total_loss / len(loader)

def plot_training_history(history):

    fig, axs = plt.subplots(3, 1, figsize=(10, 14))

    axs[0].plot(history["train_loss"], label="Train Loss", linewidth=2)
    axs[0].plot(history["val_loss"], label="Val Loss", linewidth=2)
    axs[0].set_title("Training and Validation Loss")
    axs[0].set_xlabel("Epoch")
    axs[0].set_ylabel("Loss")
    axs[0].grid(True)
    axs[0].legend()

    axs[1].plot(history["grad_norm"], label="Gradient norm")
    axs[1].set_title("Gradient Norm over Training")
    axs[1].set_xlabel("Iteration")
    axs[1].set_ylabel("||g||")
    axs[1].grid(True)

    axs[2].plot(history["grad_cosine"], label="Cosine Similarity", color="green")
    axs[2].set_title("Gradient Direction (cosine similarity)")
    axs[2].set_xlabel("Iteration")
    axs[2].set_ylabel("cos(g_t, g_{t-1})")
    axs[2].grid(True)

    plt.tight_layout()
    plt.show()

def train_model(
    model,
    loss_fn,
    train_loader,
    val_loader,
    optimizer,
    epochs=50,
    device="cpu",
    patience=75,
    min_delta=1e-6
):
    model.to(device)
    history = {
        "train_loss": [],
        "val_loss": [],
        "grad_norm": [],
        "grad_cosine": []
    }

    prev_grad = None
    best_val_loss = float('inf')
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for batch in train_loader:
            x = batch["x"].to(device)
            z = batch["z"].to(device)
            lb = batch["lb"].to(device)
            ub = batch["ub"].to(device)

            optimizer.zero_grad()
            z_pred = model(x)
            loss = loss_fn(z_pred, z, lb, ub, store=True)
            loss.backward()

            # --- Compute gradient norm ---
            grad_vec = torch.cat([
                p.grad.reshape(-1) for p in model.parameters() if p.grad is not None
            ])
            grad_norm = grad_vec.norm().item()
            history["grad_norm"].append(grad_norm)

            # --- Compute gradient direction (cosine similarity) ---
            if prev_grad is not None:
                cosine = torch.nn.functional.cosine_similarity(
                    grad_vec, prev_grad, dim=0
                ).item()
            else:
                cosine = 1.0
            history["grad_cosine"].append(cosine)
            prev_grad = grad_vec.detach()

            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        history["train_loss"].append(epoch_loss)

        val_loss = evaluate_model(model, loss_fn, val_loader, device)
        history["val_loss"].append(val_loss)

        print(
            f"[Epoch {epoch+1:03d}] "
            f"Train Loss = {epoch_loss:.6f} | "
            f"Val Loss = {val_loss:.6f} | "
            f"Grad Norm = {grad_norm:.3e} | "
            f"Cosine Dir = {cosine:.3f}"
        )

        # --- Early Stopping ---
        if val_loss + min_delta < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # deepcopy: state_dict() returns tensors that share storage with the live parameters
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping at epoch {epoch+1}: no improvement for {patience} epochs.\n")
                break

    # Restore the best weights (also when training ends without early stopping)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    print("\nTraining finished! Generating plots...\n")
    plot_training_history(history)

    print("\nGenerating QPLoss-specific plots...")
    loss_fn.visualize(show_total=True)

    return history
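Early stopping keeps a snapshot of the best weights, and that snapshot must be a deep copy (as `k_fold_train` does with `copy.deepcopy`). The minimal sketch below uses a toy `nn.Linear`, not the notebook's model, to show why: the tensors returned by `state_dict()` share storage with the live parameters, so in-place optimizer updates also change a snapshot taken without copying.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

alias = model.state_dict()                    # tensors share storage with the live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the current weights

# One optimization step updates the parameters in place
loss = model(torch.ones(1, 2)).pow(2).sum()
opt.zero_grad()
loss.backward()
opt.step()

print(torch.equal(alias["weight"], model.weight))     # True: the "saved" state moved too
print(torch.equal(snapshot["weight"], model.weight))  # False: the deep copy kept the old values
```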

In [None]:
DATASET = MPC_DATASET(MPC)

train_size = int(0.8 * len(DATASET))
val_size   = len(DATASET) - train_size

train_dataset, val_dataset = random_split(DATASET, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_dataset,  batch_size=256, shuffle=False)

model = PlannerNet(MPC.n_states, MPC.dim_z)
loss_fn = QPLoss(DATASET.D)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

history = train_model(
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    epochs=100,
    device=DEVICE
)

# 5. Final Model and Final Training

To find the best model, a grid search was implemented. A parametric space of network and training configurations is defined, each configuration is trained with k-fold cross-validation, and configurations are ranked by the mean, over the folds, of the best validation loss.

- **Note**: Since there are many combinations and training is slow, the search cell takes a long time, about 16 hours on the computer used.
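The runtime follows from the size of the search space. A quick count uses the same grid as the `param_grid` cell in Section 5.2 and the defaults of `grid_search` (k = 5 folds, at most 300 epochs per fold). Early stopping (patience 75) can end runs earlier, so the epoch total is an upper bound:

```python
from itertools import product

# Same grid as the param_grid cell in Section 5.2
param_grid = {
    "hidden_dim": [32, 64, 128],
    "lr": [1e-3, 5e-4, 1e-4],
    "batch_size": [32, 64, 128],
    "weight_decay": [0.0, 1e-5],
}
k = 5              # folds (grid_search default)
max_epochs = 300   # epoch cap per fold (grid_search default)

n_combos = len(list(product(*param_grid.values())))
n_runs = n_combos * k
print(n_combos, n_runs, n_runs * max_epochs)  # 54 270 81000
```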

## 5.1 Training Functions

In [None]:
def train_one_epoch(model, loss_fn, loader, optimizer, device, clip_value=None, store_loss=False):
    model.train()
    running_loss = 0.0
    n_batches = 0

    for batch in loader:
        x = batch["x"].to(device)
        z = batch["z"].to(device)
        lb = batch["lb"].to(device)
        ub = batch["ub"].to(device)

        optimizer.zero_grad()
        z_pred = model(x)
        loss = loss_fn(z_pred, z, lb, ub, store=store_loss)
        loss.backward()

        if clip_value is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    return running_loss / max(1, n_batches)


@torch.no_grad()
def evaluate(model, loss_fn, loader, device):
    model.eval()
    running_loss = 0.0
    n_batches = 0

    for batch in loader:
        x = batch["x"].to(device)
        z = batch["z"].to(device)
        lb = batch["lb"].to(device)
        ub = batch["ub"].to(device)

        z_pred = model(x)
        loss = loss_fn(z_pred, z, lb, ub, store=False)
        running_loss += loss.item()
        n_batches += 1

    return running_loss / max(1, n_batches)


def k_fold_train(
    dataset,
    model_fn: Callable[[], torch.nn.Module],
    loss_fn_fn: Callable[[], torch.nn.Module],
    k: int = 5,
    num_epochs: int = 100,
    batch_size: int = 256,
    lr: float = 3e-4,
    weight_decay: float = 1e-5,
    device: str = None,
    patience: int = 100,
    min_delta: float = 1e-6,
    clip_value: float = 1.0,
    verbose: bool = True,
    use_tqdm_folds: bool = True,
    use_tqdm_epochs: bool = False
) -> Dict[str, Any]:

    if device is None:
        device = "mps" if torch.backends.mps.is_available() else "cpu"

    n = len(dataset)
    indices = np.arange(n)
    rng = np.random.RandomState(SEED)
    rng.shuffle(indices)

    # Compute fold sizes
    fold_sizes = np.full(k, n // k, dtype=int)
    fold_sizes[: n % k] += 1
    current = 0

    all_fold_histories = []
    best_model_states = []

    # -----------------------
    # Setup tqdm for FOLDS
    # -----------------------
    fold_range = range(k)
    if use_tqdm_folds:
        fold_range = tqdm(range(k), desc="K-Fold", leave=False)

    for fold in fold_range:

        start, stop = current, current + fold_sizes[fold]
        val_idx = indices[start:stop]
        train_idx = np.setdiff1d(indices, val_idx)
        current = stop

        train_subset = Subset(dataset, train_idx.tolist())
        val_subset = Subset(dataset, val_idx.tolist())

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        model = model_fn().to(device)
        loss_fn = loss_fn_fn()
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

        best_val = float("inf")
        best_state = None
        patience_counter = 0

        history = {"train_loss": [], "val_loss": []}

        if verbose and not use_tqdm_folds:
            print(f"\n=== Fold {fold+1}/{k} | train {len(train_idx)} samples | val {len(val_idx)} samples ===")

        # -----------------------
        # tqdm for EPOCHS
        # -----------------------
        epoch_range = range(1, num_epochs + 1)
        if use_tqdm_epochs:
            epoch_range = tqdm(epoch_range, desc=f"Fold {fold+1}", leave=False)

        for epoch in epoch_range:
            train_loss = train_one_epoch(
                model, loss_fn, train_loader, optimizer,
                device, clip_value=clip_value, store_loss=True
            )

            val_loss = evaluate(model, loss_fn, val_loader, device)

            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)

            if verbose and not use_tqdm_epochs:
                print(
                    f"Fold {fold+1} | Epoch {epoch:03d} "
                    f"| Train {train_loss:.6f} | Val {val_loss:.6f} | Best {best_val:.6f}"
                )

            # early stopping
            if val_loss + min_delta < best_val:
                best_val = val_loss
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                if verbose:
                    print(f"Early stopping on fold {fold+1} at epoch {epoch}.")
                break

        all_fold_histories.append(history)
        best_model_states.append({
            "fold": fold,
            "best_val_loss": best_val,
            "state_dict": best_state
        })

    result = {
        "fold_histories": all_fold_histories,
        "best_models": best_model_states,
        "params": {
            "k": k,
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "lr": lr,
            "weight_decay": weight_decay,
            "patience": patience,
            "clip_value": clip_value
        }
    }

    return result
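The fold construction inside `k_fold_train` can be checked in isolation. The sketch below reproduces its index logic in a hypothetical helper `fold_indices`, with a fixed `seed` standing in for `SEED`. It confirms that the validation folds partition the dataset, that the first `n % k` folds get one extra sample, and that no fold overlaps its training indices:

```python
import numpy as np

def fold_indices(n, k, seed=0):
    """Reproduce the split in k_fold_train: shuffled indices, first n % k folds get one extra sample."""
    indices = np.arange(n)
    np.random.RandomState(seed).shuffle(indices)
    fold_sizes = np.full(k, n // k, dtype=int)
    fold_sizes[: n % k] += 1
    folds, current = [], 0
    for size in fold_sizes:
        val_idx = indices[current:current + size]
        train_idx = np.setdiff1d(indices, val_idx)   # everything not in this validation fold
        folds.append((train_idx, val_idx))
        current += size
    return folds

folds = fold_indices(n=103, k=5)
print([len(v) for _, v in folds])   # [21, 21, 21, 20, 20]
```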

## 5.2 Search Functions

In [None]:
def generate_param_combinations(param_grid):
    keys = list(param_grid.keys())
    values = list(param_grid.values())
    for combo in product(*values):
        params = dict(zip(keys, combo))
        yield params
        
def grid_search(param_grid, k=5, num_epochs=300):

    all_results = []
    best_result = None
    best_avg_val_loss = float("inf")

    def mean_best_val_loss(result):
        fold_histories = result["fold_histories"]
        best_vals = [min(f["val_loss"]) for f in fold_histories]
        return sum(best_vals) / len(best_vals)

    param_combinations = list(generate_param_combinations(param_grid))

    print(f"\n=== Starting grid search with {len(param_combinations)} combinations ===\n")

    for params in tqdm(param_combinations, desc="Grid Search"):

        dataset = MPC_DATASET(MPC)

        result = k_fold_train(
            dataset=dataset,
            model_fn=lambda : PlannerNet(
                MPC.n_states,
                MPC.dim_z,
                (params["hidden_dim"], params["hidden_dim"])
            ),
            loss_fn_fn=lambda : QPLoss(dataset.D),
            k=k,
            num_epochs=num_epochs,
            batch_size=params["batch_size"],
            lr=params["lr"],
            weight_decay=params["weight_decay"],
            patience=75,
            min_delta=1e-6,
            clip_value=1.0,
            verbose=True,
            use_tqdm_folds=False
        )

        avg_val = mean_best_val_loss(result)

        tqdm.write(f"Params={params} | Mean val loss={avg_val:.6f}")

        entry = {"params": params, "result": result}
        all_results.append(entry)

        if avg_val < best_avg_val_loss:
            best_avg_val_loss = avg_val
            best_result = entry

    return {
        "all_results": all_results,
        "best_result": best_result,
        "best_val_loss": best_avg_val_loss
    }

In [None]:
param_grid = {
    "hidden_dim": [32, 64, 128],
    "lr": [1e-3, 5e-4, 1e-4],
    "batch_size": [32, 64, 128],
    "weight_decay": [0.0, 1e-5],
}

grid_search_result = grid_search(param_grid)

## 5.3 Final Training

In [None]:
print(grid_search_result["best_result"]["params"])

In [None]:
DATASET = MPC_DATASET(MPC)

# params = grid_search_result["best_result"]["params"]
params = {'hidden_dim': 128, 'lr': 0.001, 'batch_size': 64, 'weight_decay': 1e-05}

train_size = int(0.8 * len(DATASET))
val_size   = len(DATASET) - train_size

train_dataset, val_dataset = random_split(DATASET, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True)
val_loader   = DataLoader(val_dataset,  batch_size=params["batch_size"], shuffle=False)

model = PlannerNet(MPC.n_states, MPC.dim_z, (params["hidden_dim"], params["hidden_dim"]))
loss_fn = QPLoss(DATASET.D)

# AdamW, matching the optimizer used in k_fold_train during the grid search
optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"], weight_decay=params["weight_decay"])

history = train_model(
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    epochs=300,
    device=DEVICE
)

## 5.4 Model Results

In [None]:
model.eval()
val_loss = 0
all_preds = []
all_targets = []

with torch.no_grad():
    for batch in val_loader:
        x = batch["x"].to(DEVICE)
        z = batch["z"].to(DEVICE)    # target: optimal primal from the solver
        lb = batch["lb"].to(DEVICE)  # lower bounds on D z
        ub = batch["ub"].to(DEVICE)  # upper bounds on D z

        y_pred = model(x)
        loss = loss_fn(y_pred, z, lb, ub)
        val_loss += loss.item() * x.size(0)

        all_preds.append(y_pred.cpu().numpy())
        all_targets.append(z.cpu().numpy())

val_loss /= len(val_dataset)
print(f"Mean Validation Loss: {val_loss:.6e}")

all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

mse = mean_squared_error(all_targets, all_preds)
mae = mean_absolute_error(all_targets, all_preds)
r2  = r2_score(all_targets, all_preds)

errors = all_preds - all_targets

plt.figure(figsize=(8,4))
plt.plot(errors)
plt.xlabel("Sample index")
plt.ylabel("Prediction error")
plt.title("Prediction Errors on Validation Set")
plt.grid(True)
plt.show()

plt.figure(figsize=(6,6))
plt.scatter(all_targets, all_preds, alpha=0.5)
plt.plot([all_targets.min(), all_targets.max()],
         [all_targets.min(), all_targets.max()], 'r--')
plt.xlabel("Target")
plt.ylabel("Prediction")
plt.title("Predictions vs Targets")
plt.grid(True)
plt.show()

plt.figure(figsize=(8,4))
plt.hist(errors.ravel(), bins=50, color='skyblue', edgecolor='black')
plt.xlabel("Prediction error")
plt.ylabel("Frequency")
plt.title("Histogram of Prediction Errors")
plt.grid(True)
plt.show()

print(f"MSE: {mse:.4e}, MAE: {mae:.4e}, R2: {r2:.4f}")

# 6. Exporting the Model

## 6.1 Raw Model

In [None]:
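dataset = MPC_DATASET(MPC)
sample = dataset[0]
init_state_sample = sample["x"].unsqueeze(0)   # (1, n_states) example input

# Export from CPU in eval mode so the model and the example input share a device
model = model.to("cpu").eval()

torch.onnx.export(
    model,
    init_state_sample,
    "planner_raw.onnx",
    input_names=["x"],
    output_names=["z_pred"],
    opset_version=17
)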

## 6.2 Quantized Model

In [None]:
input_model = "planner_raw.onnx"
output_model = "planner_quant.onnx"

quantize_dynamic(
    model_input=input_model,
    model_output=output_model,
    weight_type=QuantType.QInt8
)
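`quantize_dynamic` stores the weights as 8-bit integers with scale parameters and quantizes activations on the fly at inference time. To give a feel for the weight error involved, the sketch below uses a simplified per-tensor symmetric INT8 scheme on a hypothetical weight matrix. This scheme is an assumption for illustration, not necessarily the exact one onnxruntime applies. With round-to-nearest, each weight's reconstruction error is at most half the scale:

```python
import numpy as np

def quantize_int8_symmetric(w):
    """Simplified per-tensor symmetric INT8 quantization (illustrative, not onnxruntime's exact scheme)."""
    scale = np.max(np.abs(w)) / 127.0                        # largest |w| maps to 127
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=(128, 64)).astype(np.float32)  # hypothetical weight matrix
q, scale = quantize_int8_symmetric(w)
err = np.abs(dequantize(q, scale) - w)
print(q.dtype, float(scale), float(err.max()))
```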