# PART 1: BREAKOUT WITH REINFORCE

### IMPORTS

In [None]:
"""
Dependencies (we used Kaggle for this part)
!pip install gymnasium==1.0.0
!pip install ale-py
!pip install wandb
!pip install torchsummary""" 

import gymnasium as gym
import ale_py
from gymnasium.wrappers import MaxAndSkipObservation, ResizeObservation, GrayscaleObservation, FrameStackObservation, ReshapeObservation

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torchsummary import summary

import collections

import wandb
import datetime

import os

### PREPROCESSING

In [None]:
# Report which Gymnasium release is installed
print(f"Using Gymnasium version {gym.__version__}")

ENV_NAME = "ALE/Breakout-v5"
# Unwrapped-pipeline environment, created only to inspect the action set and raw frame shape
test_env = gym.make(ENV_NAME, render_mode="rgb_array")

print(test_env.unwrapped.get_action_meanings())
print(test_env.observation_space.shape)

In [None]:
# Source: M3-2_Example_1a (DQN on Pong, train)
class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that moves the last image axis (axis 2) to the front.

    The advertised observation space is a float32 Box in [0, 1] whose shape is
    (last axis, first axis, second axis) of the wrapped space. Pixel values
    themselves are passed through unchanged; only the axis order is altered.
    """

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        # Last axis goes first, followed by the first two axes.
        channels_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=channels_first,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Observation wrapper that converts frames to float32 and divides them by 255."""

    def observation(self, obs):
        """Return ``obs`` as a new float32 array divided by 255.0."""
        as_float = np.asarray(obs, dtype=np.float32)
        return as_float / 255.0


def make_env(env_name):
    """Create the environment and apply the preprocessing wrappers in a fixed order.

    The observation-space shape is printed once for the bare environment and
    once after every wrapper, so the effect of each stage is visible.

    Args:
        env_name: Environment id handed to ``gym.make``.

    Returns:
        The environment wrapped, in order, with MaxAndSkipObservation (skip=4),
        ResizeObservation (84x84), GrayscaleObservation (keep_dim=True),
        ImageToPyTorch, ReshapeObservation (84, 84),
        FrameStackObservation (stack_size=4) and ScaledFloatFrame.
    """
    # Each entry: (message printed after the stage, callable that wraps the current env).
    # FireResetEnv is not applied; that stage is disabled in this pipeline.
    stages = (
        ("MaxAndSkipObservation: {}", lambda e: MaxAndSkipObservation(e, skip=4)),
        ("ResizeObservation    : {}", lambda e: ResizeObservation(e, (84, 84))),
        ("GrayscaleObservation : {}", lambda e: GrayscaleObservation(e, keep_dim=True)),
        ("ImageToPyTorch       : {}", lambda e: ImageToPyTorch(e)),
        ("ReshapeObservation   : {}", lambda e: ReshapeObservation(e, (84, 84))),
        ("FrameStackObservation: {}", lambda e: FrameStackObservation(e, stack_size=4)),
        ("ScaledFloatFrame     : {}", lambda e: ScaledFloatFrame(e)),
    )

    wrapped = gym.make(env_name, render_mode='rgb_array')
    print("Standard Env.        : {}".format(wrapped.observation_space.shape))

    for message, apply_stage in stages:
        wrapped = apply_stage(wrapped)
        print(message.format(wrapped.observation_space.shape))

    return wrapped

# Module-level preprocessed environment, reused by the model-loading and GIF cells
env = make_env(ENV_NAME)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### POLICY NETWORK

In [None]:
# Policy Network
class PolicyNetwork(nn.Module):
    """Convolutional network with a shared trunk feeding a policy head and a value head.

    Args:
        input_shape: Shape of a single observation without the batch axis;
            its first entry is used as the number of input channels.
        num_actions: Number of logits produced by the policy head.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()

        in_channels = input_shape[0]
        trunk_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Push one all-zero observation through the trunk to learn the
        # width of the flattened feature vector.
        with torch.no_grad():
            probe = torch.zeros((1, *input_shape))
            feature_size = self.feature_extractor(probe).shape[1]

        # Policy head (action logits), built before the value head (baseline).
        self.policy = self._build_head(feature_size, num_actions)
        self.value = self._build_head(feature_size, 1)

    @staticmethod
    def _build_head(in_features, out_features):
        """Return a Linear(512) -> ReLU -> Dropout(0.5) -> Linear head."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, out_features),
        )

    def forward(self, x):
        """Return ``(action_logits, state_value)`` for a batch of observations ``x``."""
        shared = self.feature_extractor(x)
        return self.policy(shared), self.value(shared)

