In [28]:
!pip install "gymnasium>=1.2.0" "imageio>=2.37.0" "lightning>=2.5.2" "matplotlib>=3.10.3" "minigrid>=3.0.0" "pyyaml>=6.0.2" "minigrid" "seaborn>=0.13.2" "tensorboard>=2.20.0" "torch>=2.7.1"

Collecting gymnasium>=1.2.0
  Downloading gymnasium-1.2.0-py3-none-any.whl.metadata (9.9 kB)
Collecting lightning>=2.5.2
  Downloading lightning-2.5.2-py3-none-any.whl.metadata (38 kB)
Collecting matplotlib>=3.10.3
  Downloading matplotlib-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting minigrid>=3.0.0
  Downloading minigrid-3.0.0-py3-none-any.whl.metadata (6.7 kB)
Collecting seaborn>=0.13.2
  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting tensorboard>=2.20.0
  Downloading tensorboard-2.20.0-py3-none-any.whl.metadata (1.8 kB)
Collecting torch>=2.7.1
  Downloading torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch>=2.7.1)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch>=2.7.1)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-ru

### Base Agent

In [6]:
from abc import ABC,abstractmethod
import pytorch_lightning as pl
from typing import Dict, Any



class BaseAgent(pl.LightningModule, ABC):
    """Base class for agents"""
    
    def __init__(self,config:Dict[str,Any]):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config)
        
    @abstractmethod
    def act(self,obs:Dict[str,Any]):
        """select an action given the observation"""
        pass
    
    @abstractmethod
    def training_step(self,batch:Dict[str,Any],batch_idx:int):
        """training step for the agent"""
        pass
        
    
    def configure_optimizers(self):
        """configure the optimizer for the agent"""
        raise NotImplementedError("Subclasses must implement the configure_optimizers method.")
    
    
    def save_checkpoint(self,path:str):
        """save the checkpoint for the agent"""
        self.trainer.save_checkpoint(path)
    
    @classmethod
    def load_checkpoint(cls,checkpoint_path:str,config=None):
        """load the checkpoint for the agent"""
        if config is None :
            return cls.load_from_checkpoint(checkpoint_path)
        return cls.load_from_checkpoint(checkpoint_path,config=config)

### Base Model

In [7]:
import torch.nn as nn




class BaseModel(nn.Module):
    def __init__(self, ):
        super().__init__()
    
    def forward(self, x):
        raise NotImplementedError("Subclasses must implement the forward method.")
    
    
    def save(self, path: str):
        torch.save(self.state_dict(), path)
    
    
    def load(self, path: str):
        self.load_state_dict(torch.load(path))

## DIAYN Agent

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp
import numpy as np
from collections import deque, namedtuple
from typing import Tuple, Dict, Any


# Define transition tuple
Transition = namedtuple('Transition', 
                       ('state', 'action', 'skill', 'next_state', 'done', 'reward'))

