## (Successful) Check the latest entry (SESSION 7) in the **.log** file for the results, and see **take_2_1_1.png** for the reward-vs-episode graph

In [None]:
import Logger
from Interrupt import *

# Shared logger instance used by the cells below, both as a callable
# (log("...")) and through log.start_session().
log = Logger.Logger()


In [None]:
import torch
from torch import nn
import gym
import matplotlib.pyplot as plt

# Start a session on the shared logger, then write a banner line for this run.
# NOTE(review): presumably each session becomes its own entry in the .log
# file — confirm against the Logger module.
log.start_session()
log("Beginning Test Run - 2")

class PolicyNet(nn.Module):
    """Single-hidden-layer policy network that outputs action probabilities.

    The network is Linear -> ReLU -> Linear -> Softmax, with the softmax taken
    over the last dimension so each output row sums to 1.

    Args:
        state_dim: Number of input features per state.
        action_dim: Number of output probabilities.
        hidden_dim: Width of the hidden layer. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128) -> None:
        super().__init__()
        # Layer order (and therefore the state_dict keys net.0.* / net.2.*)
        # is kept stable so existing checkpoints remain loadable.
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the action probabilities for ``state``.

        Args:
            state: Tensor whose last dimension has size ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size ``action_dim``.
        """
        return self.net(state)

# ---------------------------------------------------------------------------
# REINFORCE training run on CartPole.
#
# Each episode: roll out the current policy, compute discounted returns,
# normalize them, and take one Adam step on -sum(log_prob * return).
# ---------------------------------------------------------------------------

# Use the GPU when one is available; otherwise fall back to the CPU so the
# run does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): render_mode="human" draws every step, which is likely to slow
# a 5000-episode run considerably — confirm this is intended.
env = gym.make("CartPole-v1", max_episode_steps=1000, render_mode="human")
policy = PolicyNet(env.observation_space.shape[0], env.action_space.n).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

gamma = 0.99           # discount factor applied to future rewards
episodes = 5000        # maximum number of training episodes
episode_rewards = []   # undiscounted total reward of every finished episode
N = 1                  # log the average reward over the last N episodes, every N episodes

# NOTE(review): this value is overwritten at the top of every iteration; the
# call is kept in case CheckFlag() has set-up side effects — confirm.
flag = CheckFlag()

try:
    for episode in range(episodes):

        # Poll the interrupt flag once per episode so a run can be stopped
        # cleanly between episodes.
        flag = CheckFlag()
        if flag:
            log("Interrupt recieved... Terminating.")
            break

        state, _ = env.reset()
        log_probs = []
        rewards = []
        total_reward = 0

        # ---- Roll out one episode with the current policy ----
        done = False
        while not done:
            state = torch.tensor(state, dtype=torch.float32, device=device)
            probs = policy(state)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            log_probs.append(log_prob)
            rewards.append(reward)
            total_reward += reward
            state = next_state

        # ---- Discounted return G_t for every step ----
        # Accumulate back-to-front, then reverse once; this avoids the
        # quadratic cost of inserting at the head of a list.
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32, device=device)

        # Normalize the returns. The sample standard deviation of a single
        # element is NaN, which would propagate into the weights, so a
        # one-step episode keeps its raw return.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # ---- Policy-gradient loss: -sum_t log_prob_t * G_t ----
        loss = -(torch.stack(log_probs) * returns).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        episode_rewards.append(total_reward)
        if (episode + 1) % N == 0:
            avg = sum(episode_rewards[-N:]) / N
            log(f"Episode {episode+1}, Avg Reward (last {N}): {avg:.2f}")
finally:
    # Always release the environment (and its render window), even when the
    # loop exits through an exception.
    env.close()

# Draw the per-episode total reward collected during training.
plt.plot(episode_rewards)
plt.ylabel("Reward")
plt.xlabel("Episode")
# The title call stays last so the cell's final expression is unchanged.
plt.title("REINFORCE on CartPole")
