In [None]:
from Simulation.mpc import *
from TD3Agent.agent import *
from Simulation.system_functions import PolymerCSTR
from BasicFunctions.td3_functions import *
from TD3Agent.replay_buffer import ReplayDataset, DataLoader

In [3]:
import datetime
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from torch.optim.lr_scheduler import StepLR
from TD3Agent.actor import Actor
from TD3Agent.critic import Critic
from TD3Agent.replay_buffer import ReplayBuffer
from torch.amp import GradScaler, autocast
import gc


class Agent(object):
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 actor_layer_sizes: list,
                 critic_layer_sizes: list,
                 buffer_capacity: int,
                 actor_lr: float,
                 critic_lr: float,
                 target_policy_smoothing_noise_std: float,
                 noise_clip: float,
                 exploration_noise_std: float,
                 gamma: float,
                 tau: float,
                 max_action: float,
                 policy_delay: int,
                 device):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.actor_layer_sizes = actor_layer_sizes
        self.critic_layer_sizes = critic_layer_sizes
        self.buffer_capacity = buffer_capacity

        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.t_std = target_policy_smoothing_noise_std
        self.noise_clip = noise_clip
        self.exploration_noise_std = exploration_noise_std
        self.gamma = gamma
        self.tau = tau

        self.max_action = max_action
        self.policy_delay = policy_delay
        self.total_it = 0

        self.device = device

        self.actor = Actor(self.state_dim, self.action_dim, self.actor_layer_sizes).to(self.device)
        self.actor_target = Actor(self.state_dim, self.action_dim, self.actor_layer_sizes).to(self.device)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(param.data)

        self.critic = Critic(self.state_dim, self.action_dim, self.critic_layer_sizes).to(self.device)
        self.critic_target = Critic(self.state_dim, self.action_dim, self.critic_layer_sizes).to(self.device)

        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(param.data)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        self.replay_buffer = ReplayBuffer(self.buffer_capacity, self.state_dim, self.action_dim)
        self.loss_fn = nn.MSELoss()

        # Let's implement Huber loss
        # self.loss_fn = nn.SmoothL1Loss()

        # Defining the lists for losses storage for later visualization
        self.critic_losses_pretrain = []
        self.actor_losses_pretrain = []
        self.critic_losses = []
        self.actor_losses = []

        self.scheduler_actor = StepLR(self.actor_optimizer, step_size=150, gamma=0.1)
        self.scheduler_critic = StepLR(self.critic_optimizer, step_size=150, gamma=0.1)

        self.buffer_save = False

        self.exploration_noise_std_min = exploration_noise_std * 0.1
        self.exploration_noise_std_initial = exploration_noise_std
        self.decay_rate = 0.99995

        self.warm_start = True

    def pre_train_entire_data(self, path, data_loader, epochs):

        scaler = GradScaler('cuda')  # Initialize GradScaler for mixed precision training
        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
        save_path = os.path.join(path, "PLots", timestamp)
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        for epoch in range(epochs):
            loss_critic_epoch = 0.0
            loss_actor_epoch = 0.0

            for batch_idx, (states_batch, actions_batch, rewards_batch, next_states_batch) in enumerate(data_loader):
                # Move data to GPU
                states_batch = states_batch.to(self.device)
                actions_batch = actions_batch.to(self.device)
                rewards_batch = rewards_batch.to(self.device)
                next_states_batch = next_states_batch.to(self.device)
                batch_size = states_batch.size(0)

                # Compute target Q-value without tracking gradients
                with torch.no_grad():
                    noise = (self.t_std * torch.randn(batch_size, self.action_dim, device=self.device)).clamp(
                        -self.noise_clip, self.noise_clip)
                    next_actions = (self.actor_target(next_states_batch) + noise).clamp(-self.max_action,
                                                                                        self.max_action)
                    target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
                    target_Q = torch.min(target_Q1, target_Q2)
                    target_value = rewards_batch + self.gamma * target_Q

                # --- Critic Update with Mixed Precision ---
                with autocast('cuda'):
                    current_Q1, current_Q2 = self.critic(states_batch, actions_batch)
                    critic_loss = self.loss_fn(current_Q1, target_value) + self.loss_fn(current_Q2, target_value)
                self.critic_optimizer.zero_grad()
                scaler.scale(critic_loss).backward()
                scaler.step(self.critic_optimizer)
                loss_critic_epoch += critic_loss.item() * batch_size

                # --- Actor Update with Mixed Precision ---
                with autocast('cuda'):
                    actor_output = self.actor(states_batch)
                    actor_loss = self.loss_fn(actions_batch, actor_output)

                self.actor_optimizer.zero_grad()
                scaler.scale(actor_loss).backward()
                scaler.step(self.actor_optimizer)
                loss_actor_epoch += actor_loss.item() * batch_size

                scaler.update()

            # Average losses over the epoch
            num_samples = len(data_loader.dataset)
            loss_critic_epoch /= num_samples
            loss_actor_epoch /= num_samples

            self.critic_losses_pretrain.append(loss_critic_epoch)
            self.actor_losses_pretrain.append(loss_actor_epoch)

            # if epoch == 0 or epoch % 10 == 9:
            print(f"{datetime.datetime.now()} Epoch {epoch + 1},"
                  f" Actor Loss: {loss_actor_epoch},"
                  f" Critic Loss: {loss_critic_epoch}")
            if epoch == 0 or epoch % 100 == 99:
                current_lr_cr = self.critic_optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch + 1}, Learning rate (Critic): {current_lr_cr}")
                current_lr_ac = self.actor_optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch + 1}, Learning rate (Actor): {current_lr_ac}")

            if epoch % 100 == 0:
                if epoch > 1:
                    # Exclude the first `plot_start` losses to get a clearer plot of later epochs
                    if epoch < 110:
                        plot_start = 0
                    else:
                        plot_start = epoch - 100
                    critic_losses_plot = self.critic_losses_pretrain[plot_start:]
                    actor_losses_plot = self.actor_losses_pretrain[plot_start:]

                    fig, axs = plt.subplots(1, 2, figsize=(15, 6))

                    # Plot Critic Loss
                    axs[0].plot(range(plot_start + 1, len(critic_losses_plot) + 1 + plot_start), critic_losses_plot,
                                label="Critic Loss", color='r', linewidth=2)
                    axs[0].set_title("Critic Loss (Pre-train)", fontsize=16)
                    axs[0].set_xlabel("Epochs", fontsize=12)
                    axs[0].set_ylabel("Loss", fontsize=12)
                    axs[0].grid(True)
                    axs[0].legend()

                    # Plot Actor Loss
                    axs[1].plot(range(plot_start + 1, len(actor_losses_plot) + 1 + plot_start), actor_losses_plot,
                                label="Actor Loss", color='b', linewidth=2)
                    axs[1].set_title("Actor Loss (Pre-train)", fontsize=16)
                    axs[1].set_xlabel("Epochs", fontsize=12)
                    axs[1].set_ylabel("Loss", fontsize=12)
                    axs[1].grid(True)
                    axs[1].legend()

                    # Show the plot
                    plt.tight_layout()

                    filename = os.path.join(save_path, f"{plot_start}_{epoch}.png")
                    plt.savefig(filename)
                    plt.close(fig)

            # Update the learning rate schedulers at the end of the epoch
            # self.scheduler_actor.step()
            # self.scheduler_critic.step()

            # Clear CUDA cache and collect garbage to free up memory (helpful with large models/datasets)
            torch.cuda.empty_cache()
            gc.collect()

    def pre_train(self, path: str, batch_size: int, epochs: int = 1000, log_interval: int = 1000):
        """
        Pre-train on stored replay buffer samples using TD3-style updates
        with mixed-precision (GradScaler + autocast).

        Args:
            path:            path to store plots
            batch_size:      number of samples per gradient update
            epochs:          total number of gradient-update iterations
            log_interval:    print losses every `log_interval` steps
        """

        # time to store the losses of critic and actor
        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")

        path = os.path.join(path, "PLots_pretrain_normal", timestamp)
        if not os.path.exists(path):
            os.makedirs(path)


        scaler = GradScaler(str(self.device))
        for it in range(1, epochs + 1):
            # 1) Sample a random minibatch
            states, actions, rewards, next_states = \
                self.replay_buffer.sample_pretrain(batch_size, device=str(self.device))

            # 2) Compute target Q-values (no grad)
            with torch.no_grad():
                noise = (torch.randn(batch_size, self.action_dim, device=self.device)
                         * self.t_std).clamp(-self.noise_clip, self.noise_clip)
                next_actions = (self.actor_target(next_states) + noise).clamp(
                    -self.max_action, self.max_action
                )
                target_Q1, target_Q2 = self.critic_target(next_states, next_actions)
                target_Q = torch.min(target_Q1, target_Q2)
                target_value = rewards + self.gamma * target_Q

            # 3) Critic update (mixed precision)
            self.critic_optimizer.zero_grad()
            with autocast(str(self.device)):
                current_Q1, current_Q2 = self.critic(states, actions)
                critic_loss = self.loss_fn(current_Q1, target_value) \
                              + self.loss_fn(current_Q2, target_value)
            scaler.scale(critic_loss).backward()
            scaler.step(self.critic_optimizer)
            scaler.update()

            # 4) Delayed actor update
            if it % self.policy_delay == 0:
                self.actor_optimizer.zero_grad()
                with autocast(str(self.device)):
                    actor_actions = self.actor(states)
                    # Imitation actor loss:
                    actor_loss = self.loss_fn(actor_actions, actions)
                scaler.scale(actor_loss).backward()
                scaler.step(self.actor_optimizer)
                scaler.update()

                # 5) Soft update of targets
                for p, p_tgt in zip(self.actor.parameters(), self.actor_target.parameters()):
                    p_tgt.data.mul_(1 - self.tau)
                    p_tgt.data.add_(self.tau * p.data)
                for p, p_tgt in zip(self.critic.parameters(), self.critic_target.parameters()):
                    p_tgt.data.mul_(1 - self.tau)
                    p_tgt.data.add_(self.tau * p.data)

                # store the actor loss
                self.actor_losses_pretrain.append(actor_loss.item())

            # store the critic loss
            self.critic_losses_pretrain.append(critic_loss)

            # 6) Logging
            if it % log_interval == 0 or it == 1:
                print(f"[Pre-train] Iter {it}/{epochs}: "
                      f"critic_loss={critic_loss.item():.6e}, "
                      f"actor_loss={(actor_loss.item() if 'actor_loss' in locals() else float('nan')):.6e}")

            # if it % 10000 == 0:
            #         if it > 1:
            #             if it < 11000:
            #                 plot_start = 0
            #             else:
            #                 plot_start = it - 10000
            #             critic_losses_plot = self.critic_losses_pretrain[plot_start:]
            #             actor_losses_plot = self.actor_losses_pretrain[plot_start:]
            #
            #             fig, axs = plt.subplots(1, 2, figsize=(15, 6))
            #
            #             # Plot Critic Loss
            #             axs[0].plot(range(plot_start + 1, len(critic_losses_plot) + 1 + plot_start), critic_losses_plot,
            #                         label="Critic Loss", color='r', linewidth=2)
            #             axs[0].set_title("Critic Loss (Pre-train)", fontsize=16)
            #             axs[0].set_xlabel("Epochs", fontsize=12)
            #             axs[0].set_ylabel("Loss", fontsize=12)
            #             axs[0].grid(True)
            #             axs[0].legend()
            #
            #             # Plot Actor Loss
            #             axs[1].plot(range(plot_start + 1, len(actor_losses_plot) + 1 + plot_start), actor_losses_plot,
            #                         label="Actor Loss", color='b', linewidth=2)
            #             axs[1].set_title("Actor Loss (Pre-train)", fontsize=16)
            #             axs[1].set_xlabel("Epochs", fontsize=12)
            #             axs[1].set_ylabel("Loss", fontsize=12)
            #             axs[1].grid(True)
            #             axs[1].legend()
            #
            #             # Show the plot
            #             plt.tight_layout()
            #
            #             filename = os.path.join(path, f"{plot_start}_{it}.png")
            #             plt.savefig(filename)
            #             plt.close(fig)


    def train(self, n_samples=100):

        self.buffer_save = True

        (states_batch, actions_batch,
         rewards_batch, next_states_batch) = self.replay_buffer.sample(n_samples, device=self.device)
        # (states_batch, actions_batch,
        #  rewards_batch, next_states_batch) = self.replay_buffer.sample_pretrain(n_samples, device=self.device)

        loss_critic = 0
        loss_actor = 0

        with torch.no_grad():
            noise = (self.t_std * np.random.randn(n_samples,
                                                  self.action_dim)).clip(
                -self.noise_clip, self.noise_clip)
            noise = torch.from_numpy(noise).to(dtype=torch.float32, device=self.device)
            next_actions = (self.actor_target(next_states_batch) + noise).clip(-1, 1)

            target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
            target_Q = torch.min(target_Q1, target_Q2)

            r = rewards_batch + self.gamma * target_Q

        q1, q2 = self.critic(states_batch, actions_batch)

        loss = self.loss_fn(q1, r) + self.loss_fn(q2, r)

        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
        loss_critic += loss.item()
        self.critic_losses.append(loss.item())

        # if self.total_it % self.policy_delay == 0:
        actions = self.actor(states_batch)
        q = self.critic.q1_forward(states_batch, actions)

        actor_loss = - torch.mean(q)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        loss_actor += actor_loss.item()
        self.actor_losses.append(actor_loss.item())

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # print(f'Critic loss: {self.critic_losses[-1]}')
        # print(f'Actor loss: {self.actor_losses[-1]}')
        self.total_it += 1

        self.decay_exploration()

    def decay_exploration(self):

        self.exploration_noise_std = max(
            self.exploration_noise_std_min,
            self.exploration_noise_std_initial * (self.decay_rate ** self.total_it)
        )

    def take_action(self, state, explore=False):
        state = state if isinstance(state, torch.Tensor) else torch.from_numpy(state).to(dtype=torch.float32,
                                                                                         device=self.device)

        # If warm start is active, choose a random action from the action space.
        if self.warm_start:
            action = np.random.uniform(low=-self.max_action, high=self.max_action, size=self.action_dim)
        else:
            with torch.no_grad():
                action = self.actor(state).detach().cpu().numpy()
            if explore:
                action += np.random.randn(self.action_dim) * self.exploration_noise_std
            action = action.clip(-self.max_action, self.max_action)

        return action

    def save(self, path: str, name_prefix="agent"):
        path = os.path.join(path, "models")
        if not os.path.exists(path):
            os.makedirs(path)

        timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
        filename = os.path.join(path, f"{name_prefix}_{timestamp}.pkl")

        save_dict = {
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'actor_target_state_dict': self.actor_target.state_dict(),
            'critic_target_state_dict': self.critic_target.state_dict(),
            'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
            'agent_attributes': {key: value for key, value in self.__dict__.items() if key not in ['actor', 'critic',
                                                                                                   'actor_target',
                                                                                                   'critic_target',
                                                                                                   'replay_buffer',
                                                                                                   'total_it',
                                                                                                   'device',
                                                                                                   'loss_fn',
                                                                                                   'replay_buffer']}
        }

        with open(filename, 'wb') as f:
            pickle.dump(save_dict, f)
        print(f"Agent saved successfully to {filename}")
        return filename

    def load(self, path: str):
        with open(path, 'rb') as f:
            loaded_dict = pickle.load(f)

        self.actor.load_state_dict(loaded_dict['actor_state_dict'])
        self.critic.load_state_dict(loaded_dict['critic_state_dict'])
        self.actor_target.load_state_dict(loaded_dict['actor_target_state_dict'])
        self.critic_target.load_state_dict(loaded_dict['critic_target_state_dict'])

        for key, value in loaded_dict['agent_attributes'].items():
            setattr(self, key, value)

        self.actor_optimizer.load_state_dict(loaded_dict['actor_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(loaded_dict['critic_optimizer_state_dict'])
        self.actor_optimizer.param_groups[0]['params'] = list(self.actor.parameters())
        self.critic_optimizer.param_groups[0]['params'] = list(self.critic.parameters())
        self.actor_optimizer.param_groups[0]['lr'] = self.actor_lr # * 1e-3
        self.critic_optimizer.param_groups[0]['lr'] = self.critic_lr

        # self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr * 1e-2)
        # self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(param.data)

        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(param.data)

        print(f"Agent loaded successfully from: {path}")


## Initialize the system

In [1]:
# Model parameters of the system. The unit beside each value is carried over
# from the original annotations.
Ad = 2.142e17           # 1/h
Ed = 14897              # K
Ap = 3.816e10           # L/(mol*h)
Ep = 3557               # K
At = 4.50e12            # L/(mol*h)
Et = 843                # K
fi = 0.6                # coefficient
m_delta_H_r = -6.99e4   # J/mol
hA = 1.05e6             # J/(K*h)
rhocp = 1506            # J/(K*h)
rhoccpc = 4043          # J/(K*h)
Mm = 104.14             # g/mol

# Pack the parameters into one vector, in the order listed above.
system_params = np.array([
    Ad, Ed,
    Ap, Ep,
    At, Et,
    fi, m_delta_H_r,
    hA, rhocp, rhoccpc,
    Mm,
])

NameError: name 'np' is not defined

In [5]:
# Design parameters of the system (units carried over from the original annotations)
CIf = 0.5888    # mol/L
CMf = 8.6981    # mol/L
Qi = 108.       # L/h
Qs = 459.       # L/h
Tf = 330.       # K
Tcf = 295.      # K
V = 3000.       # L
Vc = 3312.4     # L

# Pack the design parameters into one vector, in the order listed above.
system_design_params = np.array([
    CIf, CMf,
    Qi, Qs,
    Tf, Tcf,
    V, Vc,
])

In [6]:
# Steady-state values of the two inputs
Qm_ss = 378.    # L/h
Qc_ss = 471.6   # L/h

# Note the ordering inside the vector: Qc first, then Qm.
system_steady_state_inputs = np.array((Qc_ss, Qm_ss))

In [7]:
# Sampling time of the system: 30 minutes, written as a fraction of an hour
delta_t = 30.0 / 60.0

In [8]:
# Build the CSTR model; its steady-state inputs and outputs are read back below
cstr = PolymerCSTR(
    system_params,
    system_design_params,
    system_steady_state_inputs,
    delta_t,
)
steady_states = dict(ss_inputs=cstr.ss_inputs, y_ss=cstr.y_ss)

## Loading the system matrices, the min-max scaling data, and the min-max bounds of the states

In [9]:
# The "Data" folder under the current working directory
_working_dir = os.getcwd()
dir_path = os.path.join(_working_dir, "Data")

In [10]:
# Range of setpoints used for data generation
setpoint_y = np.array([
    [3.2, 321],
    [4.5, 325],
])
# Lower and upper input bounds handed to the loader
u_min = np.array([71.6, 78])
u_max = np.array([870, 670])

system_data = load_and_prepare_system_data(
    steady_states=steady_states,
    setpoint_y=setpoint_y,
    u_min=u_min,
    u_max=u_max,
)

In [11]:
# Pull the three matrices stored under the "*_aug" keys out of system_data
A_aug, B_aug, C_aug = (system_data[name] for name in ("A_aug", "B_aug", "C_aug"))

In [12]:
# Minimum / maximum entries stored in system_data
data_min, data_max = system_data["data_min"], system_data["data_max"]

In [13]:
# min_max_states = system_data["min_max_states"]
# min_max_states

In [14]:
# Hard-coded state bounds, used here instead of system_data["min_max_states"]
_states_upper = np.array([
    267.22485424, 309.39633111, 59.30922216, 176.90730552,
    2.93078497, 3.25879914, 2.90161768, 5.05992019,
    8.006078,
])
_states_lower = np.array([
    -2.45841151e+02, -4.42441998e+03, -7.59124687e+01, -2.66399529e+03,
    -2.63938070e+00, -2.98235548e+00, -2.66964174e+00, -5.38486567e+00,
    -1.00988888e+02,
])
min_max_states = dict(max_s=_states_upper, min_s=_states_lower)

In [15]:
# Entry stored under "y_sp_scaled_deviation" in system_data
# (presumably the setpoints as scaled deviation variables - TODO confirm)
y_sp_scaled_deviation = system_data["y_sp_scaled_deviation"]

In [16]:
# Entries stored under "b_min" / "b_max" in system_data
b_min, b_max = system_data["b_min"], system_data["b_max"]

In [17]:
# Take the scaling dictionary from system_data and overwrite its state bounds
# with hard-coded values. This mutates the dictionary held by system_data.
min_max_dict = system_data["min_max_dict"]
min_max_dict.update(
    x_max=np.array([
        267.22485424, 309.39633111, 59.30922216, 176.90730552,
        2.93078497, 3.25879914, 2.90161768, 5.05992019,
        8.006078,
    ]),
    x_min=np.array([
        -2.45841151e+02, -4.42441998e+03, -7.59124687e+01, -2.66399529e+03,
        -2.63938070e+00, -2.98235548e+00, -2.66964174e+00, -5.38486567e+00,
        -1.00988888e+02,
    ]),
)

In [18]:
min_max_dict

{'x_max': array([267.22485424, 309.39633111,  59.30922216, 176.90730552,
          2.93078497,   3.25879914,   2.90161768,   5.05992019,
          8.006078  ]),
 'x_min': array([-2.45841151e+02, -4.42441998e+03, -7.59124687e+01, -2.66399529e+03,
        -2.63938070e+00, -2.98235548e+00, -2.66964174e+00, -5.38486567e+00,
        -1.00988888e+02]),
 'y_sp_min': array([-3.11304008, -3.33251984]),
 'y_sp_max': array([2.75198906, 1.7855982 ]),
 'u_max': array([9.96, 7.3 ]),
 'u_min': array([-10. ,  -7.5])}

## Setting the hyperparameters for the TD3 agent

In [19]:
# import time
# class Agent(object):
#     def __init__(self,
#                  state_dim: int,
#                  action_dim: int,
#                  actor_layer_sizes: list,
#                  critic_layer_sizes: list,
#                  buffer_capacity: int,
#                  actor_lr: float,
#                  critic_lr: float,
#                  target_policy_smoothing_noise_std: float,
#                  noise_clip: float,
#                  exploration_noise_std: float,
#                  gamma: float,
#                  tau: float,
#                  max_action: float,
#                  policy_delay: int,
#                  device):
#         self.state_dim = state_dim
#         self.action_dim = action_dim
#         self.actor_layer_sizes = actor_layer_sizes
#         self.critic_layer_sizes = critic_layer_sizes
#         self.buffer_capacity = buffer_capacity
#
#         self.actor_lr = actor_lr
#         self.critic_lr = critic_lr
#         self.t_std = target_policy_smoothing_noise_std
#         self.noise_clip = noise_clip
#         self.exploration_noise_std = exploration_noise_std
#         self.gamma = gamma
#         self.tau = tau
#
#         self.max_action = max_action
#         self.policy_delay = policy_delay
#         self.total_it = 0
#
#         self.device = device
#
#         self.actor = Actor(self.state_dim, self.action_dim, self.actor_layer_sizes).to(self.device)
#         self.actor_target = Actor(self.state_dim, self.action_dim, self.actor_layer_sizes).to(self.device)
#
#         for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
#             target_param.data.copy_(param.data)
#
#         self.critic = Critic(self.state_dim, self.action_dim, self.critic_layer_sizes).to(self.device)
#         self.critic_target = Critic(self.state_dim, self.action_dim, self.critic_layer_sizes).to(self.device)
#
#         for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
#             target_param.data.copy_(param.data)
#
#         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr)
#         self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)
#
#         self.replay_buffer = ReplayBuffer(self.buffer_capacity, self.state_dim, self.action_dim)
#         # self.loss_fn = nn.MSELoss()
#
#         # Let's implement Huber loss
#         self.loss_fn = nn.SmoothL1Loss()
#
#         # Defining the lists for losses storage for later visualization
#         self.critic_losses_pretrain = []
#         self.actor_losses_pretrain = []
#         self.critic_losses = []
#         self.actor_losses = []
#
#         self.scheduler_actor = StepLR(self.actor_optimizer, step_size=150, gamma=0.1)
#         self.scheduler_critic = StepLR(self.critic_optimizer, step_size=150, gamma=0.1)
#
#         self.buffer_save = False
#
#         self.exploration_noise_std_min = exploration_noise_std * 0.1
#         self.exploration_noise_std_initial = exploration_noise_std
#         self.decay_rate = 0.99995
#
#         self.warm_start = True
#
#     def pre_train_entire_data(self, path, data_loader, epochs):
#
#         scaler = GradScaler('cuda')  # Initialize GradScaler for mixed precision training
#         timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
#         save_path = os.path.join(path, "PLots", timestamp)
#         if not os.path.exists(save_path):
#             os.makedirs(save_path)
#
#         for epoch in range(epochs):
#             loss_critic_epoch = 0.0
#             loss_actor_epoch = 0.0
#
#             for batch_idx, (states_batch, actions_batch, rewards_batch, next_states_batch) in enumerate(data_loader):
#                 # Move data to GPU
#                 states_batch = states_batch.to(self.device)
#                 actions_batch = actions_batch.to(self.device)
#                 rewards_batch = rewards_batch.to(self.device)
#                 next_states_batch = next_states_batch.to(self.device)
#                 batch_size = states_batch.size(0)
#
#                 # Compute target Q-value without tracking gradients
#                 with torch.no_grad():
#                     noise = (self.t_std * torch.randn(batch_size, self.action_dim, device=self.device)).clamp(
#                         -self.noise_clip, self.noise_clip)
#                     next_actions = (self.actor_target(next_states_batch) + noise).clamp(-self.max_action,
#                                                                                         self.max_action)
#                     target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
#                     target_Q = torch.min(target_Q1, target_Q2)
#                     target_value = rewards_batch + self.gamma * target_Q
#
#                 # --- Critic Update with Mixed Precision ---
#                 with autocast('cuda'):
#                     current_Q1, current_Q2 = self.critic(states_batch, actions_batch)
#                     critic_loss = self.loss_fn(current_Q1, target_value) + self.loss_fn(current_Q2, target_value)
#                 self.critic_optimizer.zero_grad()
#                 scaler.scale(critic_loss).backward()
#                 scaler.step(self.critic_optimizer)
#                 loss_critic_epoch += critic_loss.item() * batch_size
#
#                 # --- Actor Update with Mixed Precision ---
#                 with autocast('cuda'):
#                     # For demonstration, we use MSE between actions and actor outputs.
#                     # (You might instead use the critic’s Q-value as in standard TD3.)
#                     actor_output = self.actor(states_batch)
#                     actor_loss = self.loss_fn(actions_batch, actor_output)
#                 self.actor_optimizer.zero_grad()
#                 scaler.scale(actor_loss).backward()
#                 scaler.step(self.actor_optimizer)
#                 loss_actor_epoch += actor_loss.item() * batch_size
#
#                 scaler.update()
#
#             # Average losses over the epoch
#             num_samples = len(data_loader.dataset)
#             loss_critic_epoch /= num_samples
#             loss_actor_epoch /= num_samples
#
#             self.critic_losses_pretrain.append(loss_critic_epoch)
#             self.actor_losses_pretrain.append(loss_actor_epoch)
#
#             if epoch == 0 or epoch % 10 == 9:
#                 print(f"{datetime.datetime.now()} Epoch {epoch + 1},"
#                       f" Actor Loss: {loss_actor_epoch},"
#                       f" Critic Loss: {loss_critic_epoch}")
#             if epoch == 0 or epoch % 100 == 99:
#                 current_lr_cr = self.critic_optimizer.param_groups[0]['lr']
#                 print(f"Epoch {epoch + 1}, Learning rate (Critic): {current_lr_cr}")
#                 current_lr_ac = self.actor_optimizer.param_groups[0]['lr']
#                 print(f"Epoch {epoch + 1}, Learning rate (Actor): {current_lr_ac}")
#
#             if epoch % 100 == 0:
#                 if epoch > 1:
#                     # Exclude the first `plot_start` losses to get a clearer plot of later epochs
#                     if epoch < 110:
#                         plot_start = 0
#                     else:
#                         plot_start = epoch - 100
#                     critic_losses_plot = self.critic_losses_pretrain[plot_start:]
#                     actor_losses_plot = self.actor_losses_pretrain[plot_start:]
#
#                     fig, axs = plt.subplots(1, 2, figsize=(15, 6))
#
#                     # Plot Critic Loss
#                     axs[0].plot(range(plot_start + 1, len(critic_losses_plot) + 1 + plot_start), critic_losses_plot,
#                                 label="Critic Loss", color='r', linewidth=2)
#                     axs[0].set_title("Critic Loss (Pre-train)", fontsize=16)
#                     axs[0].set_xlabel("Epochs", fontsize=12)
#                     axs[0].set_ylabel("Loss", fontsize=12)
#                     axs[0].grid(True)
#                     axs[0].legend()
#
#                     # Plot Actor Loss
#                     axs[1].plot(range(plot_start + 1, len(actor_losses_plot) + 1 + plot_start), actor_losses_plot,
#                                 label="Actor Loss", color='b', linewidth=2)
#                     axs[1].set_title("Actor Loss (Pre-train)", fontsize=16)
#                     axs[1].set_xlabel("Epochs", fontsize=12)
#                     axs[1].set_ylabel("Loss", fontsize=12)
#                     axs[1].grid(True)
#                     axs[1].legend()
#
#                     # Show the plot
#                     plt.tight_layout()
#
#                     filename = os.path.join(save_path, f"{plot_start}_{epoch}.png")
#                     plt.savefig(filename)
#                     plt.close(fig)
#
#             # Update the learning rate schedulers at the end of the epoch
#             self.scheduler_actor.step()
#             self.scheduler_critic.step()
#
#
#     def pre_train(self, path, epochs, n_samples=100, log_interval=1000):
#
#         timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
#         path = os.path.join(path, "Pre_training_plots", timestamp)
#         os.makedirs(path, exist_ok=True)
#
#         # 1) set up AMP scaler
#         scaler = GradScaler()
#
#         # logging time
#         t0 = time.perf_counter()
#
#         for epoch in range(epochs):
#
#             # 1_1) sample a mini‐batch
#             states_batch, actions_batch, rewards_batch, next_states_batch = \
#                 self.replay_buffer.sample_pretrain(n_samples, device=self.device)
#
#             # 2) cast everything to float32 on the correct device
#             states_batch      = states_batch.to(self.device, dtype=torch.float32)
#             actions_batch     = actions_batch.to(self.device, dtype=torch.float32)
#             rewards_batch     = rewards_batch.to(self.device, dtype=torch.float32)
#             next_states_batch = next_states_batch.to(self.device, dtype=torch.float32)
#
#             # 2) compute targets in mixed precision
#             with torch.no_grad(), autocast(device_type=str(self.device)):
#                 # noise for target policy
#                 noise = (
#                     self.t_std
#                     * torch.randn(n_samples, self.action_dim, device=self.device, dtype=torch.float32)
#                 ).clamp(-self.noise_clip, self.noise_clip)
#
#                 next_actions = (self.actor_target(next_states_batch) + noise).clamp(-1, 1)
#                 target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
#                 target_Q = torch.min(target_Q1, target_Q2)
#                 r = rewards_batch + self.gamma * target_Q
#
#             # 3) critic update
#             self.critic_optimizer.zero_grad()
#             with autocast(device_type=str(self.device)):
#                 q1, q2 = self.critic(states_batch, actions_batch)
#                 critic_loss = self.loss_fn(q1, r) + self.loss_fn(q2, r)
#             scaler.scale(critic_loss).backward()
#             scaler.step(self.critic_optimizer)
#             scaler.update()
#             loss_critic = critic_loss.item()
#
#             # 4) actor update (delayed if you use policy_delay logic)
#             self.actor_optimizer.zero_grad()
#             with autocast(device_type=str(self.device)):
#                 # imitation Learning
#                 actor_out = self.actor(states_batch)
#                 actor_loss = self.loss_fn(actions_batch, actor_out)
#             scaler.scale(actor_loss).backward()
#             scaler.step(self.actor_optimizer)
#             scaler.update()
#             loss_actor = actor_loss.item()
#
#
#             # 5) soft‐update targets
#             for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
#                 target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
#             for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
#                 target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
#
#
#             # 6) logging
#             self.critic_losses_pretrain.append(loss_critic)
#             self.actor_losses_pretrain.append(loss_actor)
#
#             if epoch == 0 or epoch % log_interval  == (log_interval - 1):
#                 print(f"[Pre-train] Step {epoch+1}/{epochs} |"
#                       f"Computation Time: {time.perf_counter() - t0:.3f}s |"
#                       f"Actor Loss: {loss_actor:.6e} |"
#                       f"Critic Loss: {loss_critic:.6e}")
#                 # logging time
#                 t0 = time.perf_counter()
#
#             # 7) occasional plotting
#             if epoch % 2500 == 0:
#                 if epoch > 1:
#                     if epoch < 1100:
#                         plot_start = 0
#                     else:
#                         plot_start = epoch - 2500
#                     critic_losses_plot = self.critic_losses_pretrain[plot_start:]
#                     actor_losses_plot = self.actor_losses_pretrain[plot_start:]
#
#                     fig, axs = plt.subplots(1, 2, figsize=(15, 6))
#
#                     # Plot Critic Loss
#                     axs[0].plot(range(plot_start + 1, len(critic_losses_plot) + 1 + plot_start), critic_losses_plot,
#                                 label="Critic Loss", color='r', linewidth=2)
#                     axs[0].set_title("Critic Loss (Pre-train)", fontsize=16)
#                     axs[0].set_xlabel("Epochs", fontsize=12)
#                     axs[0].set_ylabel("Loss", fontsize=12)
#                     axs[0].grid(True)
#                     axs[0].legend()
#
#                     # Plot Actor Loss
#                     axs[1].plot(range(plot_start + 1, len(actor_losses_plot) + 1 + plot_start), actor_losses_plot,
#                                 label="Actor Loss", color='b', linewidth=2)
#                     axs[1].set_title("Actor Loss (Pre-train)", fontsize=16)
#                     axs[1].set_xlabel("Epochs", fontsize=12)
#                     axs[1].set_ylabel("Loss", fontsize=12)
#                     axs[1].grid(True)
#                     axs[1].legend()
#
#                     # Show the plot
#                     plt.tight_layout()
#
#                     filename = os.path.join(path, f"{plot_start}_{epoch}.png")
#                     plt.savefig(filename)
#                     plt.close(fig)
#
#
#     def train_online(self, n_samples=100):
#
#         self.buffer_save = True
#
#         # (states_batch, actions_batch,
#         #  rewards_batch, next_states_batch) = self.replay_buffer.sample(n_samples, device=self.device)
#         (states_batch, actions_batch,
#          rewards_batch, next_states_batch) = self.replay_buffer.sample_pretrain(n_samples, device=self.device)
#
#         loss_critic = 0
#         loss_actor = 0
#
#         with torch.no_grad():
#             noise = (self.t_std * np.random.randn(n_samples,
#                                                   self.action_dim)).clip(
#                 -self.noise_clip, self.noise_clip)
#             noise = torch.from_numpy(noise).to(dtype=torch.float32, device=self.device)
#             next_actions = (self.actor_target(next_states_batch) + noise).clip(-1, 1)
#
#             target_Q1, target_Q2 = self.critic_target(next_states_batch, next_actions)
#             target_Q = torch.min(target_Q1, target_Q2)
#
#             r = rewards_batch + self.gamma * target_Q
#
#         q1, q2 = self.critic(states_batch, actions_batch)
#
#         loss = self.loss_fn(q1, r) + self.loss_fn(q2, r)
#
#         self.critic_optimizer.zero_grad()
#         loss.backward()
#         self.critic_optimizer.step()
#         loss_critic += loss.item()
#         self.critic_losses.append(loss.item())
#
#         # if self.total_it % self.policy_delay == 0:
#         actions = self.actor(states_batch)
#         q = self.critic.q1_forward(states_batch, actions)
#
#         actor_loss = - torch.mean(q)
#
#         self.actor_optimizer.zero_grad()
#         actor_loss.backward()
#         self.actor_optimizer.step()
#         loss_actor += actor_loss.item()
#         self.actor_losses.append(actor_loss.item())
#
#         for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
#             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
#
#         for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
#             target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
#
#         # print(f'Critic loss: {self.critic_losses[-1]}')
#         # print(f'Actor loss: {self.actor_losses[-1]}')
#         self.total_it += 1
#
#         self.decay_exploration()
#
#     def decay_exploration(self):
#
#         self.exploration_noise_std = max(
#             self.exploration_noise_std_min,
#             self.exploration_noise_std_initial * (self.decay_rate ** self.total_it)
#         )
#
#     def take_action(self, state, explore=False):
#         state = state if isinstance(state, torch.Tensor) else torch.from_numpy(state).to(dtype=torch.float32,
#                                                                                          device=self.device)
#
#         # If warm start is active, choose a random action from the action space.
#         if self.warm_start:
#             action = np.random.uniform(low=-self.max_action, high=self.max_action, size=self.action_dim)
#         else:
#             with torch.no_grad():
#                 action = self.actor(state).detach().cpu().numpy()
#             if explore:
#                 action += np.random.randn(self.action_dim) * self.exploration_noise_std
#             action = action.clip(-self.max_action, self.max_action)
#
#         return action
#
#     def save(self, path: str, name_prefix="agent"):
#         path = os.path.join(path, "models")
#         if not os.path.exists(path):
#             os.makedirs(path)
#
#         timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
#         filename = os.path.join(path, f"{name_prefix}_{timestamp}.pkl")
#
#         save_dict = {
#             'actor_state_dict': self.actor.state_dict(),
#             'critic_state_dict': self.critic.state_dict(),
#             'actor_target_state_dict': self.actor_target.state_dict(),
#             'critic_target_state_dict': self.critic_target.state_dict(),
#             'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
#             'critic_optimizer_state_dict': self.critic_optimizer.state_dict(),
#             'agent_attributes': {key: value for key, value in self.__dict__.items() if key not in ['actor', 'critic',
#                                                                                                    'actor_target',
#                                                                                                    'critic_target',
#                                                                                                    'replay_buffer',
#                                                                                                    'total_it',
#                                                                                                    'device',
#                                                                                                    'loss_fn',
#                                                                                                    'replay_buffer']}
#         }
#
#         with open(filename, 'wb') as f:
#             pickle.dump(save_dict, f)
#         print(f"Agent saved successfully to {filename}")
#         return filename
#
#     def load(self, path: str):
#         with open(path, 'rb') as f:
#             loaded_dict = pickle.load(f)
#
#         self.actor.load_state_dict(loaded_dict['actor_state_dict'])
#         self.critic.load_state_dict(loaded_dict['critic_state_dict'])
#         self.actor_target.load_state_dict(loaded_dict['actor_target_state_dict'])
#         self.critic_target.load_state_dict(loaded_dict['critic_target_state_dict'])
#
#         for key, value in loaded_dict['agent_attributes'].items():
#             setattr(self, key, value)
#
#         self.actor_optimizer.load_state_dict(loaded_dict['actor_optimizer_state_dict'])
#         self.critic_optimizer.load_state_dict(loaded_dict['critic_optimizer_state_dict'])
#         self.actor_optimizer.param_groups[0]['params'] = list(self.actor.parameters())
#         self.critic_optimizer.param_groups[0]['params'] = list(self.critic.parameters())
#         self.actor_optimizer.param_groups[0]['lr'] = self.actor_lr # * 1e-3
#         self.critic_optimizer.param_groups[0]['lr'] = self.critic_lr
#
#         # self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.actor_lr * 1e-2)
#         # self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr)
#
#         for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
#             target_param.data.copy_(param.data)
#
#         for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
#             target_param.data.copy_(param.data)
#
#         print(f"Agent loaded successfully from: {path}")

In [20]:
# class AgentPretrainer:
#     def __init__(self, agent):
#         """
#         agent: TD3 Agent, with .actor, .critic, .actor_target, .critic_target,
#                .actor_optimizer, .critic_optimizer, .loss_fn, .device, .action_dim,
#                .gamma, .tau, .t_std, .noise_clip
#         The buffer is not a separate argument: it is taken from agent.replay_buffer,
#         which must provide .sample_pretrain(batch_size, device).
#         """
#         self.agent = agent
#         self.buffer = agent.replay_buffer
#         self.device = agent.device
#         self.loss_fn = agent.loss_fn
#         self.critic_optimizer = agent.critic_optimizer
#         self.actor_optimizer = agent.actor_optimizer
#
#         self.bc_losses = []
#         self.phase2_losses = []
#         self.actor_losses = []
#         self.critic_losses = []
#
#     def phase1_behavioral_cloning(self, epochs, batch_size, log_interval=100):
#         """Phase 1: freeze critic, train actor to mimic MPC actions via Huber (SmoothL1) loss."""
#         # 1) Freeze critic
#         for p in self.agent.critic.parameters():
#             p.requires_grad = False
#         for p in self.agent.actor.parameters():
#             p.requires_grad = True
#         self.agent.actor.train()
#         self.agent.critic.eval()
#
#
#         scaler = GradScaler()
#
#         for ep in range(1, epochs+1):
#             # sample (s, u_mpc, _, _)
#             states, actions, _, _ = self.buffer.sample_pretrain(batch_size, device=self.device)
#             states  = states .to(self.device, dtype=torch.float32)
#             targets = actions.to(self.device, dtype=torch.float32)
#
#             self.actor_optimizer.zero_grad()
#             with autocast(device_type=str(self.device)):
#                 preds = self.agent.actor(states)
#                 loss  = self.loss_fn(preds, targets)
#             scaler.scale(loss).backward()
#             scaler.step(self.actor_optimizer)
#             scaler.update()
#
#             if ep % log_interval == 0 or ep == 1:
#                 print(f"[BC Phase] Epoch {ep}/{epochs}  Loss: {loss.item():.3e}")
#
#             self.bc_losses.append(loss.item())
#
#     def phase2_critic_warmup(self, epochs, batch_size, log_interval=100):
#         """Phase 2: freeze actor, train critic on TD targets from MPC data."""
#         # 1) Freeze actor
#         for p in self.agent.critic.parameters():
#             p.requires_grad = True
#         for p in self.agent.actor.parameters():
#             p.requires_grad = False
#         self.agent.actor.eval()
#         self.agent.critic.train()
#
#         scaler = GradScaler()
#
#         for ep in range(1, epochs+1):
#             # sample (s, u_mpc, r, s')
#             states, actions, rewards, next_states = \
#                 self.buffer.sample_pretrain(batch_size, device=self.device)
#             states      = states     .to(self.device, dtype=torch.float32)
#             actions     = actions    .to(self.device, dtype=torch.float32)
#             rewards     = rewards    .to(self.device, dtype=torch.float32)
#             next_states = next_states.to(self.device, dtype=torch.float32)
#
#             # build TD target in full precision
#             with torch.no_grad():
#                 noise = (
#                     self.agent.t_std
#                     * torch.randn(batch_size, self.agent.action_dim,
#                                  device=self.device, dtype=torch.float32)
#                 ).clamp(-self.agent.noise_clip, self.agent.noise_clip)
#                 next_a = (self.agent.actor_target(next_states) + noise).clamp(-1, 1)
#                 tQ1, tQ2  = self.agent.critic_target(next_states, next_a)
#                 target_Q  = torch.min(tQ1, tQ2)
#                 y_target  = rewards + self.agent.gamma * target_Q
#
#             # critic update
#             self.critic_optimizer.zero_grad()
#             with autocast(device_type=str(self.device)):
#                 Q1, Q2 = self.agent.critic(states, actions)
#                 loss   = self.loss_fn(Q1, y_target) + self.loss_fn(Q2, y_target)
#             scaler.scale(loss).backward()
#             scaler.step(self.critic_optimizer)
#             scaler.update()
#
#             if ep % log_interval == 0 or ep == 1:
#                 print(f"[Critic Phase] Epoch {ep}/{epochs}  Loss: {loss.item():.3e}")
#
#             self.phase2_losses.append(loss.item())
#
#     def phase3_offline_td3(self, epochs, batch_size, policy_delay=4, log_interval=100, rl_factor=0, bc_factor=1):
#         """
#         Phase 3: offline TD3 updates on filled buffer (still no env).
#         This alternates critic & actor updates exactly as in TD3, sampling from buffer.
#         """
#         # Unfreeze both
#         for p in self.agent.actor.parameters():
#             p.requires_grad = True
#         for p in self.agent.critic.parameters():
#             p.requires_grad = True
#         self.agent.actor.train()
#         self.agent.critic.train()
#
#         scaler = GradScaler()
#
#         for ep in range(1, epochs+1):
#             # sample batch
#             states, actions, rewards, next_states = \
#                 self.buffer.sample_pretrain(batch_size, device=self.device)
#             states      = states     .to(self.device, dtype=torch.float32)
#             actions     = actions    .to(self.device, dtype=torch.float32)
#             rewards     = rewards    .to(self.device, dtype=torch.float32)
#             next_states = next_states.to(self.device, dtype=torch.float32)
#
#             # 1) critic step
#             with torch.no_grad():
#                 noise = (
#                     self.agent.t_std
#                     * torch.randn(batch_size, self.agent.action_dim,
#                                  device=self.device, dtype=torch.float32)
#                 ).clamp(-self.agent.noise_clip, self.agent.noise_clip)
#                 next_a = (self.agent.actor_target(next_states) + noise).clamp(-1, 1)
#                 tQ1, tQ2 = self.agent.critic_target(next_states, next_a)
#                 y_target = rewards + self.agent.gamma * torch.min(tQ1, tQ2)
#
#             self.critic_optimizer.zero_grad()
#             with autocast(device_type=str(self.device)):
#                 Q1, Q2 = self.agent.critic(states, actions)
#                 loss_c = self.loss_fn(Q1, y_target) + self.loss_fn(Q2, y_target)
#             scaler.scale(loss_c).backward()
#             scaler.step(self.critic_optimizer)
#             scaler.update()
#             self.critic_losses.append(loss_c.item())
#
#             # 2) actor step (delayed)
#             if ep % policy_delay == 0:
#                 self.actor_optimizer.zero_grad()
#                 with autocast(device_type=str(self.device)):
#                     a_pred = self.agent.actor(states)
#                     # 1) RL term: maximize Q1 -> minimize neg-Q1
#                     q_val, _  = self.agent.critic(states, a_pred)
#                     rl_loss = -q_val.mean()
#                     # 2) BC term: keep close to u_MPC
#                     bc_loss = self.loss_fn(a_pred, actions)
#                     # 3) combined
#                     loss_a = rl_factor * rl_loss + bc_factor * bc_loss
#
#                 scaler.scale(loss_a).backward()
#                 scaler.step(self.actor_optimizer)
#                 scaler.update()
#                 self.actor_losses.append(loss_a.item())
#
#             # 3) soft target‐update
#             for p, tp in zip(self.agent.actor.parameters(), self.agent.actor_target.parameters()):
#                 tp.data.copy_(self.agent.tau * p.data + (1-self.agent.tau)*tp.data)
#             for p, tp in zip(self.agent.critic.parameters(), self.agent.critic_target.parameters()):
#                 tp.data.copy_(self.agent.tau * p.data + (1-self.agent.tau)*tp.data)
#
#             if ep % log_interval == 0 or ep == 1:
#                 print(f"[TD3 Phase] Step {ep}/{epochs}  Critic Loss: {loss_c.item():.3e}"
#                       + (f", Actor Loss: {loss_a.item():.3e}" if ep % policy_delay==0 else ""))

In [21]:
# --- Problem dimensions, read off the augmented model matrices -------------
set_points_number = int(C_aug.shape[0])   # first dimension of C_aug
inputs_number = int(B_aug.shape[1])       # second dimension of B_aug
ACTION_DIM = inputs_number
# Agent state size = first dimension of A_aug + set-points + inputs.
STATE_DIM = int(A_aug.shape[0]) + set_points_number + inputs_number

# --- Network architecture ---------------------------------------------------
ACTOR_LAYER_SIZES = [2048] * 3
CRITIC_LAYER_SIZES = [2048] * 3

# --- TD3 hyper-parameters (passed positionally to Agent below) ---------------
BUFFER_CAPACITY = 10_000_000
ACTOR_LR = 1e-4
CRITIC_LR = 1e-4
# NOTE(review): the smoothing std is orders of magnitude below NOISE_CLIP, so
# the clip is effectively never active — confirm this is intended.
SMOOTHING_STD = 1e-6
NOISE_CLIP = 0.5
EXPLORATION_NOISE_STD = 0.1
GAMMA = 0.95
TAU = 0.005
MAX_ACTION = 1
POLICY_DELAY = 4

# --- Runtime ---------------------------------------------------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)

# NOTE(review): later cells hard-code their own batch sizes (2056 for the
# DataLoader, 2048 for the AgentPretrainer phases) instead of using BATCH_SIZE.
BATCH_SIZE = 256
NUM_BATCHES = 1

Using device: cuda


In [22]:
# Build the TD3 agent from the hyper-parameters configured above.
# Arguments are given by keyword (same order as Agent.__init__) so that a
# misplaced value cannot silently land in the wrong parameter.
agent = Agent(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    actor_layer_sizes=ACTOR_LAYER_SIZES,
    critic_layer_sizes=CRITIC_LAYER_SIZES,
    buffer_capacity=BUFFER_CAPACITY,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    target_policy_smoothing_noise_std=SMOOTHING_STD,
    noise_clip=NOISE_CLIP,
    exploration_noise_std=EXPLORATION_NOISE_STD,
    gamma=GAMMA,
    tau=TAU,
    max_action=MAX_ACTION,
    policy_delay=POLICY_DELAY,
    device=DEVICE,
)

In [23]:
# MPC parameters: horizons, input bounds, optimiser start point and weights.
# (predict_h / cont_h are presumably the prediction and control horizons.)
predict_h, cont_h = 9, 3

# One (min, max) pair per input, repeated once per move in cont_h.
b1 = (b_min[0], b_max[0])
b2 = (b_min[1], b_max[1])
bnds = cont_h * (b1, b2)
cons = []                                  # no extra constraints
IC_opt = np.zeros(cont_h * inputs_number)  # zero initial guess for the optimiser

# Scalar weights handed to MpcSolver.
Q1, Q2 = 5, 1
R1, R2 = 1, 1

# Diagonal weight matrices handed to the buffer-filling helpers
# (presumably the reward weights on outputs and inputs — confirm).
Q_rew = np.diag([12, 8])
R_rew = np.diag([1, 1])

In [24]:
# Instantiate the MPC solver from the augmented model matrices, the scalar
# weights (Q1, Q2, R1, R2) and the two horizons set in the MPC-parameters cell.
# NOTE(review): the meaning and order of the weight arguments are defined by
# MpcSolver, which is not visible in this notebook — confirm against its signature.
MPC_obj = MpcSolver(A_aug, B_aug, C_aug,
                    Q1, Q2, R1, R2,
                    predict_h, cont_h)

In [23]:
# Split the buffer capacity: a fixed share for steady-state samples,
# everything that is left for MPC-generated pre-training samples.
steady_states_samples_number = 100_000
mpc_pretrain_samples_numbers = BUFFER_CAPACITY - steady_states_samples_number

In [24]:
# Generate the MPC pre-training samples in chunks of 100_000; the recorded log
# reports 99 chunks for the requested mpc_pretrain_samples_numbers.
# NOTE(review): presumably the samples are written into agent.replay_buffer —
# confirm in filling_the_buffer. This is the expensive cell of the notebook;
# the buffer is saved right after so it does not have to be regenerated.
filling_the_buffer(
        min_max_dict,
        A_aug, B_aug, C_aug,
        MPC_obj,
        mpc_pretrain_samples_numbers,
        Q_rew, R_rew,
        agent,
        IC_opt, bnds, cons, chunk_size= 100_000)

Processing chunk 1/99
Processing chunk 2/99
Processing chunk 3/99
Processing chunk 4/99
Processing chunk 5/99
Processing chunk 6/99
Processing chunk 7/99
Processing chunk 8/99
Processing chunk 9/99
Processing chunk 10/99
Processing chunk 11/99
Processing chunk 12/99
Processing chunk 13/99
Processing chunk 14/99
Processing chunk 15/99
Processing chunk 16/99
Processing chunk 17/99
Processing chunk 18/99
Processing chunk 19/99
Processing chunk 20/99
Processing chunk 21/99
Processing chunk 22/99
Processing chunk 23/99
Processing chunk 24/99
Processing chunk 25/99
Processing chunk 26/99
Processing chunk 27/99
Processing chunk 28/99
Processing chunk 29/99
Processing chunk 30/99
Processing chunk 31/99
Processing chunk 32/99
Processing chunk 33/99
Processing chunk 34/99
Processing chunk 35/99
Processing chunk 36/99
Processing chunk 37/99
Processing chunk 38/99
Processing chunk 39/99
Processing chunk 40/99
Processing chunk 41/99
Processing chunk 42/99
Processing chunk 43/99
Processing chunk 44/

In [25]:
# Add steady_states_samples_number steady-state samples on top of the MPC
# samples; with chunk_size equal to the sample count this runs as one chunk
# (the recorded log shows "Processing chunk 1/1").
add_steady_state_samples(
        min_max_dict,
        A_aug, B_aug, C_aug,
        MPC_obj,
        steady_states_samples_number,
        Q_rew, R_rew,
        agent,
        IC_opt, bnds, cons, chunk_size= 100000)

Processing chunk 1/1
Replay buffer has been filled up with the steady_state values.


## Saving and reloading the replay buffer to verify that it was persisted correctly

In [26]:
# Persist the filled replay buffer under dir_path. The return value is kept in
# filename_buffer (presumably the path of the written .h5 file — confirm).
filename_buffer = agent.replay_buffer.save(dir_path)

Replay buffer saved to C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\replay_buffer_2505100048.h5


In [25]:
# Fresh buffer with the same capacity and dimensions as the agent's,
# used as the target for reloading the saved data in the next cell.
replay_buffer = ReplayBuffer(BUFFER_CAPACITY, STATE_DIM, ACTION_DIM)

In [26]:
# Reload the buffer written by the save cell. The location is built from
# dir_path instead of a hard-coded absolute, user-specific path, so the
# notebook also runs on other machines.
# NOTE(review): assumes save(dir_path) writes to <dir_path>/models/, as its
# log message indicates — confirm against ReplayBuffer.save.
replay_buffer.load(os.path.join(dir_path, "models", "replay_buffer_2505100048.h5"))

Replay buffer loaded from C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\replay_buffer_2505100048.h5


In [27]:
# Swap the agent's buffer for the one reloaded from disk, so that training
# uses exactly the persisted data.
agent.replay_buffer = replay_buffer

In [28]:
# mu = agent.replay_buffer.rewards.mean()
# sigma = agent.replay_buffer.rewards.std() + 1e-6
# r_norm = (agent.replay_buffer.rewards - mu)/sigma
# r_clipped = np.clip(r_norm,-5.0,+5.0)
# agent.replay_buffer.rewards = r_clipped

## Pre-training the agent

In [29]:
import gc

# Free memory before building the DataLoader.
# Run the garbage collector first: tensors that are only unreachable through
# reference cycles still hold their CUDA blocks until they are collected, and
# empty_cache() can only return blocks that no live tensor occupies.
gc.collect()
torch.cuda.empty_cache()

92

In [30]:
# Wrap the four replay-buffer arrays as float32 tensors (same order as before:
# states, actions, rewards, next_states) and serve them in shuffled batches.
dataset = ReplayDataset(*(
    torch.from_numpy(getattr(replay_buffer, field)).to(dtype=torch.float32)
    for field in ("states", "actions", "rewards", "next_states")
))
# NOTE(review): batch_size=2056 looks like a typo for 2048 — confirm; kept as-is.
data_loader = DataLoader(
    dataset,
    batch_size=2056,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

In [31]:
# Pre-training over the full DataLoader.
# One epoch is one pass over data_loader; the recorded run logged roughly
# 50 s per epoch and was stopped manually (KeyboardInterrupt) after epoch 2.
EPOCHS_FOR_PRETRAIN = 1300

# dir_path is where the method writes its outputs (presumably the loss plots —
# confirm in Agent.pre_train_entire_data).
agent.pre_train_entire_data(dir_path, data_loader, EPOCHS_FOR_PRETRAIN)

2025-06-17 19:09:13.817925 Epoch 1, Actor Loss: 0.0008212525874044106, Critic Loss: 5427385.012223375
Epoch 1, Learning rate (Critic): 0.0001
Epoch 1, Learning rate (Actor): 0.0001
2025-06-17 19:10:03.324002 Epoch 2, Actor Loss: 5.354618446895038e-05, Critic Loss: 4107.190626141406


KeyboardInterrupt: 

In [None]:
# -------------------------------------------------------
# offline only training
# -------------------------------------------------------
# NOTE(review): the AgentPretrainer definition in this notebook is commented
# out, so this cell relies on the class being importable (e.g. through one of
# the star imports at the top) — confirm.
agent_pre = AgentPretrainer(agent=agent)

# Phase 1: Behavioral cloning
# Each "epoch" here is one sampled mini-batch of batch_size transitions
# (see the per-step log below), not a pass over the whole buffer.
agent_pre.phase1_behavioral_cloning(epochs=300000, batch_size=2048)

# Phase 2: Critic warm-up
agent_pre.phase2_critic_warmup(epochs=100000, batch_size=2048)

# Phase 3: Offline TD3 fine-tuning
# NOTE(review): called with the method's defaults. In the commented-out
# reference implementation these are policy_delay=4, rl_factor=0, bc_factor=1,
# which makes the actor loss pure behavioural cloning — confirm that the
# imported class matches and that this is intended (POLICY_DELAY is not passed).
agent_pre.phase3_offline_td3(epochs=100000, batch_size=2048)

[BC Phase] Epoch 1/300000  Loss: 1.620e-01
[BC Phase] Epoch 100/300000  Loss: 1.689e-03
[BC Phase] Epoch 200/300000  Loss: 1.367e-03
[BC Phase] Epoch 300/300000  Loss: 4.601e-04
[BC Phase] Epoch 400/300000  Loss: 2.340e-04
[BC Phase] Epoch 500/300000  Loss: 2.088e-04
[BC Phase] Epoch 600/300000  Loss: 1.411e-04
[BC Phase] Epoch 700/300000  Loss: 1.313e-04
[BC Phase] Epoch 800/300000  Loss: 1.107e-04
[BC Phase] Epoch 900/300000  Loss: 1.001e-04
[BC Phase] Epoch 1000/300000  Loss: 9.610e-05
[BC Phase] Epoch 1100/300000  Loss: 1.010e-04
[BC Phase] Epoch 1200/300000  Loss: 8.791e-05
[BC Phase] Epoch 1300/300000  Loss: 7.880e-05
[BC Phase] Epoch 1400/300000  Loss: 1.071e-04
[BC Phase] Epoch 1500/300000  Loss: 6.732e-05
[BC Phase] Epoch 1600/300000  Loss: 6.970e-05
[BC Phase] Epoch 1700/300000  Loss: 7.588e-05
[BC Phase] Epoch 1800/300000  Loss: 7.007e-05
[BC Phase] Epoch 1900/300000  Loss: 7.516e-05
[BC Phase] Epoch 2000/300000  Loss: 7.875e-05
[BC Phase] Epoch 2100/300000  Loss: 8.138e-05


In [33]:
# # # Pretraining
# # EPOCHS_FOR_PRETRAIN = int(BUFFER_CAPACITY / 5)
#
# EPOCHS_FOR_PRETRAIN = 100000
#
# BATCH_SIZE = 256
#
# agent.pre_train(dir_path, EPOCHS_FOR_PRETRAIN, BATCH_SIZE, log_interval=1000)
#
# # agent.pre_train(path=dir_path, batch_size=BATCH_SIZE, epochs=EPOCHS_FOR_PRETRAIN, log_interval=10000)

## Saving and reloading the agent to verify that it was stored correctly

In [41]:
# Persist the pre-trained agent under dir_path; the returned value is kept in
# filename_agent and reused by the load cell below.
filename_agent = agent.save(dir_path)

Agent saved successfully to C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\agent_2505121104.pkl


In [42]:
# Re-create an untrained agent with the same configuration; its weights are
# overwritten by agent.load(...) in the next cell.
# Arguments are given by keyword, in the order of Agent.__init__.
agent = Agent(
    state_dim=STATE_DIM,
    action_dim=ACTION_DIM,
    actor_layer_sizes=ACTOR_LAYER_SIZES,
    critic_layer_sizes=CRITIC_LAYER_SIZES,
    buffer_capacity=BUFFER_CAPACITY,
    actor_lr=ACTOR_LR,
    critic_lr=CRITIC_LR,
    target_policy_smoothing_noise_std=SMOOTHING_STD,
    noise_clip=NOISE_CLIP,
    exploration_noise_std=EXPLORATION_NOISE_STD,
    gamma=GAMMA,
    tau=TAU,
    max_action=MAX_ACTION,
    policy_delay=POLICY_DELAY,
    device=DEVICE,
)

In [43]:
# Restore the agent from the file written by the save cell above.
# NOTE(review): if load() unpickles the file (as the commented-out reference
# implementation does), only load files you created yourself.
agent.load(filename_agent)
# Re-attach the replay buffer. The commented-out Agent.save excludes
# 'replay_buffer' from the saved attributes, so load() presumably does not
# restore it — confirm.
agent.replay_buffer = replay_buffer

Agent loaded successfully from: C:\Users\HAMEDI\OneDrive - McMaster University\PythonProjects\Polymer_example\Data\models\agent_2505121104.pkl


## Checking the agent's accuracy by comparing its actions with the MPC actions

In [44]:
# Compare the reloaded agent's actions with the MPC actions stored in the
# buffer, using n_samples=5000; the helper prints R² scores overall and per input.
print_accuracy(agent.replay_buffer, agent, n_samples=5000, device=DEVICE)

Agent r2 score for the predicted inputs compare to MPC inputs: 1.000000
Agent r2 score for the predicted input 1 compare to MPC input 1: 1.000000
Agent r2 score for the predicted input 1 compare to MPC input 2: 1.000000