class MiniGridEncoder(BaseModel):
    """Encoder for the MiniGrid environment So that it can be used in the DIAYN agent
    
    Args:
        input_shape (Tuple[int]): The shape of the input observation.
        hidden_size (int): The size of the hidden layer.
    
    Returns:
        torch.Tensor: The encoded observation.
    """
    
    def __init__(self, obs_shape: Tuple[int], feature_dim: int = 64, obs_type: str = "rgb"):
        super().__init__()
        self.obs_type = obs_type
        self.obs_shape = obs_shape
        
        # Determine input channels based on observation type
        self.in_channels = 3 if obs_type == "rgb" else 1
        
        # CNN architecture for processing observations
        self.conv = nn.Sequential(
            nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        

        with torch.no_grad():
            # Create a dummy input with correct shape (N, C, H, W)
            dummy = torch.zeros(1, self.in_channels, *obs_shape[:2])
            conv_out = self.conv(dummy)
            self.conv_output_dim = conv_out.shape[1]
            

        self.fc = nn.Linear(self.conv_output_dim, feature_dim)
        self.feature_dim = feature_dim
    
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder
        
        Args:
            obs (torch.Tensor): The input observation of shape (batch, H, W, C) or (H, W, C)
            
        Returns:
            torch.Tensor: The encoded observation of shape (batch, hidden_dim)
        """
        # Ensure we have a batch dimension
        if len(obs.shape) == 3:  # (H, W, C) -> (1, H, W, C)
            obs = obs.unsqueeze(0)
            
        # Convert to float and normalize if needed
        if obs.dtype == torch.uint8:
            obs = obs.float() / 255.0
            
        # Convert from NHWC to NCHW format expected by PyTorch
        if obs.shape[-1] in [1, 3]:  # If channels are last
            obs = obs.permute(0, 3, 1, 2)  # NHWC -> NCHW
            
        # Ensure we have the right number of channels
        if self.obs_type == 'rgb' and obs.shape[1] != 3:
            if obs.shape[1] == 1:  # If grayscale, repeat to 3 channels
                obs = obs.repeat(1, 3, 1, 1)
            else:
                raise ValueError(f"Expected 1 or 3 channels for RGB, got {obs.shape[1]} channels")
                

        x = self.conv(obs)
        return torch.relu(self.fc(x))
    




class SkillDiscriminator(BaseModel):
    """
    Skill discriminator for the DIAYN agent
    
    Args:
        input_dim (int): The dimension of the input.
        hidden_dim (int): The dimension of the hidden layer.
    
    Returns:
        torch.Tensor: The skill discriminator output.
    """
    
    def __init__(self,state_dim,skill_dim,hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,skill_dim)
        )
        
    def forward(self,state:torch.Tensor):
        return self.net(state)
    
    
    def compute_reward(self,state,skill):
        logits = self(state)
        log_probs = F.log_softmax(logits,dim=-1)
        return (log_probs * skill).sum(dim=-1) + np.log(skill.size(1))



class DIAYNAgent(BaseAgent):
    """DIAYN agent for the MiniGrid environment"""
    def __init__(self,config:Dict[str,Any]):
        super().__init__(config)
        
        #Environment parameters
        self.obs_shape = config["agent"]["obs_shape"]
        self.action_dim = config["agent"]["action_dim"]
        self.skill_dim = config.get("skill_dim",8)
        self.obs_type = config.get("obs_type","rgb")
        
        # Training parameters
        self.lr = float(config.get("lr", 3e-4))
        self.gamma = float(config.get("gamma", 0.99))
        self.entropy_coeff = float(config.get("entropy_coeff", 0.01))
        self.batch_size = int(config.get("batch_size", 64))
        self.replay_size = int(config.get("replay_size", 10000))
        self.entropy_coef = float(config.get("entropy_coef", 0.01))
        
        #Models
        self.encoder = MiniGridEncoder(self.obs_shape,
                                       feature_dim = config.get("hidden_dim",64),
                                       obs_type = self.obs_type
                                       ).to(self.device)
        
        #Policy Network 
        self.policy = nn.Sequential(
            nn.Linear(self.encoder.feature_dim + self.skill_dim,64),
            nn.ReLU(),
            nn.Linear(64,64),
            nn.ReLU(),
            nn.Linear(64,self.action_dim)
        ).to(self.device)
        
        #Discriminator Network
        self.discriminator = SkillDiscriminator(
            self.encoder.feature_dim,
            self.skill_dim,
            hidden_dim = config.get("hidden_dim",64)
        ).to(self.device)
        
        #Replay Buffer
        self.replay_buffer = deque(maxlen=self.replay_size)
        
        #Metrics
        self.episode_rewards = []
        self.episode_lengths = []
        
    def forward(self,obs:torch.Tensor,skill:torch.Tensor,deterministic:bool=False) -> torch.Tensor:
        """
        Forward pass of the agent
        Args: 
            obs (torch.Tensor): The observation.
            skill (torch.Tensor): The skill.
        Returns:
            torch.Tensor: The action.
        """
        with torch.no_grad():
            encoded_obs = self.encoder(obs) #return the state in latent space
            x = torch.cat([encoded_obs,skill],dim=-1).to(self.device)
            logits = self.policy(x).to(self.device)
            if deterministic:
                return torch.argmax(logits,dim=-1).to(self.device)
            else:
                probs = F.softmax(logits,dim=-1).to(self.device)
                return torch.multinomial(probs,1).squeeze(-1).to(self.device)
            
    
    def act(self,obs:torch.Tensor,skill:torch.Tensor=None,deterministic:bool=False) -> torch.Tensor:
        """Select an action given the observation and skill"""
        if skill is None:
            skill = self._sample_skill()
        skill = torch.FloatTensor(skill).unsqueeze(0).to(self.device) if not isinstance(skill,torch.Tensor) else skill
        obs = torch.FloatTensor(obs).unsqueeze(0).to(self.device) if not isinstance(obs,torch.Tensor) else obs
        return self.forward(obs,skill,deterministic).cpu().numpy().item()
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        Perform a single training step with mixed precision support.
        
        Args:
            batch: Batch of transitions (if None, sample from replay buffer)
            batch_idx: Batch index
            optimizer_idx: Index of the optimizer (0: discriminator, 1: policy)
            
        Returns:
            torch.Tensor: The computed loss
        """
        # If batch is None, sample from replay buffer
        if batch is None:
            if len(self.replay_buffer) < self.batch_size:
                return torch.tensor(0.0, device=self.device)
            batch = self._sample_batch()
            
        states, actions, skills, next_states, dones, rewards = self._unpack_batch(batch)
        
        with torch.amp.autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu', 
                              enabled=self.device.type == 'cuda'):
            # Encode states and next_states
            with torch.no_grad():
                states_enc = self.encoder(states)
                next_states_enc = self.encoder(next_states)
            
            # Train Discriminator
            if optimizer_idx == 0:
                logits = self.discriminator(next_states_enc)
                loss_d = F.cross_entropy(logits, skills.argmax(dim=-1))
                self.log("train/loss_discriminator", loss_d, prog_bar=True)
                return loss_d
                
            # Compute Policy
            if optimizer_idx == 1:
                policy_input = torch.cat([states_enc, skills], dim=-1)
                logits = self.policy(policy_input)
                
                probs = F.softmax(logits, dim=-1)
                log_probs = F.log_softmax(logits, dim=-1)
                entropy = -(probs * log_probs).sum(dim=-1)
                
                with torch.no_grad():
                    intrinsic_reward = self.discriminator.compute_reward(next_states_enc, skills)
                    
                # Compute policy gradient loss
                policy_loss = -(log_probs.gather(1, actions.unsqueeze(1)) * intrinsic_reward.detach()).mean()
                entropy_loss = -self.entropy_coef * entropy.mean()
                
                total_loss = policy_loss + entropy_loss
                
                # Log metrics
                self.log("train/loss_policy", policy_loss, prog_bar=True)
                self.log("train/entropy", entropy.mean(), prog_bar=True)
                self.log("train/total_loss", total_loss)
                self.log("train/avg_intrinsic_reward", intrinsic_reward.mean())
                
                return total_loss
            
    def configure_optimizers(self):
        """Configure optimizers for discriminator and policy with weight decay."""
        # Use AdamW with weight decay
        opt_d = torch.optim.AdamW(
            self.discriminator.parameters(), 
            lr=self.lr,
            weight_decay=1e-5,
            eps=1e-5
        )
        opt_p = torch.optim.AdamW(
            list(self.encoder.parameters()) + list(self.policy.parameters()),
            lr=self.lr,
            weight_decay=1e-5,
            eps=1e-5
        )
        
        # Learning rate scheduling
        scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(opt_d, T_max=1000)
        scheduler_p = torch.optim.lr_scheduler.CosineAnnealingLR(opt_p, T_max=1000)
        
        return [opt_d, opt_p], [scheduler_d, scheduler_p]   
        
    def _unpack_batch(self, batch):
        """Unpack a batch of transitions from the replay buffer.
        
        Args:
            batch: Batch of transitions. Can be:
                - List of transitions (state, action, skill, next_state, done, reward)
                - Tuple of (states, actions, skills, next_states, dones, rewards)
                
        Returns:
            tuple: (states, actions, skills, next_states, dones, rewards) as torch tensors
        """
        # Handle case where batch is already a tuple of tensors
        if isinstance(batch, (list, tuple)) and len(batch) == 6 and all(torch.is_tensor(x) for x in batch):
            return batch
            
        # Handle case where batch is a list of transitions
        if isinstance(batch, list) and len(batch) > 0 and isinstance(batch[0], (list, tuple)):
            # Transpose the batch: from list of transitions to list of fields
            states, actions, skills, next_states, dones, rewards = zip(*batch)
        else:
            # Assume batch is already in the correct format
            states, actions, skills, next_states, dones, rewards = batch
        
        # Convert to numpy arrays if they aren't already
        states = np.array(states) if not isinstance(states, np.ndarray) else states
        next_states = np.array(next_states) if not isinstance(next_states, np.ndarray) else next_states
        skills = np.array(skills) if not isinstance(skills, np.ndarray) else skills
        
        # Convert to tensors and move to device
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        skills = torch.as_tensor(skills, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        
        # Ensure correct shapes
        if len(states.shape) == 3:  # (B, H, W) -> (B, 1, H, W)
            states = states.unsqueeze(1)
        if len(next_states.shape) == 3:
            next_states = next_states.unsqueeze(1)
            
        return states, actions, skills, next_states, dones, rewards
        
    def _sample_batch(self):
        """Sample a batch from replay buffer."""
        if len(self.replay_buffer) < self.batch_size:
            batch_size = len(self.replay_buffer)
        else:
            batch_size = self.batch_size
            
        # Sample random indices
        indices = np.random.choice(len(self.replay_buffer), size=batch_size, replace=False)
        transitions = [self.replay_buffer[i] for i in indices]
        batch = Transition(*zip(*transitions))
        
        return (
            torch.stack(batch.state).to(self.device),
            torch.cat(batch.action).to(self.device),
            torch.stack(batch.skill).to(self.device),
            torch.stack(batch.next_state).to(self.device),
            torch.stack(batch.done).to(self.device),
            torch.stack(batch.reward).to(self.device)
        )
    
    
    def _sample_skill(self):
        """Sample a random one-hot skill vector"""
        skill = np.zeros(self.skill_dim)
        skill[np.random.randint(self.skill_dim)] = 1
        return skill

    
    def add_to_replay(self,state,action,skill,next_state,done,reward):
        
        """Add transition to replay buffer."""
        self.replay_buffer.append(Transition(
            torch.FloatTensor(state).to(self.device),
            torch.LongTensor([action]).to(self.device),
            torch.FloatTensor(skill).to(self.device),
            torch.FloatTensor(next_state).to(self.device),
            torch.FloatTensor([done]).to(self.device),
            torch.FloatTensor([reward]).to(self.device)
        ))

### Configs

In [20]:
with open("diayn.yaml","w") as f :
    f.write("""
# configs/diayn.yaml
# DIAYN (Diversity is All You Need) Configuration
# Reference: https://arxiv.org/abs/1802.06070

# ===== Environment Configuration =====
env_id: "MiniGrid-Empty-8x8-v0"  # MiniGrid environment ID
obs_type: "rgb"  # Observation type: "rgb" (3-channel) or "grid" (single-channel)

# Available MiniGrid environments (uncomment to use):
# - "MiniGrid-Empty-5x5-v0"
# - "MiniGrid-Empty-8x8-v0"
# - "MiniGrid-Empty-16x16-v0"
# - "MiniGrid-DoorKey-5x5-v0"
# - "MiniGrid-DoorKey-8x8-v0"
# - "MiniGrid-FourRooms-v0"

# ===== Agent Configuration =====
agent:
  # Observation and action spaces (will be auto-filled)
  obs_shape: [7, 7, 3]  # [height, width, channels] - will be overridden
  action_dim: 7  # MiniGrid action space size - will be overridden
  
  # Skill configuration
  skill_dim: 8  # Number of discrete skills to learn
  
  # Network architecture
  hidden_dim: 256  # Hidden layer size for all networks
  
  # Training hyperparameters
  lr: 3e-4  # Learning rate
  gamma: 0.99  # Discount factor
  entropy_coef: 0.1  # Entropy coefficient for policy gradient
  
  # Replay buffer
  batch_size: 512  # Batch size for training
  replay_size: 50000  # Maximum replay buffer size
  
  # Intrinsic reward scaling
  intrinsic_reward_scale: 1.0  # Scale factor for intrinsic rewards
  
  # Device configuration
  device: "auto"  # "auto", "cpu", or "cuda"

# ===== Training Configuration =====
training:
  # Training procedure
  max_episodes: 100000 # Maximum number of training episodes
  max_steps_per_episode: 200  # Maximum steps per episode
  
  # Logging and evaluation
  log_interval: 100  # Log metrics every N episodes
  eval_interval: 500  # Evaluate every N episodes
  eval_episodes: 50  # Number of evaluation episodes
  
  # Checkpointing
  save_interval: 10000  # Save model every N episodes
  
  # Early stopping (optional)
  early_stop_reward: None  # Stop training if average reward exceeds this value
  patience: 100  # Number of episodes to wait before early stopping

# ===== Logging Configuration =====
logging:
  log_dir: "logs"  # Base directory for logs
  project_name: "skill-discovery"  # Project name for experiment tracking
  use_tensorboard: true  # Enable TensorBoard logging
  save_video: false  # Save video of agent's performance
  video_interval: 1000  # Save video every N episodes

# ===== Notes =====
# 1. obs_shape and action_dim will be automatically set based on the environment
# 2. For best results, adjust batch_size and replay_size based on available memory
# 3. Increase skill_dim for more diverse behaviors, but training will be slower
# 4. Monitor training progress using TensorBoard: `tensorboard --logdir=logs`
    """)

### MinitGrid init

In [23]:
"""Register MiniGrid environments with Gymnasium."""
from gymnasium.envs.registration import register

# Register all available MiniGrid environments
register(
    id="MiniGrid-Empty-5x5-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 5}
)

register(
    id="MiniGrid-Empty-8x8-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 8}
)

register(
    id="MiniGrid-Empty-16x16-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 16}
)

# You can add more environments as needed
register(
    id="MiniGrid-DoorKey-5x5-v0",
    entry_point="minigrid.envs:DoorKeyEnv",
    kwargs={"size": 5}
)

register(
    id="MiniGrid-DoorKey-8x8-v0",
    entry_point="minigrid.envs:DoorKeyEnv",
    kwargs={"size": 8}
)


In [24]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces



class MiniGridWrapper(gym.Wrapper):
    """Wrapper for the MiniGrid environment"""
    def __init__(self,env,skill_dim=8,obs_type="rgb"):
        super().__init__(env)
        self.skill_dim = skill_dim
        self.obs_type = obs_type
        
        if obs_type == "rgb":
            self.obs_shape = (7,7,3)
        else: #obs_type = grid
            self.obs_shape = (7,7)
        
        self.observation_space = spaces.Dict({
            "observation": spaces.Box(
                low=0,high=255,
                shape=self.obs_shape,
                dtype=np.uint8
            ),
            "skill": spaces.Box(
                low=-1.0,high=1.0,
                shape=(skill_dim,),
                dtype=np.float32
            )
        })
    
    def reset(self,**kwargs):
        obs,info = super().reset(**kwargs)
        return self._process_obs(obs),info
    
    def step(self,action):
        obs,reward,terminated,truncated,info = self.env.step(action)
        return self._process_obs(obs),reward,terminated,truncated,info
    
    def _process_obs(self,obs):
        """
        Process the observation to match the observation space
        """
        
        if self.obs_type == "rgb":
            obs_array = obs["image"][...,:3]
        else:
            obs_array = obs["image"][...,0]
        
        skill = np.random.uniform(-1,1,size=(self.skill_dim,))
        
        return {
            "observation":obs_array,
            "skill":skill
        }
        

### Train

In [21]:
import os
import yaml
import argparse
import gymnasium as gym
import numpy as np
import torch
import torch.optim as optim
import torch.amp
from pathlib import Path
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint


def parse_args():
    import sys
    sys.argv = ['']

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="/kaggle/working/diayn.yaml")
    parser.add_argument("--log_dir", type=str, default="logs")
    return parser.parse_args()

