# In[ ]:
import os
import torch
import gymnasium as gym
import numpy as np
import time
from torch.distributions import Beta
import torch.optim as optim
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2



# --- Masking Pipeline ---
def apply_gaussian_blur(channel, ksize=5):
    """Smooth ``channel`` with a square ``ksize`` x ``ksize`` Gaussian kernel.

    Sigma is passed as 0, which tells OpenCV to derive it from the kernel size.
    """
    kernel_shape = (ksize, ksize)
    return cv2.GaussianBlur(channel, kernel_shape, 0)

def adjust_brightness(channel, factor=1.2):
    """Multiply intensities by ``factor`` and return a uint8 image.

    The product is computed in float32, saturated to [0, 255], and then
    truncated (not rounded) by the uint8 conversion.
    """
    brightened = np.float32(channel) * factor
    return np.clip(brightened, 0, 255).astype(np.uint8)

def adjust_saturation_single_channel(channel, factor=1.5):
    """Scale HSV saturation of ``channel`` (replicated to RGB) and return it as gray.

    NOTE(review): the three replicated channels are identical, so the HSV
    saturation is 0 everywhere and scaling it has no effect; the round trip
    returns the input channel up to colour-conversion rounding.
    """
    as_rgb = cv2.merge([channel] * 3)
    hsv = cv2.cvtColor(as_rgb, cv2.COLOR_RGB2HSV)
    boosted = np.clip(hsv[..., 1] * factor, 0, 255)
    hsv[..., 1] = boosted
    back_to_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return cv2.cvtColor(back_to_rgb, cv2.COLOR_RGB2GRAY)

def extract_center_value(channel, region_size=7, offset_y=10):
    """Return the mean intensity of three square patches around the image center.

    Each patch is ``region_size`` x ``region_size`` pixels and is centred
    horizontally. Vertically, the patches sit on the middle row shifted by
    ``-offset_y``, ``0`` and ``+offset_y``. The result is the mean of the
    per-patch means. Patches are cropped to the image, and a patch that lies
    completely outside it is ignored.

    Args:
        channel: 2-D array (height, width).
        region_size: side length of each patch, in pixels.
        offset_y: vertical distance of the outer patches from the center, in rows.

    Returns:
        The mean intensity (NaN if every patch falls outside the image).
    """
    h, w = channel.shape
    cx, cy = w // 2, h // 2
    half = region_size // 2

    def _clamp(value, upper):
        # Keep slice bounds inside [0, upper]. A negative bound would wrap
        # around to the opposite edge instead of cropping.
        return min(max(value, 0), upper)

    # The stop index is start + region_size, so each patch really is
    # region_size pixels wide (start + size//2 would give size - 1 pixels).
    x1 = _clamp(cx - half, w)
    x2 = _clamp(cx - half + region_size, w)

    values = []
    for off in (-offset_y, 0, offset_y):
        y_start = cy + off - half
        region = channel[_clamp(y_start, h):_clamp(y_start + region_size, h), x1:x2]
        if region.size:
            values.append(np.mean(region))
    return np.mean(values)

def extract_mask(channel, center_value, tolerance=25):
    """Mark pixels whose intensity is within ``tolerance`` of ``center_value``.

    Returns a uint8 mask with 255 for matching pixels and 0 elsewhere. The mask
    is cleaned with one 3x3 morphological opening (removes isolated specks)
    and then a two-iteration 3x3 closing (fills small holes).
    """
    reference = int(center_value)  # truncated toward zero
    distance = np.abs(channel.astype(np.int16) - reference)
    raw = (distance < tolerance).astype(np.uint8) * 255
    structuring = np.ones((3, 3), np.uint8)
    opened = cv2.morphologyEx(raw, cv2.MORPH_OPEN, structuring, iterations=1)
    return cv2.morphologyEx(opened, cv2.MORPH_CLOSE, structuring, iterations=2)

def mask_quality(mask):
    """Return the fraction of pixels in ``mask`` that are exactly 255 (foreground)."""
    return np.mean(mask == 255)

