# Google ColabでLPM探索を実行 (Atariのみ)

このノートブックは、Craftiumのインストールをスキップし、Atari環境で `train_lpm.py` スクリプトを実行します。

## 1. セットアップ
Atari環境と録画に必要なシステム依存関係をインストールします。

In [None]:
# Check the GPU: run nvidia-smi to show the NVIDIA device attached to this runtime
!nvidia-smi

In [None]:
# Refresh the apt package index, then install the system packages listed
# below (ffmpeg, xvfb, SDL2 headers, OpenGL bindings, cmake, zlib headers).
!apt-get update
!apt-get install -y ffmpeg xvfb libsdl2-dev python3-opengl cmake zlib1g-dev

## 2. Pythonライブラリのインストール
AtariをサポートするGymnasiumとその他のユーティリティをインストール

In [None]:
# Install Gymnasium (classic-control and Atari extras) together with the
# video, logging, image and plotting utilities listed below.
# NOTE(review): the `accept-rom-license` extra may no longer be provided by
# gymnasium>=1.0.0 (pip would only print a warning) — confirm and drop if so.
!pip install gymnasium[classic_control,atari,accept-rom-license] \
    moviepy imageio wandb opencv-python matplotlib tqdm pillow "gymnasium>=1.0.0" shimmy

## 3. Explorationモジュールの作成
train_lpm.pyの実行に必要な `exploration` モジュールのファイルを作成。

In [None]:
import os

# Create the `exploration` directory; exist_ok makes re-running this cell harmless.
os.makedirs("exploration", exist_ok=True)

# Write an empty __init__.py inside the directory.
init_path = os.path.join("exploration", "__init__.py")
with open(init_path, "w") as init_file:
    init_file.write("")

print("explorationディレクトリと__init__.pyを作成しました")

In [None]:
%%writefile exploration/cifar.py
import math
import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

def load_cifar_dataset():
    """Build the CIFAR-10 training dataset with a tensor-conversion transform.

    The dataset is rooted at ``./data`` and is requested with
    ``download=True``.

    Returns:
        The ``torchvision.datasets.CIFAR10`` training split whose only
        transform is ``transforms.ToTensor()``.
    """
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    dataset_options = {
        'root': './data',
        'train': True,
        'download': True,
        'transform': to_tensor_only,
    }
    return torchvision.datasets.CIFAR10(**dataset_options)

def generate_random_cifar_observation(cifar_dataset, width=32, height=32):
    """Return one randomly chosen dataset image stretched to ``width`` x ``height``.

    The image is resized with bilinear resampling and no aspect-ratio
    preservation, so it always fills the requested dimensions.

    Args:
        cifar_dataset: Indexable dataset whose items are ``(image, label)``
            pairs; ``image`` must support ``permute(1, 2, 0).numpy()``
            (channel-first tensor). Values are scaled by 255, i.e. they are
            expected to lie in [0, 1].
        width: Width of the returned observation in pixels.
        height: Height of the returned observation in pixels.

    Returns:
        uint8 numpy array of shape (height, width, 3) for a 3-channel image.

    Raises:
        ValueError: If ``cifar_dataset`` is empty.
    """
    dataset_size = len(cifar_dataset)
    if dataset_size == 0:
        raise ValueError("cifar_dataset is empty; cannot sample an image")

    # Select a random image (the label is not needed).
    image, _ = cifar_dataset[random.randrange(dataset_size)]

    # Channel-first tensor -> HxWxC array.
    image_np = image.permute(1, 2, 0).numpy()

    # Round instead of truncating: a scaled value that lands just below an
    # integer because of floating-point error would otherwise lose one
    # intensity level. Clipping keeps out-of-range inputs from wrapping
    # around when cast to uint8.
    pixels = np.clip(np.rint(image_np * 255.0), 0, 255).astype(np.uint8)
    pil_image = Image.fromarray(pixels)

    # Stretch the image to the desired dimensions (PIL takes (width, height)).
    stretched_image = pil_image.resize((width, height), Image.Resampling.BILINEAR)

    return np.array(stretched_image)

def create_cifar_function_simple():
    """Load CIFAR-10 once and return a zero-argument random-image sampler.

    Returns:
        A callable that, on every call, returns the result of
        ``generate_random_cifar_observation`` for the loaded dataset at
        32x32 pixels.
    """
    print("Loading CIFAR-10 dataset...")
    dataset = load_cifar_dataset()
    image_count = len(dataset)
    print(f"✅ Loaded {image_count} CIFAR-10 images")

    def get_random_cifar():
        """Sample one random 32x32 observation from the captured dataset."""
        observation = generate_random_cifar_observation(dataset, width=32, height=32)
        return observation

    return get_random_cifar

