In [None]:
# Cell 1: Install dependencies
# Python packages used by the later cells: pybind11 (C++ bindings), gymnasium (env API),
# pygame (human rendering) and tqdm (progress bar).
# NOTE(review): versions are unpinned and torch (imported by train_sokoban_ppo.py) is not
# installed here — presumably it is preinstalled on the target platform; confirm.
!pip install pybind11 gymnasium pygame tqdm
# Compiler toolchain and Python headers needed to build the extension module in Cell 7.
!apt-get update && apt-get install -y build-essential python3-dev

In [None]:
# Cell 2: Create C++ files
%%writefile sokoban.h
#pragma once

#include <vector>
#include <string>
#include <tuple>
#include <array>
#include <span>
#include <algorithm>
#include <ranges>
#include <concepts>
#include <string_view>

// Grid-based Sokoban game state.
//
// The board is stored row-major as grid_[y][x], one integer tile code per cell.
// step() applies one move and returns (flattened observation, reward, solved flag).
class Sokoban {
public:
    // Tile types
    static constexpr int WALL = 0;
    static constexpr int EMPTY = 1;
    static constexpr int PLAYER = 2;
    static constexpr int BOX = 3;
    static constexpr int TARGET = 4;
    static constexpr int BOX_ON_TARGET = 5;
    static constexpr int PLAYER_ON_TARGET = 6;
    
    // Actions
    static constexpr int UP = 0;
    static constexpr int DOWN = 1;
    static constexpr int LEFT = 2;
    static constexpr int RIGHT = 3;
    
    // Default edge length used by the default constructor and as the default height.
    static constexpr int DEFAULT_SIZE = 10;
    // Number of distinct actions; valid action ids are 0 .. MAX_ACTIONS - 1.
    static constexpr int MAX_ACTIONS = 4;

public:
    // Builds a DEFAULT_SIZE x DEFAULT_SIZE game (which loads the built-in level).
    Sokoban();
    // Builds a game of the given size; throws std::invalid_argument for non-positive sizes.
    explicit Sokoban(int width, int height = DEFAULT_SIZE);
    
    // Restores the grid and player position captured at construction / load_level().
    void reset();
    // Replaces the current level with the one described by level_str
    // ('#' wall, ' ' empty, '@' player, '$' box, '.' target, '*' box on target, '+' player on target).
    void load_level(const std::string& level_str);
    // Applies one action and returns (observation, reward, done).
    std::tuple<std::vector<int>, float, bool> step(int action);
    // Returns the grid flattened row by row.
    std::vector<int> get_observation() const;
    // Returns a copy of the 2D grid.
    std::vector<std::vector<int>> get_grid() const;
    // True when the grid contains at least one box and none of them is off a target.
    bool is_solved() const;
    // True when the action id is in range and the destination cell is inside the grid.
    bool is_valid_action(int action) const;
    
    // Additional utility methods
    constexpr int get_width() const noexcept { return width_; }
    constexpr int get_height() const noexcept { return height_; }
    // Player position as (x, y).
    constexpr std::pair<int, int> get_player_position() const noexcept { return {player_x_, player_y_}; }

private:
    int width_, height_;
    std::vector<std::vector<int>> grid_;          // current board, indexed [y][x]
    std::vector<std::vector<int>> initial_grid_;  // snapshot restored by reset()
    int player_x_, player_y_;
    int initial_player_x_, initial_player_y_;
    
    // Pre-allocated vectors to avoid dynamic allocation in step()
    // (mutable so the const get_observation() can refill it)
    mutable std::vector<int> observation_buffer_;
    
    // Helper methods
    // NOTE: the two constexpr helpers below are defined in sokoban.cpp, so they can only be
    // used from that translation unit.
    static constexpr std::pair<int, int> get_direction_offset(int action) noexcept;
    constexpr bool is_valid_position(int x, int y) const noexcept;
    // Scans the grid for the player tile; throws std::runtime_error if there is none.
    void find_player_position();
    // Rebuilds width_, height_ and grid_ from a textual level description.
    void parse_level_string(std::string_view level_str);
    void create_default_level();
    void initialize_buffers();
    
    // Game logic helpers
    // NOTE(review): can_move_to() is not called anywhere in the visible sources.
    bool can_move_to(int x, int y) const noexcept;
    bool can_push_box_to(int x, int y) const noexcept;
    void move_player(int from_x, int from_y, int to_x, int to_y) noexcept;
    void move_box(int from_x, int from_y, int to_x, int to_y) noexcept;
};

In [None]:
# Cell 3: Create sokoban.cpp
%%writefile sokoban.cpp
#include "sokoban.h"
#include <sstream>
#include <stdexcept>
#include <cassert>

// Out-of-class definitions of the static constexpr members, added so the constants
// can be odr-used by the pybind11 bindings.
// NOTE(review): since C++17 static constexpr data members are implicitly inline, so with
// the -std=c++20 build used by this notebook these redeclarations look redundant —
// confirm before removing them.
constexpr int Sokoban::WALL;
constexpr int Sokoban::EMPTY;
constexpr int Sokoban::PLAYER;
constexpr int Sokoban::BOX;
constexpr int Sokoban::TARGET;
constexpr int Sokoban::BOX_ON_TARGET;
constexpr int Sokoban::PLAYER_ON_TARGET;