def combine_masks(mask_r, mask_g, mask_b, threshold=127):
    """Fuse per-channel masks by averaging those with plausible coverage.

    A mask is trusted only if its foreground fraction (``mask_quality``) lies
    strictly between 0.05 and 0.29. Trusted masks are averaged with equal
    weight, truncated to uint8, and binarised so that pixels strictly greater
    than ``threshold`` become 255. With the default threshold, a pixel is kept
    if:

    - it is set in the only trusted mask;
    - it is set in both of two trusted masks; or
    - it is set in at least two of three trusted masks.

    If no mask is trusted, an all-zero mask shaped like ``mask_r`` is returned.
    """
    trusted = [
        candidate
        for candidate in (mask_r, mask_g, mask_b)
        if 0.05 < mask_quality(candidate) < 0.29
    ]
    if not trusted:
        return np.zeros_like(mask_r)
    share = 1.0 / len(trusted)
    average = sum(share * candidate for candidate in trusted)
    _, fused = cv2.threshold(average.astype(np.uint8), threshold, 255, cv2.THRESH_BINARY)
    return fused

def get_binary_mask(obs, frame_count, center_color):
    """Segment the region whose colour matches a reference sampled near the frame center.

    Each RGB channel goes through three steps: a 5x5 Gaussian blur, darkening
    by 0.9, and ``adjust_saturation_single_channel``. When ``frame_count == 15``,
    the reference colour is sampled from the center of these *processed*
    channels and written into ``center_color`` in place, so later calls reuse
    it. Once a reference exists, each processed channel is thresholded against
    it and the three masks are fused with ``combine_masks``.

    Args:
        obs: RGB frame of shape (height, width, 3); assumed uint8 — TODO confirm
            for non-CarRacing inputs.
        frame_count: index of this frame; calibration happens only when it equals 15.
        center_color: mutable 3-element list holding the reference [r, g, b];
            ``[None, None, None]`` means "not calibrated yet". Updated in place.

    Returns:
        A 0/255 mask with the spatial shape of ``obs`` (all zeros before calibration).
    """
    r, g, b = cv2.split(obs)
    if center_color[0] is None and frame_count != 15:
        # No reference yet and this is not the calibration frame: the
        # preprocessing would be unused, so return an empty mask directly.
        return np.zeros_like(r)

    processed = [
        adjust_saturation_single_channel(
            adjust_brightness(apply_gaussian_blur(channel, 5), factor=0.9),
            factor=1.5,
        )
        for channel in (r, g, b)
    ]

    if frame_count == 15:
        # Sample the reference from the processed channels. The reference and
        # the pixels compared against it then share the same blur/brightness
        # transform; raw values would be biased relative to the processed ones.
        center_color[:] = [extract_center_value(channel) for channel in processed]

    masks = [extract_mask(channel, ref) for channel, ref in zip(processed, center_color)]
    return combine_masks(*masks)



class Net(nn.Module):
    """
    Actor-critic convolutional network for PPO on stacked grayscale frames.

    ``forward`` returns ``((alpha, beta), v)``:

    - ``alpha`` and ``beta`` are the parameter vectors of a 3-dimensional Beta
      policy, each entry >= 1.
    - ``v`` is the state-value estimate.

    Submodule names and their order must not change, because saved
    ``state_dict`` checkpoints are keyed by them.
    """

    def __init__(self, img_stack):
        super().__init__()
        # Each entry is (in_channels, out_channels, kernel_size, stride), with the
        # output shape for an (img_stack, 96, 96) input noted alongside.
        conv_specs = (
            (img_stack, 8, 4, 2),  # -> (8, 47, 47)
            (8, 16, 3, 2),         # -> (16, 23, 23)
            (16, 32, 3, 2),        # -> (32, 11, 11)
            (32, 64, 3, 2),        # -> (64, 5, 5)
            (64, 128, 3, 1),       # -> (128, 3, 3)
            (128, 256, 3, 1),      # -> (256, 1, 1)
        )
        layers = []
        for in_ch, out_ch, kernel, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride))
            layers.append(nn.ReLU())
        self.cnn_base = nn.Sequential(*layers)
        # Critic: 256 features -> scalar value.
        self.v = nn.Sequential(nn.Linear(256, 100), nn.ReLU(), nn.Linear(100, 1))
        # Shared actor trunk that feeds the two Beta-parameter heads.
        self.fc = nn.Sequential(nn.Linear(256, 100), nn.ReLU())
        self.alpha_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus())
        self.beta_head = nn.Sequential(nn.Linear(100, 3), nn.Softplus())
        self.apply(self._weights_init)

    @staticmethod
    def _weights_init(m):
        # Only conv layers get custom init; linear layers keep PyTorch defaults.
        if not isinstance(m, nn.Conv2d):
            return
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.constant_(m.bias, 0.1)

    def forward(self, x):
        features = self.cnn_base(x).view(-1, 256)
        value = self.v(features)
        hidden = self.fc(features)
        # Softplus is non-negative; adding 1 keeps both Beta parameters >= 1,
        # which rules out U-shaped densities.
        alpha = self.alpha_head(hidden) + 1
        beta = self.beta_head(hidden) + 1
        return (alpha, beta), value


