# MCTS Chess Engine - Colab GPU Test

This notebook compiles and tests the C++ MCTS engine on Google Colab with T4 GPU support.


## Step 1: Install Dependencies and Setup Environment


In [None]:
# Install build tools
!apt-get update -qq
!apt-get install -y -qq build-essential cmake wget unzip

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


## Step 2: Download LibTorch (CUDA version)


In [None]:
import os

# Download LibTorch with CUDA support
LIBTORCH_VERSION = "2.1.0"
CUDA_VERSION = "11.8"  # Colab T4 uses CUDA 11.8

LIBTORCH_PATH = "/content/libtorch"

if not os.path.exists(LIBTORCH_PATH):
    print("Downloading LibTorch with CUDA support...")
    print("This may take a few minutes (~1.5GB download)...")
    
    # Download LibTorch
    download_url = f"https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-{LIBTORCH_VERSION}%2Bcu118.zip"
    print(f"Downloading from: {download_url}")
    
    !wget -q --show-progress {download_url} -O libtorch.zip
    
    if not os.path.exists("libtorch.zip"):
        raise FileNotFoundError("Failed to download LibTorch. Check your internet connection.")
    
    print("Extracting LibTorch (this may take a minute)...")
    !unzip -q libtorch.zip
    
    if not os.path.exists(LIBTORCH_PATH):
        raise FileNotFoundError("LibTorch extraction failed. Check disk space.")
    
    !rm libtorch.zip
    print("✓ LibTorch downloaded and extracted successfully")
else:
    print("✓ LibTorch already exists")

# Verify installation
print(f"\nVerifying LibTorch installation...")
print(f"LibTorch path: {LIBTORCH_PATH}")
print(f"LibTorch exists: {os.path.exists(LIBTORCH_PATH)}")

if os.path.exists(LIBTORCH_PATH):
    torch_include = os.path.join(LIBTORCH_PATH, "include")
    torch_lib = os.path.join(LIBTORCH_PATH, "lib")
    script_h = os.path.join(torch_include, "torch", "script.h")
    
    print(f"  Include dir: {torch_include} (exists: {os.path.exists(torch_include)})")
    print(f"  Lib dir: {torch_lib} (exists: {os.path.exists(torch_lib)})")
    print(f"  torch/script.h: {os.path.exists(script_h)}")
    
    if os.path.exists(script_h):
        print("\n✓ LibTorch installation verified!")
    else:
        print("\n⚠ Warning: torch/script.h not found. LibTorch may be incomplete.")
else:
    print("\n✗ LibTorch installation failed!")


## Step 3: Upload MCTS Source Code

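Choose one of the options in the cell below (numbered the same way as in the code):
1. Clone from GitHub (if your repo is public). This runs by default; set `REPO_URL` first.
2. Upload the MCTS folder as a zip file (if your repo is private).
3. Use Colab's file browser to upload the files to `/content/MCTS/`.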


In [None]:
import os

# ============================================================================
# CHOOSE ONE OPTION BELOW:
# ============================================================================

# OPTION 1: Clone entire repo from GitHub (EASIEST if repo is public)
# Replace with your GitHub username and repo name:
REPO_URL = "https://github.com/yourusername/ChessMirror.git"  # UPDATE THIS!
REPO_DIR = "/content/ChessMirror"
MCTS_DIR = "/content/ChessMirror/MCTS"

if not os.path.exists(MCTS_DIR):
    print("Cloning repository from GitHub...")
    !git clone {REPO_URL} {REPO_DIR}
    if os.path.exists(MCTS_DIR):
        print(f"✓ Repository cloned! MCTS directory: {MCTS_DIR}")
    else:
        print(f"✗ MCTS directory not found. Check repo structure.")
else:
    print(f"✓ MCTS directory already exists: {MCTS_DIR}")

# ============================================================================
# OPTION 2: Upload MCTS folder as zip file (if repo is private)
# ============================================================================
# Uncomment this if you prefer to upload:
# from google.colab import files
# import zipfile
# 
# print("Please upload your MCTS folder as a zip file...")
# uploaded = files.upload()
# for filename in uploaded.keys():
#     if filename.endswith('.zip'):
#         with zipfile.ZipFile(filename, 'r') as zip_ref:
#             zip_ref.extractall('/content')
#         MCTS_DIR = "/content/MCTS"  # Adjust if zip extracts differently
#         print(f"✓ Extracted to {MCTS_DIR}")
#         break

# ============================================================================
# OPTION 3: Manual upload via file browser
# ============================================================================
# Use Colab's file browser (folder icon on left) to upload files to /content/MCTS/

# Verify MCTS structure
if os.path.exists(MCTS_DIR):
    print(f"\n✓ MCTS directory found: {MCTS_DIR}")
    required = [
        "src/mcts.cpp",
        "include/mcts.h",
        "Chess/mcts_bridge.cpp",
        "Makefile"
    ]
    missing = []
    for file in required:
        path = os.path.join(MCTS_DIR, file)
        if not os.path.exists(path):
            missing.append(file)
    
    if missing:
        print(f"⚠ Missing files: {missing}")
    else:
        print("✓ All required files present!")
else:
    print(f"✗ MCTS directory not found. Please use one of the options above.")


## Step 4: Compile MCTS Bridge


In [None]:
import os

# Auto-detect MCTS directory based on what exists
# Priority: 1) Cloned repo path, 2) Uploaded path
if os.path.exists("/content/ChessMirror/MCTS"):
    MCTS_DIR = "/content/ChessMirror/MCTS"
    print("✓ Detected: Cloned repository structure (/content/ChessMirror/MCTS)")
elif os.path.exists("/content/MCTS"):
    MCTS_DIR = "/content/MCTS"
    print("✓ Detected: Uploaded MCTS folder (/content/MCTS)")
else:
    print("✗ MCTS directory not found!")
    print("Please run Step 3 first to either:")
    print("  - Clone your repository, OR")
    print("  - Upload your MCTS source files")
    raise FileNotFoundError("MCTS directory not found. Please run Step 3 first.")

print(f"Using MCTS directory: {MCTS_DIR}\n")

LIBTORCH_PATH = "/content/libtorch"

# Verify LibTorch exists
print("Verifying LibTorch installation...")
print(f"LibTorch path: {LIBTORCH_PATH}")
print(f"LibTorch exists: {os.path.exists(LIBTORCH_PATH)}")

if not os.path.exists(LIBTORCH_PATH):
    print("\n✗ LibTorch not found!")
    print("Please run Step 2 first to download LibTorch.")
    print("\nTo fix this:")
    print("1. Go back to Step 2 (Download LibTorch)")
    print("2. Run that cell")
    print("3. Wait for download to complete")
    print("4. Then come back and run this cell")
    raise FileNotFoundError(f"LibTorch not found at {LIBTORCH_PATH}. Please run Step 2 first.")

# Verify structure
torch_include = os.path.join(LIBTORCH_PATH, "include")
torch_lib = os.path.join(LIBTORCH_PATH, "lib")
script_h = os.path.join(torch_include, "torch", "script.h")

print(f"  Include dir: {torch_include} (exists: {os.path.exists(torch_include)})")
print(f"  Lib dir: {torch_lib} (exists: {os.path.exists(torch_lib)})")
print(f"  torch/script.h: {os.path.exists(script_h)}")

if not os.path.exists(script_h):
    print("\n⚠ Warning: torch/script.h not found. LibTorch may be incomplete.")
    print("Please re-run Step 2 to re-download LibTorch.")

# Set environment variables
os.environ['LIBTORCH_PATH'] = LIBTORCH_PATH
os.environ['LD_LIBRARY_PATH'] = f"{LIBTORCH_PATH}/lib:{os.environ.get('LD_LIBRARY_PATH', '')}"

# Change to MCTS directory
os.chdir(MCTS_DIR)
print(f"\nChanged to: {os.getcwd()}")

# Compile the bridge
print("\nCompiling MCTS bridge...")
print(f"Using LibTorch at: {LIBTORCH_PATH}")

# Pass LIBTORCH_PATH as a Make variable (not just environment variable)
# The Makefile uses 'ifdef LIBTORCH_PATH' which checks Make variables
compile_cmd = f"make LIBTORCH_PATH={LIBTORCH_PATH} Bridge"
result = !{compile_cmd} 2>&1

for line in result:
    print(line)

# Check if compilation succeeded
if os.path.exists("mcts_bridge"):
    print("\n✓ Compilation successful! mcts_bridge executable created")
    !ls -lh mcts_bridge
else:
    print("\n✗ Compilation failed. Check errors above.")
    print("\nTroubleshooting:")
    print("1. Make sure LibTorch was downloaded in Step 2")
    print("2. Check that torch/script.h exists in LibTorch include directory")
    print("3. Verify MCTS source files are complete")


## Step 5: Download/Upload Model File


In [None]:
# Option 1: Upload model file
from google.colab import files
import os
import shutil

MODEL_PATH = "/content/model.pt"
# chessnet_new_ts.pt is ~223MB; files smaller than 90% of this are flagged below
EXPECTED_SIZE_MB = 220

if not os.path.exists(MODEL_PATH):
    print("⚠️  IMPORTANT: Your model file is ~223MB. Large uploads can be corrupted.")
    print("   If upload fails or file is corrupted, consider using Google Drive instead.\n")
    print("Please upload your model file (chessnet_new_ts.pt)...")
    uploaded = files.upload()
    # files.upload() saves the files into the current working directory and
    # returns {filename: bytes}. Colab may rename a duplicate upload to something
    # like "chessnet_new_ts (1).pt", so move it with shutil instead of an unquoted
    # `!mv {filename}`, which splits on the space.
    pt_files = [name for name in uploaded if name.endswith('.pt')]
    if pt_files:
        shutil.move(pt_files[0], MODEL_PATH)
        print(f"Model saved to {MODEL_PATH}")
    else:
        print(f"✗ No .pt file was uploaded (got: {list(uploaded)})")
else:
    print(f"Model already exists: {MODEL_PATH}")

# Option 2: Download from URL (if hosted somewhere)
# !wget -O {MODEL_PATH} https://your-model-url.com/model.pt

# Option 3: Mount Google Drive and copy from there
# from google.colab import drive
# drive.mount('/content/drive')
# !cp /content/drive/MyDrive/chessnet_new_ts.pt {MODEL_PATH}

# Verify file exists and check size
if os.path.exists(MODEL_PATH):
    file_size_mb = os.path.getsize(MODEL_PATH) / (1024 * 1024)
    print(f"\n✓ Model file found: {MODEL_PATH}")
    print(f"  Size: {file_size_mb:.1f} MB")

    if file_size_mb < EXPECTED_SIZE_MB * 0.9:
        print(f"\n⚠️  WARNING: File size is suspiciously small!")
        print(f"   Expected: ~{EXPECTED_SIZE_MB}MB, Got: {file_size_mb:.1f}MB")
        print(f"   The file may be corrupted or incomplete.")
        print(f"   Please re-upload or use Google Drive for large files.\n")
    else:
        print(f"  ✓ File size looks correct (~{EXPECTED_SIZE_MB}MB expected)")

    # A truncated/corrupted TorchScript archive fails to load, so a CPU load is a
    # cheap integrity check before handing the file to the C++ bridge.
    print("\nVerifying model file integrity...")
    try:
        import torch
        model = torch.jit.load(MODEL_PATH, map_location='cpu')
        print("✓ Model file is valid and can be loaded!")
        print(f"  Model type: {type(model)}")
    except Exception as e:
        print(f"\n✗ ERROR: Model file appears to be corrupted!")
        print(f"   Error: {str(e)}")
        print(f"\n   Solutions:")
        print(f"   1. Re-upload the model file")
        print(f"   2. Use Google Drive (mount it and copy from there)")
        print(f"   3. Download from a URL if you have one")
        print(f"   4. Check that the file wasn't truncated during upload")
        raise
else:
    print(f"\n✗ Model file not found at {MODEL_PATH}")
    print("Please upload the model file using one of the options above.")


## Step 6: Test MCTS with GPU


In [None]:
import subprocess
import os

# MCTS_DIR should match what you set in Step 3 (cloned repo first, then uploaded folder)
MCTS_DIR = "/content/ChessMirror/MCTS" if os.path.exists("/content/ChessMirror/MCTS") else "/content/MCTS"
MODEL_PATH = "/content/model.pt"
LIBTORCH_PATH = "/content/libtorch"

# Environment for the bridge process: LibTorch's shared libraries on the loader
# path, GPU 0 visible.
env = os.environ.copy()
torch_lib_dir = f"{LIBTORCH_PATH}/lib"
# Prepend once and drop empty entries (a trailing ':' would make the loader
# also search the current directory).
ld_parts = [p for p in env.get('LD_LIBRARY_PATH', '').split(':') if p and p != torch_lib_dir]
env['LD_LIBRARY_PATH'] = ":".join([torch_lib_dir] + ld_parts)
env['CUDA_VISIBLE_DEVICES'] = '0'  # Use GPU 0

# Test position (starting position)
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

# MCTS parameters
max_iterations = 10000
max_seconds = 5
cpuct = 1.5

print(f"Testing MCTS with GPU...")
print(f"FEN: {fen}")
print(f"Max iterations: {max_iterations}")
print(f"Max seconds: {max_seconds}")
print(f"CPUCT: {cpuct}")
print("\n" + "="*60)

# Bridge CLI: mcts_bridge <model> <fen> <max_iterations> <max_seconds> <cpuct>
bridge_path = os.path.join(MCTS_DIR, "mcts_bridge")
cmd = [
    bridge_path,
    MODEL_PATH,
    fen,
    str(max_iterations),
    str(max_seconds),
    str(cpuct)
]

missing_inputs = [p for p in (bridge_path, MODEL_PATH) if not os.path.exists(p)]
if missing_inputs:
    print(f"\n✗ Required file(s) not found: {missing_inputs}")
    print("Run Step 4 (compile the bridge) and Step 5 (model file) first.")
else:
    try:
        # Timeout gives the process 10s beyond its own search time limit.
        result = subprocess.run(
            cmd,
            cwd=MCTS_DIR,
            env=env,
            capture_output=True,
            text=True,
            timeout=max_seconds + 10
        )

        print("STDOUT (move):")
        print(result.stdout)

        print("\n" + "="*60)
        print("STDERR (debug output with MCTS stats):")
        print(result.stderr)

        if result.returncode == 0:
            print("\n✓ MCTS completed successfully")
            print(f"Selected move: {result.stdout.strip()}")
        else:
            print(f"\n✗ MCTS failed with return code: {result.returncode}")

    except subprocess.TimeoutExpired:
        print("\n✗ MCTS timed out")
    except Exception as e:
        print(f"\n✗ Error: {e}")


In [None]:
import re
import subprocess
import os
import time

def parse_mcts_stats(stderr_output):
    """Collect performance figures from the bridge's ``[DEBUG]`` stderr lines.

    Args:
        stderr_output: text the ``mcts_bridge`` process wrote to stderr.

    Returns:
        dict that may contain ``'iterations'`` (int), ``'time'`` (float, from
        the ``Time: <n>s`` line) and ``'positions_per_sec'`` (float). A key is
        left out when its ``[DEBUG]`` line is not present; only the first
        match of each pattern is used.
    """
    extractors = (
        ('iterations', r'\[DEBUG\] Iterations: (\d+)', int),
        ('time', r'\[DEBUG\] Time: ([\d.]+)s', float),
        ('positions_per_sec', r'\[DEBUG\] Positions/s: ([\d.]+)', float),
    )

    parsed = {}
    for key, pattern, convert in extractors:
        hit = re.search(pattern, stderr_output)
        if hit is not None:
            parsed[key] = convert(hit.group(1))
    return parsed

# Run performance test
# MCTS_DIR should match what you set in Step 3 (cloned repo first, then uploaded folder)
MCTS_DIR = "/content/ChessMirror/MCTS" if os.path.exists("/content/ChessMirror/MCTS") else "/content/MCTS"
MODEL_PATH = "/content/model.pt"
LIBTORCH_PATH = "/content/libtorch"

# LibTorch's shared libraries must be on the loader path for the bridge.
# Prepend once and drop empty entries (a trailing ':' means "current directory").
env = os.environ.copy()
torch_lib_dir = f"{LIBTORCH_PATH}/lib"
ld_parts = [p for p in env.get('LD_LIBRARY_PATH', '').split(':') if p and p != torch_lib_dir]
env['LD_LIBRARY_PATH'] = ":".join([torch_lib_dir] + ld_parts)

fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
max_iterations = 10000
max_seconds = 5
cpuct = 1.5
timeout_s = max_seconds + 10  # 10s grace beyond the bridge's own search time limit

bridge_path = os.path.join(MCTS_DIR, "mcts_bridge")
cmd = [bridge_path, MODEL_PATH, fen, str(max_iterations), str(max_seconds), str(cpuct)]


def _fmt_stat(value, spec):
    """Format a parsed numeric stat with ``spec``, or return 'N/A' if it is missing.

    parse_mcts_stats omits keys whose [DEBUG] line was absent; applying a float
    spec such as '.3f' to the 'N/A' placeholder string would raise ValueError.
    """
    return "N/A" if value is None else format(value, spec)


print("Running MCTS performance test...")
print(f"Max iterations: {max_iterations}, Max seconds: {max_seconds}")
print("\n" + "="*60)

# perf_counter is monotonic, so wall-clock adjustments cannot skew the measurement.
start_time = time.perf_counter()
try:
    result = subprocess.run(cmd, cwd=MCTS_DIR, env=env, capture_output=True, text=True, timeout=timeout_s)
except subprocess.TimeoutExpired:
    result = None
elapsed = time.perf_counter() - start_time

if result is None:
    print(f"\n✗ MCTS timed out after {elapsed:.1f}s (limit {timeout_s}s)")
else:
    if result.returncode != 0:
        print(f"\n✗ MCTS failed with return code: {result.returncode} (stats below may be incomplete)")

    stats = parse_mcts_stats(result.stderr)

    print("MCTS Performance Stats:")
    print(f"  Iterations: {stats.get('iterations', 'N/A')}")
    print(f"  Time taken: {_fmt_stat(stats.get('time'), '.3f')}s")
    print(f"  Positions/s: {_fmt_stat(stats.get('positions_per_sec'), '.0f')}")
    print(f"  Total elapsed: {elapsed:.3f}s")
    print(f"  Selected move: {result.stdout.strip()}")

    print("\n" + "="*60)
    print("Full debug output:")
    print(result.stderr)
