In [1]:
import os
os.environ["ROM_DIR"] = "/net/csefiles/xzhanglab/mlobo6/miniconda3/envs/drl/lib/python3.10/site-packages/AutoROM/roms"

In [2]:
import gymnasium as gym
import ale_py
from ale_py import ALEInterface

In [3]:
%%capture
# @title Imports
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy
import numpy as np

# code should work on either, faster on gpu
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# random seeds for reproducability
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
ale = ALEInterface()
gym.register_envs(ale_py)

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


In [5]:
# @title Define Replay Buffer
class ReplayBuffer: # why have a replay buffer ? 
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = 256

    def store(self, state, skill, action, reward, next_state, done):
        transitions = list(zip(state, skill, action, reward, next_state, 1 - torch.Tensor(done)))
        self.buffer.extend(transitions)

    def sample(self):
        batch = random.sample(self.buffer, self.batch_size)
        return [torch.stack(e).to(device) for e in zip(*batch)]  # state, skill, action, reward, next_state, not_done

    def __len__(self):
        return len(self.buffer)

In [6]:
# @title Visualization Code
import os
from gym.wrappers import RecordVideo
from IPython.display import Video, display, clear_output

def visualize(policy, n_episodes=10, max_steps=1000):
    """Visualize agent in an Atari environment."""

    env = gym.make("ALE/MsPacman-v5", render_mode="rgb_array")
    env = gym.wrappers.RecordVideo(
        env,
        episode_trigger=lambda num: num % 2 == 0,
        video_folder="./videos",
        name_prefix="pacman-video",
    )

    for episode in range(n_episodes):
        obs, info = env.reset()
        episode_over = False

        for i in range(max_steps):
            action = env.action_space.sample()
            obs, reward, done, truncated, info = env.step(action)

            episode_over = done or truncated
            if episode_over:
              break

    env.close()

    # Display the latest video
    clear_output(wait=True)
    display(Video("./rl-video-episode-0.mp4", embed=True))

In [7]:
# @title Evaluation Code

def evaluate(policy):

    # Create environment in rgb_array mode
    env = gym.make("InvertedPendulum-v5", reset_noise_scale=0.1, frame_skip=5)

    n = 3
    mean_duration = 0
    for i in range(n):
        obs, _ = env.reset()
        done, t = False, 0
        while not done and t < 200:
            with torch.no_grad():
                actions = policy(torch.Tensor(obs).to(device)[None, :])[:, 0]
            obs, _, done, _, _ = env.step(actions.cpu().numpy())
            t += 1

        mean_duration += t
        print(f"trial {i+1}/{n} lasted {t*.1:.3f} seconds")

    env.close()
    print(f"\nmean duration: {(mean_duration * .1 / n):.3f} seconds")


In [8]:
# Launch TensorBoard
%load_ext tensorboard
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 2154436), started 1 day, 3:59:38 ago. (Use '!kill 2154436' to kill it.)

In [9]:
import gymnasium as gym
import ale_py
from ale_py import ALEInterface

import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy
import numpy as np
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import imageio
import tqdm

from sklearn.manifold import TSNE
from torchvision.transforms import ToTensor, Resize
from PIL import Image
import gc

import os
from gymnasium.wrappers import RecordVideo


In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


Using device: cuda


In [11]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7c6f537668c0>

In [12]:
now = datetime.now()
timestamp = now.strftime("%d_%H%M%S")
writer = SummaryWriter(log_dir=f'runs/SAC-Discrete-{timestamp}')

In [13]:
ale = ALEInterface()
gym.register_envs(ale_py)

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


In [14]:
def cleanup_memory():
    gc.collect()  # Clean up CPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()            # Release cached GPU memory (PyTorch only)
        torch.cuda.ipc_collect()            # Clean up interprocess memory (multi-GPU safe)


In [15]:
import warnings

