# Эксперимент по обучению с использованием сохраненных данных из буфера воспроизведения

In [None]:
from pathlib import Path
import torch
from omegaconf import OmegaConf
from nn_laser_stabilizer.data.utils import make_buffer

# Restore the Hydra config of a finished experiment, rebuild an empty replay
# buffer from it, then fill that buffer with the transitions saved by the run.
exp_dir = Path("../experiments/test/2025-09-29_15-58-48")
hydra_dir = exp_dir / ".hydra"
config_path = hydra_dir / "config.yaml"
config = OmegaConf.load(config_path)

# Same location as exp_dir / "saved_data/replay_buffer.pkl".
buffer_path = exp_dir / "saved_data" / "replay_buffer.pkl"
buffer = make_buffer(config)
buffer.load(buffer_path)

print(f"Replay buffer загружен. Размер: {len(buffer)} сэмплов")


In [None]:
# Draw one random mini-batch to check that sampling from the restored buffer
# works and yields the requested number of transitions.
batch_size = 64
batch = buffer.sample(batch_size)

print(f"Сэмплирован batch из {len(batch)} сэмплов")

In [None]:
# Inspect the first transition of the sampled batch.
sample = batch[0]
next_step = sample["next"]

print("Поля сэмпла:", sample.keys())

print("\nObservation:", sample["observation"])
print("Action:", sample["action"])
print("Reward:", next_step["reward"])
print("Next Observation:", next_step["observation"])
# NOTE: this is the root-level "done" key, not ("next", "done").
print("Done:", sample["done"])

In [None]:
# Normalize the restored transitions in place:
#   * observations and next observations are z-scored per feature, using
#     statistics computed over every "observation" in the buffer;
#   * rewards are divided by their largest absolute value.
# NOTE: the buffer is modified in place, so re-running this cell without
# reloading the buffer normalizes already-normalized data a second time.
obs_list = []
action_list = []
reward_list = []

# Pass 1 (read-only): gather everything on the CPU. Nothing is modified here,
# so nothing needs to be written back to the buffer.
for i in range(len(buffer)):
    sample = buffer[i]
    obs_list.append(sample["observation"].cpu())
    action_list.append(sample["action"].cpu())
    reward_list.append(sample["next"]["reward"].cpu())

obs_tensor = torch.stack(obs_list)
action_tensor = torch.stack(action_list)
reward_tensor = torch.stack(reward_list)

# Per-feature statistics over the sample axis; the epsilon keeps a constant
# feature from producing a division by zero.
obs_mean = obs_tensor.mean(0)
obs_std = obs_tensor.std(0) + 1e-8

# Despite its name, `reward_max` is the largest reward *magnitude*; it is the
# scale applied below (falling back to 1.0 if every reward is zero).
# `reward_min` is only kept for inspection.
reward_max = reward_tensor.abs().max()
reward_min = reward_tensor.min()
reward_max_val = reward_max if reward_max != 0 else 1.0

# Pass 2: apply the normalization and write each transition back.
for i in range(len(buffer)):
    sample = buffer[i]
    next_td = sample["next"]

    # The statistics were computed from CPU copies; move them to the sample's
    # device so this also works when the buffer storage is not on the CPU.
    device = sample["observation"].device
    obs_mean_dev = obs_mean.to(device)
    obs_std_dev = obs_std.to(device)

    sample["observation"] = (sample["observation"] - obs_mean_dev) / obs_std_dev
    next_td["observation"] = (next_td["observation"] - obs_mean_dev) / obs_std_dev
    next_td["reward"] = next_td["reward"] / reward_max_val

    buffer[i] = sample

In [None]:
from collections import deque
from nn_laser_stabilizer.envs.utils import make_specs
from nn_laser_stabilizer.logging.utils import set_seeds
from nn_laser_stabilizer.agents.td3 import (
    make_td3_agent,
    make_loss_module,
    make_optimizers,
    make_target_updater,
    train_step,
    warmup_from_specs
)

set_seeds(config.seed)

# CHECK: experimental override applied before the loss module is built.
# Presumably gamma = 0 removes bootstrapping so the critic regresses the
# immediate reward only (bandit-style) — confirm this is intended.
config.agent.gamma = 0.0

specs = make_specs(config.env.bounds)
action_spec = specs["action"]
observation_spec = specs["observation"]

actor, qvalue = make_td3_agent(config, observation_spec, action_spec)
warmup_from_specs(observation_spec, action_spec, actor, qvalue)

loss_module = make_loss_module(config, actor, qvalue, action_spec)
optimizer_actor, optimizer_critic = make_optimizers(config, loss_module)
target_net_updater = make_target_updater(config, loss_module)

train_config = config.train