### AGENT

In [None]:
class REINFORCEAgent:
    """REINFORCE agent that uses the network's value output as a baseline.

    During an episode the caller alternates ``select_action(state)`` with
    appending the received reward to ``self.rewards``; ``finish_episode()``
    then performs one optimizer step and empties the per-episode buffers.

    Args:
        env: Environment; only ``observation_space.shape`` and
            ``action_space.n`` are read here, to size the network.
        device: Torch device the network and all tensors live on.
        learning_rate: Learning rate of the Adam optimizer.
        gamma: Discount factor used for the returns.
        value_loss_coeff: Weight of the value loss in the total loss.
    """

    def __init__(self, env, device, learning_rate=1e-3, gamma=0.99, value_loss_coeff=0.5):
        self.env = env
        self.device = device
        self.gamma = gamma
        self.value_loss_coeff = value_loss_coeff

        self.policy_net = PolicyNetwork(env.observation_space.shape, env.action_space.n).to(device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)

        # Per-episode buffers: filled step by step, emptied by finish_episode().
        self.saved_log_probs = []
        self.rewards = []
        self.saved_values = []

    def select_action(self, state):
        """Sample an action for ``state`` and remember its log-prob and value estimate.

        Args:
            state: One observation without the batch axis.

        Returns:
            The sampled action as a Python int.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        logits, state_value = self.policy_net(state)

        # Build the distribution straight from the logits: log-probabilities
        # are then derived from normalized logits rather than from the log of
        # already-rounded softmax probabilities.
        m = Categorical(logits=logits)
        action = m.sample()

        self.saved_log_probs.append(m.log_prob(action))
        self.saved_values.append(state_value)

        return action.item()

    def _normalized_returns(self):
        """Return the discounted returns of ``self.rewards``, standardized, as a 1-D float32 tensor."""
        running = 0.0
        discounted = []
        # Accumulate back-to-front with append + reverse (linear time).
        for reward in reversed(self.rewards):
            running = reward + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        returns = torch.tensor(discounted, dtype=torch.float32, device=self.device)
        # The sample standard deviation of a single value is NaN; treat the
        # spread as zero so a one-step episode yields a centered return of 0.
        spread = returns.std() if returns.numel() > 1 else returns.new_zeros(())
        return (returns - returns.mean()) / (spread + 1e-8)

    def _clear_episode(self):
        """Empty the per-episode buffers in place."""
        del self.rewards[:]
        del self.saved_log_probs[:]
        del self.saved_values[:]

    def finish_episode(self):
        """Run one policy/value update from the recorded episode and reset the buffers.

        The policy loss is the sum over steps of ``-log_prob * advantage`` with
        ``advantage = normalized_return - value`` (value detached); the value
        loss is the sum of squared errors between the value estimates and the
        normalized returns. If the three buffers differ in length, only the
        common leading steps are used. An episode without any usable step
        performs no update.

        Returns:
            Tuple ``(policy_loss, value_loss)`` of Python floats; ``(0.0, 0.0)``
            when there was nothing to learn from.
        """
        steps = min(len(self.saved_log_probs), len(self.saved_values), len(self.rewards))
        if steps == 0:
            self._clear_episode()
            return 0.0, 0.0

        returns = self._normalized_returns()[:steps]
        log_probs = torch.cat([lp.reshape(1) for lp in self.saved_log_probs[:steps]])
        values = torch.cat([v.reshape(1) for v in self.saved_values[:steps]])

        # The baseline must not receive gradients through the policy term.
        advantages = returns - values.detach()
        total_policy_loss = -(log_probs * advantages).sum()
        total_value_loss = ((values - returns) ** 2).sum()
        total_loss = total_policy_loss + self.value_loss_coeff * total_value_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        self._clear_episode()

        return total_policy_loss.item(), total_value_loss.item()

### HYPERPARAMETERS

In [None]:
# Number of episodes train() iterates over
MAX_EPISODES = 100000
# Size of the trailing window of episode rewards averaged into "mean_reward"
NUMBER_OF_REWARDS_TO_AVERAGE = 10
# Discount factor handed to REINFORCEAgent (overrides its default of 0.99)
GAMMA = 0.995
# Learning rate handed to REINFORCEAgent for its Adam optimizer
LEARNING_RATE = 1e-3
# Weight of the value loss in the agent's total loss
VALUE_LOSS_COEFF = 0.5

### TRAINING LOOP

In [None]:
def _run_episode(env, agent):
    """Play one episode, recording each reward on ``agent``; return ``(total_reward, steps)``."""
    state, _ = env.reset()
    episode_reward = 0
    steps = 0
    done = False

    while not done:
        action = agent.select_action(state)
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        agent.rewards.append(reward)
        episode_reward += reward
        steps += 1

    return episode_reward, steps


def train():
    """Train a REINFORCE agent on ``ENV_NAME`` for ``MAX_EPISODES`` episodes.

    After every episode the agent is updated, metrics are logged to Weights &
    Biases, and the network weights are saved whenever the mean reward over the
    last ``NUMBER_OF_REWARDS_TO_AVERAGE`` episodes exceeds the best seen so far.
    The W&B run is left open; the caller is expected to call ``wandb.finish()``.
    """
    # Take the API key from the environment rather than hard-coding a secret
    # in the notebook. With no key set, wandb.login() is called with key=None
    # and uses its own default credential lookup (per the wandb documentation).
    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    wandb.init(project="breakout-reinforce", config={
        "gamma": GAMMA,
        "learning_rate": LEARNING_RATE,
        "value_loss_coeff": VALUE_LOSS_COEFF,
    })

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    checkpoint_path = "../../models/breakout/REINFORCE_policy_net.dat"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    env = make_env(ENV_NAME)
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent = REINFORCEAgent(env, device, learning_rate=LEARNING_RATE, gamma=GAMMA, value_loss_coeff=VALUE_LOSS_COEFF)

        total_rewards = []
        best_mean_reward = None

        for episode in range(MAX_EPISODES):
            episode_reward, steps = _run_episode(env, agent)

            total_rewards.append(episode_reward)
            mean_reward = np.mean(total_rewards[-NUMBER_OF_REWARDS_TO_AVERAGE:])

            policy_loss, value_loss = agent.finish_episode()

            # Log metrics with WandB
            wandb.log({
                "episode": episode,
                "reward": episode_reward,
                "mean_reward": mean_reward,
                "policy_loss": policy_loss,
                "value_loss": value_loss,
                "steps_per_episode": steps
            })

            # Save the best-performing model
            if best_mean_reward is None or best_mean_reward < mean_reward:
                torch.save(agent.policy_net.state_dict(), checkpoint_path)
                best_mean_reward = mean_reward

            print(f"Episode {episode}, reward: {episode_reward:.2f}, mean reward: {mean_reward:.2f}")
    finally:
        # Release the emulator even when training is interrupted or fails.
        env.close()

#### GridSearch for Hyperparameters

In [None]:
"""from itertools import product