In [16]:
# store experience in a replay buffer, sample from it 
class ReplayBuffer:
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)
        self.batch_size = 512

    def store(self, state, skill, action, reward, next_state, done):
        transitions = list(zip(state, skill, action, reward, next_state, 1 - torch.Tensor(done)))
        self.buffer.extend(transitions)

    def sample(self):
        if len(self.buffer) < self.batch_size:
            batch = random.choices(self.buffer, k=self.batch_size)
            warnings.warn(f"Requested batch size {self.batch_size} is larger than buffer size \
                 {len(self.buffer)}. Sampling with replacement.", category=UserWarning)

        else:
            batch = random.sample(self.buffer, self.batch_size)

        return [torch.stack(e).to(device) for e in zip(*batch)]  # state, skill, action, reward, next_state, not_done
    
    def sample_by_skill(self, skill_id, num_samples=128): ## THIS METHOD ASSUMES ONE HOT ENCODED SKILLS
        filtered = [transition for transition in self.buffer
                    if torch.argmax(transition[1]).item() == skill_id]

        if len(filtered) < num_samples:
            batch = filtered
            warnings.warn(f"Not enough samples for skill {skill_id}. Sampling with replacement.", category=UserWarning)
        else:
            batch = random.sample(filtered, num_samples)

        return [torch.stack(e).to(device) for e in zip(*batch)]


    def __len__(self):
        return len(self.buffer)


In [17]:
# # to encode input pacman images 
# class CNNEncoder(nn.Module):
#     def __init__(self, out_dim):
#         super().__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(3, 32, kernel_size=8, stride=4),  # (3, 210, 160) → (32, 52, 39)
#             nn.ReLU(),
#             nn.Conv2d(32, 64, kernel_size=4, stride=2), # → (64, 25, 18)
#             nn.ReLU(),
#             nn.Conv2d(64, 64, kernel_size=3, stride=1), # → (64, 23, 16)
#             nn.ReLU()
#         ).to(device)
#         self.fc = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(64 * 22 * 16, out_dim),
#             nn.ReLU()
#         ).to(device)

#     def forward(self, x):
#         return self.fc(self.conv(x / 255.0))