# Convenience entry point for the NoisyTV wrapper
def create_cifar_function_for_noisy_tv(dataset_type='cifar10'):
    """Return a random-image sampler for use with the NoisyTV wrapper.

    Only CIFAR-10 is available. Any other ``dataset_type`` (compared
    case-insensitively) prints a warning and CIFAR-10 is used anyway.

    Args:
        dataset_type: Name of the requested dataset; expected to be 'cifar10'.

    Returns:
        The callable produced by ``create_cifar_function_simple``.
    """
    is_cifar10 = dataset_type.lower() == 'cifar10'
    if not is_cifar10:
        print(f"Warning: Only cifar10 supported with simple loader, got {dataset_type}")

    sampler = create_cifar_function_simple()
    return sampler


In [None]:
%%writefile exploration/noisy_wrapper.py
"""
Debug CIFAR Wrapper - Adds extensive logging to find the issue
"""

import numpy as np
import random
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium import spaces

class NoisyTVEnvWrapperCIFAR(gym.Wrapper):
    """Wrapper that appends "noisy TV" actions to a discrete-action environment.

    The wrapped environment's action space is extended by
    ``num_random_actions`` extra actions. Selecting one of them steps the
    underlying environment with action 0 and replaces the returned
    observation with a tiled grayscale image obtained from
    ``get_random_cifar``.

    Args:
        env: Environment to wrap; ``env.action_space.n`` is read, so the
            action space must be discrete.
        get_random_cifar: Zero-argument callable returning an image array,
            either 2D or HxWx3 (the latter is converted to grayscale).
        num_random_actions: Number of extra noise-triggering actions.
    """

    # Action names used when neither the wrapped env nor its base env
    # provides ``get_action_meanings``.
    _FALLBACK_ACTION_MEANINGS = ('NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN')

    def __init__(self, env, get_random_cifar, num_random_actions=1):
        super().__init__(env)

        self.get_random_cifar = get_random_cifar
        self.num_random_actions = num_random_actions

        # The extra actions are the indices directly after the original ones.
        self.original_actions = self.env.action_space.n
        self.random_actions = list(
            range(self.original_actions, self.original_actions + num_random_actions)
        )

        # Advertise the enlarged action space to the agent.
        self.action_space = spaces.Discrete(self.original_actions + num_random_actions)

        # Smoke-test the image source once. This is best-effort: a failure is
        # only reported, because add_noisy_tv has its own random-noise fallback.
        try:
            self.get_random_cifar()
        except Exception as e:
            print(f"   ❌ CIFAR test failed: {e}")

    def step(self, action):
        """Step the environment, substituting noise for the extra actions.

        Args:
            action: Index into the extended action space.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple of the
            wrapped environment. For an extra action, ``obs`` is replaced by
            the output of ``add_noisy_tv``.
        """
        if action not in self.random_actions:
            return self.env.step(action)

        # Step with action 0 instead (presumably NOOP in the wrapped env —
        # TODO confirm for non-Atari environments).
        obs, reward, terminated, truncated, info = self.env.step(0)
        return self.add_noisy_tv(obs), reward, terminated, truncated, info

    def add_noisy_tv(self, obs):
        """Return a noise observation with the same shape and dtype as ``obs``.

        The noise is a tiled grayscale image from ``get_random_cifar``. If
        anything goes wrong while building it, the traceback is printed and
        uniform random noise in [0, 255] is returned instead.

        Args:
            obs: Observation array whose shape and dtype are to be matched.

        Returns:
            A new array; ``obs`` itself is not modified.
        """
        try:
            cifar_img = self.get_random_cifar()

            # Collapse RGB to a single luminance channel.
            if len(cifar_img.shape) == 3 and cifar_img.shape[2] == 3:
                cifar_gray = np.dot(cifar_img, [0.2989, 0.5870, 0.1140]).astype(np.uint8)
            else:
                cifar_gray = cifar_img

            return self._create_replacement(cifar_gray, obs)

        except Exception:
            # Deliberate best-effort: report the problem but keep the episode running.
            import traceback
            traceback.print_exc()

            return self._random_noise(obs.shape, obs.dtype)

    @staticmethod
    def _random_noise(shape, dtype):
        """Return uniform random values in [0, 255] with the given shape and dtype.

        ``np.random.randint`` only accepts integer dtypes wide enough for the
        requested range, so other dtypes (e.g. floats) are sampled as
        integers and converted afterwards.
        """
        dtype = np.dtype(dtype)
        if dtype.kind in 'iu' and np.iinfo(dtype).max >= 255:
            return np.random.randint(0, 256, size=shape, dtype=dtype)
        return np.random.randint(0, 256, size=shape).astype(dtype)

    def _create_replacement(self, cifar_gray, target_obs):
        """Tile ``cifar_gray`` so that it matches ``target_obs`` exactly.

        Args:
            cifar_gray: 2D grayscale image used as the tile.
            target_obs: Observation whose shape and dtype are to be matched.
                2D observations are treated as (H, W) and 3D ones as
                (H, W, C), with the tile repeated across all C channels.
                Any other rank yields uniform random noise.

        Returns:
            Array with the shape and dtype of ``target_obs``.

        Raises:
            ValueError: If ``cifar_gray`` is not 2D.
        """
        target_shape = target_obs.shape
        target_dtype = target_obs.dtype

        if len(target_shape) not in (2, 3):
            return self._random_noise(target_shape, target_dtype)

        cifar_gray = np.asarray(cifar_gray)
        if cifar_gray.ndim != 2:
            raise ValueError(
                f"expected a 2D grayscale image, got shape {cifar_gray.shape}"
            )

        h, w = target_shape[:2]
        src_h, src_w = cifar_gray.shape

        # Ceiling division on the tile's real size, so the tiled image always
        # covers the full h x w area whatever the source resolution.
        tile_h = -(-h // src_h)
        tile_w = -(-w // src_w)

        cropped = np.tile(cifar_gray, (tile_h, tile_w))[:h, :w]

        if len(target_shape) == 3:
            # Repeat the single gray plane across every channel.
            replacement = np.repeat(cropped[:, :, np.newaxis], target_shape[2], axis=2)
        else:
            replacement = cropped

        # Ensure correct dtype (astype also returns a fresh array).
        return replacement.astype(target_dtype)

    def get_action_meanings(self):
        """Return the action names including the appended CIFAR actions.

        The names come from the wrapped environment or, if it does not expose
        ``get_action_meanings`` itself, from its unwrapped base environment;
        wrapper chains do not necessarily forward that method. A fixed
        default list is used only when neither provides it.

        Returns:
            List of action names followed by one ``CIFAR_DEBUG_<k>`` entry per
            extra action.
        """
        meanings = None
        for candidate in (self.env, self.env.unwrapped):
            getter = getattr(candidate, 'get_action_meanings', None)
            if callable(getter):
                meanings = list(getter())
                break

        if meanings is None:
            meanings = list(self._FALLBACK_ACTION_MEANINGS)

        meanings.extend(f'CIFAR_DEBUG_{i + 1}' for i in range(self.num_random_actions))
        return meanings


## 4. 必要であればGoogleドライブのマウント
Googleドライブをマウント

In [None]:
import os

from google.colab import drive

# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

# CHANGE THE PATH BELOW to the folder you uploaded,
# e.g. '/content/drive/MyDrive/WorldModel/LPM_exploration'
try:
    os.chdir('/content/drive/MyDrive/')
    print(f"カレントディレクトリ: {os.getcwd()}")
except Exception as err:
    # Report the failure but do not raise.
    print(f"ドライブのマウントまたはパスの検索に失敗しました: {err}")
    print("ファイルがアップロードされているか、パスが正しいか確認してください。")

## 5. WandB ログイン
実験記録のためにWandBにログイン \
事前にAPIキーを取得しておいてください． \
実行すると最初に3つの選択肢が表示されます．2の "Use an existing W&B account" を選択すると，その後でAPIキーの入力を求められます． \
記録の必要がなければ3を選択すると記録せずに実行します．この場合はAPIキーを求められません．

In [None]:
# Log in to Weights & Biases so that the training run can record to it.
import wandb
wandb.login()

## 6. 学習の実行 (Atari)
Atari Breakoutの学習スクリプトを実行します。 \
実行コマンド引数
- --env-name  実行するAtari環境の名前を指定します．デフォルトはブロック崩し（Breakout）になっています．
- --steps  学習を行うステップ数です．Colab上でブロック崩しだと1分で10000ステップぐらい学習しました．
- --wandb  WandBで記録するようにしています．

In [None]:
import os
# Set MUJOCO_GL to 'egl' for processes launched from this cell.
# NOTE(review): this variable looks MuJoCo-specific and may be unnecessary
# for an Atari-only run — confirm.
os.environ['MUJOCO_GL'] = 'egl'

# Run training for Atari Breakout: 50000 steps with the --wandb flag set
!python train_lpm.py --env-name "ALE/Breakout-v5" --steps 50000 --wandb

以下で利用可能なAtari環境を確認できます。

In [None]:
import gymnasium as gym
import ale_py

# Register the ALE environments with Gymnasium.
gym.register_envs(ale_py)

# Collect the ids of all registered ALE v5 environments, sorted alphabetically.
atari_envs = sorted(
    env_id
    for env_id in gym.registry
    if "ALE/" in env_id and "-v5" in env_id
)

print(f"Available Atari Envs: {len(atari_envs)}")
print("All Environments:")
for env in atari_envs:
    print(env)