total_train_steps = 0
recent_qvalue_losses = deque(maxlen=train_config.update_to_data)
# The actor is updated at i = 0, k, 2k, ... inside each round, i.e.
# ceil(update_to_data / update_actor_freq) times. Floor division undercounts
# and gives maxlen=0 (an always-empty deque) when update_to_data < k.
actor_updates_per_round = -(-train_config.update_to_data // train_config.update_actor_freq)
recent_actor_losses = deque(maxlen=actor_updates_per_round)


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN if it is empty."""
    return sum(values) / len(values) if values else float("nan")


print("Training process initiated")

# Mark every stored transition as terminal in ("next", "done") and
# ("next", "terminated"); together with gamma = 0 this presumably ensures no
# bootstrapped target is used — confirm.
for i in range(len(buffer)):
    sample = buffer[i]
    sample["next"]["done"] = torch.tensor(True)
    sample["next"]["terminated"] = torch.tensor(True)
    buffer[i] = sample

# The try wraps the whole loop: Ctrl-C (or an error) stops training instead
# of merely skipping one round, and the summary is printed exactly once.
try:
    for _ in range(1000):
        for i in range(train_config.update_to_data):
            batch = buffer.sample(train_config.batch_size)
            update_actor = i % train_config.update_actor_freq == 0
            loss_qvalue_val, loss_actor_val = train_step(
                batch, loss_module, optimizer_actor, optimizer_critic,
                target_net_updater, update_actor
            )

            recent_qvalue_losses.append(loss_qvalue_val)
            if loss_actor_val is not None:
                recent_actor_losses.append(loss_actor_val)

        avg_qvalue_loss = _mean_or_nan(recent_qvalue_losses)
        avg_actor_loss = _mean_or_nan(recent_actor_losses)
        print(f"step={total_train_steps} Loss/Critic={avg_qvalue_loss} Loss/Actor={avg_actor_loss}")

        total_train_steps += 1

except KeyboardInterrupt:
    print("Training interrupted by user.")

except Exception as ex:
    print(f"Error while training: {ex}")

finally:
    print("Training finished")
    print(f"Final buffer size: {len(buffer)} samples")

In [None]:
# Spot-check the trained actor and critic on a few random stored transitions.
# Inference mode is switched on once, outside the loop.
actor.eval()
qvalue.eval()

for _ in range(5):
    sample = buffer.sample(1)[0]

    with torch.no_grad():
        # TensorDict modules write their outputs into the tensordict they are
        # given, so each module gets a clone. Otherwise the actor would
        # overwrite the buffer's "action" in `sample` before it is printed.
        action = actor(sample.clone())
        q_buffer = qvalue(sample.clone())  # Q(s, a) for the stored action
        q_actor = qvalue(action)           # Q(s, a) for the actor's action

    print("Observation:", sample["observation"])
    print("Action (from buffer):", sample["action"])
    print("Action predicted by actor:", action["action"])
    print("Reward:", sample["next"]["reward"])
    print("Predicted Q-value (buffer action):", q_buffer["state_action_value"])
    print("Predicted Q-value (actor action):", q_actor["state_action_value"])
    print()

In [None]:
# Pull per-transition scalars out of the whole buffer for summary statistics.
# Layout inferred from the labels used below (TODO confirm):
# observation = [error_mean, error_std], action = [kp, ki, kd].
error_mean_list, error_std_list = [], []
kp_list, ki_list, kd_list = [], [], []
reward_list, done_list = [], []

for i in range(len(buffer)):
    sample = buffer[i]
    obs_cpu = sample["observation"].cpu()
    act_cpu = sample["action"].cpu()

    error_mean_list.append(obs_cpu[0])
    error_std_list.append(obs_cpu[1])

    kp_list.append(act_cpu[0])
    ki_list.append(act_cpu[1])
    kd_list.append(act_cpu[2])

    reward_list.append(sample["next"]["reward"].cpu())
    # NOTE: root-level "done"; the training cell overwrote ("next", "done").
    done_list.append(sample["done"].cpu())

error_mean, error_std, kp, ki, kd, reward, done = [
    torch.stack(column)
    for column in (
        error_mean_list,
        error_std_list,
        kp_list,
        ki_list,
        kd_list,
        reward_list,
        done_list,
    )
]

def stats(tensor, name):
    """Print min, max, mean and std of ``tensor`` under a ``name`` header.

    Each reduction is computed and printed in turn, followed by a blank line.
    ``std`` is the tensor's default (unbiased) standard deviation.
    """
    print(f"=== {name} ===")
    for label, reduction in (("Min", "min"), ("Max", "max"), ("Mean", "mean"), ("Std", "std")):
        print(f"{label}:", getattr(tensor, reduction)().item())
    print()

# Summaries of every gathered column, in display order.
for column_tensor, column_label in (
    (error_mean, "Error Mean (obs[0])"),
    (error_std, "Error Std (obs[1])"),
    (kp, "KP"),
    (ki, "KI"),
    (kd, "KD"),
    (reward, "Reward"),
    (done.float(), "Done flags"),
):
    stats(column_tensor, column_label)