# Two networks are trained in this algorithm (a typical actor-critic framework):
# 1. The policy network (actor) outputs a probability distribution over all actions available in the environment.
# 2. The value network (critic) returns a single value that estimates how good the current state is.

# Actor Network
input_layer → hidden_layers → output_layer(softmax for discrete actions)

# Critic Network  
input_layer → hidden_layers → output_layer(single value)

In [None]:
# Below is a sample PPO-style training loop (still needs closer study).
"""
for epoch in range(num_epochs):
    # 1. Data Collection
    collect_trajectories()
    
    # 2. Advantage Calculation
    compute_advantages_and_returns()
    
    # 3. Policy Update (multiple times)
    for _ in range(policy_updates):
        update_actor()
        if kl_divergence > threshold:
            break  # Early stopping
    
    # 4. Value Function Update
    for _ in range(value_updates):
        update_critic()
"""
# Below is the basic policy-gradient (REINFORCE) update, which follows from the policy gradient theorem. It is the foundation for PPO, which later adds a clipped objective on top of it.
"""
# Basic policy gradient (REINFORCE)
for episode in episodes:
    actions, states, rewards = collect_episode()
    returns = compute_returns(rewards)  # Future rewards
    
    for t in range(len(actions)):
        loss = -log_prob(actions[t]) * returns[t]
        loss.backward()  # Update policy to increase good actions
"""        

In [None]:
# Dry run of the REINFORCE algorithm (simple loss function).
# This cell implements REINFORCE first to get used to policy-gradient methods; the next step is to try PPO with the GAE advantage function.
import gymnasium as gym # Use gymnasium instead of gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 2e-4  # Adam step size used by Policy's optimizer
gamma = 0.98  # discount factor applied to future rewards when computing returns

