# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from environment2 import Env

# Actor network
class Actor(nn.Module):
    """Policy network mapping an observation to a probability distribution over actions.

    Two fully connected hidden layers with ReLU activations are followed by a
    linear output layer and a softmax over the last dimension, so the output
    sums to 1 along ``dim=-1``.

    Args:
        input_size: Number of features in an observation (last dimension of ``x``).
        output_size: Number of discrete actions.
        hidden_size: Width of each hidden layer. Defaults to 24, the value that
            was previously hard-coded.
    """

    def __init__(self, input_size, output_size, hidden_size=24):
        super().__init__()
        # Layer order is kept identical to the original layout so the
        # state_dict keys stay fc.0 / fc.2 / fc.4.
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities for ``x`` (softmax over the last dimension)."""
        return self.fc(x)

# REINFORCE agent
class REINFORCEAgent:
    """Monte-Carlo policy-gradient (REINFORCE) agent with a softmax policy network.

    Per episode:
    1. Call :meth:`select_action` once per step.
    2. Append that step's reward to :attr:`rewards`.
    3. Call :meth:`update_policy` when the episode ends.

    Args:
        input_size: Length of an observation vector.
        output_size: Number of discrete actions.
        lr: Adam learning rate (default 0.01, the previously hard-coded value).
        gamma: Discount factor for returns (default 0.99).
    """

    def __init__(self, input_size, output_size, lr=0.01, gamma=0.99):
        self.actor = Actor(input_size, output_size)
        self.optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.gamma = gamma
        # Per-episode buffers, filled step by step and cleared by update_policy().
        self.log_probs = []  # log pi(a_t | s_t), one tensor per step
        self.rewards = []    # reward r_t per step, appended by the caller

    def select_action(self, state):
        """Sample an action from the current policy and record its log-probability.

        Args:
            state: Observation convertible to a float tensor (list, tuple or array).

        Returns:
            The sampled action index as a Python ``int``.
        """
        # torch.tensor copies, so later in-place changes to the caller's array
        # cannot alter the input saved by autograd for the backward pass.
        state = torch.tensor(state, dtype=torch.float32)
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        # Stored with its autograd graph so update_policy() can backpropagate.
        self.log_probs.append(dist.log_prob(action))
        return action.item()

    def _discounted_returns(self):
        """Return G_t = r_t + gamma * G_{t+1} for every step of ``self.rewards`` (float32)."""
        running = 0.0
        returns = []
        for reward in reversed(self.rewards):
            running = reward + self.gamma * running
            returns.append(running)
        returns.reverse()  # accumulated back-to-front; restore time order
        return torch.tensor(returns, dtype=torch.float32)

    def update_policy(self):
        """Take one REINFORCE gradient step on the finished episode, then clear the buffers.

        Discounted returns are standardized (zero mean, unit std) to reduce
        gradient variance. One optimizer step then minimizes
        ``-sum_t log pi(a_t|s_t) * G_t``. If nothing was recorded, the call
        only clears the buffers.
        """
        try:
            if not self.rewards or not self.log_probs:
                return
            returns = self._discounted_returns()
            returns = returns - returns.mean()
            # std() of a single element is NaN (unbiased estimator) and would
            # turn every weight into NaN, so only scale with two or more steps.
            if returns.numel() > 1:
                returns = returns / (returns.std() + 1e-5)
            # log_prob of a scalar sample is 0-dim; torch.cat rejects 0-dim
            # tensors, so stack the per-step terms instead.
            policy_loss = torch.stack(
                [-log_prob * ret for log_prob, ret in zip(self.log_probs, returns)]
            ).sum()
            self.optimizer.zero_grad()
            policy_loss.backward()
            self.optimizer.step()
        finally:
            # Always start the next episode with empty buffers.
            self.log_probs = []
            self.rewards = []

if __name__ == "__main__":
    NUM_EPISODES = 1000

    env = Env()
    # reset() is called once up front only to learn the observation length,
    # which becomes the policy network's input size.
    agent = REINFORCEAgent(input_size=len(env.reset()), output_size=env.action_size)

    for episode in range(NUM_EPISODES):
        state = env.reset()
        total_reward = 0
        done = False

        # Roll out one full episode, then make a single Monte-Carlo policy update.
        while not done:
            action = agent.select_action(state)
            state, reward, done = env.step(action)
            total_reward += reward
            agent.rewards.append(reward)

        agent.update_policy()
        print("Episode:", episode, "Total Reward:", total_reward)
