In [1]:
import torch
import torch.optim as optim
from src.Environment import Environment
from src.actor_critic import ActorCritic
import torch.distributions as dist
from torch.distributions import Categorical

In [2]:
def train(n_episodes, max_steps):
    env = Environment()
    for episode in range(n_episodes):
        state = env.reset(45,135)
        episode_reward = 0
        
        for step in range(max_steps):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action_probs, state_value = ac_net(state_tensor)
            
            # Sample action from the probability distribution
            dist = Categorical(action_probs)
            action = dist.sample()
            # print(action)
            
            # Take action in the environment
            next_state, reward, done, _ = env.step(action.item())
            episode_reward += reward
            
            # Compute TD error
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
            _, next_state_value = ac_net(next_state_tensor)
            td_error = reward + (0.99 * next_state_value * (1 - int(done))) - state_value
            
            # Compute losses
            actor_loss = dist.log_prob(action) * td_error.detach()
            critic_loss = td_error * state_value
            print(actor_loss,"..", critic_loss)
            loss = actor_loss*1e11 + critic_loss/1000
            
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if done:
                break
            
            state = next_state
        
        print(f"Episode {episode+1}, Reward: {episode_reward}")


In [3]:
# Train the agent
ac_net = ActorCritic(7,4)
optimizer = optim.Adam(ac_net.parameters(), lr = 1e-2)
train(n_episodes=5, max_steps=10)

tensor([[2.5815e-06]], grad_fn=<MulBackward0>) .. tensor([[-6501.2476]], grad_fn=<MulBackward0>)
tensor([[2.9068e-06]], grad_fn=<MulBackward0>) .. tensor([[-13973.2412]], grad_fn=<MulBackward0>)
tensor([[3.2343e-06]], grad_fn=<MulBackward0>) .. tensor([[-23001.7383]], grad_fn=<MulBackward0>)
tensor([[3.5649e-06]], grad_fn=<MulBackward0>) .. tensor([[-33647.9805]], grad_fn=<MulBackward0>)
tensor([[3.9007e-06]], grad_fn=<MulBackward0>) .. tensor([[-46033.2500]], grad_fn=<MulBackward0>)
tensor([[4.2433e-06]], grad_fn=<MulBackward0>) .. tensor([[-60305.3203]], grad_fn=<MulBackward0>)
tensor([[4.5948e-06]], grad_fn=<MulBackward0>) .. tensor([[-76664.9844]], grad_fn=<MulBackward0>)
tensor([[4.9585e-06]], grad_fn=<MulBackward0>) .. tensor([[-95425.0234]], grad_fn=<MulBackward0>)
tensor([[5.3371e-06]], grad_fn=<MulBackward0>) .. tensor([[-116934.1484]], grad_fn=<MulBackward0>)
tensor([[5.7330e-06]], grad_fn=<MulBackward0>) .. tensor([[-141576.5938]], grad_fn=<MulBackward0>)
Episode 1, Reward: 