GAMMA_VALUES = [0.95, 0.99, 0.999]
LEARNING_RATE_VALUES = [1e-3, 1e-4, 1e-5]
VALUE_LOSS_COEFF_VALUES = [0.1, 0.5, 1.0]

def train_with_hyperparameters():
    for gamma, learning_rate, value_loss_coeff in product(GAMMA_VALUES, LEARNING_RATE_VALUES, VALUE_LOSS_COEFF_VALUES):
        print(f"Training with GAMMA={gamma}, LEARNING_RATE={learning_rate}, VALUE_LOSS_COEFF={value_loss_coeff}")
        
        wandb.init(
            project="breakout-reinforce-hyperparam-search",
            config={
                "gamma": gamma,
                "learning_rate": learning_rate,
                "value_loss_coeff": value_loss_coeff,
            },
            reinit=True  # Allow multiple runs in the same script
        )

        env = make_env(ENV_NAME)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        agent = REINFORCEAgent(env, device, learning_rate=learning_rate, gamma=gamma, value_loss_coeff=value_loss_coeff)

        total_rewards = []
        best_mean_reward = None

        for episode in range(2000): 
            state, _ = env.reset()
            episode_reward = 0
            steps = 0

            while True:
                action = agent.select_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                agent.rewards.append(reward)
                episode_reward += reward
                state = next_state
                steps += 1

                if done:
                    total_rewards.append(episode_reward)
                    mean_reward = np.mean(total_rewards[-10:]) 

                    policy_loss, value_loss = agent.finish_episode()

                    # Log metrics with WandB
                    wandb.log({
                        "episode": episode,
                        "reward": episode_reward,
                        "mean_reward": mean_reward,
                        "policy_loss": policy_loss,
                        "value_loss": value_loss,
                        "steps_per_episode": steps,
                    })

                    # Save the best-performing model
                    if best_mean_reward is None or best_mean_reward < mean_reward:
                        model_name = f"policy_net_gamma{gamma}_lr{learning_rate}_vlc{value_loss_coeff}.dat"
                        torch.save(agent.policy_net.state_dict(), f"/kaggle/working/{model_name}")
                        best_mean_reward = mean_reward

                    print(f"Ep {episode}, reward: {episode_reward:.2f}, mean reward: {mean_reward:.2f}, "
                          f"gamma={gamma}, lr={learning_rate}, vlc={value_loss_coeff}")

                    break

        # Log final results for the current combination
        print(f"Completed training for GAMMA={gamma}, LEARNING_RATE={learning_rate}, VALUE_LOSS_COEFF={value_loss_coeff}")
        wandb.finish()

