In [1]:
from Simulation.mpc import *
from TD3Agent.agent import *
from Simulation.system_functions import PolymerCSTR
from BasicFunctions.td3_functions import *
from TD3Agent.replay_buffer import ReplayDataset, DataLoader

In [2]:
import datetime
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from torch.optim.lr_scheduler import StepLR
from TD3Agent.actor import Actor
from TD3Agent.critic import Critic
from TD3Agent.replay_buffer import ReplayBuffer

class Agent(object):
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 actor_layer_sizes: list,
                 critic_layer_sizes: list,
                 buffer_capacity: int,
                 actor_lr: float,
                 critic_lr: float,
                 target_policy_smoothing_noise_std: float,
                 noise_clip: float,
                 exploration_noise_std: float,
                 gamma: float,
                 tau: float,
                 max_action: float,
                 policy_delay: int,
                 device):
        """
        Build the online/target actor and critic networks, their Adam
        optimizers, the replay buffer and the bookkeeping used while training.

        Args:
            state_dim: Size of the state vector passed to ``Actor``/``Critic``.
            action_dim: Size of the action vector.
            actor_layer_sizes: Layer sizes forwarded to ``Actor``.
            critic_layer_sizes: Layer sizes forwarded to ``Critic``.
            buffer_capacity: Capacity forwarded to ``ReplayBuffer``.
            actor_lr: Learning rate of the actor's Adam optimizer.
            critic_lr: Learning rate of the critic's Adam optimizer.
            target_policy_smoothing_noise_std: Std of the noise added to the
                target action when TD targets are built (stored as ``t_std``).
            noise_clip: Symmetric clipping bound applied to that noise.
            exploration_noise_std: Initial std of the exploration noise used by
                ``take_action``; ``decay_exploration`` floors it at 10 % of
                this value.
            gamma: Discount factor of the TD target.
            tau: Interpolation factor of the soft target-network updates.
            max_action: Symmetric bound used to clip actions.
            policy_delay: Actor-update period (in iterations) used by ``pre_train``.
            device: Torch device on which the four networks are placed.
        """
        # --- dimensions and sizing -------------------------------------------
        self.state_dim, self.action_dim = state_dim, action_dim
        self.actor_layer_sizes = actor_layer_sizes
        self.critic_layer_sizes = critic_layer_sizes
        self.buffer_capacity = buffer_capacity

        # --- learning hyper-parameters ---------------------------------------
        self.actor_lr, self.critic_lr = actor_lr, critic_lr
        self.t_std = target_policy_smoothing_noise_std
        self.noise_clip = noise_clip
        self.exploration_noise_std = exploration_noise_std
        self.gamma, self.tau = gamma, tau

        self.max_action = max_action
        self.policy_delay = policy_delay
        self.total_it = 0  # number of completed train() calls

        self.device = device

        # --- actor pair: the target starts as an exact copy of the online net --
        actor_args = (self.state_dim, self.action_dim, self.actor_layer_sizes)
        self.actor = Actor(*actor_args).to(self.device)
        self.actor_target = Actor(*actor_args).to(self.device)
        for online_p, target_p in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_p.data.copy_(online_p.data)

        # --- critic pair: same initial synchronisation ------------------------
        critic_args = (self.state_dim, self.action_dim, self.critic_layer_sizes)
        self.critic = Critic(*critic_args).to(self.device)
        self.critic_target = Critic(*critic_args).to(self.device)
        for online_p, target_p in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_p.data.copy_(online_p.data)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        self.replay_buffer = ReplayBuffer(self.buffer_capacity, self.state_dim, self.action_dim)

        # All three criteria are currently plain MSE (the critic criterion is
        # kept separate so that it can be swapped, e.g. for a Huber loss).
        self.loss_fn_actor = nn.MSELoss()
        self.loss_fn = nn.MSELoss()
        self.loss_fn_critic = nn.MSELoss()

        # Loss histories kept for later visualisation.
        self.critic_losses_pretrain = []
        self.actor_losses_pretrain = []
        self.critic_losses = []
        self.actor_losses = []

        # Step schedulers: x0.1 every 150 scheduler steps (only effective if
        # the training code calls .step() on them).
        self.scheduler_actor = StepLR(self.actor_optimizer, step_size=150, gamma=0.1)
        self.scheduler_critic = StepLR(self.critic_optimizer, step_size=150, gamma=0.1)

        self.buffer_save = False

        # Exploration-noise annealing, see decay_exploration().
        self.exploration_noise_std_min = exploration_noise_std * 0.1
        self.exploration_noise_std_initial = exploration_noise_std
        self.decay_rate = 0.99995

        # While True, take_action() returns uniformly random actions.
        self.warm_start = True

    def pre_train_entire_data(self, path, data_loader, epochs):

        # scaler = GradScaler('cuda')  # Initialize GradScaler for mixed precision training
        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
        save_path = os.path.join(path, "PLots", timestamp)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        for epoch in range(epochs):
            loss_critic_epoch = 0.0
            loss_actor_epoch = 0.0

            for batch_idx, (states_batch, actions_batch, rewards_batch, next_states_batch) in enumerate(data_loader):
                # Move data to GPU
                states_batch = states_batch.to(self.device)
                actions_batch = actions_batch.to(self.device)
                rewards_batch = rewards_batch.to(self.device)
                next_states_batch = next_states_batch.to(self.device)
                batch_size = states_batch.size(0)

                # Compute target Q-value without tracking gradients
                with torch.no_grad():
                    noise = (self.t_std * torch.randn(batch_size, self.action_dim, device=self.device)).clamp(
                        -self.noise_clip, self.noise_clip)
                    next_actions = (self.actor_target(next_states_batch) + noise).clamp(-self.max_action,
                                                                                        self.max_action)
                    target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
                    target_Q = torch.min(target_Q1, target_Q2)
                    target_value = rewards_batch + self.gamma * target_Q

                # --- Critic Update with Mixed Precision ---
                # with autocast('cuda'):
                current_Q1, current_Q2 = self.critic(states_batch, actions_batch)
                critic_loss = self.loss_fn(current_Q1, target_value) + self.loss_fn(current_Q2, target_value)
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()
                loss_critic_epoch += critic_loss.item() * batch_size

                # --- Actor Update with Mixed Precision ---
                # with autocast('cuda'):
                actor_output = self.actor(states_batch)
                actor_loss = self.loss_fn(actions_batch, actor_output)

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                loss_actor_epoch += actor_loss.item() * batch_size

            # Average losses over the epoch
            num_samples = len(data_loader.dataset)
            loss_critic_epoch /= num_samples
            loss_actor_epoch /= num_samples

            self.critic_losses_pretrain.append(loss_critic_epoch)
            self.actor_losses_pretrain.append(loss_actor_epoch)

        # """
        # Pre-trains actor-critic networks on an offline replay buffer
        # without mixed-precision.  All computations are carried out
        # in default FP32.
        # """
        # # ------------------------------------------------------------------
        # # 1.  House-keeping: create a time-stamped directory for plots
        # # ------------------------------------------------------------------
        # timestamp  = datetime.datetime.now().strftime("%y%m%d%H%M")
        # save_path  = os.path.join(path, "PLots", timestamp)
        # os.makedirs(save_path, exist_ok=True)
        #
        # # ------------------------------------------------------------------
        # # 2.  Epoch loop
        # # ------------------------------------------------------------------
        # for epoch in range(epochs):
        #     loss_critic_epoch = 0.0
        #     loss_actor_epoch  = 0.0
        #
        #     for batch_idx, (states, actions, rewards, next_states) in enumerate(data_loader):
        #         # -------- move mini-batch to the same device as the networks ---
        #         states      = states.to(self.device)
        #         actions     = actions.to(self.device)
        #         rewards     = rewards.to(self.device)
        #         next_states = next_states.to(self.device)
        #         B           = states.size(0)                       # batch size
        #
        #         # ------------------------------------------------------------------
        #         # 2-A.  Compute TD target (no gradients needed)
        #         # ------------------------------------------------------------------
        #         with torch.no_grad():
        #             noise = (self.t_std
        #                      * torch.randn(B, self.action_dim, device=self.device)
        #                     ).clamp(-self.noise_clip, self.noise_clip)
        #
        #             next_actions       = (self.actor_target(next_states) + noise).clamp(-self.max_action,
        #                                                                                  self.max_action)
        #             target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
        #             target_Q   = torch.min(target_Q1, target_Q2)
        #             target_val = rewards + self.gamma * target_Q          # shape: [B, 1]
        #
        #         # ------------------------------------------------------------------
        #         # 2-B.  Critic update (FP32)
        #         # ------------------------------------------------------------------
        #         current_Q1, current_Q2 = self.critic(states, actions)
        #         critic_loss = (self.loss_fn_critic(current_Q1, target_val)
        #                        + self.loss_fn_critic(current_Q2, target_val))
        #
        #         self.critic_optimizer.zero_grad()
        #         critic_loss.backward()
        #         self.critic_optimizer.step()
        #
        #         loss_critic_epoch += critic_loss.item() * B
        #
        #         # ------------------------------------------------------------------
        #         # 2-C.  Actor update (FP32)
        #         # ------------------------------------------------------------------
        #         predicted_actions = self.actor(states)             # Ï€(s)
        #         actor_loss = self.loss_fn_actor(predicted_actions, actions)
        #
        #         self.actor_optimizer.zero_grad()
        #         actor_loss.backward()
        #         self.actor_optimizer.step()
        #
        #         loss_actor_epoch += actor_loss.item() * B
        #
        #     # ------------------------------------------------------------------
        #     # 3.  Book-keeping: store average epoch losses
        #     # ------------------------------------------------------------------
        #     N = len(data_loader.dataset)       # total samples in the loader
        #     self.critic_losses_pretrain.append(loss_critic_epoch / N)
        #     self.actor_losses_pretrain.append(loss_actor_epoch  / N)

            if epoch == 0 or epoch % 10 == 9:
                print(f"{datetime.datetime.now()} Epoch {epoch + 1},"
                      f" Actor Loss: {loss_actor_epoch},"
                      f" Critic Loss: {loss_critic_epoch}")
                if epoch == 0 or epoch % 100 == 99:
                    current_lr_cr = self.critic_optimizer.param_groups[0]['lr']
                    print(f"Epoch {epoch + 1}, Learning rate (Critic): {current_lr_cr}")
                    current_lr_ac = self.actor_optimizer.param_groups[0]['lr']
                    print(f"Epoch {epoch + 1}, Learning rate (Actor): {current_lr_ac}")

            if epoch % 100 == 0:
                if epoch > 1:
                    # Exclude the first `plot_start` losses to get a clearer plot of later epochs
                    if epoch < 110:
                        plot_start = 0
                    else:
                        plot_start = epoch - 100
                    critic_losses_plot = self.critic_losses_pretrain[plot_start:]
                    actor_losses_plot = self.actor_losses_pretrain[plot_start:]

                    fig, axs = plt.subplots(1, 2, figsize=(15, 6))

                    # Plot Critic Loss
                    axs[0].plot(range(plot_start + 1, len(critic_losses_plot) + 1 + plot_start), critic_losses_plot,
                                label="Critic Loss", color='r', linewidth=2)
                    axs[0].set_title("Critic Loss (Pre-train)", fontsize=16)
                    axs[0].set_xlabel("Epochs", fontsize=12)
                    axs[0].set_ylabel("Loss", fontsize=12)
                    axs[0].grid(True)
                    axs[0].legend()

                    # Plot Actor Loss
                    axs[1].plot(range(plot_start + 1, len(actor_losses_plot) + 1 + plot_start), actor_losses_plot,
                                label="Actor Loss", color='b', linewidth=2)
                    axs[1].set_title("Actor Loss (Pre-train)", fontsize=16)
                    axs[1].set_xlabel("Epochs", fontsize=12)
                    axs[1].set_ylabel("Loss", fontsize=12)
                    axs[1].grid(True)
                    axs[1].legend()

                    # Show the plot
                    plt.tight_layout()

                    filename = os.path.join(save_path, f"{plot_start}_{epoch}.png")
                    plt.savefig(filename)
                    plt.close(fig)

            # Update the learning rate schedulers at the end of the epoch
            # self.scheduler_actor.step()
            # self.scheduler_critic.step()

            # # Clear CUDA cache and collect garbage to free up memory (helpful with large models/datasets)
            # torch.cuda.empty_cache()
            # gc.collect()

    def pre_train(self, path: str, batch_size: int, epochs: int = 1000, log_interval: int = 1000):
        """
        Pre-train on stored replay buffer samples using TD3-style updates
        with mixed-precision (GradScaler + autocast).

        Args:
            path:            path to store plots
            batch_size:      number of samples per gradient update
            epochs:          total number of gradient-update iterations
            log_interval:    print losses every `log_interval` steps
        """

        # time to store the losses of critic and actor
        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")

        path = os.path.join(path, "PLots_pretrain_normal", timestamp)
        if not os.path.exists(path):
            os.makedirs(path)

        for it in range(1, epochs + 1):
            # 1) Sample a random minibatch
            states, actions, rewards, next_states = self.replay_buffer.sample_pretrain(batch_size, device=str(self.device))

            # 2) Compute target Q-values (no grad)
            with torch.no_grad():
                noise = (self.t_std * np.random.randn(batch_size,
                                                      self.action_dim)).clip(
                    -self.noise_clip, self.noise_clip)
                noise = torch.from_numpy(noise).to(dtype=torch.float32, device=self.device)
                next_actions = (self.actor_target(next_states) + noise).clip(-1, 1)

                target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
                target_Q = torch.min(target_Q1, target_Q2)

                y = rewards + self.gamma * target_Q

            # 3) Critic update (mixed precision)
            q1, q2 = self.critic(states, actions)

            loss_critic = self.loss_fn_critic(q1, y) + self.loss_fn_critic(q2, y)

            self.critic_optimizer.zero_grad()
            loss_critic.backward()
            self.critic_optimizer.step()
            self.critic_losses.append(loss_critic.item())

            # 4) Delayed actor update
            if it % self.policy_delay == 0:
                actor_actions = self.actor(states)
                actor_loss = self.loss_fn_actor(actor_actions, actions)

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                self.actor_losses.append(actor_loss.item())

                # 5) Soft update of targets
                for p, p_tgt in zip(self.actor.parameters(), self.actor_target.parameters()):
                    p_tgt.data.mul_(1 - self.tau)
                    p_tgt.data.add_(self.tau * p.data)
                for p, p_tgt in zip(self.critic.parameters(), self.critic_target.parameters()):
                    p_tgt.data.mul_(1 - self.tau)
                    p_tgt.data.add_(self.tau * p.data)

                # store the actor loss
                self.actor_losses_pretrain.append(actor_loss.item())

            # store the critic loss
            self.critic_losses_pretrain.append(loss_critic.item())

            # 6) Logging
            if it % log_interval == 0 or it == 1:
                print(f"[Pre-train] Iter {it}/{epochs}: "
                      f"critic_loss={loss_critic.item():.6e}, "
                      f"actor_loss={(actor_loss.item() if 'actor_loss' in locals() else float('nan')):.6e}")


    def train(self, n_samples=100):

        self.buffer_save = True

        (states_batch, actions_batch,
         rewards_batch, next_states_batch) = self.replay_buffer.sample(n_samples, device=self.device)
        # (states_batch, actions_batch,
        #  rewards_batch, next_states_batch) = self.replay_buffer.sample_pretrain(n_samples, device=self.device)

        loss_critic = 0
        loss_actor = 0

        with torch.no_grad():
            noise = (self.t_std * np.random.randn(n_samples,
                                                  self.action_dim)).clip(
                -self.noise_clip, self.noise_clip)
            noise = torch.from_numpy(noise).to(dtype=torch.float32, device=self.device)
            next_actions = (self.actor_target(next_states_batch) + noise).clip(-1, 1)

            target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)

            r = rewards_batch + self.gamma * target_Q

        q1, q2 = self.critic(states_batch, actions_batch)

        loss = self.loss_fn(q1, r) + self.loss_fn(q2, r)

        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
        loss_critic += loss.item()
        self.critic_losses.append(loss.item())

        # if self.total_it % self.policy_delay == 0:
        actions = self.actor(states_batch)
        q = self.critic.q1_forward(states_batch, actions)

        actor_loss = - torch.mean(q)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        loss_actor += actor_loss.item()
        self.actor_losses.append(actor_loss.item())

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # print(f'Critic loss: {self.critic_losses[-1]}')
        # print(f'Actor loss: {self.actor_losses[-1]}')
        self.total_it += 1

        self.decay_exploration()

    def decay_exploration(self):

        self.exploration_noise_std = max(
            self.exploration_noise_std_min,
            self.exploration_noise_std_initial * (self.decay_rate ** self.total_it)
        )

    def take_action(self, state, explore=False):
        state = state if isinstance(state, torch.Tensor) else torch.from_numpy(state).to(dtype=torch.float32,
                                                                                         device=self.device)

        # If warm start is active, choose a random action from the action space.
        if self.warm_start:
            action = np.random.uniform(low=-self.max_action, high=self.max_action, size=self.action_dim)
        else:
            with torch.no_grad():
                action = self.actor(state).detach().cpu().numpy()
            if explore:
                action += np.random.randn(self.action_dim) * self.exploration_noise_std
            action = action.clip(-self.max_action, self.max_action)

        return action

    def save(self, path: str, name_prefix="agent"):
        path = os.path.join(path, "models")
        if not os.path.exists(path):
            os.makedirs(path)

        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
        filename = os.path.join(path, f"{name_prefix}_{timestamp}.pkl")

        save_dict = {
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_target_state_dict': self.actor_target.state_dict(),
            'critic_target_state_dict': self.critic_target.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'agent_attributes': {key: value for key, value in self.__dict__.items() if key not in ['actor', 'critic',
                                                                                                   'actor_target',
                                                                                                   'critic_target',
                                                                                                   'replay_buffer',
                                                                                                   'total_it',
                                                                                                   'device',
                                                                                                   'loss_fn',
                                                                                                   'replay_buffer']}
        }

        with open(filename, 'wb') as f:
            pickle.dump(save_dict, f)
        print(f"Agent saved successfully to {filename}")
        return filename

    def load(self, path: str):
        with open(path, 'rb') as f:
            loaded_dict = pickle.load(f)

        self.actor.load_state_dict(loaded_dict['actor_state_dict'])
        self.critic.load_state_dict(loaded_dict['critic_state_dict'])
        self.actor_target.load_state_dict(loaded_dict['actor_target_state_dict'])
        self.critic_target.load_state_dict(loaded_dict['critic_target_state_dict'])

        for key, value in loaded_dict['agent_attributes'].items():
            setattr(self, key, value)

        self.actor_optimizer.load_state_dict(loaded_dict['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(loaded_dict['critic_optimizer_state_dict'])
        self.actor_optimizer.param_groups[0]['params'] = list(self.actor.parameters())
        self.critic_optimizer.param_groups[0]['params'] = list(self.critic.parameters())
        self.actor_optimizer.param_groups[0]['lr'] = self.actor_lr # * 1e-3
        self.critic_optimizer.param_groups[0]['lr'] = self.critic_lr

        # self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr * 1e-2)
        # self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(param.data)

        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(param.data)

        print(f"Agent loaded successfully from: {path}")


## Initialize the system

In [3]:
# Kinetic and thermal parameters of the polymerization CSTR model.
# They are packed, in exactly this order, into `system_params`, which is
# later handed to PolymerCSTR.
Ad = 2.142e17           # h^-1
Ed = 14897              # K
Ap = 3.816e10           # L/(mol h)
Ep = 3557               # K
At = 4.50e12            # L/(mol h)
Et = 843                # K
fi = 0.6                # Coefficient
m_delta_H_r = -6.99e4   # J/mol
hA = 1.05e6             # J/(K h)
rhocp = 1506            # noted as J/(K h); NOTE(review): a rho*cp term is presumably J/(K L) - confirm
rhoccpc = 4043          # noted as J/(K h); NOTE(review): presumably J/(K L) as well - confirm
Mm = 104.14             # g/mol
system_params = np.array([Ad, Ed, Ap, Ep, At, Et, fi, m_delta_H_r, hA, rhocp, rhoccpc, Mm])

In [4]:
# Design parameters (feed concentrations, flows, temperatures, volumes).
# They are packed, in exactly this order, into `system_design_params`.
CIf = 0.5888    # mol/L
CMf = 8.6981    # mol/L
Qi = 108.       # L/h
Qs = 459.       # L/h
Tf = 330.       # K
Tcf = 295.      # K
V = 3000.       # L
Vc = 3312.4     # L

system_design_params = np.array([CIf, CMf, Qi, Qs, Tf, Tcf, V, Vc])

In [5]:
# Steady-state inputs
Qm_ss = 378.    # L/h
Qc_ss = 471.6   # L/h

# Note the order in the array: Qc_ss first, then Qm_ss (the reverse of the
# definition order above).
system_steady_state_inputs = np.array([Qc_ss, Qm_ss])

In [6]:
# Sampling time of the system, passed to PolymerCSTR below.
delta_t = 0.5 # 0.5 corresponds to 30 min, i.e. the value is expressed in hours

In [7]:
# Build the CSTR model once; only its steady-state inputs and outputs are
# collected here for the data-preparation helpers.
cstr = PolymerCSTR(
    system_params,
    system_design_params,
    system_steady_state_inputs,
    delta_t,
)
steady_states = dict(ss_inputs=cstr.ss_inputs, y_ss=cstr.y_ss)

## Loading the system matrices, the min-max scaling bounds, and the state min/max values

In [8]:
# Working data directory: "Data" under the current working directory, so the
# notebook has to be started from the project root for paths to resolve.
dir_path = os.path.join(os.getcwd(), "Data")

In [9]:
# Range of set-points used for data generation: first row / second row of
# `setpoint_y` (presumably lower and upper bounds per output - confirm against
# load_and_prepare_system_data).
setpoint_y = np.array([[2.8, 320.],
                       [5., 326.]])
# Input bounds. Note that u_max is built from integers only and is therefore an
# integer array, whereas u_min is a float array.
u_min = np.array([71.6, 78])
u_max = np.array([870, 670])

# `system_data` is the dict that the following cells read the model matrices and
# scaling information from.
system_data = load_and_prepare_system_data(steady_states=steady_states, setpoint_y=setpoint_y, u_min=u_min, u_max=u_max)

In [10]:
# Augmented model matrices, pulled out of the prepared system data.
A_aug, B_aug, C_aug = (system_data[key] for key in ("A_aug", "B_aug", "C_aug"))

In [11]:
# Min / max values used for scaling the data.
data_min, data_max = (system_data[key] for key in ("data_min", "data_max"))

In [12]:
# min_max_states = system_data["min_max_states"]
# min_max_states

In [13]:
# State bounds are hard-coded here instead of being read from the prepared
# system data; the lookup is kept below, commented out, for reference.
# min_max_states = system_data["min_max_states"]
min_max_states = {'max_s': np.array([256.79686253, 256.01560603,  48.99447186, 144.79949103,
          2.82199733,   3.14014989,   2.78866348,   3.71691422,
          6.2029936 ]),
                  'min_s': np.array([ -272.28060121, -1112.33972595,   -76.63993491,  -608.60327886,
           -3.94399122,    -3.93115257,    -2.9532091 ,    -4.06547624,
          -28.25906582])}

In [14]:
# Set-point entry of the prepared system data (per its key: scaled, in deviation form).
y_sp_scaled_deviation = system_data["y_sp_scaled_deviation"]

In [15]:
# Lower / upper bounds later used as the MPC optimisation bounds.
b_min, b_max = (system_data[key] for key in ("b_min", "b_max"))

In [16]:
# Override the state bounds of the scaling dictionary with the hard-coded
# values from `min_max_states`, so the numbers live in one place only and the
# two cannot drift apart. Copies are stored so that later in-place edits of one
# structure do not silently change the other.
min_max_dict = system_data["min_max_dict"]
min_max_dict["x_max"] = min_max_states["max_s"].copy()
min_max_dict["x_min"] = min_max_states["min_s"].copy()

In [17]:
# Display the scaling bounds after the x_min / x_max override.
min_max_dict

{'x_max': array([256.79686253, 256.01560603,  48.99447186, 144.79949103,
          2.82199733,   3.14014989,   2.78866348,   3.71691422,
          6.2029936 ]),
 'x_min': array([ -272.28060121, -1112.33972595,   -76.63993491,  -608.60327886,
           -3.94399122,    -3.93115257,    -2.9532091 ,    -4.06547624,
          -28.25906582]),
 'y_sp_min': array([-4.91766443, -4.61204935]),
 'y_sp_max': array([5.00776949, 3.06512771]),
 'u_max': array([9.96, 7.3 ]),
 'u_min': array([-10. ,  -7.5])}

## Setting The hyperparameters for the TD3 Agent

In [18]:
# Reproducibility: seed the NumPy and torch generators used for buffer
# sampling, noise generation and network initialisation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dimensions derived from the augmented model: the agent's state is the
# augmented model state followed by the set-points and the inputs.
set_points_number = int(C_aug.shape[0])
inputs_number = int(B_aug.shape[1])
STATE_DIM = int(A_aug.shape[0]) + set_points_number + inputs_number
ACTION_DIM = inputs_number

# Network sizes
# ACTOR_LAYER_SIZES = [1024, 1024, 1024, 1024, 1024]
# CRITIC_LAYER_SIZES = [1024, 1024, 1024, 1024, 1024]
ACTOR_LAYER_SIZES = [512, 512, 512, 512, 512]
CRITIC_LAYER_SIZES = [512, 512, 512, 512, 512]

# Replay buffer and optimisation
BUFFER_CAPACITY = 30_000_000
ACTOR_LR = 1e-4
CRITIC_LR = 1e-4

# TD3 settings.
# NOTE(review): with a smoothing std of 1e-6 the clip at 0.5 never becomes
# active, i.e. target-policy smoothing is effectively off - confirm this is intended.
SMOOTHING_STD = 0.000001
NOISE_CLIP = 0.5
EXPLORATION_NOISE_STD = 0.1
GAMMA = 0.995
TAU = 0.005
MAX_ACTION = 1
POLICY_DELAY = 4

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
BATCH_SIZE = 256
NUM_BATCHES = 1

Using device: cuda


In [19]:
# Agent: arguments are passed by keyword so that none of the fifteen
# constructor parameters can be swapped silently.
agent = Agent(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    actor_layer_sizes=ACTOR_LAYER_SIZES,
    critic_layer_sizes=CRITIC_LAYER_SIZES,
    buffer_capacity=BUFFER_CAPACITY,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    target_policy_smoothing_noise_std=SMOOTHING_STD,
    noise_clip=NOISE_CLIP,
    exploration_noise_std=EXPLORATION_NOISE_STD,
    gamma=GAMMA,
    tau=TAU,
    max_action=MAX_ACTION,
    policy_delay=POLICY_DELAY,
    device=DEVICE,
)

In [20]:
# MPC parameters: horizons, optimisation bounds / constraints / initial guess,
# MPC weights, and the weighting matrices passed to the buffer-filling helpers.
predict_h, cont_h = 9, 3

# One (lower, upper) pair per input, repeated for every control move.
b1, b2 = ((b_min[idx], b_max[idx]) for idx in range(2))
bnds = cont_h * (b1, b2)
cons = list()
IC_opt = np.zeros(cont_h * inputs_number)

Q1, Q2 = 5, 1
R1, R2 = 1, 1

Q_rew = np.diag([12, 8])
R_rew = np.diag([1, 1])

In [21]:
# MPC solver built on the augmented model, with the weights and the
# prediction / control horizons defined above.
MPC_obj = MpcSolver(A_aug, B_aug, C_aug,
                    Q1, Q2, R1, R2,
                    predict_h, cont_h)

In [22]:
# Split of the buffer capacity: a fixed number of steady-state samples, with
# the remainder of the capacity left for MPC-generated samples.
steady_states_samples_number = 100000
mpc_pretrain_samples_numbers = BUFFER_CAPACITY - steady_states_samples_number

In [23]:
# Generate the MPC pre-training samples in chunks of 100,000 (presumably they
# are written into agent.replay_buffer - confirm in filling_the_buffer).
# Long-running: the recorded run reports 299 chunks.
filling_the_buffer(
        min_max_dict,
        A_aug, B_aug, C_aug,
        MPC_obj,
        mpc_pretrain_samples_numbers,
        Q_rew, R_rew,
        agent,
        IC_opt, bnds, cons, steady_states["y_ss"], data_min, data_max, chunk_size= 100_000)

Processing chunk 1/299
Processing chunk 2/299
Processing chunk 3/299
Processing chunk 4/299
Processing chunk 5/299
Processing chunk 6/299
Processing chunk 7/299
Processing chunk 8/299
Processing chunk 9/299
Processing chunk 10/299
Processing chunk 11/299
Processing chunk 12/299
Processing chunk 13/299
Processing chunk 14/299
Processing chunk 15/299
Processing chunk 16/299
Processing chunk 17/299
Processing chunk 18/299
Processing chunk 19/299
Processing chunk 20/299
Processing chunk 21/299
Processing chunk 22/299
Processing chunk 23/299
Processing chunk 24/299
Processing chunk 25/299
Processing chunk 26/299
Processing chunk 27/299
Processing chunk 28/299
Processing chunk 29/299
Processing chunk 30/299
Processing chunk 31/299
Processing chunk 32/299
Processing chunk 33/299
Processing chunk 34/299
Processing chunk 35/299
Processing chunk 36/299
Processing chunk 37/299
Processing chunk 38/299
Processing chunk 39/299
Processing chunk 40/299
Processing chunk 41/299
Processing chunk 42/299
P

In [24]:
# Add the steady-state samples that make up the rest of the buffer capacity
# (a single chunk in the recorded run).
add_steady_state_samples(
        min_max_dict,
        A_aug, B_aug, C_aug,
        MPC_obj,
        steady_states_samples_number,
        Q_rew, R_rew,
        agent,
        IC_opt, bnds, cons, steady_states["y_ss"], data_min, data_max, chunk_size= 100_000)

Processing chunk 1/1
Replay buffer has been filled up with the steady_state values.


## Saving and reloading the replay buffer to verify that the saved file is usable

In [25]:
# Persist the filled replay buffer; the returned file name is reused below to
# reload it.
filename_buffer = agent.replay_buffer.save(dir_path)

Replay buffer saved to C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\replay_buffer_2507151626.h5


In [26]:
# Fresh buffer object with the same capacity and dimensions, to be filled from the saved file.
replay_buffer = ReplayBuffer(BUFFER_CAPACITY, STATE_DIM, ACTION_DIM)

In [27]:
# Round-trip check: load the file that was just written.
replay_buffer.load(filename_buffer)

Replay buffer loaded from C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\replay_buffer_2507151626.h5


In [28]:
# From here on the agent uses the reloaded buffer.
agent.replay_buffer = replay_buffer

## Pre training the Agent

In [29]:
# import gc
#
# # Clear memory before DataLoader
# torch.cuda.empty_cache()
# gc.collect()

In [30]:
# Wrap the four replay-buffer arrays as float32 tensors in a dataset and
# expose them through a shuffling loader (8192 samples per batch).
dataset = ReplayDataset(*(
    torch.from_numpy(buffer_array).to(dtype=torch.float32)
    for buffer_array in (
        replay_buffer.states,
        replay_buffer.actions,
        replay_buffer.rewards,
        replay_buffer.next_states,
    )
))
data_loader = DataLoader(
    dataset,
    batch_size=2048 * 4,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

In [31]:
# Pre-training over the whole dataset. Long-running: the recorded run needed
# a little under ten minutes per ten epochs.
EPOCHS_FOR_PRETRAIN = 3000

agent.pre_train_entire_data(dir_path, data_loader, EPOCHS_FOR_PRETRAIN)

2025-07-15 16:27:47.085449 Epoch 1, Actor Loss: 0.0026822865463982646, Critic Loss: 6079348.787957609
Epoch 1, Learning rate (Critic): 0.0001
Epoch 1, Learning rate (Actor): 0.0001
2025-07-15 16:34:47.300607 Epoch 10, Actor Loss: 7.853822999541686e-06, Critic Loss: 516.0803024298177
2025-07-15 16:42:58.983043 Epoch 20, Actor Loss: 2.8033056189640774e-06, Critic Loss: 211.43094332109376
2025-07-15 16:52:06.261738 Epoch 30, Actor Loss: 1.7116014426205463e-06, Critic Loss: 143.60284401822918
2025-07-15 17:01:25.091879 Epoch 40, Actor Loss: 1.2262625033714964e-06, Critic Loss: 115.35663267147623
2025-07-15 17:10:46.611564 Epoch 50, Actor Loss: 9.522098025248852e-07, Critic Loss: 97.16969638343099
2025-07-15 17:20:05.843347 Epoch 60, Actor Loss: 7.867441409694341e-07, Critic Loss: 81.32849118634441
2025-07-15 17:29:24.701614 Epoch 70, Actor Loss: 6.651147892878119e-07, Critic Loss: 70.1532533137736
2025-07-15 17:38:44.349470 Epoch 80, Actor Loss: 5.719167145211637e-07, Critic Loss: 67.55745

In [41]:
# Buffer-sampled pre-training: EPOCHS_FOR_PRETRAIN gradient iterations of
# BATCH_SIZE samples each.
# Alternative iteration budget:
# EPOCHS_FOR_PRETRAIN = int(BUFFER_CAPACITY / 5)

EPOCHS_FOR_PRETRAIN = 1000000

# The signature is pre_train(path, batch_size, epochs, log_interval). Passing
# the values by keyword keeps the batch size and the iteration count from being
# swapped, which is what the former positional call did.
agent.pre_train(dir_path, batch_size=BATCH_SIZE, epochs=EPOCHS_FOR_PRETRAIN, log_interval=10000)

KeyboardInterrupt: 

## Saving and reloading the agent to verify that it was stored correctly

In [32]:
# Persist the pre-trained agent; the returned file name is reused below to reload it.
filename_agent = agent.save(dir_path)

Agent saved successfully to C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\agent_2507171027.pkl


In [33]:
# Fresh agent with the same configuration; its weights are replaced by the
# saved ones in the next cell. Arguments are passed by keyword so that none of
# the fifteen constructor parameters can be swapped silently.
agent = Agent(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    actor_layer_sizes=ACTOR_LAYER_SIZES,
    critic_layer_sizes=CRITIC_LAYER_SIZES,
    buffer_capacity=BUFFER_CAPACITY,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    target_policy_smoothing_noise_std=SMOOTHING_STD,
    noise_clip=NOISE_CLIP,
    exploration_noise_std=EXPLORATION_NOISE_STD,
    gamma=GAMMA,
    tau=TAU,
    max_action=MAX_ACTION,
    policy_delay=POLICY_DELAY,
    device=DEVICE,
)

In [34]:
# Restore the saved agent, then re-attach the replay buffer: Agent.save()
# leaves the buffer out of the pickled attributes.
agent.load(filename_agent)
agent.replay_buffer = replay_buffer

Agent loaded successfully from: C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\agent_2507171027.pkl


## Checking the accuracy of the agent by comparing its actions to the MPC actions

In [35]:
# Report R^2 scores of the agent's predicted inputs against the MPC inputs
# stored in the replay buffer.
# NOTE(review): n_samples=2 looks very small for an R^2 estimate - confirm how
# print_accuracy interprets this argument.
print_accuracy(agent.replay_buffer, agent, n_samples=2, device=DEVICE)

Agent r2 score for the predicted inputs compare to MPC inputs: 0.999998
Agent r2 score for the predicted input 1 compare to MPC input 1: 0.999997
Agent r2 score for the predicted input 1 compare to MPC input 2: 0.999999


In [36]:
# Display the scaling bounds once more for reference.
min_max_dict

{'x_max': array([256.79686253, 256.01560603,  48.99447186, 144.79949103,
          2.82199733,   3.14014989,   2.78866348,   3.71691422,
          6.2029936 ]),
 'x_min': array([ -272.28060121, -1112.33972595,   -76.63993491,  -608.60327886,
           -3.94399122,    -3.93115257,    -2.9532091 ,    -4.06547624,
          -28.25906582]),
 'y_sp_min': array([-4.91766443, -4.61204935]),
 'y_sp_max': array([5.00776949, 3.06512771]),
 'u_max': array([9.96, 7.3 ]),
 'u_min': array([-10. ,  -7.5])}