constexpr int Sokoban::UP;
constexpr int Sokoban::DOWN;
constexpr int Sokoban::LEFT;
constexpr int Sokoban::RIGHT;

// Default game: delegates to the sized constructor, whose height already defaults
// to DEFAULT_SIZE, giving a DEFAULT_SIZE x DEFAULT_SIZE request.
Sokoban::Sokoban() : Sokoban(DEFAULT_SIZE) {}

// Builds a game of the requested size.
//
// - DEFAULT_SIZE x DEFAULT_SIZE loads the built-in level; parsing that level overwrites
//   width_/height_ with the level's own dimensions.
// - Any other size of at least 3x3 gets a walled border with the player in the centre.
// - Smaller grids just get a player at (min(1, w-1), min(1, h-1)).
//
// Throws std::invalid_argument when either dimension is not positive.
Sokoban::Sokoban(int width, int height)
    : width_(width), height_(height), player_x_(0), player_y_(0),
      initial_player_x_(0), initial_player_y_(0) {

    if (width < 1 || height < 1) {
        throw std::invalid_argument("Grid dimensions must be positive");
    }

    grid_.resize(height_, std::vector<int>(width_, EMPTY));

    const bool use_builtin_level = (width_ == DEFAULT_SIZE) && (height_ == DEFAULT_SIZE);
    const bool room_for_border = (width_ >= 3) && (height_ >= 3);

    if (use_builtin_level) {
        create_default_level();
        find_player_position();
    } else if (room_for_border) {
        // Top and bottom rows are solid wall ...
        std::fill(grid_.front().begin(), grid_.front().end(), WALL);
        std::fill(grid_.back().begin(), grid_.back().end(), WALL);
        // ... and every row starts and ends with a wall.
        for (auto& cells : grid_) {
            cells.front() = WALL;
            cells.back() = WALL;
        }

        // Player starts in the middle of the room.
        player_x_ = width_ / 2;
        player_y_ = height_ / 2;
        grid_[player_y_][player_x_] = PLAYER;
    } else {
        // Too small for a border: put the player at (1,1) or as close as the grid allows.
        player_x_ = std::min(1, width_ - 1);
        player_y_ = std::min(1, height_ - 1);
        grid_[player_y_][player_x_] = PLAYER;
    }

    // Remember the starting state so reset() can restore it.
    initial_grid_ = grid_;
    initial_player_x_ = player_x_;
    initial_player_y_ = player_y_;
    initialize_buffers();
}

// Rewinds the game to the snapshot taken by the constructor or the last load_level().
void Sokoban::reset() {
    player_y_ = initial_player_y_;
    player_x_ = initial_player_x_;
    grid_ = initial_grid_;
}

void Sokoban::load_level(const std::string& level_str) {
    parse_level_string(level_str);
    initial_grid_ = grid_;
    find_player_position();
    initial_player_x_ = player_x_;
    initial_player_y_ = player_y_;
    initialize_buffers();
}

std::tuple<std::vector<int>, float, bool> Sokoban::step(int action) {
    if (!is_valid_action(action)) {
        return {get_observation(), 0.0f, false};
    }
    
    auto [dx, dy] = get_direction_offset(action);
    const int new_x = player_x_ + dx;
    const int new_y = player_y_ + dy;
    
    assert(is_valid_position(new_x, new_y));
    
    const int target_tile = grid_[new_y][new_x];
    
    // Handle wall collision
    if (target_tile == WALL) {
        return {get_observation(), 0.0f, false};
    }
    
    // Handle box pushing
    if (target_tile == BOX || target_tile == BOX_ON_TARGET) {
        const int box_new_x = new_x + dx;
        const int box_new_y = new_y + dy;
        
        if (!is_valid_position(box_new_x, box_new_y) || 
            !can_push_box_to(box_new_x, box_new_y)) {
            return {get_observation(), 0.0f, false};
        }
        
        // Execute box move
        move_box(new_x, new_y, box_new_x, box_new_y);
    }
    
    // Execute player move
    move_player(player_x_, player_y_, new_x, new_y);
    player_x_ = new_x;
    player_y_ = new_y;
    
    // Check win condition and calculate reward
    const bool solved = is_solved();
    const float reward = solved ? 1.0f : 0.0f;
    
    return {get_observation(), reward, solved};
}

// Returns the grid flattened row by row (row 0 first).
//
// The result is built directly in the returned vector. The previous version filled the
// mutable member buffer and then returned a copy of it, which cost a second pass over
// the data and made this const method write to shared state.
std::vector<int> Sokoban::get_observation() const {
    using size_type = std::vector<int>::size_type;

    std::vector<int> flat;
    // Widen before multiplying so large grids cannot overflow int.
    flat.reserve(static_cast<size_type>(width_) * static_cast<size_type>(height_));

    for (const auto& row : grid_) {
        flat.insert(flat.end(), row.begin(), row.end());
    }

    return flat;
}

// Returns a copy of the 2D board (indexed [y][x]); the caller may modify it freely.
std::vector<std::vector<int>> Sokoban::get_grid() const {
    std::vector<std::vector<int>> snapshot(grid_);
    return snapshot;
}