# Main Body
print("Training starts at", datetime.datetime.now())
train_with_hyperparameters()
print("Training ends at", datetime.datetime.now())"""

### MAIN BODY

In [None]:
# Print the wall-clock start time, run the whole training loop, then print the end time
print("Training starts at", datetime.datetime.now())
train()
print("Training ends at", datetime.datetime.now())
# Close the Weights & Biases run that train() opened with wandb.init()
wandb.finish()

### LOAD MODELS

In [None]:
# Rebuild the network from the module-level env's observation shape and action count
policy_net = PolicyNetwork(env.observation_space.shape, env.action_space.n).to(device)

# The saved tensors are mapped to the CPU on load and then copied into the model's parameters.
# NOTE(review): train() writes "../../models/breakout/REINFORCE_policy_net.dat" while this cell
# reads "model_REINFORCE.dat" — confirm the checkpoint is copied/renamed before running.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
policy_net.load_state_dict(torch.load("model_REINFORCE.dat", map_location=torch.device("cpu")))

policy_net.eval() # Evaluation mode: disables the Dropout layers during inference

### MAKING GIF

In [None]:
from PIL import Image
import time  # not used in this cell; kept in case other cells rely on it

# Parameters
visualize = True  # when False no frames are recorded and no GIF is written
images = []
gif_file = "video_REINFORCE.gif"
gif_path = f"../../videos/breakout/{gif_file}"

# Reset environment
state, _ = env.reset()
total_reward = 0.0
done = False

# Play one episode, sampling actions from the loaded policy
while not done:
    if visualize:
        # Render the environment's frame (for RGB environments)
        img = env.render()
        images.append(Image.fromarray(img))  # Store for GIF creation

    # Add the batch axis and move the observation to the model's device.
    # (np.array([state], copy=False) raises on NumPy >= 2.0, because wrapping
    # the state in a list always requires a copy.)
    state_tensor = torch.tensor(np.array([state])).float().to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits, _ = policy_net(state_tensor)  # policy_net returns (action logits, state value)
        action_probs = torch.softmax(logits, dim=1)  # Convert logits to probabilities
    action = torch.multinomial(action_probs, 1).item()  # Sample action from probability distribution

    # Step in the environment
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

    total_reward += reward

print(f"Total reward: {total_reward:.2f}")

# Create GIF from the frames collected
if images:
    # Image.save does not create missing directories.
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    images[0].save(gif_path, save_all=True, append_images=images[1:], duration=60, loop=0)
    print(f"Episode exported to {gif_path}")
else:
    print("No frames were recorded; GIF not written")