# Number of consecutive grayscale frames stacked into one observation.
img_stack = 4

# One buffer record: state, action, action log-probability, reward, next state.
transition = np.dtype(
    [
        ("s", np.float64, (img_stack, 96, 96)),
        ("a", np.float64, (3,)),
        ("a_logp", np.float64),
        ("r", np.float64),
        ("s_", np.float64, (img_stack, 96, 96)),
    ]
)

# Training hyperparameters (meanings presumed from PPO conventions — confirm
# against the training code).
GAMMA = 0.99  # presumably the discount factor
EPOCH = 8  # update epochs; 8 worked better than 10
MAX_SIZE = 2000  # buffer capacity; 10000 ran out of CUDA memory
BATCH = 128  # presumably the minibatch size
EPS = 0.1  # presumably the PPO clip range
LEARNING_RATE = 0.001  # 1e-3 worked better than 5e-3 or 2e-3
# NOTE(review): not passed to Wrapper (whose default is 4); confirm which
# value the loaded model was trained with.
action_repeat = 10

def rgb2gray(rgb, norm=True):
    """Convert an RGB(A) image to luma with ITU-R BT.601 weights.

    Only the first three channels are used. When ``norm`` is true, the result
    is rescaled as ``gray / 128 - 1``, mapping 0 to -1 and 255 to about 0.992.
    """
    luma_weights = [0.299, 0.587, 0.114]
    luma = np.dot(rgb[..., :3], luma_weights)
    return luma / 128.0 - 1.0 if norm else luma
class Agent:
    """PPO agent that holds the actor-critic network, its optimiser and a rollout buffer."""

    def __init__(self, device):
        self.training_step = 0
        # The network runs in float64, matching the float64 transition buffer.
        self.net = Net(img_stack).double().to(device)
        self.buffer = np.empty(MAX_SIZE, dtype=transition)
        self.counter = 0
        self.device = device
        self.optimizer = optim.Adam(self.net.parameters(), lr=LEARNING_RATE)

    def select_action(self, state):
        """Sample an action from the Beta policy for one stacked observation.

        Returns:
            ``(action, a_logp)``: a numpy array of 3 Beta samples in [0, 1], and
            the log-probability summed over the action dimensions, as a float.
        """
        batch = torch.from_numpy(state).double().to(self.device).unsqueeze(0)
        with torch.no_grad():
            (alpha, beta), _ = self.net(batch)
        policy = Beta(alpha, beta)
        sample = policy.sample()
        log_prob = policy.log_prob(sample).sum(dim=1)
        return sample.squeeze().cpu().numpy(), log_prob.item()

    def store(self, transition):
        """Write ``transition`` into the ring buffer.

        Returns True exactly when this write filled the buffer's last slot, at
        which point the write index wraps back to 0.
        """
        self.buffer[self.counter] = transition
        self.counter = (self.counter + 1) % MAX_SIZE
        return self.counter == 0


# The Agent is created once, with the environment set-up further below. A second
# instance here would only build and discard another network and rollout buffer.