// A level counts as solved when it contains at least one box and every box sits on a
// target. A level without any box is never reported as solved.
bool Sokoban::is_solved() const {
    bool saw_box_on_target = false;

    for (const auto& row : grid_) {
        for (const int tile : row) {
            switch (tile) {
                case BOX:
                    // One box off its target is enough to know the answer.
                    return false;
                case BOX_ON_TARGET:
                    saw_box_on_target = true;
                    break;
                default:
                    break;
            }
        }
    }

    return saw_box_on_target;
}

// An action is valid when its id lies in [0, MAX_ACTIONS) and the cell it would move
// the player to is inside the grid. Walls and boxes are not considered here.
bool Sokoban::is_valid_action(int action) const {
    const bool id_in_range = (0 <= action) && (action < MAX_ACTIONS);
    if (!id_in_range) {
        return false;
    }

    const std::pair<int, int> offset = get_direction_offset(action);
    return is_valid_position(player_x_ + offset.first, player_y_ + offset.second);
}

// Maps an action id to its (dx, dy) step; y grows downwards, so UP is (0, -1).
// Unknown ids map to (0, 0).
constexpr std::pair<int, int> Sokoban::get_direction_offset(int action) noexcept {
    if (action == UP) {
        return {0, -1};
    }
    if (action == DOWN) {
        return {0, 1};
    }
    if (action == LEFT) {
        return {-1, 0};
    }
    if (action == RIGHT) {
        return {1, 0};
    }
    return {0, 0};
}

// True when (x, y) addresses a cell inside the current grid.
constexpr bool Sokoban::is_valid_position(int x, int y) const noexcept {
    const bool column_ok = (0 <= x) && (x < width_);
    const bool row_ok = (0 <= y) && (y < height_);
    return column_ok && row_ok;
}

// Scans the grid row by row and stores the coordinates of the first player tile
// (PLAYER or PLAYER_ON_TARGET) in player_x_/player_y_.
// Throws std::runtime_error when the grid holds no player tile.
void Sokoban::find_player_position() {
    for (int row = 0; row < height_; ++row) {
        const auto& cells = grid_[row];
        for (int col = 0; col < width_; ++col) {
            const int tile = cells[col];
            if (tile != PLAYER && tile != PLAYER_ON_TARGET) {
                continue;
            }
            player_x_ = col;
            player_y_ = row;
            return;
        }
    }
    throw std::runtime_error("Player not found in level");
}

// Rebuilds width_, height_ and grid_ from a textual level.
//
// Lines are separated by '\n'; a trailing '\r' on a line is dropped so CRLF input does
// not add a phantom column. height_ is the number of lines and width_ the length of the
// longest one; shorter lines are padded with EMPTY cells.
//
// Symbols: '#' wall, ' ' empty, '@' player, '$' box, '.' target, '*' box on target,
// '+' player on target. Any other character becomes an EMPTY cell.
//
// Throws std::invalid_argument when the string contains no lines at all.
void Sokoban::parse_level_string(std::string_view level_str) {
    using size_type = std::string_view::size_type;

    std::vector<std::string_view> lines;
    size_type widest = 0;

    // Records one line, without its carriage return, and tracks the maximum width.
    const auto add_line = [&lines, &widest](std::string_view line) {
        if (!line.empty() && line.back() == '\r') {
            line.remove_suffix(1);
        }
        widest = std::max(widest, line.length());
        lines.push_back(line);
    };

    // Split on '\n' by hand; a final segment without a newline is still a line.
    size_type start = 0;
    for (size_type end = level_str.find('\n'); end != std::string_view::npos;
         start = end + 1, end = level_str.find('\n', start)) {
        add_line(level_str.substr(start, end - start));
    }
    if (start < level_str.length()) {
        add_line(level_str.substr(start));
    }

    if (lines.empty()) {
        throw std::invalid_argument("Empty level string");
    }

    height_ = static_cast<int>(lines.size());
    width_ = static_cast<int>(widest);

    // Start from an all-EMPTY board so short lines are padded automatically.
    grid_.assign(height_, std::vector<int>(width_, EMPTY));

    for (int y = 0; y < height_; ++y) {
        const std::string_view row = lines[y];
        const int row_length = static_cast<int>(row.length());  // never exceeds width_
        for (int x = 0; x < row_length; ++x) {
            switch (row[x]) {
                case '#': grid_[y][x] = WALL; break;
                case ' ': grid_[y][x] = EMPTY; break;
                case '@': grid_[y][x] = PLAYER; break;
                case '$': grid_[y][x] = BOX; break;
                case '.': grid_[y][x] = TARGET; break;
                case '*': grid_[y][x] = BOX_ON_TARGET; break;
                case '+': grid_[y][x] = PLAYER_ON_TARGET; break;
                default:  grid_[y][x] = EMPTY; break;
            }
        }
    }
}

// Loads the built-in 7x7 level (two boxes, two targets) into the grid.
// Parsing replaces width_/height_ with the level's own dimensions.
void Sokoban::create_default_level() {
    static constexpr std::string_view kBuiltinLevel{
        "#######\n"
        "#  .  #\n"
        "# $ @ #\n"
        "#     #\n"
        "# $   #\n"
        "#  .  #\n"
        "#######"};

    parse_level_string(kBuiltinLevel);
}