def load_config(config_path):
    with open(config_path, "r") as f:
        return yaml.safe_load(f)

def collect_rollout(env, agent, max_steps=1000):
    """Collect a single rollout from the environment with optimized data transfer."""
    obs, _ = env.reset()
    skill = agent._sample_skill()
    episode_reward = 0
    episode_length = 0
    
    # Convert initial observation to tensor
    obs_array = np.asarray(obs["observation"], dtype=np.float32)
    
    for _ in range(max_steps):
        # Convert to tensor and move to device
        obs_tensor = torch.from_numpy(obs_array).unsqueeze(0).to(agent.device)
        skill_tensor = torch.FloatTensor(skill).unsqueeze(0).to(agent.device)
        
        # Ensure tensors are contiguous
        obs_tensor = obs_tensor.contiguous()
        skill_tensor = skill_tensor.contiguous()
        
        # Get action from agent
        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if agent.device.type == 'cuda' else 'cpu', enabled=agent.device.type == 'cuda'):
            action = agent.act(obs_tensor, skill_tensor, deterministic=False)
        
        # Step environment
        next_obs, reward, done, _, _ = env.step(action.item() if torch.is_tensor(action) else action)
        
        # Convert next observation to numpy
        next_obs_array = np.asarray(next_obs["observation"], dtype=np.float32)
        
        # Store transition
        agent.add_to_replay(
            obs_array,  # Use numpy array
            action.item() if torch.is_tensor(action) else action,
            skill,
            next_obs_array,  # Use numpy array
            done,
            reward
        )
        
        # Update state
        obs_array = next_obs_array
        obs = next_obs
        episode_reward += reward
        episode_length += 1
        
        if done:
            break
            
        # Clear CUDA cache periodically
        if episode_length % 100 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            
    return episode_reward, episode_length