class Wrapper:
    """
    Environment wrapper for CarRacing.

    Observations are stacks of the last ``img_stack`` grayscale frames (see
    ``rgb2gray``). Each action is repeated up to ``action_repeat`` times with
    light reward shaping. ``step`` also returns a binary colour mask of the
    latest RGB frame, computed by ``get_binary_mask``.
    """

    def __init__(self, env, img_stack=4, action_repeat=4):
        self.env = env
        self.img_stack = img_stack
        self.action_repeat = action_repeat
        self.die = False
        self.stack = []
        # Wrapper steps taken in the current episode. get_binary_mask samples
        # its reference colour when this reaches 15.
        self.frame_count = 0
        # Reference [r, g, b] for get_binary_mask, filled in place on calibration.
        self.center_color = [None, None, None]
        self.av_r = self.reward_memory()

    def reset(self):
        """Reset the environment and all per-episode state.

        Returns:
            The initial stack: the first grayscale frame repeated ``img_stack`` times.
        """
        self.counter = 0
        self.die = False
        # Per-episode state: restart mask calibration and the reward history.
        # Otherwise rewards from the previous episode could trigger the
        # "no recent progress" termination right after a reset.
        self.frame_count = 0
        self.center_color = [None, None, None]
        self.av_r = self.reward_memory()
        img_rgb, _ = self.env.reset()
        img_gray = rgb2gray(img_rgb)
        self.stack = [img_gray] * self.img_stack
        return np.array(self.stack)

    def step(self, action):
        """Repeat ``action`` up to ``action_repeat`` times and return the new observation.

        Args:
            action: array-like ``[steer, gas, brake]``. It is clipped to
                ``[-1, 1]``, ``[0, 1]`` and ``[0, 1]`` respectively; the
                caller's array is not modified.

        Returns:
            ``(stack, total_reward, done, die, binary_mask)``.

            - ``done`` is True when the episode ended for any reason: env
              termination, env truncation, or the running mean reward
              dropping to -0.1 or below.
            - ``die`` mirrors the env's ``terminated`` flag.
        """
        # Clip once, into a copy, instead of rewriting the caller's array on
        # every repeat.
        action = np.clip(
            np.asarray(action, dtype=np.float64),
            [-1.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
        )
        total_reward = 0
        for _ in range(self.action_repeat):
            img_rgb, reward, terminated, truncated, _info = self.env.step(action)

            self.die = terminated
            if self.die:
                # Presumably offsets the -100 penalty CarRacing gives when the
                # car leaves the playfield — confirm against the env version.
                reward += 100
            # High mean green intensity (presumably the car is on grass): small penalty.
            if np.mean(img_rgb[:, :, 1]) > 185.0:
                reward -= 0.05

            total_reward += reward
            # Always record the reward so the running mean stays current.
            no_progress = self.av_r(reward) <= -0.1
            # Truncation (time limit) must also end the episode, or the
            # caller's loop may never finish.
            done = no_progress or terminated or truncated
            if done:
                break

        img_gray = rgb2gray(img_rgb)
        self.stack.pop(0)
        self.stack.append(img_gray)
        assert len(self.stack) == self.img_stack

        self.frame_count += 1
        binary_mask = get_binary_mask(img_rgb, self.frame_count, self.center_color)

        return np.array(self.stack), total_reward, done, self.die, binary_mask

    @staticmethod
    def reward_memory():
        """Return a closure that tracks the mean of the last 100 rewards.

        Each call records one reward in a zero-initialised ring buffer of
        length 100 and returns the buffer's mean.
        """
        count = 0
        length = 100
        history = np.zeros(length)

        def memory(reward):
            nonlocal count
            history[count] = reward
            count = (count + 1) % length
            return np.mean(history)

        return memory

    
agent = Agent('cpu')

video_dir = "/Applications/Files/SEM_7/MAJOR/RL/datavideos"  # Path where videos will be saved
os.makedirs(video_dir, exist_ok=True)  # Make sure the directory exists


# render_mode must be a mode the env supports. 'rgb_array' makes env.render()
# return frames, which the RecordVideo wrapper below needs.
env = gym.make('CarRacing-v2', verbose=1, render_mode='rgb_array', domain_randomize=False)
# env = gym.wrappers.RecordVideo(env, video_dir, episode_trigger=lambda x: True)  # Record all episodes
# NOTE(review): the module-level action_repeat (10) is not passed here, so
# Wrapper's default of 4 is used. Confirm which value the model was trained with.
env_wrap = Wrapper(env, img_stack=img_stack)

# Load Model
def load(agent, directory, filename):
    """Load a saved ``state_dict`` from ``directory/filename`` into ``agent.net``.

    Tensors are mapped onto ``agent.device``, or the CPU when the agent has no
    ``device`` attribute. This lets checkpoints saved on a GPU be restored on a
    CPU-only machine.

    Security: ``torch.load`` unpickles the file, so only load checkpoints from
    trusted sources.
    """
    model_path = os.path.join(directory, filename)
    device = getattr(agent, "device", "cpu")
    state_dict = torch.load(model_path, map_location=device)
    agent.net.load_state_dict(state_dict)

# --- Data extraction ---
# NOTE(review): `deque` is not used by the code visible in this section; confirm before removing.
from collections import deque

# Directory that extract_data_npz writes its .npz file to. The path is
# relative to the current working directory and is created at import time.
output_dir = "ppo_output_npz"
os.makedirs(output_dir, exist_ok=True)


def extract_data_npz(env_wrap, agent, n_episodes=50, save_dir=None):
    """Roll out ``agent`` and save per-step policy outputs and masks to a compressed ``.npz``.

    For every step, the function records the Beta parameters (alpha, beta),
    the critic value, and the binary mask of the *current* observation.
    Actions are sampled with ``agent.select_action``.

    Episodes have different lengths, so the per-step arrays of all episodes are
    concatenated along axis 0 and the per-episode step counts are stored
    alongside them. Episode ``i`` can be recovered with
    ``np.split(data["alphas"], np.cumsum(data["episode_lengths"])[:-1])[i]``.

    Saved keys: ``alphas`` (steps, 3), ``betas`` (steps, 3), ``values`` (steps,),
    ``masks`` (steps, H, W), ``episode_lengths`` (n_episodes,).

    Args:
        env_wrap: a ``Wrapper``; its ``step`` returns
            ``(state, reward, done, die, binary_mask)``.
        agent: object with ``net``, ``device`` and ``select_action``.
        n_episodes: number of episodes to roll out.
        save_dir: output directory; defaults to the module-level ``output_dir``.

    Returns:
        Path of the written ``.npz`` file.
    """
    save_dir = output_dir if save_dir is None else save_dir
    os.makedirs(save_dir, exist_ok=True)

    alphas, betas, values, masks = [], [], [], []
    episode_lengths = []

    for i_episode in range(n_episodes):
        state = env_wrap.reset()
        # reset() produces no mask. Use an empty one for the first observation;
        # get_binary_mask is all-zero until it has been calibrated anyway.
        mask = np.zeros(state.shape[-2:], dtype=np.uint8)
        done = False
        die = False
        step = 0

        while not done and not die:
            state_tensor = torch.from_numpy(state).double().unsqueeze(0).to(agent.device)
            with torch.no_grad():
                (alpha, beta), v = agent.net(state_tensor)

            alphas.append(alpha.squeeze().cpu().numpy())
            betas.append(beta.squeeze().cpu().numpy())
            values.append(v.item())
            masks.append(mask)  # mask of the same observation as alpha/beta/v

            action, _a_logp = agent.select_action(state)
            # Map the [0, 1] Beta sample to the ranges Wrapper.step clips to:
            # steering becomes [-1, 1]; gas and brake stay in [0, 1].
            state, _reward, done, die, mask = env_wrap.step(
                action * np.array([2.0, 1.0, 1.0]) + np.array([-1.0, 0.0, 0.0])
            )
            step += 1

        print(f"Episode {i_episode+1} finished with {step} steps.")
        episode_lengths.append(step)

    # Flat per-step arrays plus lengths avoid the ragged nested lists that
    # numpy cannot turn into regular arrays.
    path = os.path.join(save_dir, f"ppo_{n_episodes}_episodes_data.npz")
    np.savez_compressed(
        path,
        alphas=np.array(alphas),
        betas=np.array(betas),
        values=np.array(values),
        masks=np.array(masks),
        episode_lengths=np.array(episode_lengths, dtype=np.int64),
    )
    print(f"\nSaved data to: {path}")
    return path

# Restore the trained weights and run the extraction. The environment is
# always closed, even if loading or extraction fails.
# Alternative checkpoint location: /Users/divyansh/Downloads/model_weights_best.pth
try:
    load(agent, '/Applications/Files/SEM_7/MAJOR/RL/model', 'model_weights_best.pth')
    extract_data_npz(env_wrap, agent, n_episodes=50)
    # play(env, agent, n_episodes=1)
finally:
    env.close()
