In [1]:
import numpy as np
import torch
import os
from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator
from controllers import sacController
from sac import SAC, ActorNetwork, CriticNetwork, ReplayBuffer, Environment, PolicyEvaluator

# Train a SAC agent on every CSV file in the data folder, then evaluate the
# resulting policy on the last file (in sorted order).

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Define paths and parameters
model_path = "./models/tinyphysics.onnx"
data_folder = "./data"
controller = sacController.Controller
debug = False

# Initialize actor and critic networks
actor = ActorNetwork()
critic = CriticNetwork()

# Initialize replay buffer
state_dim = 4  # State dimension
action_dim = 2  # Action dimension
replay_buffer = ReplayBuffer(state_dim, action_dim)

# Initialize SAC agent
agent = SAC(actor, critic, replay_buffer)

# Training hyperparameters
num_episodes = 100
max_steps_per_episode = 1000
batch_size = 64

# os.listdir() returns entries in arbitrary order; sort so that the training
# order (and the file used for evaluation below) is the same on every run.
csv_files = sorted(
    name for name in os.listdir(data_folder) if name.endswith(".csv")
)
if not csv_files:
    # Fail early with a clear message instead of hitting a NameError on
    # `data_path` when the evaluation environment is built below.
    raise FileNotFoundError(f"No .csv files found in {data_folder!r}")

# Train on each CSV file in turn, reusing the same agent throughout
for file_name in csv_files:
    data_path = os.path.join(data_folder, file_name)
    # Initialize environment for current CSV file
    env = Environment(model_path, data_path, controller, debug=debug)
    print(f"Training on data file: {file_name}")
    agent.train(num_episodes, max_steps_per_episode, batch_size, env)

# Evaluate the trained policy on the last data file seen during training.
# NOTE(review): unlike the training environments above, this call does not
# pass `controller` — confirm against Environment's signature that this is
# intentional.
env = Environment(model_path, data_path, debug=debug)
evaluator = PolicyEvaluator(env, agent)
avg_reward, std_reward = evaluator.evaluate()


Training on data file: 00000.csv


IndexError: list index out of range