In [18]:
class DRL:
    def __init__(self, buffer_size = 10000):
        self.n_envs = 32
        self.n_steps = 512

        self.envs = gym.vector.SyncVectorEnv(
            [lambda: gym.make('ALE/MsPacman-v5') for _ in range(self.n_envs)])

        self.replay_buffer = ReplayBuffer(capacity=buffer_size)

    def rollout(self, agent, i, n_skill, encoder):
        """Collect experience and store it in the replay buffer"""
        encoder = encoder.to(device)
        obs, _ = self.envs.reset()
        obs = torch.tensor(obs, dtype=torch.float32, device=device).permute(0, 3, 1, 2)# .to(device)
        enc_obs = encoder(obs)

        env_skills = F.normalize(torch.randn(self.n_envs, n_skill, device=device), dim=1)  # shape: (n_envs, n_skill)
        total_rewards = torch.zeros(self.n_envs, device=device)


        for step_num in range(self.n_steps):
            with torch.no_grad():
                actions = agent.get_action(enc_obs, env_skills)

            next_obs, rewards, dones, truncateds, _ = self.envs.step(actions.cpu().numpy())
            next_obs = torch.tensor(next_obs, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
            rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
            not_done = ~(dones | truncateds)

            # Store the transitions
            self.replay_buffer.store(obs, env_skills, actions, rewards, next_obs, not_done)
            obs = next_obs
            with torch.no_grad():
                enc_obs = encoder(obs)
            total_rewards += rewards

            # Resample skills only for environments where episode ended
            done_mask = dones | truncateds
            if done_mask.any() or step_num % 50 == 0: ## RESAMPLING SKILLS QUITE A LOT :)
                new_skills = F.normalize(torch.randn(done_mask.sum(), n_skill, device=device), dim=1)
                env_skills[done_mask] = new_skills


        writer.add_scalar("stats/Rewards", total_rewards.mean().item() / self.n_steps, i)

In [19]:
def encode_state(state, encoder): # takes state and passes through CNNEncoder 
    if state.dim() == 4 and state.shape[1] == 3:
        features = encoder(state)
    elif state.dim() == 3 and state.shape[0] == 3:
        state = state.unsqueeze(0)
        features = encoder(state)
    elif state.dim() == 2:
        features = state
    else:
        raise ValueError(f'found a state with shape: {state.shape} in get_latent_representationss')
    
    return features

In [20]:
class SkillPolicy: # implements discrete SAC 
    def __init__(self, n_obs, n_skills, n_actions, representation, tau=0.005, lr=3e-4, gamma = 0.99, automatic_entropy_tuning=True, par_alpha=0.2, target_entropy=None):
        
        
        self.alpha = par_alpha # 1.5
        self.n_actions = n_actions
        self.encoder = representation.encoder
        self.representation = representation
        self.tau = tau
        self.gamma = gamma

        # Q1Network
        self.q1_net = nn.Sequential(
            nn.Linear(n_obs + n_skills, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.n_actions)
        ).to(device)
        
        # Q2Network
        self.q2_net = copy.deepcopy(self.q1_net).to(device)

        # Policy Network
        self.policy = nn.Sequential(
            nn.Linear(n_obs + n_skills, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.n_actions)
        ).to(device)

        # Target Q1 Netowrk 
        self.q1_target_net = copy.deepcopy(self.q1_net).to(device)
        
        # Target Q2 network
        self.q2_target_net = copy.deepcopy(self.q1_net).to(device)

        # self.q1_target_net.load_state_dict(self.q1_net.state_dict()) # maybe faster
        # self.q2_target_net.load_state_dict(self.q2_net.state_dict())

        # Single Q optimizer 
        self.q_optimizer = Adam(
            list(self.q1_net.parameters()) + list(self.q2_net.parameters()), lr=lr
        )

        # Policy Optimizer 
        self.policy_optimizer = Adam(self.policy.parameters(), lr=lr)
        
        
        # self.gamma = gamma

        ###########   add temperature  ############
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if automatic_entropy_tuning:
            # target entropy = −|A|
            self.target_entropy = -self.n_actions if target_entropy is None else target_entropy
            self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
            self.alpha_optimizer = Adam([self.log_alpha], lr=lr)
        else:
            self.alpha = par_alpha
        ################################################





    ############ Selection of the action  #############
    def get_policy_distribution(self, states, skills):
        states = encode_state(states, self.encoder)
        inputs = torch.cat([states, skills], dim=-1)
        logits = self.policy(inputs)
        return Categorical(logits=logits)

    def get_action(self, states, skills, eval=False):
        dist = self.get_policy_distribution(states, skills)
        if eval:
            return torch.argmax(dist.probs, dim=-1)
        return dist.sample()
    ###################################################


    # def get_entropy(self, states, skills):
    #     dist = self.get_policy_distribution(states, skills)
    #     return dist.entropy()


    def get_q_loss(self, states, actions, rewards, next_states, not_dones, skills):

        # ext_reward = rewards.to(device)

        with torch.no_grad():
            phi = self.representation.get_latent_representation(states)  # phi(s) 
            phi_next = self.representation.get_latent_representation(next_states) # phi (s')

            delta_phi = phi - phi_next
            # skills = F.normalize(skills, p=2, dim=1)

            intrinsic_reward = torch.einsum('bi,bi->b', delta_phi, skills).unsqueeze(-1)
            intrinsic_reward = torch.clamp(intrinsic_reward, -10, 10)
            # add extrinsic reward ??
            
            ############## target Q value #######################
            next_states = encode_state(next_states, self.encoder)
            next_input = torch.cat([next_states, skills], dim=-1)
            next_q1 = self.q1_target_net(next_input)
            next_q2 = self.q2_target_net(next_input)
            next_q = torch.min(next_q1, next_q2)

            
            next_pi = self.get_policy_distribution(next_states, skills)
            log_probs = next_pi.logits.log_softmax(dim=-1)
            # next_entropy = next_pi.entropy().unsqueeze(-1)
            # next_entropy = next_pi.logits.log_softmax(dim=-1)

            if self.automatic_entropy_tuning:
                alpha = self.log_alpha.exp()
            else:
                alpha = self.alpha

            next_q_val = (next_pi.probs * (next_q - alpha * log_probs)).sum(dim=-1, keepdim=True)
            q_target = intrinsic_reward + (self.gamma * not_dones * next_q_val) # + ext_reward # added extrinisic reward, may not make sense
            #####################################################

        ############  Q losses ##################
        states = encode_state(states, self.encoder)
        current_inputs = torch.cat([states, skills], dim=-1)
        q1 = self.q1_net(current_inputs).gather(1, actions.long().unsqueeze(-1))
        q2 = self.q2_net(current_inputs).gather(1, actions.long().unsqueeze(-1))

        loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)
        #######################################
        return loss
        

    #################### policy loss ##################
    def get_policy_loss(self, states, skills):
        dist = self.get_policy_distribution(states, skills)
        probs = dist.probs
        log_probs = dist.logits.log_softmax(dim=-1)

        states = encode_state(states, self.encoder)

        inputs = torch.cat([states, skills], dim=-1)
        q1 = self.q1_net(inputs)
        q2 = self.q2_net(inputs)
        q = torch.min(q1, q2)

        # added alpha 
        if self.automatic_entropy_tuning:
            alpha = self.log_alpha.exp()
        else:
            alpha = self.alpha


        policy_loss = -(probs * (q - (alpha * log_probs))).sum(dim=1).mean()
        return policy_loss
    #####################################################

    ############### entropy loss ###################
    def get_entropy_loss(self, states, skills):
        dist = self.get_policy_distribution(states, skills)
        # dist.entropy()
        probs = dist.probs
        log_probs = dist.logits.log_softmax(dim=-1)

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
        else:
            alpha_loss = 0
        return alpha_loss
    # self.alpha_opt.zero_grad()
    #         alpha_loss.backward()
    #         self.alpha_opt.step()
    ###################################################    


    def update(self, replay_buffer, i):

      for _ in range(1): # i dont think its good idea to set this too high ?
        states, skills, actions, rewards, next_states, not_dones = replay_buffer.sample()
        # Compute Q-loss
        q_loss = self.get_q_loss(states, actions, rewards, next_states, not_dones, skills)
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # Update the policy
        policy_loss = self.get_policy_loss(states, skills)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

      # Soft update the target networks using Polyak averaging
      self.soft_update(self.q1_net, self.q1_target_net)
      self.soft_update(self.q2_net, self.q2_target_net)
      # entropy = self.get_entropy(states, skills).mean().item()

      alpha_loss = self.get_entropy_loss(states, skills)
      self.alpha_optimizer.zero_grad()
      alpha_loss.backward()
      self.alpha_optimizer.step()

      writer.add_scalar("loss/entropy_loss", alpha_loss, i)

      writer.add_scalar("loss/q_loss", q_loss.item(), i)
      writer.add_scalar("loss/ - policy loss", -policy_loss.item(), i)


    def soft_update(self, source, target):
        """Soft update the target network"""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)