// Reserves room in the observation buffer for one value per grid cell.
void Sokoban::initialize_buffers() {
    const int cell_count = width_ * height_;
    observation_buffer_.reserve(cell_count);
}

// True when (x, y) is inside the grid and holds a free tile (EMPTY or TARGET).
bool Sokoban::can_move_to(int x, int y) const noexcept {
    if (!is_valid_position(x, y)) {
        return false;
    }

    switch (grid_[y][x]) {
        case EMPTY:
        case TARGET:
            return true;
        default:
            return false;
    }
}

// True when a box may be pushed onto (x, y): the cell is inside the grid and is
// either EMPTY or a TARGET.
bool Sokoban::can_push_box_to(int x, int y) const noexcept {
    if (!is_valid_position(x, y)) {
        return false;
    }

    switch (grid_[y][x]) {
        case EMPTY:
        case TARGET:
            return true;
        default:
            return false;
    }
}

// Updates the two grid cells involved in a player move. Does not touch
// player_x_/player_y_; the caller updates those.
void Sokoban::move_player(int from_x, int from_y, int to_x, int to_y) noexcept {
    // Leaving a cell uncovers what was underneath: a target stays a target.
    int& source = grid_[from_y][from_x];
    source = (source == PLAYER_ON_TARGET) ? TARGET : EMPTY;

    // Entering a cell: remember whether the player now stands on a target.
    int& destination = grid_[to_y][to_x];
    destination = (destination == TARGET) ? PLAYER_ON_TARGET : PLAYER;
}

// Updates the two grid cells involved in a box push.
void Sokoban::move_box(int from_x, int from_y, int to_x, int to_y) noexcept {
    // Drop the box on its new cell, noting whether that cell is a target.
    int& destination = grid_[to_y][to_x];
    destination = (destination == TARGET) ? BOX_ON_TARGET : BOX;

    // Uncover the cell the box left; the player move that follows overwrites it.
    int& source = grid_[from_y][from_x];
    source = (source == BOX_ON_TARGET) ? TARGET : EMPTY;
}

In [None]:
# Cell 4: Create Python bindings
%%writefile python_bindings.cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "sokoban.h"

namespace py = pybind11;

// Defines the Python extension module `sokoban_engine`: tile/action constants as module
// attributes plus the Sokoban class with its public methods.
PYBIND11_MODULE(sokoban_engine, m) {
    m.doc() = "High-performance Sokoban game engine for reinforcement learning";
    
    // Define constants at module level
    // Tile codes (the values that appear in observations and grids).
    m.attr("WALL") = py::int_(Sokoban::WALL);
    m.attr("EMPTY") = py::int_(Sokoban::EMPTY);
    m.attr("PLAYER") = py::int_(Sokoban::PLAYER);
    m.attr("BOX") = py::int_(Sokoban::BOX);
    m.attr("TARGET") = py::int_(Sokoban::TARGET);
    m.attr("BOX_ON_TARGET") = py::int_(Sokoban::BOX_ON_TARGET);
    m.attr("PLAYER_ON_TARGET") = py::int_(Sokoban::PLAYER_ON_TARGET);
    
    // Action ids accepted by Sokoban.step().
    m.attr("UP") = py::int_(Sokoban::UP);
    m.attr("DOWN") = py::int_(Sokoban::DOWN);
    m.attr("LEFT") = py::int_(Sokoban::LEFT);
    m.attr("RIGHT") = py::int_(Sokoban::RIGHT);
    
    py::class_<Sokoban>(m, "Sokoban")
        // Sokoban() and Sokoban(width, height=DEFAULT_SIZE)
        .def(py::init<>())
        .def(py::init<int, int>(), "Construct with custom dimensions", 
             py::arg("width"), py::arg("height") = Sokoban::DEFAULT_SIZE)
        
        // Core game methods
        .def("reset", &Sokoban::reset, "Reset to initial state")
        .def("load_level", &Sokoban::load_level, "Load level from string", py::arg("level_str"))
        // pybind11/stl.h converts the returned std::tuple / std::vector to Python tuple / list.
        .def("step", &Sokoban::step, "Execute action and return (observation, reward, done)", 
             py::arg("action"))
        .def("get_observation", &Sokoban::get_observation, "Get flattened grid observation")
        .def("get_grid", &Sokoban::get_grid, "Get 2D grid for rendering")
        .def("is_solved", &Sokoban::is_solved, "Check if level is solved")
        .def("is_valid_action", &Sokoban::is_valid_action, "Check if action is valid", 
             py::arg("action"))
        
        // Utility methods
        .def("get_width", &Sokoban::get_width, "Get grid width")
        .def("get_height", &Sokoban::get_height, "Get grid height")
        .def("get_player_position", &Sokoban::get_player_position, "Get player (x, y) position");
}

In [None]:
# Cell 5: Create environment file
%%writefile sokoban_env.py
import gymnasium as gym
import numpy as np
import pygame