class Policy(nn.Module):
    """REINFORCE policy network that owns its optimizer and an episode buffer.

    A two-layer MLP with a softmax head maps a state vector to a probability
    distribution over discrete actions. Transitions for the current episode are
    buffered with ``put_data`` and consumed by ``train_net``.

    Args:
        state_dim: Size of the state vector (4 for CartPole).
        action_dim: Number of discrete actions (2 for CartPole).
        hidden_dim: Width of the hidden layer.
        lr: Adam learning rate; defaults to the module-level ``learning_rate``.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=128, lr=None):
        super().__init__()
        # (reward, probability of the action taken) pairs for the current episode.
        self.data = []

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr
        )

    def forward(self, x):
        """Return action probabilities for the state(s) ``x``.

        Accepts a single state of shape ``(state_dim,)`` or a batch of shape
        ``(batch, state_dim)``.
        """
        x = F.relu(self.fc1(x))
        # Normalize over the last (action) axis. dim=0 would normalize across
        # the batch instead of across actions when x is 2-D.
        return F.softmax(self.fc2(x), dim=-1)

    def put_data(self, item):
        """Append a ``(reward, prob_of_taken_action)`` pair to the episode buffer."""
        self.data.append(item)

    def train_net(self, discount=None):
        """Run one REINFORCE update on the buffered episode, then clear the buffer.

        Walks the episode backwards computing returns
        ``G_t = r_t + discount * G_{t+1}`` and minimizes
        ``sum_t -log(prob_t) * G_t`` with a single optimizer step.

        Args:
            discount: Discount factor; defaults to the module-level ``gamma``.
        """
        if discount is None:
            discount = gamma
        if not self.data:
            return  # nothing was collected; skip the optimizer step

        R = 0.0
        losses = []
        for r, prob in reversed(self.data):
            R = r + discount * R
            losses.append(-torch.log(prob) * R)

        self.optimizer.zero_grad()
        # The gradient of the summed loss equals the sum of the per-step gradients,
        # so one backward pass replaces the per-step backward() calls.
        torch.stack(losses).sum().backward()
        self.optimizer.step()
        self.data = []

def main(n_episodes=2500, print_interval=20, save_path='reinforce_cartpole_policy.pth'):
    """Train a REINFORCE policy on CartPole-v1 and save its weights.

    Args:
        n_episodes: Number of training episodes to run.
        print_interval: Print the average score after every this many episodes.
        save_path: File the trained policy's ``state_dict`` is written to.

    Raises:
        ValueError: If ``print_interval`` is not positive.
    """
    if print_interval <= 0:
        raise ValueError("print_interval must be positive")

    env = gym.make('CartPole-v1')
    pi = Policy()
    score = 0.0

    try:
        for n_epi in range(n_episodes):
            s, info = env.reset()  # gymnasium's reset returns (observation, info)
            terminated = False
            truncated = False  # gymnasium's step returns terminated and truncated separately

            # CartPole-v1 caps each episode at 500 steps (reported via `truncated`).
            while not terminated and not truncated:
                prob = pi(torch.from_numpy(s).float())
                a = Categorical(prob).sample()
                s, r, terminated, truncated, info = env.step(a.item())
                pi.put_data((r, prob[a]))
                score += r

            pi.train_net()

            # Report after every `print_interval` completed episodes so each average
            # covers exactly `print_interval` episodes.
            if (n_epi + 1) % print_interval == 0:
                print("# of episode :{}, avg score : {}".format(n_epi + 1, score / print_interval))
                score = 0.0
    finally:
        env.close()

    # Save the model's state dictionary (e.g. to the Colab/Drive working directory).
    torch.save(pi.state_dict(), save_path)
    print("Model saved to {}".format(save_path))


if __name__ == '__main__':
    main()

# Observations in Colab:
# Ran it for 2500 episodes; the average score at the end was about 350. The best-performing agent scored 450+ out of 500, and it might need more episodes to do better.
# Still, the point is proven: a simple policy-gradient loss can teach an agent to solve small, discrete environments like this one.

In [None]:
# Import the libraries needed for video recording.
# Run this in Colab to record the trained agent playing the game and view the saved video.
from gymnasium.wrappers import RecordVideo
import base64
import io
from IPython.display import HTML

# Instantiate the policy network and load the weights saved by the training cell.
pi = Policy()
try:
    # map_location='cpu' lets weights saved on a GPU runtime load on a CPU-only one.
    pi.load_state_dict(torch.load('reinforce_cartpole_policy.pth', map_location='cpu'))
    print("Model loaded successfully.")
except FileNotFoundError:
    print("Error: Model file 'reinforce_cartpole_policy.pth' not found. Please run the training cell first to save the model.")
except Exception as e:
    print(f"An error occurred while loading the model: {e}")
# NOTE: if loading failed, the demo below runs the freshly initialized (untrained) policy.

# Inference only: switch to evaluation mode once, not on every step.
pi.eval()

# Wrap the environment with RecordVideo; videos are saved in the 'videos' folder.
# NOTE(review): no episode_trigger is passed, so RecordVideo's default schedule
# decides which of these episodes are recorded; confirm this is what you want.
env_video = RecordVideo(gym.make('CartPole-v1', render_mode='rgb_array'), video_folder='videos')

n_demo_episodes = 10
try:
    for _ in range(n_demo_episodes):
        s, info = env_video.reset()
        done = False
        truncated = False
        while not done and not truncated:
            with torch.no_grad():
                prob = pi(torch.from_numpy(s).float())
            a = Categorical(prob).sample()
            s, r, done, truncated, info = env_video.step(a.item())
finally:
    # Always close the recording environment, even if an episode raises.
    env_video.close()

# Function to display the video in the notebook
def display_video(video_path):
    """Build an HTML ``<video>`` element that embeds the MP4 at ``video_path``.

    The file contents are inlined as a base64 data URL, so the returned markup
    does not refer back to the file on disk.

    Args:
        video_path: Path to an MP4 file.

    Returns:
        An IPython ``HTML`` object wrapping the 640x480 video player markup.
    """
    # Use a context manager so the file handle is closed right after reading.
    with open(video_path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
    return HTML("""
    <video width="640" height="480" controls>
        <source src="%s" type="video/mp4">
    </video>
    """ % data_url)

# Display the most recently recorded video.
import glob
import os

from IPython.display import display

list_of_files = glob.glob('videos/*.mp4')
latest_file = max(list_of_files, key=os.path.getctime) if list_of_files else None

if latest_file:
    # IPython only auto-renders a cell's final top-level expression. This call
    # sits inside an `if`, so the HTML must be rendered explicitly.
    display(display_video(latest_file))
else:
    print("No recorded video found in 'videos/'.")