In [None]:
class CNNEncoder(nn.Module):
    def __init__(self, out_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4),  # (3, 210, 160) → (32, 52, 39)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # → (64, 25, 18)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), # → (64, 23, 16)
            nn.ReLU()
        ).to(device)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 22 * 16, out_dim),
            nn.ReLU()
        ).to(device)

    def forward(self, x):
        return self.fc(self.conv(x / 255.0))

In [22]:

class RepresentationFunction(nn.Module): # Meat of METRA
    def __init__(self, n_obs, n_skill, lr=3e-4):
        super().__init__()

        # Encode input state 
        self.encoder = CNNEncoder(out_dim=256) # output of the CNN

        # phi network - n_obs is hardcoded to 256 in main
        self.representation_func = nn.Sequential(
            nn.Linear(n_obs, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, n_skill)
        ).to(device)

        # joint optimization over encoder and phi
        self.optimizer = Adam(list(self.parameters()) + list(self.encoder.parameters()), lr=1e-4)
        
        # lagrange mult
        self.lambda_param = nn.Parameter(torch.tensor(0.05, requires_grad=True, device=device))
        
        # lambda optimizer 
        self.lambda_optimizer = Adam([self.lambda_param], lr=1e-4)
        
        # threshold for metra constraint 
        self.epsilon = 0.001 # 1.0


    def get_latent_representation(self, state, normalize=False):
        features = encode_state(state, self.encoder) # Runs CNNEncoder 
        x = self.representation_func(features) # phi output 
        if normalize:
            return F.normalize(x, p=2, dim=1)
        return x


    def update(self, replay_buffer, i):
        """
        Updates the representation function φ and the Lagrange multiplier λ
        based on skill-consistency and distance constraints.
        """
        for _ in range(50): # 10

          state, skill, action, reward, next_state, not_done  = replay_buffer.sample()

          current_representations = self.get_latent_representation(state)    # φ(s)
          next_representations = self.get_latent_representation(next_state)  # φ(s')

          # Consistency loss term
          consistency_term = torch.einsum('bi,bi->b', (next_representations - current_representations), skill)

          # Distance penalty term (to enforce norm constraints)
          diff_norm_squared = torch.norm(current_representations - next_representations, dim=1) ** 2
          penalty_term = torch.minimum(
              torch.tensor(self.epsilon, device=diff_norm_squared.device),
              (1.0 - diff_norm_squared).clone()
          )

          # Representation loss
          representation_loss = -(consistency_term + self.lambda_param * penalty_term).mean()

          # Lambda loss (only on penalty term)
          lambda_loss = (self.lambda_param * penalty_term.detach()).mean()

          # Backprop: update \phi
          self.optimizer.zero_grad()
          representation_loss.backward(retain_graph=True)
          self.optimizer.step()

          # Backprop: update \lambda
          self.lambda_optimizer.zero_grad()
          lambda_loss.backward()
          self.lambda_optimizer.step()
          with torch.no_grad():
                self.lambda_param.clamp_(min=0.0)

        writer.add_scalar("loss/representation_loss", representation_loss.item(), i)
        writer.add_scalar("loss/  lambda_loss", lambda_loss.item(), i)


