In [1]:
#!pip install "gymnasium>=1.2.0" "imageio>=2.37.0" "lightning>=2.5.2" "matplotlib>=3.10.3" "minigrid>=3.0.0" "pyyaml>=6.0.2" "minigrid" "seaborn>=0.13.2" "tensorboard>=2.20.0" "torch>=2.7.1"

# src

## Models

### Base Model

In [3]:
import torch.nn as nn




class BaseModel(nn.Module):
    def __init__(self, ):
        super().__init__()
    
    def forward(self, x):
        raise NotImplementedError("Subclasses must implement the forward method.")
    
    
    def save(self, path: str):
        torch.save(self.state_dict(), path)
    
    
    def load(self, path: str):
        self.load_state_dict(torch.load(path))

### Discriminator

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F



class SkillDiscriminator(nn.Module):
    """Skill discriminator for DIAYN that classifies which skill was used."""
    
    def __init__(self, input_dim: int, skill_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, skill_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
    
    def compute_reward(self, state: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
        """Compute the intrinsic reward for the given state and skill."""
        with torch.no_grad():
            logits = self.forward(state)
            log_probs = F.log_softmax(logits, dim=-1)
            return (log_probs * skill).sum(dim=-1, keepdim=True)


### Encoder

In [5]:
from torch import nn
from typing import Tuple
import torch


class MiniGridEncoder(BaseModel):
    """Encoder for the MiniGrid environment So that it can be used in the DIAYN agent
    
    Args:
        input_shape (Tuple[int]): The shape of the input observation.
        hidden_size (int): The size of the hidden layer.
    
    Returns:
        torch.Tensor: The encoded observation.
    """
    
    def __init__(self, obs_shape: Tuple[int], feature_dim: int = 64, obs_type: str = "rgb"):
        super().__init__()
        self.obs_type = obs_type
        self.obs_shape = obs_shape
        
        # Determine input channels based on observation type
        self.in_channels = 3 if obs_type == "rgb" else 1
        
        # CNN architecture for processing observations
        self.conv = nn.Sequential(
            nn.Conv2d(self.in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        

        with torch.no_grad():
            # Create a dummy input with correct shape (N, C, H, W)
            dummy = torch.zeros(1, self.in_channels, *obs_shape[:2])
            conv_out = self.conv(dummy)
            self.conv_output_dim = conv_out.shape[1]
            

        self.fc = nn.Linear(self.conv_output_dim, feature_dim)
        self.feature_dim = feature_dim
    
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder
        
        Args:
            obs (torch.Tensor): The input observation of shape (batch, H, W, C) or (H, W, C)
            
        Returns:
            torch.Tensor: The encoded observation of shape (batch, hidden_dim)
        """
        # Ensure we have a batch dimension
        if len(obs.shape) == 3:  # (H, W, C) -> (1, H, W, C)
            obs = obs.unsqueeze(0)
            
        # Convert to float and normalize if needed
        if obs.dtype == torch.uint8:
            obs = obs.float() / 255.0
            
        # Convert from NHWC to NCHW format expected by PyTorch
        if obs.shape[-1] in [1, 3]:  # If channels are last
            obs = obs.permute(0, 3, 1, 2)  # NHWC -> NCHW
            
        # Ensure we have the right number of channels
        if self.obs_type == 'rgb' and obs.shape[1] != 3:
            if obs.shape[1] == 1:  # If grayscale, repeat to 3 channels
                obs = obs.repeat(1, 3, 1, 1)
            else:
                raise ValueError(f"Expected 1 or 3 channels for RGB, got {obs.shape[1]} channels")
                

        x = self.conv(obs)
        return torch.relu(self.fc(x))
    



## Agents

### Base Agent

In [2]:
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import torch.nn as nn
import torch
from torch.utils.tensorboard import SummaryWriter

class BaseAgent(nn.Module, ABC):
    """Base class for all agents.
    
    This base class provides common functionality for all agents, including
    device management and basic training utilities.
    """
    
    def __init__(self, config: Dict[str, Any], writer: Optional[SummaryWriter] = None):
        """Initialize the base agent.
        
        Args:
            config: Configuration dictionary for the agent
            writer: TensorBoard SummaryWriter for logging
        """
        super().__init__()
        self.config = config
        self.writer = writer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    @abstractmethod
    def act(self, obs: Dict[str, Any], deterministic: bool = False) -> Any:
        """Select an action given an observation.
        
        Args:
            obs: Observation from the environment
            deterministic: Whether to sample deterministically
            
        Returns:
            Action to take
        """
        pass
    
    def to(self, device):
        """Move model to device and update self.device."""
        self.device = device
        return super().to(device)
        
    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value to TensorBoard.
        
        Args:
            tag: Tag for the scalar
            value: Value to log
            step: Current step for x-axis
        """
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)
    
    def log_histogram(self, tag: str, values: torch.Tensor, step: int):
        """Log a histogram to TensorBoard.
        
        Args:
            tag: Tag for the histogram
            values: Values to create histogram from
            step: Current step for x-axis
        """
        if self.writer is not None:
            self.writer.add_histogram(tag, values, step)
    
    def log_model_graph(self, model: nn.Module, sample_input: Any):
        """Log model graph to TensorBoard.
        
        Args:
            model: Model to log
            sample_input: Sample input for the model
        """
        if self.writer is not None:
            try:
                self.writer.add_graph(model, sample_input)
                self.writer.flush()
            except Exception as e:
                print(f"Failed to log model graph: {e}")
    
    def save_checkpoint(self, path: str, **kwargs):
        """Save model checkpoint.
        
        Args:
            path: Path to save checkpoint to
            **kwargs: Additional items to save in checkpoint
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            **kwargs
        }
        torch.save(checkpoint, path)
    
    @classmethod
    def load_checkpoint(cls, path: str, **kwargs):
        """Load model from checkpoint.
        
        Args:
            path: Path to checkpoint file
            **kwargs: Additional arguments to pass to model constructor
            
        Returns:
            Loaded model instance
        """
        checkpoint = torch.load(path, map_location='cpu')
        config = {**checkpoint['config'], **kwargs}
        model = cls(config)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model

2025-07-28 22:43:16.294927: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753742596.317975     469 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753742596.324927     469 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### DIAYN Agent

In [6]:
import os
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import deque, namedtuple
from typing import Dict, Any, Tuple, Optional
from torch.utils.tensorboard import SummaryWriter

Transition = namedtuple('Transition', 
                       ('state', 'action', 'skill', 'next_state', 'done', 'reward'))


class DIAYNAgent(BaseAgent):
    """Diversity is All You Need (DIAYN) agent implementation."""
    
    def __init__(self, config: Dict[str, Any], writer: Optional[SummaryWriter] = None, log_dir: Optional[str] = None):
        """Initialize the DIAYN agent.
        
        Args:
            config: Configuration dictionary containing agent parameters
            writer: TensorBoard SummaryWriter for logging
            log_dir: Directory for logging
        """
        super().__init__(config, writer)
        
        self.log_dir = log_dir
        # Environment parameters
        self.obs_shape = config["obs_shape"]
        self.action_dim = config["action_dim"]
        self.skill_dim = config.get("skill_dim", 8)
        
        # Training parameters
        self.lr = float(config.get("lr", 3e-4))
        self.gamma = float(config.get("gamma", 0.99))
        self.batch_size = int(config.get("batch_size", 64))
        self.replay_size = int(config.get("replay_size", 10000))
        self.entropy_coef = float(config.get("entropy_coef", 0.01))
        
        # Initialize models
        self.encoder = MiniGridEncoder(
            self.obs_shape,
            feature_dim=config.get("hidden_dim", 64)
        )
        
        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(self.encoder.feature_dim + self.skill_dim, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_dim)
        )
        
        # Discriminator network
        self.discriminator = SkillDiscriminator(
            input_dim=self.encoder.feature_dim,
            skill_dim=self.skill_dim,
            hidden_dim=config.get("hidden_dim", 256)
        )
        
        # Optimizers
        self.optimizer_d = torch.optim.AdamW(
            self.discriminator.parameters(),
            lr=self.lr,
            weight_decay=1e-5
        )
        self.optimizer_p = torch.optim.AdamW(
            list(self.encoder.parameters()) + list(self.policy.parameters()),
            lr=self.lr,
            weight_decay=1e-5
        )
        
        # Learning rate schedulers
        self.scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer_d, T_max=1000
        )
        self.scheduler_p = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer_p, T_max=1000
        )
        
        # Replay buffer
        self.replay_buffer = deque(maxlen=self.replay_size)
        
        # Move to device
        self.to(self.device)
    
    
    
    
    def forward(self, obs: torch.Tensor, skill: torch.Tensor) -> torch.Tensor:
        """Forward pass through the agent."""
        encoded = self.encoder(obs)
        x = torch.cat([encoded, skill], dim=-1)
        return self.policy(x)
    
    
    
    
    def act(self, obs: Dict[str, Any], skill: np.ndarray, deterministic: bool = False) -> int:
        """
        Select an action given an observation and skill.
        Args:
            obs: Observation from the environment (can be dict or array-like)
            skill: Skill vector (numpy array)
            deterministic: Whether to sample deterministically
            
        Returns:
            Action to take
        """
        with torch.no_grad():
            try:
                # Handle both dictionary and array observations
                if isinstance(obs, dict) and 'observation' in obs:
                    obs_data = obs['observation']
                else:
                    obs_data = obs
                    
                # Convert to tensor and ensure correct shape and device
                obs_tensor = torch.as_tensor(obs_data, dtype=torch.float32, device=self.device)
                if obs_tensor.dim() == 3:  # If image, add batch dimension
                    obs_tensor = obs_tensor.unsqueeze(0)
                    
                # Convert skill to tensor
                skill_tensor = torch.as_tensor(skill, dtype=torch.float32, device=self.device)
                if skill_tensor.dim() == 1:  # Add batch dimension if needed
                    skill_tensor = skill_tensor.unsqueeze(0)
                
                # Get action logits
                logits = self.forward(obs_tensor, skill_tensor)
                
                # Check for invalid logits
                if torch.isnan(logits).any() or torch.isinf(logits).any():
                    print("Warning: Invalid logits detected, using random action")
                    return np.random.randint(0, self.action_dim)
                
                if deterministic:
                    action = torch.argmax(logits, dim=-1)
                else:
                    # Clamp logits for numerical stability
                    logits = torch.clamp(logits, min=-10, max=10)
                    probs = F.softmax(logits, dim=-1)
                    
                    # Check for invalid probabilities
                    if torch.isnan(probs).any() or torch.isinf(probs).any():
                        print("Warning: Invalid probabilities detected, using uniform distribution")
                        probs = torch.ones_like(logits) / logits.shape[-1]
                    
                    action = torch.multinomial(probs, num_samples=1)
                
                return action.item()
                
            except Exception as e:
                print(f"Error in act(): {e}")
                # Return random action as fallback
                return np.random.randint(0, self.action_dim)
        
    def update(self, batch: Dict[str, torch.Tensor], step: int) -> Tuple[float, float]:
        """Update the agent's parameters using a batch of experiences.
        
        Args:
            batch: Batch of transitions
            step: Current training step (for logging)
            
        Returns:
            Tuple of (discriminator_loss, policy_loss)
        """
        # Unpack batch
        states = batch["state"].to(self.device)
        actions = batch["action"].to(self.device)
        skills = batch["skill"].to(self.device)
        next_states = batch["next_state"].to(self.device)
        dones = batch["done"].to(self.device)
        
        # Train discriminator
        with torch.amp.autocast(device_type=self.device.type, enabled=self.device.type == 'cuda'):
            # Encode states
            states_enc = self.encoder(states)
            next_states_enc = self.encoder(next_states).detach()
            
            # Train discriminator
            logits = self.discriminator(next_states_enc)
            target = skills.argmax(dim=-1)
            if target.max() >= logits.size(1):
                raise ValueError(f"Bad skill index {target.max()} vs {logits.size(1)}")

            loss_d = F.cross_entropy(logits, target)
            
            # Train policy
            policy_input = torch.cat([states_enc, skills], dim=-1)
            logits = self.policy(policy_input)
            probs = F.softmax(logits, dim=-1)
            log_probs = F.log_softmax(logits, dim=-1)
            entropy = -(probs * log_probs).sum(dim=-1)
            
            # Compute intrinsic reward
            with torch.no_grad():
                pred_skill_probs = F.softmax(self.discriminator(next_states_enc), dim=-1)
                log_pred_skill_probs = torch.log(pred_skill_probs + 1e-6)
                intrinsic_reward = (log_pred_skill_probs * skills).sum(dim=-1)
                
            # Compute policy loss
            policy_loss = -(log_probs.gather(1, actions.unsqueeze(1)) * intrinsic_reward.unsqueeze(1)).mean()
            entropy_loss = -self.entropy_coef * entropy.mean()
            loss_p = policy_loss + entropy_loss
            
        # Update discriminator
        self.optimizer_d.zero_grad()
        loss_d.backward()
        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 0.5)
        self.optimizer_d.step()
        
        # Update policy
        self.optimizer_p.zero_grad()
        loss_p.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.encoder.parameters()) + list(self.policy.parameters()),
            0.5
        )
        self.optimizer_p.step()
        
        # Update learning rates
        self.scheduler_d.step()
        self.scheduler_p.step()
        
        # Log metrics
        if self.writer is not None:
            self.writer.add_scalar('train/discriminator_loss', loss_d.item(), step)
            self.writer.add_scalar('train/policy_loss', loss_p.item(), step)
            self.writer.add_scalar('train/entropy', entropy.mean().item(), step)
            self.writer.add_scalar('train/intrinsic_reward', intrinsic_reward.mean().item(), step)
            
            # Log learning rates
            self.writer.add_scalar('lr/discriminator', self.scheduler_d.get_last_lr()[0], step)
            self.writer.add_scalar('lr/policy', self.scheduler_p.get_last_lr()[0], step)
        
        return loss_d.item(), loss_p.item()
    
    def add_to_replay(self, transition: Transition) -> None:
        """Add a transition to the replay buffer."""
        self.replay_buffer.append(transition)
    
    def sample_batch(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Sample a batch of transitions from the replay buffer.
        Args:
            batch_size: Number of transitions to sample
        Returns:
            Dictionary containing batched tensors for states, actions, skills, next_states, dones, rewards.
        """
        if len(self.replay_buffer) < batch_size:
            return None

        transitions = random.sample(self.replay_buffer, batch_size)
        batch = Transition(*zip(*transitions))
        states = torch.stack([torch.FloatTensor(s) for s in batch.state])
        actions = torch.LongTensor(batch.action)
        skills = torch.FloatTensor(np.array(batch.skill))
        next_states = torch.stack([torch.FloatTensor(s) for s in batch.next_state])
        dones = torch.FloatTensor(batch.done)
        rewards = torch.FloatTensor(batch.reward)

        return {
            'state': states,
            'action': actions,
            'skill': skills,
            'next_state': next_states,
            'done': dones,
            'reward': rewards
        }
    
    def save_checkpoint(self, path: str) -> None:
        """Save agent state to checkpoint."""
        torch.save({
            'encoder_state_dict': self.encoder.state_dict(),
            'policy_state_dict': self.policy.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_d_state_dict': self.optimizer_d.state_dict(),
            'optimizer_p_state_dict': self.optimizer_p.state_dict(),
            'scheduler_d_state_dict': self.scheduler_d.state_dict(),
            'scheduler_p_state_dict': self.scheduler_p.state_dict(),
            'replay_buffer': self.replay_buffer,
            'config': self.config
        }, path)
    
    def load_checkpoint(self, path: str) -> None:
        """Load agent state from checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
        self.optimizer_p.load_state_dict(checkpoint['optimizer_p_state_dict'])
        self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict'])
        self.scheduler_p.load_state_dict(checkpoint['scheduler_p_state_dict'])
        self.replay_buffer = checkpoint['replay_buffer']
        self.config = checkpoint['config']
        
        # Move models to device
        self.to(self.device)
    
    def log_model_graph(self) -> None:
        if self.writer is None:
            return
        if not all(hasattr(self, attr) for attr in ['encoder', 'policy', 'discriminator']):
            print("Graph logging skipped: model parts not initialized.")
            return

        dummy_obs = torch.zeros((1, *self.obs_shape), device=self.device)
        dummy_skill = torch.zeros((1, self.skill_dim), device=self.device)

        # Policy wrapper
        class PolicyWrapper(nn.Module):
            def __init__(self, agent):
                super().__init__()
                self.encoder = agent.encoder
                self.policy = agent.policy
            def forward(self, obs, skill):
                encoded = self.encoder(obs)
                return self.policy(torch.cat([encoded, skill], dim=-1))

        # Discriminator wrapper
        class DiscriminatorWrapper(nn.Module):
            def __init__(self, agent):
                super().__init__()
                self.encoder = agent.encoder
                self.discriminator = agent.discriminator
            def forward(self, obs):
                return self.discriminator(self.encoder(obs))

        # Write policy graph in its own run
        writer_policy = SummaryWriter(log_dir=os.path.join(self.log_dir, "policy_graph"))
        writer_policy.add_graph(PolicyWrapper(self), (dummy_obs, dummy_skill))
        writer_policy.flush()

        # Write discriminator graph in separate run
        writer_disc = SummaryWriter(log_dir=os.path.join(self.log_dir, "discriminator_graph"))
        writer_disc.add_graph(DiscriminatorWrapper(self), (dummy_obs,))
        writer_disc.flush()

    
    def sample_skill(self) -> np.ndarray:
        """Sample a random one-hot skill vector."""
        skill_idx = np.random.randint(0, self.skill_dim)
        skill = np.zeros(self.skill_dim, dtype=np.float32)
        skill[skill_idx] = 1.0
        return skill
    
    def train(self, mode: bool = True) -> 'DIAYNAgent':
        """Set the agent in training mode."""
        super().train(mode)
        self.encoder.train(mode)
        self.policy.train(mode)
        self.discriminator.train(mode)
        return self
    
    def eval(self) -> 'DIAYNAgent':
        """Set the agent in evaluation mode."""
        return self.train(False)
    
    def to(self, device) -> 'DIAYNAgent':
        """Move the agent to the specified device."""
        super().to(device)
        self.encoder = self.encoder.to(device)
        self.policy = self.policy.to(device)
        self.discriminator = self.discriminator.to(device)
        return self
    
    def store_transition(self, transition: Transition) -> None:
        self.replay_buffer.append(transition)



## Configs

In [None]:
with open("diayn.yaml","w") as f :
    f.write("""
# configs/diayn.yaml
# DIAYN (Diversity is All You Need) Configuration
# Reference: https://arxiv.org/abs/1802.06070

# ===== Environment Configuration =====
env_id: "MiniGrid-Empty-8x8-v0"  # MiniGrid environment ID
obs_type: "rgb"  # Observation type: "rgb" (3-channel) or "grid" (single-channel)

# Available MiniGrid environments (uncomment to use):
# - "MiniGrid-Empty-5x5-v0"
# - "MiniGrid-Empty-8x8-v0"
# - "MiniGrid-Empty-16x16-v0"
# - "MiniGrid-DoorKey-5x5-v0"
# - "MiniGrid-DoorKey-8x8-v0"
# - "MiniGrid-FourRooms-v0"

# ===== Agent Configuration =====
agent:
  # Observation and action spaces (will be auto-filled)
  obs_shape: [7, 7, 3]  # [height, width, channels] - will be overridden
  action_dim: 7  # MiniGrid action space size - will be overridden
  
  # Skill configuration
  skill_dim: 8  # Number of discrete skills to learn
  
  # Network architecture
  hidden_dim: 512  # Hidden layer size for all networks
  
  # Training hyperparameters
  lr: 1e-5  # Learning rate
  gamma: 0.99  # Discount factor
  entropy_coef: 0.1  # Entropy coefficient for policy gradient
  
  # Replay buffer
  batch_size: 32768 # Batch size for training
  replay_size: 50000  # Maximum replay buffer size
  
  # Intrinsic reward scaling
  intrinsic_reward_scale: 1.0  # Scale factor for intrinsic rewards
  
  # Device configuration
  device: "auto"  # "auto", "cpu", or "cuda"

# ===== Training Configuration =====
training:
  # Training procedure
  max_episodes: 1000000 # Maximum number of training episodes
  max_steps_per_episode: 1000  # Maximum steps per episode
  
  # Logging and evaluation
  log_interval: 10000  # Log metrics every N episodes
  eval_interval: 50000  # Evaluate every N episodes
  eval_episodes: 5000  # Number of evaluation episodes
  
  # Checkpointing
  save_interval: 10000  # Save model every N episodes
  
  # Early stopping (optional)
  early_stop_reward: None  # Stop training if average reward exceeds this value
  patience: 100  # Number of episodes to wait before early stopping

  #Envirement 
  env_parallel: 4
# ===== Logging Configuration =====
logging:
  log_dir: "logs"  # Base directory for logs
  project_name: "skill-discovery"  # Project name for experiment tracking
  use_tensorboard: true  # Enable TensorBoard logging
  save_video: false  # Save video of agent's performance
  video_interval: 1000  # Save video every N episodes

# ===== Notes =====
# 1. obs_shape and action_dim will be automatically set based on the environment
# 2. For best results, adjust batch_size and replay_size based on available memory
# 3. Increase skill_dim for more diverse behaviors, but training will be slower
# 4. Monitor training progress using TensorBoard: `tensorboard --logdir=logs`
    """)

In [8]:
4096*2

8192

## Envs

### MinitGrid init

In [9]:
"""Register MiniGrid environments with Gymnasium."""
from gymnasium.envs.registration import register

# Register all available MiniGrid environments
register(
    id="MiniGrid-Empty-5x5-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 5}
)

register(
    id="MiniGrid-Empty-8x8-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 8}
)

register(
    id="MiniGrid-Empty-16x16-v0",
    entry_point="minigrid.envs:EmptyEnv",
    kwargs={"size": 16}
)

# You can add more environments as needed
register(
    id="MiniGrid-DoorKey-5x5-v0",
    entry_point="minigrid.envs:DoorKeyEnv",
    kwargs={"size": 5}
)

register(
    id="MiniGrid-DoorKey-8x8-v0",
    entry_point="minigrid.envs:DoorKeyEnv",
    kwargs={"size": 8}
)


### MiniGrid_Wrapper

In [10]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces



class MiniGridWrapper(gym.Wrapper):
    """Wrapper for the MiniGrid environment"""
    def __init__(self,env,skill_dim=8,obs_type="rgb"):
        super().__init__(env)
        self.skill_dim = skill_dim
        self.obs_type = obs_type
        
        if obs_type == "rgb":
            self.obs_shape = (7,7,3)
        else: #obs_type = grid
            self.obs_shape = (7,7)
        
        self.observation_space = spaces.Dict({
            "observation": spaces.Box(
                low=0,high=255,
                shape=self.obs_shape,
                dtype=np.uint8
            ),
            "skill": spaces.Box(
                low=-1.0,high=1.0,
                shape=(skill_dim,),
                dtype=np.float32
            )
        })
    
    def reset(self,**kwargs):
        obs,info = super().reset(**kwargs)
        return self._process_obs(obs),info
    
    def step(self,action):
        obs,reward,terminated,truncated,info = self.env.step(action)
        return self._process_obs(obs),reward,terminated,truncated,info
    
    def _process_obs(self,obs):
        """
        Process the observation to match the observation space
        """
        
        if self.obs_type == "rgb":
            obs_array = obs["image"][...,:3]
        else:
            obs_array = obs["image"][...,0]
        
        skill = np.random.uniform(-1,1,size=(self.skill_dim,))
        
        return {
            "observation":obs_array,
            "skill":skill
        }
        

class ObservationExtractor(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = self.env.observation_space['observation']
        
    def observation(self, obs):
        return obs['observation']
        

## Utils

### Train utils

In [11]:
import yaml
import argparse
import numpy as np
from typing import Tuple, Optional
from torch.utils.tensorboard import SummaryWriter
from collections import namedtuple



Transition = namedtuple('Transition', 
                       ('state', 'action', 'skill', 'next_state', 'done', 'reward'))




def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="/kaggle/working/diayn.yaml")
    parser.add_argument("--log_dir", type=str, default="logs")
    return parser.parse_args()




def load_config(config_path):
    with open(config_path, "r") as f:
        return yaml.safe_load(f)





def collect_rollout(env: MiniGridWrapper, agent: DIAYNAgent, max_steps: int = 1000) -> Tuple[float, int]:
    """Collect a single rollout from the environment with optimized data transfer.
    Args:
        env: The environment to collect the rollout from
        agent: The agent to select actions
        max_steps: Maximum number of steps per episode
        
    Returns:
        Tuple containing total reward and episode length
    """
    obs, _ = env.reset()
    skill = agent.sample_skill()
    episode_reward = 0.0
    episode_length = 0
    done = False
    # print("Episode length: ", episode_length)
    # print("Max steps: ", max_steps)
    # print(episode_length < max_steps)
    while not done and episode_length < max_steps:
        
        action = agent.act(obs, skill)        
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        transition = Transition(state=obs['observation'] if isinstance(obs, dict) and 'observation' in obs else obs,
                                action=action,
                                skill=skill,
                                next_state=next_obs['observation'] if isinstance(next_obs, dict) and 'observation' in next_obs else next_obs,
                                done=done,
                                reward=reward)
        agent.add_to_replay(transition)
        obs = next_obs
        episode_reward += reward
        episode_length += 1
        
    return episode_reward, episode_length



def collect_parallel_rollout(env, agent: DIAYNAgent, max_steps: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect a rollout using vectorized environments.
    
    Args:
        env: Vectorized environment (SyncVectorEnv or AsyncVectorEnv)
        agent: DIAYN agent
        max_steps: Maximum number of steps per episode
    
    Returns:
        Tuple of (total_rewards, episode_lengths) arrays per environment
    """
    num_envs = env.num_envs
    obs, _ = env.reset()
    skills = np.array([agent.sample_skill() for _ in range(num_envs)], dtype=np.float32)

    episode_rewards = np.zeros(num_envs, dtype=np.float32)
    episode_lengths = np.zeros(num_envs, dtype=np.int32)
    dones = np.zeros(num_envs, dtype=bool)

    for step in range(max_steps):
        actions = []
        for i in range(num_envs):
            if not dones[i]:
                action = agent.act(obs[i], skills[i])
            else:
                action = 0  # Dummy action for already-done envs
            actions.append(action)

        actions = np.array(actions)
        next_obs, rewards, terminations, truncations, infos = env.step(actions)

        step_dones = np.logical_or(terminations, truncations)

        for i in range(num_envs):
            if dones[i]:  # Skip if already done
                continue

            transition = Transition(
                state=obs[i]['observation'] if isinstance(obs[i], dict) else obs[i],
                action=actions[i],
                skill=skills[i],
                next_state=next_obs[i]['observation'] if isinstance(next_obs[i], dict) else next_obs[i],
                done=step_dones[i],
                reward=rewards[i]
            )
            agent.add_to_replay(transition)

            episode_rewards[i] += rewards[i]
            episode_lengths[i] += 1

        dones = np.logical_or(dones, step_dones)
        obs = next_obs

        if np.all(dones):
            break

    return episode_rewards, episode_lengths







def evaluate(env, agent: DIAYNAgent, num_episodes: int = 5, episode: int = 0, writer: Optional[SummaryWriter] = None) -> float:
    """Evaluate the agent's performance on a vectorized environment."""
    agent.eval()
    eval_rewards = []
    eval_lengths = []
    
    episodes_collected = 0
    num_envs = env.num_envs

    while episodes_collected < num_episodes:
        rewards, lengths = collect_parallel_rollout(env, agent, max_steps=1000)
        
        for r ,l in zip(rewards, lengths):
            eval_rewards.append(r)
            eval_lengths.append(l)
            episodes_collected += 1
            if episodes_collected >= num_episodes:
                break

    avg_eval_reward = np.mean(eval_rewards)
    if writer:
        writer.add_scalar("eval/avg_reward", avg_eval_reward, episode)
        writer.add_scalar("eval/avg_length", np.mean(eval_lengths), episode)
    
    agent.train()
    return avg_eval_reward


### Envirement Utils

In [12]:
from gymnasium.vector import SyncVectorEnv


def make_env(env_id, obs_type):
    def _thunk():
        env = gym.make(env_id, render_mode="rgb_array")
        env = MiniGridWrapper(env, obs_type=obs_type)
        env = ObservationExtractor(env)
        return env
    return _thunk



## Scripts

### Train

In [13]:
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter


def train():
    sys.argv = ['']
    args = parse_args()
    config = load_config(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_name = torch.cuda.get_device_name(device)
    
    print(f"Using device: {device} | Device  Name: {device_name}")
    seed = config.get("seed", 42)
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create  vectorized environments
    num_envs = config.get("training", {}).get("env_parallel", 1)
    env_thunks = [make_env(config["env_id"], config.get("obs_type", "rgb")) for _ in range(num_envs)]
    print(env_thunks[0])
    env = SyncVectorEnv(env_thunks) 

    sample_obs = env.reset()[0]  # env.reset() returns (obs, info)
    print(sample_obs.shape)
    config["agent"]["obs_shape"] = sample_obs.shape[1:]
    config["agent"]["action_dim"] = env.action_space.nvec[0]
    

    # Create unique log directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(args.log_dir, f"diayn_{timestamp}")
    os.makedirs(log_dir, exist_ok=True)
    
    # Initialize TensorBoard writer with the correct log directory
    writer = SummaryWriter(log_dir=log_dir)
    print(f"TensorBoard logs will be saved to: {os.path.abspath(log_dir)}")


    agent = DIAYNAgent(config["agent"], writer=writer, log_dir=log_dir).to(device)
    agent.log_model_graph()

    training_cfg = config.get("training", {})
    print(training_cfg.keys())
    num_episodes = training_cfg.get("max_episodes", 1000)
    eval_interval = training_cfg.get("eval_interval", 100)
    save_interval = training_cfg.get("save_interval", 100)
    
    # Training loop
    best_reward = -float('inf')
    episode_rewards = []
    episode_lengths = []
    
    # Create checkpoint directory
    checkpoint_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Training progress bar
    pbar = tqdm(range(1, num_episodes + 1), desc="Training")
    
    for episode in pbar:
        try:
            # Clear CUDA cache periodically
            if episode % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
                
        except Exception as e:
            print(f"Error during episode {episode}: {str(e)}")
            import traceback
            traceback.print_exc()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue
        
        # Perform training step
        if len(agent.replay_buffer) >= agent.batch_size:
            try:
                batch = agent.sample_batch(agent.batch_size)
                if batch is None:
                    continue

                # Update both discriminator and policy
                loss_d, loss_p = agent.update(batch, episode)
                
                if writer is not None:
                    writer.add_scalar('train/discriminator_loss', loss_d, episode)
                    writer.add_scalar('train/policy_loss', loss_p, episode)
                    writer.add_scalar('train/replay_buffer_size', len(agent.replay_buffer), episode)

            except RuntimeError as e:
                print(f"Error during training step: {str(e)}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()  # Wait for all kernels to finish
            
        # Perform evaluation
        if (episode + 1) % eval_interval == 0:
            _avg_eval_reward = evaluate(env, agent, num_episodes=5, episode=episode, writer=writer)
        

        if (episode + 1) % save_interval == 0 or episode == num_episodes - 1:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"diayn_episode_{episode+1}.ckpt")
            torch.save({
                'episode': episode,
                'model_state_dict': agent.state_dict(),
                'optimizer_d_state_dict': agent.optimizer_d.state_dict(),
                'optimizer_p_state_dict': agent.optimizer_p.state_dict(),
                'episode_rewards': episode_rewards,
                'episode_lengths': episode_lengths,
                'config': config
            }, checkpoint_path)
    

    final_model_path = os.path.join(checkpoint_dir, "diayn_final.pt")
    torch.save({
        'model_state_dict': agent.state_dict(),
        'optimizer_d_state_dict': agent.optimizer_d.state_dict(),
        'optimizer_p_state_dict': agent.optimizer_p.state_dict(),
        'config': config
    }, final_model_path)
    print(f"\nTraining complete! Final model saved to {final_model_path}")
    

    metrics = {
        'episode_rewards': episode_rewards,
        'episode_lengths': episode_lengths,
        'config': config
    }
    
    metrics_path = os.path.join(log_dir, 'training_metrics.pt')
    torch.save(metrics, metrics_path)
    print(f"Training metrics saved to {metrics_path}")
    
    return agent

In [None]:
agent = train()

Using device: cuda | Device  Name: Tesla P100-PCIE-16GB
<function make_env.<locals>._thunk at 0x7ac3a5169a80>
(4, 7, 7, 3)


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


TensorBoard logs will be saved to: /kaggle/working/logs/diayn_20250728_224321


  if obs.shape[-1] in [1, 3]:  # If channels are last
  if self.obs_type == 'rgb' and obs.shape[1] != 3:


dict_keys(['max_episodes', 'max_steps_per_episode', 'log_interval', 'eval_interval', 'eval_episodes', 'save_interval', 'early_stop_reward', 'patience', 'env_parallel'])


  dones = torch.FloatTensor(batch.done)
Training:  85%|████████▌ | 850313/1000000 [05:22<29:47:10,  1.40it/s]

In [None]:
tensorboard --logdir /kaggle/working/logs

Collecting minigrid
  Downloading minigrid-3.0.0-py3-none-any.whl.metadata (6.7 kB)
Downloading minigrid-3.0.0-py3-none-any.whl (136 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m136.7/136.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: minigrid
Successfully installed minigrid-3.0.0
Note: you may need to restart the kernel to use updated packages.