def train():
    args = parse_args()
    config = load_config(args.config)
    
    # Set up device and mixed precision
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize mixed precision training
    scaler = torch.amp.GradScaler(enabled=device.type == 'cuda')
    
    # Create environment
    env = gym.make(config["env_id"], render_mode="rgb_array")
    env = MiniGridWrapper(
        env, 
        skill_dim=config["agent"]["skill_dim"],
        obs_type=config["obs_type"]
    )
    
    # Update obs_shape in config
    config["agent"]["obs_shape"] = env.observation_space["observation"].shape
    config["agent"]["action_dim"] = env.action_space.n
    
    # Initialize agent with config
    agent = DIAYNAgent(config).to(device)
    
    agent.train()
    
    # Enable optimizations if using CUDA
    if device.type == 'cuda':
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        
        # Print GPU info
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"PyTorch version: {torch.__version__}")


    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(args.log_dir, f"diayn_{timestamp}")
    os.makedirs(log_dir, exist_ok=True)
    
    logger = TensorBoardLogger(
        save_dir=args.log_dir,
        name="diayn",
        version=timestamp
    )
        

    checkpoint_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="diayn-{epoch:03d}",
        save_top_k=3,
        monitor="train/avg_reward",
        mode="max",
    )
    
    # Training parameters from config
    num_episodes = config.get("training", {}).get("num_episodes", 1000)
    max_steps = config.get("training", {}).get("max_steps", 1000)
    log_interval = config.get("training", {}).get("log_interval", 10)
    eval_interval = config.get("training", {}).get("eval_interval", 50)
    save_interval = config.get("training", {}).get("save_interval", 100)
    checkpoint_dir = Path(config.get("training", {}).get("checkpoint_dir", "checkpoints"))
    
    # Create progress bar
    pbar = tqdm(range(1, num_episodes + 1), desc="Training", unit="episode")
    
    # Initialize optimizers
    optimizers = agent.configure_optimizers()
    if isinstance(optimizers, (list, tuple)) and len(optimizers) > 0:
        if isinstance(optimizers[0], (list, tuple)):
            # Handle case where optimizers is a list of optimizers and schedulers
            optimizers = optimizers[0]
    
    # Unpack optimizers if we have multiple
    if isinstance(optimizers, (list, tuple)) and len(optimizers) >= 2:
        optimizer_d, optimizer_p = optimizers[0], optimizers[1]
    else:
        # Fallback to single optimizer if needed
        optimizer_d = optimizers[0] if isinstance(optimizers, (list, tuple)) else optimizers
        optimizer_p = optimizer_d
    
    episode_rewards = []
    episode_lengths = []
    total_steps = 0
    best_reward = -float('inf')
    
    for episode in pbar:
        try:
            # Set models to train mode
            agent.train()
            
            # Collect rollout
            episode_reward, episode_length = collect_rollout(env, agent, max_steps)
            
            # Store metrics
            episode_rewards.append(episode_reward)
            episode_lengths.append(episode_length)
            total_steps += episode_length
            
            # Calculate running averages
            avg_reward = np.mean(episode_rewards[-100:]) if episode_rewards else 0
            avg_length = np.mean(episode_lengths[-100:]) if episode_lengths else 0
            
            # Update progress bar
            pbar.set_postfix({
                'reward': f'{episode_reward:.2f}',
                'avg_reward': f'{avg_reward:.2f}',
                'length': episode_length,
                'steps': total_steps
            })
            
            # Log metrics
            if logger is not None:
                logger.experiment.add_scalar("train/episode_reward", episode_reward, episode)
                logger.experiment.add_scalar("train/episode_length", episode_length, episode)
                logger.experiment.add_scalar("train/avg_reward", avg_reward, episode)
                logger.experiment.add_scalar("train/avg_length", avg_length, episode)
                logger.experiment.add_scalar("train/total_steps", total_steps, episode)
            
            # Save best model
            if episode_reward > best_reward:
                best_reward = episode_reward
                if logger is not None and hasattr(logger, 'save_checkpoint'):
                    checkpoint = {
                        'episode': episode,
                        'model_state_dict': agent.state_dict(),
                        'reward': episode_reward,
                        'optimizer_state_dict': [opt.state_dict() for opt in optimizers] if isinstance(optimizers, list) else optimizers.state_dict(),
                    }
                    logger.save_checkpoint(checkpoint, is_best=True)
            
            # Save model at intervals
            if episode % save_interval == 0 and logger is not None and hasattr(logger, 'save_checkpoint'):
                checkpoint = {
                    'episode': episode,
                    'model_state_dict': agent.state_dict(),
                    'reward': episode_reward,
                    'optimizer_state_dict': [opt.state_dict() for opt in optimizers] if isinstance(optimizers, list) else optimizers.state_dict(),
                }
                logger.save_checkpoint(checkpoint, is_best=False, filename=f'checkpoint_ep{episode}.pt')
            
            # Clear CUDA cache periodically
            if episode % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
                
        except Exception as e:
            print(f"Error during episode {episode}: {str(e)}")
            import traceback
            traceback.print_exc()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue
        
        # Perform training step
        if len(agent.replay_buffer) >= agent.batch_size:
            try:
                # Train discriminator
                optimizer_d.zero_grad()
                with torch.amp.autocast(device_type='cuda' if agent.device.type == 'cuda' else 'cpu', enabled=agent.device.type == 'cuda'):
                    loss_d = agent.training_step(None, episode, optimizer_idx=0)
                scaler.scale(loss_d).backward()
                scaler.step(optimizer_d)
                scaler.update()
                
                # Train policy
                optimizer_p.zero_grad()
                with torch.amp.autocast(device_type='cuda' if agent.device.type == 'cuda' else 'cpu', enabled=agent.device.type == 'cuda'):
                    loss_p = agent.training_step(None, episode, optimizer_idx=1)
                scaler.scale(loss_p).backward()
                scaler.step(optimizer_p)
                
                # Update scaler for next iteration
                scaler.update()
                
                # Log training metrics
                if logger is not None:
                    logger.experiment.add_scalar("train/loss_discriminator", loss_d.item(), episode)
                    logger.experiment.add_scalar("train/loss_policy", loss_p.item(), episode)
                    
            except RuntimeError as e:
                print(f"Error during training step: {str(e)}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()  # Wait for all kernels to finish
            
            if logger:
                logger.experiment.add_scalar("train/loss_discriminator", loss_d.item(), episode)
                logger.experiment.add_scalar("train/loss_policy", loss_p.item(), episode)
        
        # Perform evaluation
        if (episode + 1) % eval_interval == 0:
            agent.eval()
            eval_rewards = []
            with torch.no_grad():
                for _ in range(5):
                    eval_reward, _ = collect_rollout(env, agent)
                    eval_rewards.append(eval_reward)
            
            avg_eval_reward = np.mean(eval_rewards)
            if logger:
                logger.experiment.add_scalar("eval/avg_reward", avg_eval_reward, episode)
            
            agent.train()
        

        if (episode + 1) % save_interval == 0 or episode == num_episodes - 1:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"diayn_episode_{episode+1}.ckpt")
            torch.save({
                'episode': episode,
                'model_state_dict': agent.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                'optimizer_p_state_dict': optimizer_p.state_dict(),
                'episode_rewards': episode_rewards,
                'episode_lengths': episode_lengths,
                'config': config
            }, checkpoint_path)
    

    final_model_path = os.path.join(checkpoint_dir, "diayn_final.pt")
    torch.save({
        'model_state_dict': agent.state_dict(),
        'config': config
    }, final_model_path)
    print(f"\nTraining complete! Final model saved to {final_model_path}")
    

    metrics = {
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
        'config': config
    }
    
    metrics_path = os.path.join(log_dir, 'training_metrics.pt')
    torch.save(metrics, metrics_path)
    print(f"Training metrics saved to {metrics_path}")
    
    return agent

In [29]:
train()

Using device: cuda


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


GPU: Tesla P100-PCIE-16GB
CUDA available: True
CUDA version: 12.4
PyTorch version: 2.6.0+cu124


Training:   0%|          | 0/1000 [00:00<?, ?episode/s, reward=0.54, avg_reward=0.54, length=132, steps=132]/usr/local/lib/python3.11/dist-packages/pytorch_lightning/core/module.py:441: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`
Training: 100%|██████████| 1000/1000 [16:12<00:00,  1.03episode/s, reward=-0.36, avg_reward=-0.43, length=387, steps=582012]


Training complete! Final model saved to checkpoints/diayn_final.pt
Training metrics saved to logs/diayn_20250727_194924/training_metrics.pt





DIAYNAgent(
  (encoder): MiniGridEncoder(
    (conv): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): Flatten(start_dim=1, end_dim=-1)
    )
    (fc): Linear(in_features=1568, out_features=64, bias=True)
  )
  (policy): Sequential(
    (0): Linear(in_features=72, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=7, bias=True)
  )
  (discriminator): SkillDiscriminator(
    (net): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=8, bias=True)
    )
  )
)