## not checkedn

In [23]:

def compute_heatmap(frames, save_path):
    heatmap = None
    for frame in frames:
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # naive thresholding to isolate the player
        _, thresh = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)
        if heatmap is None:
            heatmap = np.zeros_like(thresh, dtype=np.float32)
        heatmap += thresh.astype(np.float32)

    # Normalize and save heatmap
    plt.figure(figsize=(6, 6))
    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.axis('off')
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

def visualize_representation(replay_buffer, representation, n_samples=1000, folder='.', n_skills=10):
    skill_data = []
    for skill_id in range(n_skills):
        samples = replay_buffer.sample_by_skill(skill_id, num_samples=max(int(n_samples/10), 100))
        skill_data.append(samples)

    # Combine all samples
    states = torch.cat([s[0] for s in skill_data], dim=0)
    skills = torch.cat([s[1] for s in skill_data], dim=0)

    with torch.no_grad():
        skill_ids = torch.argmax(skills, dim=1)  # assume Gaussian skills aren't passed here!

        phis = representation.get_latent_representation(states).cpu().numpy()
        skill_ids_np = skill_ids.cpu().numpy()

    tsne = TSNE(n_components=2, perplexity=10)
    tsne_result = tsne.fit_transform(phis)

    plt.figure(figsize=(8, 6))
    for skill_id in range(n_skills):
        idx = skill_ids_np == skill_id
        if np.sum(idx) == 0:
            continue
        plt.scatter(tsne_result[idx, 0], tsne_result[idx, 1], label=f'Skill {skill_id}', alpha=0.6)
    plt.legend()
    plt.title('t-SNE of Representation φ(s)')
    plt.savefig(f"{folder}/representation_tsne.png")
    plt.close()


def evaluate_skills(env_name, policy, representation, global_step, writer=None, n_skills=5, steps_per_skill=512, video_dir='skill_videos'):
    os.makedirs(video_dir, exist_ok=True)

    temp_buffer = ReplayBuffer(capacity=steps_per_skill * n_skills)
    temp_buffer.batch_size = 512

    for skill_id in range(n_skills): # For every z vector, I will perform ONLY the action related to that skill and get the reward
        env = gym.make(env_name, render_mode='rgb_array')
        obs, _ = env.reset()
        # obs = torch.tensor(obs, dtype=torch.float32, device=device).view(1, -1) ## NO need to flatten CNN encoder in representation function will handle it
        obs = torch.unsqueeze(torch.tensor(obs, dtype=torch.float32, device=device), 0).permute(0, 3, 1, 2)

        # USING ONE HOT ENCODED SKILLS HERE SO THAT WE CAN FIND THE DIFFERENT SKILLS SEPARATELY
        skill = torch.zeros(1, n_skills, device=device)
        skill[0, skill_id] = 1.0  # One-hot vector for evaluation


        frames = []
        total_intrinsic_reward = 0
        step_count = 0

        while step_count < steps_per_skill:
            with torch.no_grad():
                action = policy.get_action(obs, skill, eval=True).item()
            next_obs, _, terminated, truncated, _ = env.step(action)
            frame = env.render()
            frames.append(frame)

            # next_obs_tensor = torch.tensor(next_obs, dtype=torch.float32, device=device).view(1, -1)
            next_obs_tensor = torch.unsqueeze(torch.tensor(next_obs, dtype=torch.float32, device=device), 0).permute(0, 3, 1, 2) ## Encoding doesn't matter because get_latent_representation encodes
            delta_phi = representation.get_latent_representation(next_obs_tensor) - representation.get_latent_representation(obs)
            intrinsic_reward = torch.einsum('bi,bi->b', delta_phi, skill).item()
            temp_buffer.store(obs, skill, torch.tensor([action]), torch.tensor([intrinsic_reward]), next_obs_tensor, torch.tensor([~(terminated or truncated)]))
            total_intrinsic_reward += intrinsic_reward

            if terminated or truncated:
                break

            obs = next_obs_tensor
            step_count += 1

        env.close()

        video_path = os.path.join(video_dir, f"skill_{skill_id}.mp4")
        imageio.mimsave(video_path, frames, fps=30)

        writer.add_scalar("representation/mean_delta_phi_norm", delta_phi.norm(dim=1).item(), global_step)

        heatmap_path = os.path.join(video_dir, f"skill_{skill_id}_heatmap.png")
        # compute_heatmap(frames, heatmap_path)

        del frames
        torch.cuda.empty_cache()  
        # print(f"Skill {skill_id}: steps = {step_count}, intrinsic reward = {total_intrinsic_reward:.2f}, saved to {video_path}")
        if writer:
            writer.add_scalar(f"eval/intrinsic_reward_skill_{skill_id}", total_intrinsic_reward, global_step)
            writer.add_scalar(f"eval/episode_length_skill_{skill_id}", step_count, global_step)
    
    visualize_representation(temp_buffer, representation, folder=video_dir, n_skills=n_skills)



