# Эксперимент по обучению с использованием сохраненных данных из буфера воспроизведения

In [None]:
from pathlib import Path
import torch
from omegaconf import OmegaConf
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

def make_buffer(config):
    """Create an empty replay buffer sized from the experiment config.

    Args:
        config: Experiment configuration; only ``config.data.buffer_size``
            is read and used as the storage capacity.

    Returns:
        A ``TensorDictReplayBuffer`` backed by a ``LazyTensorStorage`` and
        constructed with a default batch size of 1.
    """
    storage = LazyTensorStorage(max_size=config.data.buffer_size)
    return TensorDictReplayBuffer(batch_size=1, storage=storage)

# Identify the experiment run whose saved artifacts are loaded below.
EXPERIMENT_NAME = "td3_train_real_async"
EXPERIMENT_DATE = "2025-09-30"
EXPERIMENT_TIME = "13-08-38"

# Run directory layout: ../experiments/<name>/<date>_<time>
PATH_TO_EXP_DIR = (
    Path("..") / "experiments" / EXPERIMENT_NAME / f"{EXPERIMENT_DATE}_{EXPERIMENT_TIME}"
)
ENV_LOG_DIR = PATH_TO_EXP_DIR.joinpath("env_logs")
TRAIN_LOG_DIR = PATH_TO_EXP_DIR.joinpath("train_logs")

# Reload the configuration stored in the run's ".hydra" directory so the
# buffer is rebuilt with the same capacity setting the run used.
config_path = PATH_TO_EXP_DIR.joinpath(".hydra", "config.yaml")
config = OmegaConf.load(config_path)

# Create an empty buffer, then fill it from the data saved by the run.
buffer_path = PATH_TO_EXP_DIR.joinpath("saved_data", "replay_buffer.pkl")
buffer = make_buffer(config)
buffer.load(buffer_path)

print(f"Replay buffer загружен. Размер: {len(buffer)} сэмплов")

In [None]:
# Draw one mini-batch to check that the loaded buffer can be sampled.
# The size passed here is explicit and differs from the buffer's default of 1.
batch_size = 64
batch = buffer.sample(batch_size=batch_size)
print(f"Сэмплирован batch из {len(batch)} сэмплов")

In [None]:
# Inspect the first transition of the sampled batch.
sample = batch[0]

print("Поля сэмпла:", sample.keys())

# (label, key) pairs in print order; a tuple key addresses a nested entry,
# e.g. ("next", "observation") is sample["next"]["observation"].
fields_to_show = (
    ("\nObservation:", "observation"),
    ("Next Observation:", ("next", "observation")),
    ("Action:", "action"),
    ("Reward:", ("next", "reward")),
    ("Done:", "done"),
)
for label, field_key in fields_to_show:
    print(label, sample[field_key])

In [None]:
# Build a de-duplicated, rescaled subset of the loaded replay data.


def _normalize_observation(obs):
    """Return a copy of ``obs`` with ``obs[0][0]`` divided by 125 and ``obs[0][1]`` by 10.

    The input tensor is never modified, so the same helper is safe for both
    the current and the next observation of a transition.
    """
    scaled = obs.clone()
    scaled[0][0] = scaled[0][0] / 125.0
    scaled[0][1] = scaled[0][1] / 10.0
    return scaled


filtered_buffer = make_buffer(config)
# Iterate the source buffer one transition at a time.
buffer._batch_size = 1
seen = set()

# Iterating the buffer draws random samples; presumably (default sampler) the
# iterator never ends on its own — TODO confirm. Without a bound the loop
# would spin forever whenever fewer unique in-range transitions exist than
# the target size, so cap the number of draws at a generous multiple of the
# buffer length.
max_draws = 100 * max(len(buffer), 1)

for draws, sample in enumerate(buffer):
    if draws >= max_draws:
        print(f"Stopped after {draws} draws without reaching the target size")
        break

    action = sample["action"]

    # Skip transitions whose action has any component outside [-0.9, 0.9].
    if (action[0].abs() > 0.9).any():
        continue

    # De-duplicate on the raw (un-normalized) observation/action values.
    key = (
        tuple(sample["observation"].flatten().tolist()),
        tuple(action.flatten().tolist()),
    )
    if key in seen:
        continue
    seen.add(key)

    # Rescale the first two observation features on copies of the tensors.
    sample["observation"] = _normalize_observation(sample["observation"])
    sample["next"]["observation"] = _normalize_observation(sample["next"]["observation"])

    filtered_buffer.add(sample)

    # Stop once the filtered buffer holds more than 2500 transitions.
    if len(filtered_buffer) > 2500:
        break

print(f"Filtered buffer создан. Размер: {len(filtered_buffer)} сэмплов")

In [None]:
from collections import deque
from nn_laser_stabilizer.envs.utils import make_specs
from nn_laser_stabilizer.logging.utils import set_seeds
from nn_laser_stabilizer.agents.td3 import (
    make_td3_agent,
    make_loss_module,
    make_optimizers,
    make_target_updater,
    train_step,
    warmup_from_specs
)

# Offline TD3 training on the filtered replay data.
set_seeds(config.seed)

specs = make_specs(config.env.bounds)
action_spec = specs["action"]
observation_spec = specs["observation"]