# Constants will be imported from the compiled module
class SokobanEnv(gym.Env):
    """Gymnasium environment backed by the compiled ``sokoban_engine`` module.

    Observations are the flattened grid as an ``int32`` vector of tile codes (0..6);
    actions are the four move directions. The reward and ``terminated`` flag come
    straight from ``Sokoban.step``; the environment never truncates on its own.

    Args:
        render_mode: ``None``, ``"human"`` (pygame window) or ``"ansi"`` (text).

    Raises:
        ValueError: if ``render_mode`` is not one of the supported modes.
    """

    metadata = {"render_modes": ["human", "ansi"], "render_fps": 5}

    def __init__(self, render_mode=None):
        if render_mode not in self.metadata["render_modes"] + [None]:
            raise ValueError(f"Invalid render_mode: {render_mode}")

        # Import here to avoid issues during compilation
        import sokoban_engine
        self.sokoban_engine = sokoban_engine

        self.game = sokoban_engine.Sokoban()

        # Get actual observation dimensions
        obs = self.game.get_observation()
        self.obs_dim = len(obs)

        self.observation_space = gym.spaces.Box(low=0, high=6, shape=(self.obs_dim,), dtype=np.int32)
        self.action_space = gym.spaces.Discrete(4)

        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._current_obs = None
        self.window_size = 600

        # Set the grid geometry up front so render() cannot hit missing attributes
        # when it is called before the first reset().
        self._sync_grid_geometry()

    def _sync_grid_geometry(self):
        """Cache the engine's grid size and the pixel size of one tile."""
        self.grid_width = self.game.get_width()
        self.grid_height = self.game.get_height()
        self.tile_size = min(self.window_size // self.grid_width, self.window_size // self.grid_height)

    def reset(self, seed=None, options=None):
        """Reset the level and return ``(observation, info)``."""
        super().reset(seed=seed)
        self.game.reset()
        self._current_obs = np.array(self.game.get_observation(), dtype=np.int32)

        # Update grid dimensions
        self._sync_grid_geometry()

        if self.render_mode == "human":
            self.render()
        return self._current_obs, {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        grid, reward, terminated = self.game.step(action)
        self._current_obs = np.array(grid, dtype=np.int32)
        truncated = False
        info = {}
        if self.render_mode == "human":
            self.render()
        return self._current_obs, reward, terminated, truncated, info

    def render(self):
        """Render according to ``render_mode``; returns a string in ``"ansi"`` mode."""
        if self.render_mode == "human":
            return self._render_human()
        elif self.render_mode == "ansi":
            return self._render_ansi()

    def _render_human(self):
        """Draw the current observation into a pygame window."""
        if self._current_obs is None:
            # Nothing to draw before the first reset().
            return

        if self.window is None:
            pygame.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
            pygame.display.set_caption("Sokoban")
        if self.clock is None:
            self.clock = pygame.time.Clock()

        self.window.fill((255, 255, 255))

        engine = self.sokoban_engine
        target_tiles = (engine.TARGET, engine.BOX_ON_TARGET, engine.PLAYER_ON_TARGET)
        player_tiles = (engine.PLAYER, engine.PLAYER_ON_TARGET)

        grid = self._current_obs.reshape((self.grid_height, self.grid_width))
        for row in range(self.grid_height):
            for col in range(self.grid_width):
                tile = grid[row, col]
                x = col * self.tile_size
                y = row * self.tile_size
                rect = pygame.Rect(x, y, self.tile_size, self.tile_size)

                # Determine base color
                base_color = (144, 238, 144) if tile in target_tiles else (255, 255, 255)
                pygame.draw.rect(self.window, base_color, rect)

                # Draw objects
                if tile == engine.WALL:
                    pygame.draw.rect(self.window, (100, 100, 100), rect)
                elif tile == engine.BOX:
                    pygame.draw.rect(self.window, (165, 42, 42), rect)
                elif tile == engine.BOX_ON_TARGET:
                    pygame.draw.rect(self.window, (0, 128, 0), rect)
                elif tile in player_tiles:
                    center = (x + self.tile_size // 2, y + self.tile_size // 2)
                    radius = self.tile_size // 2 - 5
                    pygame.draw.circle(self.window, (0, 0, 255), center, radius)

                # Draw tile borders for clarity
                pygame.draw.rect(self.window, (200, 200, 200), rect, 1)

        # Let pygame process window-system events; without this the window can be
        # flagged as unresponsive because the event queue is never serviced.
        pygame.event.pump()
        pygame.display.flip()
        self.clock.tick(self.metadata["render_fps"])

    def _render_ansi(self):
        """Return the current grid as text, or ``""`` before the first reset()."""
        if self._current_obs is None:
            return ""

        grid = self._current_obs.reshape((self.grid_height, self.grid_width))

        chars = {
            self.sokoban_engine.WALL: '#',
            self.sokoban_engine.EMPTY: ' ',
            self.sokoban_engine.PLAYER: '@',
            self.sokoban_engine.BOX: '$',
            self.sokoban_engine.TARGET: '.',
            self.sokoban_engine.BOX_ON_TARGET: '*',
            self.sokoban_engine.PLAYER_ON_TARGET: '+',
        }

        return '\n'.join(
            ''.join(chars.get(cell, '?') for cell in row)
            for row in grid
        )

    def close(self):
        """Shut down the pygame window, if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            # The clock belongs to the pygame session that was just shut down.
            self.clock = None

In [None]:
# Cell 6: Create training script
%%writefile train_sokoban_ppo.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import time
import os
from tqdm import tqdm

# Import environment
from sokoban_env import SokobanEnv

class PPONetwork(nn.Module):
    """Actor-critic MLP with a shared trunk, a policy head and a value head.

    Args:
        obs_dim: number of grid cells in one observation.
        action_dim: number of discrete actions.
        hidden_dim: width of the shared layers (heads use ``hidden_dim // 2``).
        use_one_hot: if True, each tile code is one-hot encoded over
            ``num_tile_types`` classes; otherwise codes are scaled into [0, 1].
    """

    def __init__(self, obs_dim, action_dim, hidden_dim=256, use_one_hot=True):
        super().__init__()
        self.use_one_hot = use_one_hot
        self.num_tile_types = 7

        if use_one_hot:
            input_dim = obs_dim * self.num_tile_types
        else:
            input_dim = obs_dim

        print(f"Network input dimension: {input_dim}")

        self.shared_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, action_dim)
        )

        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal weights (gain sqrt(2)) and zero biases for every linear layer."""
        if isinstance(module, nn.Linear):
            torch.nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            torch.nn.init.constant_(module.bias, 0.0)

    def preprocess_obs(self, obs):
        """Turn raw tile codes into the float tensor the network expects.

        Accepts a single observation of shape ``(obs_dim,)`` or a batch of shape
        ``(batch, obs_dim)``. The result lives on the same device as the model's
        parameters, so raw observations can be fed to a model that was moved to CUDA.

        Returns:
            ``(obs_dim * num_tile_types,)`` / ``(batch, obs_dim * num_tile_types)``
            when one-hot encoding is enabled, otherwise a tensor of the input shape
            with values divided by the largest tile code.
        """
        device = next(self.parameters()).device
        if self.use_one_hot:
            tiles = torch.as_tensor(obs, dtype=torch.long, device=device)
            one_hot = torch.nn.functional.one_hot(tiles, num_classes=self.num_tile_types)
            # Merge the cell and class axes but keep any leading batch axis.
            return one_hot.float().reshape(*tiles.shape[:-1], -1)
        largest_code = float(self.num_tile_types - 1)
        return torch.as_tensor(obs, dtype=torch.float32, device=device) / largest_code

    def forward(self, obs):
        """Return ``(logits, value)``.

        Non-tensor input is run through :meth:`preprocess_obs`; a tensor is assumed
        to be preprocessed already.
        """
        if not isinstance(obs, torch.Tensor):
            obs = self.preprocess_obs(obs)

        features = self.shared_net(obs)
        logits = self.policy_head(features)
        value = self.value_head(features)
        return logits, value

    def get_action(self, obs):
        """Sample an action for one observation.

        Returns:
            ``(action, log_prob, value)`` as plain Python numbers.
        """
        with torch.no_grad():
            logits, value = self.forward(obs)
            # Building the distribution from logits avoids a softmax -> log round trip.
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            return action.item(), log_prob.item(), value.item()

    def evaluate_actions(self, obs, actions):
        """Return ``(log_probs, entropy, value)`` of ``actions`` under the current policy."""
        logits, value = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_probs, entropy, value

class PPOBuffer:
    """Fixed-size rollout storage with GAE-lambda advantages and discounted returns.

    Args:
        obs_dim: length of one (flattened) observation.
        size: number of transitions collected before :meth:`get` may be called.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.
    """

    def __init__(self, obs_dim, size, gamma=0.99, gae_lambda=0.95):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.int32)
        self.act_buf = np.zeros(size, dtype=np.int32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.gae_lambda = gamma, gae_lambda
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        """Append one transition; asserts that the buffer is not already full."""
        assert self.ptr < self.max_size
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        """Close the trajectory that started at ``path_start_idx``.

        Args:
            last_val: value estimate of the state after the last stored step; 0 for
                a terminal state, the critic's estimate when the path was cut off.

        Fills ``adv_buf`` with GAE-lambda advantages and ``ret_buf`` with discounted
        rewards-to-go. The bootstrap value is part of the reward-to-go, so a path that
        was cut off mid-episode is not treated as if it had ended with zero future
        return.
        """
        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)

        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = self._discount_cumsum(deltas, self.gamma * self.gae_lambda)

        # Discount over rewards plus bootstrap, then drop the bootstrap slot itself.
        self.ret_buf[path_slice] = self._discount_cumsum(rews, self.gamma)[:-1]

        self.path_start_idx = self.ptr

    def _discount_cumsum(self, x, discount):
        """Return ``y`` with ``y[i] = sum_k discount**k * x[i + k]`` in one backward pass."""
        out = np.zeros(len(x), dtype=np.float64)
        running = 0.0
        for i in range(len(x) - 1, -1, -1):
            running = x[i] + discount * running
            out[i] = running
        return out

    def get(self):
        """Return ``(obs, act, ret, adv, logp)`` for a full buffer and rewind it.

        Advantages are normalised to zero mean / unit standard deviation. The arrays
        returned are the buffer's own storage and are overwritten by later ``store``
        calls.
        """
        assert self.ptr == self.max_size
        self.ptr, self.path_start_idx = 0, 0

        adv_mean, adv_std = np.mean(self.adv_buf), np.std(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / (adv_std + 1e-8)

        return (self.obs_buf,
                self.act_buf,
                self.ret_buf,
                self.adv_buf,
                self.logp_buf)

class PPOAgent:
    """PPO (clipped objective) trainer for a Gymnasium-style environment.

    Args:
        env: environment exposing ``reset()``, ``step()`` and ``action_space.n``.
        hidden_dim: hidden width of the actor-critic network.
        lr: Adam learning rate.
        gamma: discount factor.
        gae_lambda: GAE smoothing factor.
        clip_ratio: PPO probability-ratio clipping range.
        train_epochs: optimisation passes over each rollout.
        batch_size: minibatch size inside one pass.
        use_one_hot: whether the network one-hot encodes tile codes.
    """

    def __init__(self, env, hidden_dim=256, lr=3e-4, gamma=0.99, gae_lambda=0.95,
                 clip_ratio=0.2, train_epochs=4, batch_size=64, use_one_hot=True):
        self.env = env

        obs, _ = self.env.reset()
        self.obs_dim = len(obs)
        self.act_dim = env.action_space.n

        print(f"Detected observation dimension: {self.obs_dim}")
        print(f"Action dimension: {self.act_dim}")

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.train_epochs = train_epochs
        self.batch_size = batch_size
        self.use_one_hot = use_one_hot

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.network = PPONetwork(
            self.obs_dim, self.act_dim, hidden_dim, use_one_hot
        ).to(self.device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)

        # Rolling statistics over the most recent 100 episodes.
        self.episode_returns = deque(maxlen=100)
        self.episode_lengths = deque(maxlen=100)
        self.success_rate = deque(maxlen=100)
        self.training_steps = 0

    def _policy_step(self, obs):
        """Return ``(action, log_prob, value)`` for one raw observation.

        The observation is preprocessed and moved to the training device first, so
        the network never receives a CPU tensor while its weights are on CUDA.
        """
        obs_tensor = self.network.preprocess_obs(obs).to(self.device)
        return self.network.get_action(obs_tensor)

    def train(self, total_timesteps=5_000_000, rollout_length=128,
              save_freq=100_000, checkpoint_path="sokoban_ppo.pth"):
        """Run PPO until ``training_steps`` reaches ``total_timesteps``.

        Args:
            total_timesteps: total number of environment steps to train for
                (counting steps restored from a checkpoint).
            rollout_length: environment steps collected per policy update.
            save_freq: a checkpoint is written each time this many further steps
                have been completed.
            checkpoint_path: checkpoint file; if it exists, training resumes from it.
        """
        # Kaggle-specific path (join leaves an already absolute path untouched)
        if '/kaggle/working' in os.getcwd():
            checkpoint_path = os.path.join('/kaggle/working', checkpoint_path)

        buffer = PPOBuffer(self.obs_dim, rollout_length, self.gamma, self.gae_lambda)

        if os.path.exists(checkpoint_path):
            self.load_checkpoint(checkpoint_path)
            print(f"Resumed training from step {self.training_steps}")

        obs, _ = self.env.reset()
        ep_return, ep_length = 0, 0
        episodes_finished = 0

        # First step count at which a checkpoint is due. A plain
        # ``training_steps % save_freq == 0`` test almost never fires because
        # training_steps only moves in multiples of rollout_length.
        next_save_at = (self.training_steps // save_freq + 1) * save_freq

        progress_bar = tqdm(total=total_timesteps,
                            initial=min(self.training_steps, total_timesteps),
                            desc="Training")
        try:
            # One iteration is one rollout of rollout_length environment steps, so
            # the loop is driven by the step counter, not by an iteration index.
            while self.training_steps < total_timesteps:
                for _ in range(rollout_length):
                    action, logp, val = self._policy_step(obs)
                    next_obs, reward, terminated, truncated, _ = self.env.step(action)
                    done = terminated or truncated

                    buffer.store(obs, action, reward, val, logp)

                    obs = next_obs
                    ep_return += reward
                    ep_length += 1

                    if done:
                        # Close the trajectory here so advantages and returns do not
                        # leak across the episode boundary. A truly terminal state
                        # has value 0; a truncated one is bootstrapped.
                        buffer.finish_path(0.0 if terminated else self._policy_step(obs)[2])

                        success = (reward > 0.5)

                        self.episode_returns.append(ep_return)
                        self.episode_lengths.append(ep_length)
                        self.success_rate.append(success)
                        episodes_finished += 1

                        # Report every 10th episode (the deques cap at 100 entries,
                        # so their length cannot be used as an episode counter).
                        if episodes_finished % 10 == 0:
                            avg_return = np.mean(self.episode_returns)
                            avg_length = np.mean(self.episode_lengths)
                            current_success_rate = np.mean(self.success_rate)
                            tqdm.write(f"Step {self.training_steps}: Avg Return: {avg_return:.2f}, "
                                       f"Avg Length: {avg_length:.1f}, "
                                       f"Success Rate: {current_success_rate:.2f}")

                        obs, _ = self.env.reset()
                        ep_return, ep_length = 0, 0

                # Bootstrap whatever part of an episode is still open at the end of
                # the rollout (a no-op when the last step finished an episode).
                buffer.finish_path(self._policy_step(obs)[2])

                self.update(buffer)

                progress_bar.update(rollout_length)
                self.training_steps += rollout_length

                if self.training_steps >= next_save_at:
                    self.save_checkpoint(checkpoint_path)
                    print(f"\nCheckpoint saved at step {self.training_steps}")
                    while next_save_at <= self.training_steps:
                        next_save_at += save_freq
        finally:
            progress_bar.close()

        self.save_checkpoint(checkpoint_path)
        print("Training completed!")

    def update(self, buffer):
        """Run ``train_epochs`` passes of clipped-PPO minibatch updates on a full buffer."""
        obs, acts, rets, advs, old_logps = buffer.get()

        acts_tensor = torch.tensor(acts, dtype=torch.long).to(self.device)
        rets_tensor = torch.tensor(rets, dtype=torch.float32).to(self.device)
        advs_tensor = torch.tensor(advs, dtype=torch.float32).to(self.device)
        old_logps_tensor = torch.tensor(old_logps, dtype=torch.float32).to(self.device)

        # Use the network's own preprocessing for both encodings so the scaling /
        # one-hot logic lives in a single place.
        obs_processed = torch.stack(
            [self.network.preprocess_obs(ob) for ob in obs]
        ).to(self.device)

        num_samples = len(obs_processed)
        for _ in range(self.train_epochs):
            indices = torch.randperm(num_samples, device=self.device)

            for start in range(0, num_samples, self.batch_size):
                idx = indices[start:start + self.batch_size]

                batch_obs = obs_processed[idx]
                batch_acts = acts_tensor[idx]
                batch_rets = rets_tensor[idx]
                batch_advs = advs_tensor[idx]
                batch_old_logps = old_logps_tensor[idx]

                logps, entropy, values = self.network.evaluate_actions(batch_obs, batch_acts)

                ratio = torch.exp(logps - batch_old_logps)
                clip_adv = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advs
                policy_loss = -torch.min(ratio * batch_advs, clip_adv).mean()

                # squeeze(-1) only removes the value head's trailing axis, so a
                # minibatch of one sample keeps its batch dimension.
                value_loss = 0.5 * ((values.squeeze(-1) - batch_rets) ** 2).mean()

                entropy_loss = -entropy.mean()

                total_loss = policy_loss + value_loss + 0.01 * entropy_loss

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                self.optimizer.step()

    def save_checkpoint(self, path):
        """Write network, optimizer, step counter and rolling statistics to ``path``."""
        checkpoint = {
            'network_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'training_steps': self.training_steps,
            'episode_returns': list(self.episode_returns),
            'episode_lengths': list(self.episode_lengths),
            'success_rate': list(self.success_rate),
            'obs_dim': self.obs_dim,
            'act_dim': self.act_dim,
        }
        torch.save(checkpoint, path)

    def load_checkpoint(self, path):
        """Restore the state written by :meth:`save_checkpoint`."""
        checkpoint = torch.load(path, map_location=self.device)
        self.network.load_state_dict(checkpoint['network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.training_steps = checkpoint['training_steps']
        self.episode_returns = deque(checkpoint['episode_returns'], maxlen=100)
        self.episode_lengths = deque(checkpoint['episode_lengths'], maxlen=100)
        self.success_rate = deque(checkpoint['success_rate'], maxlen=100)

def main():
    """Create the environment and agent, then run PPO training with the script defaults."""
    agent_settings = dict(
        hidden_dim=256,
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_ratio=0.2,
        train_epochs=4,
        batch_size=64,
        use_one_hot=False,  # Start with raw for stability
    )
    training_settings = dict(
        total_timesteps=5_000_000,
        rollout_length=128,
        save_freq=100_000,
        checkpoint_path="sokoban_ppo.pth",
    )

    env = SokobanEnv(render_mode=None)
    agent = PPOAgent(env, **agent_settings)

    print("Starting PPO training on Sokoban...")

    agent.train(**training_settings)


if __name__ == "__main__":
    main()

In [None]:
# Cell 7: Compile C++ module
# Builds python_bindings.cpp + sokoban.cpp into a shared library named
# sokoban_engine<extension-suffix> in the current directory. Include paths come from
# `python3 -m pybind11 --includes`; C++20 is required by the headers sokoban.h pulls in.
!g++ -O3 -Wall -shared -std=c++20 -fPIC $(python3 -m pybind11 --includes) \
    python_bindings.cpp sokoban.cpp -o sokoban_engine$(python3-config --extension-suffix)

In [None]:
# Cell 8: Set PYTHONPATH and test
import sys
import os
# NOTE(review): the path is hard-coded for Kaggle and assumes the earlier cells wrote and
# compiled their files in /kaggle/working — confirm when running elsewhere.
sys.path.append('/kaggle/working')
os.chdir('/kaggle/working')

# Test the compiled module
import sokoban_engine
print("✅ C++ module loaded successfully!")
print(f"Constants: WALL={sokoban_engine.WALL}, EMPTY={sokoban_engine.EMPTY}")

# Test environment
# One reset is enough to check that the wrapper and the engine agree on the observation size.
from sokoban_env import SokobanEnv
env = SokobanEnv(render_mode=None)
obs, _ = env.reset()
print(f"Environment works! Observation shape: {obs.shape}")

In [None]:
# Cell 9: Start training.
# `!` runs the script as a shell command and this cell blocks until it exits.
# NOTE(review): whether training keeps running after the browser tab is closed depends on
# how the notebook is executed (e.g. a committed/background run) — confirm for your platform.
!cd /kaggle/working && python train_sokoban_ppo.py