In [24]:


env = gym.make('ALE/MsPacman-v5')
# n_obs = env.observation_space.shape[0] * env.observation_space.shape[1] * env.observation_space.shape[2]
n_obs = 256 # Encoded dimension
n_actions = env.action_space.n
n_skill = 10 # 0 - NOOP, 1 - UP, 2 - LEFT, 3 - RIGHT, 4 - DOWN

representation = RepresentationFunction(n_obs, n_skill)
policy = SkillPolicy(n_obs, n_skill, n_actions, representation)

start_epoch = 0

chkpt_path = ''
if len(chkpt_path) > 0:
    if not os.path.exists(chkpt_path):
        raise ValueError("Load folder path doesn't exist")
    checkpoint = torch.load(chkpt_path)
    policy.policy.load_state_dict(checkpoint['policy_state_dict'])
    policy.q1_net.load_state_dict(checkpoint['q1_state_dict'])
    policy.q2_net.load_state_dict(checkpoint['q2_state_dict'])
    policy.q1_target_net.load_state_dict(checkpoint['q1_target_state_dict'])
    policy.q2_target_net.load_state_dict(checkpoint['q2_target_state_dict'])
    policy.policy_optimizer.load_state_dict(checkpoint['policy_optimizer_state_dict'])
    policy.q_optimizer.load_state_dict(checkpoint['q_optimizer_state_dict'])

    representation.representation_func.load_state_dict(checkpoint['representation_state_dict'])
    representation.optimizer.load_state_dict(checkpoint['representation_optimizer_state_dict'])
    representation.lambda_param = torch.tensor(checkpoint['lambda_param'], requires_grad=True)
    representation.lambda_optimizer.load_state_dict(checkpoint['lambda_optimizer_state_dict'])

    start_epoch = checkpoint['epoch'] + 1


# Training loop
num_epochs = 551

drl = DRL(buffer_size = 10_000)

if not os.path.exists(f"./model/{timestamp}"):
    os.makedirs(f"./model/{timestamp}")

folder = f'./model/{timestamp}'

for epoch in range(start_epoch, num_epochs + start_epoch):
    print(f'EPOCH: {epoch}')
    drl.rollout(policy, epoch, n_skill, representation.encoder)
    representation.update(drl.replay_buffer, epoch)
    policy.update(drl.replay_buffer, epoch)
    if epoch % 10 == 0:
        ## EVALUATING SKILL
        evaluate_skills('ALE/MsPacman-v5', policy, representation,\
                epoch, writer=writer, n_skills=n_skill, video_dir=f'{folder}/video_eval_{epoch}') ## ONLY doing for the first 3 skills

        ## EVALUATING REPRESENTATION FUNCTION
        # visualize_representation(replay_buffer=drl.replay_buffer,\
        #      representation=representation, folder=f"{folder}/video_eval_{epoch}") ## DOING THIS IN SkILL EVALUATION WHICH DOES ONE HOT ENCODED SKILLS
        # writer.add_image("eval/tSNE", \
        #      torch.tensor(np.array(Image.open(f"{folder}/video_eval_{epoch}/representation_tsne.png"))).permute(2, 0, 1), epoch)

        gc.collect()
    cleanup_memory()