# Override the actor learning rate stored in the loaded config.
config.agent.learning_rate_actor = 1e-4

actor, qvalue = make_td3_agent(config, observation_spec, action_spec)
warmup_from_specs(observation_spec, action_spec, actor, qvalue)

loss_module = make_loss_module(config, actor, qvalue, action_spec)
optimizer_actor, optimizer_critic = make_optimizers(config, loss_module)
target_net_updater = make_target_updater(config, loss_module)

train_config = config.train

total_train_steps = 0
# Sliding windows covering the losses of (at most) one outer iteration.
recent_qvalue_losses = deque(maxlen=train_config.update_to_data)
recent_actor_losses = deque(maxlen=train_config.update_to_data // train_config.update_actor_freq)


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


print("Training process initiated")

try:
    for _ in range(100):
        try:
            for i in range(train_config.update_to_data):
                batch = filtered_buffer.sample(train_config.batch_size)
                # The actor is updated only on every `update_actor_freq`-th step.
                update_actor = i % train_config.update_actor_freq == 0
                loss_qvalue_val, loss_actor_val = train_step(
                    batch, loss_module, optimizer_actor, optimizer_critic,
                    target_net_updater, update_actor
                )

                recent_qvalue_losses.append(loss_qvalue_val)
                if loss_actor_val is not None:
                    recent_actor_losses.append(loss_actor_val)

            # An empty window (e.g. update_to_data < update_actor_freq) is
            # reported as NaN instead of raising ZeroDivisionError.
            avg_qvalue_loss = _mean_or_nan(recent_qvalue_losses)
            avg_actor_loss = _mean_or_nan(recent_actor_losses)
            print(f"step={total_train_steps} Loss/Critic={avg_qvalue_loss} Loss/Actor={avg_actor_loss}")

            total_train_steps += 1

        except Exception as ex:
            # Best effort: report the failure and move on to the next iteration.
            print(f"Error while training: {ex}")

except KeyboardInterrupt:
    # Handled outside the loop so that an interrupt actually stops training
    # instead of only skipping the current iteration.
    print("Training interrupted by user.")

finally:
    print("Training finished")
    print(f"Final buffer size: {len(filtered_buffer)} samples")

In [None]:
# Spot-check the trained networks on a few transitions from the filtered buffer.
# Both modules are switched to evaluation mode once, before any forward pass.
actor.eval()
qvalue.eval()

for _ in range(5):
    sample = filtered_buffer.sample(1)[0]
    # Keep a reference to the stored action before the actor is called.
    sample_action = sample["action"]

    with torch.no_grad():
        action = actor(sample)

    print("Observation:", sample["observation"])
    print("Action (from buffer):", sample_action)
    print("Action predicted by actor:", action["action"])

    # NOTE(review): if the actor writes its output into `sample` in place,
    # the critic below scores the predicted action rather than the stored
    # one — confirm which is intended.
    with torch.no_grad():
        q_pred = qvalue(sample)

    print("Predicted Q-value:", q_pred["state_action_value"])
    print()

In [None]:
# Gather per-transition values from the full (unfiltered) buffer into
# per-field lists, then stack each list into a single tensor.
error_mean_list, error_std_list = [], []
kp_list, ki_list, kd_list = [], [], []
reward_list, done_list = [], []

for i in range(len(buffer)):
    sample = buffer[i]
    obs_vec = sample["observation"]
    act_vec = sample["action"]

    # Observation components 0 and 1.
    error_mean_list.append(obs_vec[0].cpu())
    error_std_list.append(obs_vec[1].cpu())

    # Action components 0, 1 and 2.
    kp_list.append(act_vec[0].cpu())
    ki_list.append(act_vec[1].cpu())
    kd_list.append(act_vec[2].cpu())

    reward_list.append(sample["next"]["reward"].cpu())
    done_list.append(sample["done"].cpu())

error_mean, error_std = torch.stack(error_mean_list), torch.stack(error_std_list)
kp, ki, kd = torch.stack(kp_list), torch.stack(ki_list), torch.stack(kd_list)
reward, done = torch.stack(reward_list), torch.stack(done_list)

def stats(tensor: torch.Tensor, name: str) -> None:
    """Print min, max, mean and standard deviation of ``tensor``.

    Args:
        tensor: Non-empty tensor to summarise. Integer and boolean tensors
            are accepted: mean and std are computed on a float copy because
            those reductions are not defined for non-floating dtypes.
        name: Title printed in the ``=== name ===`` header line.
    """
    # Min/max are taken on the original dtype; only mean/std need floats.
    as_float = tensor if tensor.is_floating_point() else tensor.float()

    print(f"=== {name} ===")
    print("Min:", tensor.min().item())
    print("Max:", tensor.max().item())
    print("Mean:", as_float.mean().item())
    print("Std:", as_float.std().item())
    print()

# Summarise every collected column; the done flags are cast to float first.
columns_to_report = (
    (error_mean, "Error Mean (obs[0])"),
    (error_std, "Error Std (obs[1])"),
    (kp, "KP"),
    (ki, "KI"),
    (kd, "KD"),
    (reward, "Reward"),
    (done.float(), "Done flags"),
)
for column_values, column_title in columns_to_report:
    stats(column_values, column_title)