torch.save({
    'epoch': num_epochs,
    'policy_state_dict': policy.policy.state_dict(),
    'q1_state_dict': policy.q1_net.state_dict(),
    'q2_state_dict': policy.q2_net.state_dict(),
    'q1_target_state_dict': policy.q1_target_net.state_dict(),
    'q2_target_state_dict': policy.q2_target_net.state_dict(),
    'policy_optimizer_state_dict': policy.policy_optimizer.state_dict(),
    'q_optimizer_state_dict': policy.q_optimizer.state_dict(),

    'representation_state_dict': representation.representation_func.state_dict(),
    'representation_optimizer_state_dict': representation.optimizer.state_dict(),
    'lambda_param': representation.lambda_param.detach().item(),
    'lambda_optimizer_state_dict': representation.lambda_optimizer.state_dict(),
}, f"./model/{timestamp}/checkpoint.pth")

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]
  return disable_fn(*args, **kwargs)


EPOCH: 0


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 1


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 2
EPOCH: 3
EPOCH: 4
EPOCH: 5
EPOCH: 6
EPOCH: 7
EPOCH: 8
EPOCH: 9
EPOCH: 10




EPOCH: 11


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 12
EPOCH: 13
EPOCH: 14
EPOCH: 15
EPOCH: 16
EPOCH: 17
EPOCH: 18
EPOCH: 19
EPOCH: 20




EPOCH: 21


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 22
EPOCH: 23
EPOCH: 24
EPOCH: 25
EPOCH: 26
EPOCH: 27
EPOCH: 28
EPOCH: 29
EPOCH: 30




EPOCH: 31


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 32
EPOCH: 33
EPOCH: 34
EPOCH: 35
EPOCH: 36
EPOCH: 37
EPOCH: 38
EPOCH: 39
EPOCH: 40




EPOCH: 41


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 42
EPOCH: 43
EPOCH: 44
EPOCH: 45
EPOCH: 46
EPOCH: 47
EPOCH: 48
EPOCH: 49
EPOCH: 50




EPOCH: 51


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 52
EPOCH: 53
EPOCH: 54
EPOCH: 55
EPOCH: 56
EPOCH: 57
EPOCH: 58
EPOCH: 59
EPOCH: 60




EPOCH: 61


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 62
EPOCH: 63
EPOCH: 64
EPOCH: 65
EPOCH: 66
EPOCH: 67
EPOCH: 68
EPOCH: 69
EPOCH: 70




EPOCH: 71


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 72
EPOCH: 73
EPOCH: 74
EPOCH: 75
EPOCH: 76
EPOCH: 77
EPOCH: 78
EPOCH: 79
EPOCH: 80




EPOCH: 81


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 82
EPOCH: 83
EPOCH: 84
EPOCH: 85
EPOCH: 86
EPOCH: 87
EPOCH: 88
EPOCH: 89
EPOCH: 90




EPOCH: 91


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 92
EPOCH: 93
EPOCH: 94
EPOCH: 95
EPOCH: 96
EPOCH: 97
EPOCH: 98
EPOCH: 99
EPOCH: 100




EPOCH: 101


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 102
EPOCH: 103
EPOCH: 104
EPOCH: 105
EPOCH: 106
EPOCH: 107
EPOCH: 108
EPOCH: 109
EPOCH: 110




EPOCH: 111


  loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)


EPOCH: 112
EPOCH: 113
EPOCH: 114
EPOCH: 115
EPOCH: 116
EPOCH: 117
EPOCH: 118
EPOCH: 119
EPOCH: 120




OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 44.45 GiB of which 1.75 MiB is free. Process 2191234 has 418.00 MiB memory in use. Including non-PyTorch memory, this process has 6.92 GiB memory in use. Process 2394182 has 536.00 MiB memory in use. Process 2404824 has 18.14 GiB memory in use. Process 2409193 has 18.40 GiB memory in use. Of the allocated memory 5.69 GiB is allocated by PyTorch, and 935.67 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)