In [None]:
# -*- coding: utf-8 -*-
"""chaturaji-mcts-training-colab-async.ipynb

Automatically generated by Colab.

Original file is located at
    # Replace with your Colab notebook URL if desired
"""

import glob  # Locating shared libraries by file-name pattern
import os
import shlex  # Shell-quoted display of command lines
import shutil  # Removing directory trees / locating executables
import subprocess  # Needed to run external commands
import sys
import time  # Timing utilities
from urllib.parse import unquote, urlsplit  # Deriving file names from URLs

import requests  # For better downloads

# --- Configuration ---
repo_url = "https://github.com/Anurag-Baundwal/4pc-ffa-chaturaji-mcts-cpp"
# Directory name `git clone` will create: the last URL path component, tolerating
# a trailing slash and stripping an optional ".git" suffix.
repo_name = repo_url.rstrip("/").rsplit("/", 1)[-1]
if repo_name.endswith(".git"):
    repo_name = repo_name[: -len(".git")]
local_repo_path = os.path.join("/content", repo_name)
drive_models_path = '/content/drive/MyDrive/cpp_engine_models'  # Google Drive path for saving models

# Libtorch download (Linux, C++11 ABI, shared libs with bundled deps).
# The CUDA variant in the URL (cu118, cu121, ...) must be supported by the Colab
# GPU driver; pick the matching link from the PyTorch "Get Started" page.
# Example for 2.7.0 with CUDA 11.8:
libtorch_url_release = "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.7.0%2Bcu118.zip"

# Local archive name: percent-decode the URL path (e.g. "%2B" -> "+") and take
# its last component, ignoring any query string.
libtorch_zip_release = os.path.basename(unquote(urlsplit(libtorch_url_release).path))
libtorch_extract_dir = "/content/libtorch"  # Directory to extract libtorch into
# Path expected to exist after the archive is unzipped into libtorch_release/.
libtorch_path_release = os.path.join(libtorch_extract_dir, "libtorch_release", "libtorch")


# --- Training Parameters (passed to the C++ engine as command-line flags) ---
iterations = 1000
games_per_iter = 1000
epochs_per_iter = 7
training_batch_size = 16384  # Batch size for the training DataLoader (--train-batch)
num_workers = 8              # Number of self-play worker threads (--workers)
nn_batch_size = 4096         # Batch size for the NN evaluator thread during self-play (--nn-batch)
sims_per_move = 250
# --- End Training Parameters ---

# Models are saved to Google Drive so they survive the Colab runtime.
save_dir = drive_models_path
# Optionally load a previous model from Drive, e.g.
# load_model_path = f"{drive_models_path}/run_YYYYMMDD_HHMMSS/chaturaji_iter_X.pt"
load_model_path = ""  # Set to a path to resume training; empty means start fresh

# --- Step 1: Fresh Clone of the Repository ---
# Any existing checkout is deleted and the 'main' branch is cloned again, so every
# run starts from the upstream sources (Step 5 later overwrites WORKSPACE in place).
print("--- Step 1: Cloning/Updating Repository ---")
if os.path.exists(local_repo_path):
    print(f"Repository '{repo_name}' already exists at {local_repo_path}. Removing for a fresh clone...")
    try:
        # In-process equivalent of `rm -rf`; does not depend on an external binary.
        shutil.rmtree(local_repo_path)
    except OSError as e:
        print(f"Error removing {local_repo_path}: {e}")
        raise
    print(f"Removed {local_repo_path}")

print(f"Cloning {repo_url} branch 'main' into {local_repo_path}...")
clone_cmd = ['git', 'clone', '--branch', 'main', repo_url, local_repo_path]
result = subprocess.run(clone_cmd, capture_output=True, text=True)
if result.returncode != 0:
    print("Git clone failed:")
    print(result.stderr)
    raise RuntimeError("Git clone failed.")
print(result.stdout)
print("Clone successful.")

# --- Step 2: Install Bazel ---
# Installs Bazel from the official apt repository, but only when no `bazel`
# executable is already on PATH.
# NOTE: the unpinned `bazel` package is the newest release; Bazel 8+ ignores the
# WORKSPACE file unless `--enable_workspace` is passed on the command line.
print("\n--- Step 2: Installing Bazel ---")
bazel_path = shutil.which('bazel')
if bazel_path is None:
    print("Bazel not found. Installing...")

    # Official Bazel installation steps for Debian/Ubuntu based systems (like Colab).
    install_cmds = [
        # 1. Install prerequisites
        "sudo apt-get update",
        "sudo apt-get install -y apt-transport-https curl gnupg",
        # 2. Add Bazel GPG key
        "curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg",
        "sudo mv bazel.gpg /etc/apt/trusted.gpg.d/",
        # 3. Add Bazel repository
        'echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list',
        # 4. Update package list *after* adding the repo
        "sudo apt-get update",
        # 5. Install Bazel (use a versioned package such as 'bazel-7.4.1' to pin a release)
        "sudo apt-get install -y bazel"
    ]

    for cmd in install_cmds:
        print(f"Running: {cmd}")
        # Run through bash because some steps use pipes (|) or redirection (>).
        # `pipefail` makes a failure on the left of a pipe (e.g. curl) fail the
        # whole step instead of being masked by the last command's exit status.
        # The commands are fixed strings; nothing is interpolated from user input.
        result = subprocess.run(['bash', '-o', 'pipefail', '-c', cmd], capture_output=True, text=True)
        if result.returncode != 0:
            print(f"Command failed: {cmd}\nstdout: {result.stdout}\nstderr: {result.stderr}")
            raise RuntimeError(f"Failed to execute Bazel installation step: {cmd}")
        print("Command successful.")

    # Verify installation
    print("Verifying Bazel installation...")
    result = subprocess.run(['bazel', '--version'], capture_output=True, text=True)
    if result.returncode != 0:
        print("Bazel verification failed after installation attempt!")
        print(result.stderr)
        raise RuntimeError("Bazel installation appeared successful but verification failed.")
    print(f"Bazel installation successful! Version:\n{result.stdout}")

else:
    print(f"Bazel already installed at: {bazel_path}")
    result = subprocess.run(['bazel', '--version'], capture_output=True, text=True)
    print(f"Bazel version:\n{result.stdout}")


# --- Step 3: Download and Extract Libtorch (GPU Linux) ---
# Announce the step and make sure the parent extraction directory exists; the
# downloaded archive is written directly into it.
print("\n--- Step 3: Downloading and Extracting Libtorch ---")
os.makedirs(name=libtorch_extract_dir, mode=0o777, exist_ok=True)

# Downloads a Libtorch archive with `requests` and unpacks it with `unzip`.
def download_and_extract(libtorch_url, libtorch_zip, libtorch_path, description,
                         *, min_size_bytes=100 * 1024 * 1024, chunk_size=1024 * 1024):
    """Download a Libtorch zip and unpack it, unless it is already unpacked.

    Layout, all derived from ``libtorch_path``:
      * the archive is saved as ``dirname(dirname(libtorch_path))/<libtorch_zip>``;
      * it is unzipped into ``dirname(libtorch_path)``;
      * ``libtorch_path`` itself must exist once extraction is done.
    If ``libtorch_path`` already exists, nothing is downloaded or extracted.
    The archive is deleted after a successful extraction.

    Args:
        libtorch_url: URL of the Libtorch zip archive.
        libtorch_zip: File name under which the archive is saved.
        libtorch_path: Directory expected to exist after extraction.
        description: Label used in progress messages.
        min_size_bytes: Downloads smaller than this are rejected as truncated or
            as an error page (default 100 MiB).
        chunk_size: Streaming chunk size in bytes.

    Raises:
        requests.exceptions.RequestException: the HTTP download failed.
        ValueError: the downloaded file is smaller than ``min_size_bytes``.
        subprocess.CalledProcessError: ``unzip`` (or installing it) failed.
        FileNotFoundError: the archive or ``libtorch_path`` is missing afterwards.
    """
    if os.path.exists(libtorch_path):
        print(f"{description} Libtorch already extracted at {libtorch_path}")
        return

    base_extract_dir = os.path.dirname(os.path.dirname(libtorch_path))
    zip_dest_path = os.path.join(base_extract_dir, libtorch_zip)
    extract_target_path = os.path.dirname(libtorch_path)
    # Stream into a temporary name and rename only on success, so an interrupted
    # or failed download never leaves a truncated archive under the final name.
    partial_path = zip_dest_path + ".part"
    os.makedirs(base_extract_dir, exist_ok=True)

    print(f"Downloading {description} Libtorch ({libtorch_url})...")
    print(f"Target file: {zip_dest_path}")
    try:
        with requests.get(libtorch_url, stream=True, timeout=60) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            print(f"Download size: {total_size / (1024*1024):.2f} MB")
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        os.replace(partial_path, zip_dest_path)
        print("\nDownload completed successfully.")
    except requests.exceptions.RequestException as e:
        print(f"\nERROR: Download failed: {e}")
        raise
    finally:
        # Remove any leftover partial file (HTTP error, disk full, interrupt, ...).
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"Verifying downloaded file: {zip_dest_path}")
    if not os.path.exists(zip_dest_path):
        raise FileNotFoundError(f"Downloaded zip file not found: {zip_dest_path}")
    file_size = os.path.getsize(zip_dest_path)
    print(f"File size: {file_size} bytes")
    if file_size < min_size_bytes:
        raise ValueError(f"Downloaded file size is too small: {file_size} bytes")

    print(f"Extracting {description} Libtorch to {extract_target_path}...")
    os.makedirs(extract_target_path, exist_ok=True)
    unzip_command = ['unzip', '-qo', zip_dest_path, '-d', extract_target_path]
    print(f"Executing: {' '.join(unzip_command)}")
    try:
        subprocess.run(unzip_command, check=True)
        print("Unzip completed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"ERROR: unzip command failed with exit code {e.returncode}")
        raise
    except FileNotFoundError:
        # The `unzip` binary itself is missing: install it and retry once.
        print("ERROR: 'unzip' command not found. Trying to install...")
        subprocess.run(['sudo', 'apt-get', 'update'], check=True)
        subprocess.run(['sudo', 'apt-get', 'install', '-y', 'unzip'], check=True)
        print("Retrying extraction with unzip...")
        subprocess.run(unzip_command, check=True)

    print(f"Cleaning up {libtorch_zip}...")
    os.remove(zip_dest_path)

    if not os.path.exists(libtorch_path):
        print(f"ERROR: Expected libtorch path '{libtorch_path}' not found after extraction!")
        print(f"Contents of '{extract_target_path}': {os.listdir(extract_target_path)}")
        raise FileNotFoundError(f"Extraction failed to produce expected path: {libtorch_path}")
    print(f"{description} Libtorch extracted successfully to {libtorch_path}")

# Fetch the release build of Libtorch (a no-op when it is already extracted).
download_and_extract(
    libtorch_url=libtorch_url_release,
    libtorch_zip=libtorch_zip_release,
    libtorch_path=libtorch_path_release,
    description="libtorch_release",
)


# --- Step 4: Mount Google Drive & Create Model Directory ---
# The engine's --save-dir points at drive_models_path, so Drive is mounted and the
# target directory created before training starts.
print("\n--- Step 4: Mounting Google Drive ---")
from google.colab import drive
drive.mount(mountpoint='/content/drive')

print(f"Creating model directory (if needed): {drive_models_path}")
os.makedirs(name=drive_models_path, exist_ok=True)


# --- Step 5: Modify WORKSPACE file for Colab Paths, Rules_CC, and Linux Libtorch ---
# The repository's WORKSPACE is regenerated from scratch by the code that follows;
# the strings defined here are its building blocks.
print("\n--- Step 5: Modifying WORKSPACE file for Colab (Linux Libtorch) ---")
workspace_file = os.path.join(local_repo_path, "WORKSPACE")

# Standard boilerplate for rules_cc.
# Starlark snippet that fetches rules_cc 0.1.1 via http_archive, so the generated
# Libtorch BUILD content below can `load("@rules_cc//cc:defs.bzl", ...)`.
rules_cc_boilerplate = """
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
    name = "rules_cc",
    urls = ["https://github.com/bazelbuild/rules_cc/releases/download/0.1.1/rules_cc-0.1.1.tar.gz"],
    sha256 = "712d77868b3152dd618c4d64faaddefcc5965f90f5de6e6dd1d5ddcd0be82d42",
    strip_prefix = "rules_cc-0.1.1",
)
"""

# --- build_file_content for the Linux GPU Libtorch repository ---
# Starlark BUILD file that is embedded verbatim into the WORKSPACE's
# new_local_repository(build_file_content=...). It defines one public cc_library,
# `libtorch`, combining the headers under include/ with the prebuilt shared
# libraries in lib/ (CPU, CUDA, c10, c10_cuda, global deps). Its paths are
# resolved relative to the new_local_repository `path`, i.e. the extracted
# Libtorch directory.
# NOTE: the `#` comments inside the raw string are part of the generated BUILD
# text, not Python comments.
# NOTE(review): "include/torch/cuda" is listed in `includes`; whether that
# directory exists in this Libtorch distribution is not verified here.
linux_libtorch_build_content = r'''
load("@rules_cc//cc:defs.bzl", "cc_import", "cc_library")

package(default_visibility = ["//visibility:public"])

# Import necessary shared libraries (.so) for Linux
cc_import(
    name = "torch_so",
    shared_library = "lib/libtorch.so",
)
cc_import(
    name = "torch_cpu_so",
    shared_library = "lib/libtorch_cpu.so",
)
cc_import(
    name = "c10_so",
    shared_library = "lib/libc10.so",
)
cc_import(
    name = "torch_cuda_so",
    shared_library = "lib/libtorch_cuda.so", # Needed for GPU support
)
# Import c10_cuda explicitly if linker errors occur
cc_import(
    name = "c10_cuda_so",
    shared_library = "lib/libc10_cuda.so",
)
cc_import(
    name = "torch_global_deps_so",
    shared_library = "lib/libtorch_global_deps.so", # Often required
)
# Add imports for other .so files in libtorch/lib if linker errors occur later

cc_library(
    name = "libtorch",
    hdrs = glob(
        ["include/**/*.h", "include/**/*.hpp"],
        exclude = ["include/torch/csrc/autograd/generated/python_*.h"]
    ),
    includes = [
        "include",
        "include/torch/csrc/api/include",
        "include/torch/cuda", # <--- Ensure this CUDA include path is present
    ],
    # Link against the imported .so targets
    deps = [
        ":torch_so",
        ":torch_cpu_so",
        ":c10_so",
        ":c10_cuda_so", # Link against c10_cuda
        ":torch_cuda_so",
        ":torch_global_deps_so",
    ],
)
''' # End linux_libtorch_build_content definition

# Rewrite the repository's WORKSPACE from scratch with three parts:
#   * a workspace() name,
#   * the rules_cc http_archive,
#   * a new_local_repository "libtorch_release" that points Bazel at the extracted
#     Libtorch directory and supplies its BUILD file inline.
# Everything in the upstream WORKSPACE is discarded. Any other external
# dependencies added upstream would have to be re-added here.
if not os.path.exists(workspace_file):
    # The file is fully overwritten below, but a missing WORKSPACE presumably means
    # the clone does not have the expected project layout, so stop early.
    print(f"Error: {workspace_file} not found!")
    raise FileNotFoundError(f"{workspace_file} not found!")

print("Generating new WORKSPACE content for Colab paths and Linux Libtorch...")
new_workspace_content = [
    'workspace(name = "chaturaji_cpp_project")\n\n',
    rules_cc_boilerplate + "\n\n",
    "# --- Libtorch Configuration for Linux GPU (Colab) ---\n\n",
    'new_local_repository(\n',
    '    name = "libtorch_release",\n',
    f'    path = "{libtorch_path_release}",\n',  # Colab path of the extracted Libtorch
    f'    build_file_content = """{linux_libtorch_build_content}""",\n',
    ')\n\n',  # Close the new_local_repository call
    "# --- End Libtorch Configuration ---\n",
]
final_workspace_content = "".join(new_workspace_content)

# --- Print the generated file for verification ---
# The file is only a few KB, so print it in full. Searching for the first ")"
# after the repository name would stop inside the embedded BUILD content (at the
# end of its load(...) line) and truncate the preview.
print("Generated WORKSPACE content:")
print(final_workspace_content)

with open(workspace_file, 'w', encoding='utf-8') as f:
    f.write(final_workspace_content)
print(f"{workspace_file} overwritten successfully with Linux configuration.")

# --- Step 6: Construct Bazel Command (UPDATED Arguments) ---
# Builds the equivalent one-shot `bazel run` command line and prints it for
# reference / copy-paste. It is NOT executed. Step 7 instead runs `bazel build`
# and launches bazel-bin/chaturaji_engine directly with LD_LIBRARY_PATH set.
print("\n--- Step 6: Constructing Bazel Command ---")
bazel_command = [
    "bazel",
    "run",
    "--enable_bzlmod=false",  # Do not use Bzlmod (MODULE.bazel)
    "--enable_workspace",     # Use the WORKSPACE written in Step 5 (off by default since Bazel 8)
    "//:chaturaji_engine",    # Target defined in the repository's BUILD file
    "-c", "opt",              # Compile in optimized mode
    "--define", "use_cuda=true",  # Enable CUDA compilation path
    "--",                     # Separator for program arguments
    # --- C++ Engine Arguments ---
    "--train",
    "--iterations", str(iterations),
    "--games", str(games_per_iter),
    "--epochs", str(epochs_per_iter),
    "--train-batch", str(training_batch_size),
    "--workers", str(num_workers),
    "--nn-batch", str(nn_batch_size),
    "--sims", str(sims_per_move),
    "--save-dir", save_dir,
]

if load_model_path:
    bazel_command.extend(["--load-model", load_model_path])

print("Bazel command to execute:")
# shlex.join quotes arguments, so the printed line can be pasted into a shell.
print(shlex.join(bazel_command))


# --- Step 7: Build and Execute Training ---
print("\n--- Step 7: Building and Starting Training ---")

# --- Build Only Command ---
# Compile //:chaturaji_engine in opt mode with CUDA enabled. Bzlmod is disabled
# and WORKSPACE explicitly enabled, so the WORKSPACE written in Step 5 is used.
# --enable_workspace is required because Bazel 8+ turns WORKSPACE support off by default.
print("Building the executable first...")
bazel_build_command = [
    "bazel", "build",
    "--enable_bzlmod=false",
    "--enable_workspace",
    "//:chaturaji_engine",
    "-c", "opt",
    "--define", "use_cuda=true",
]
print("Bazel build command:")
print(shlex.join(bazel_build_command))

try:
    # Output is captured so it can be shown in full if the build fails.
    subprocess.run(bazel_build_command, cwd=local_repo_path, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
    print(f"ERROR: Bazel build failed with exit code {e.returncode}")
    print("Command:", shlex.join(e.cmd))
    print("stdout:", e.stdout)
    print("stderr:", e.stderr)
    raise
print("Build successful.")

# --- Execute Directly Command (UPDATED Arguments) ---
# Locate the binary produced by the build above and assemble the training command
# line. The engine flags are the same ones shown in the Step 6 `bazel run` command.
print("\nExecuting the compiled engine directly...")

bazel_bin_path = os.path.join(local_repo_path, "bazel-bin")
executable_path = os.path.join(bazel_bin_path, "chaturaji_engine")
print(f"Executable path: {executable_path}")

# Fail fast, listing bazel-bin to help debug path issues, if the binary is missing.
if not os.path.exists(executable_path):
    print(f"ERROR: Compiled executable not found at {executable_path}")
    print(f"Contents of {bazel_bin_path}:")
    try:
        print(os.listdir(bazel_bin_path))
    except FileNotFoundError:
        print(f"{bazel_bin_path} directory not found.")
    except Exception as e:
        print(f"Could not list directory contents: {e}")
    raise FileNotFoundError(f"Compiled executable not found: {executable_path}")

# (flag, value) pairs appended after "--train"; every value is passed as a string.
engine_flag_values = (
    ("--iterations", iterations),
    ("--games", games_per_iter),
    ("--epochs", epochs_per_iter),
    ("--train-batch", training_batch_size),
    ("--workers", num_workers),
    ("--nn-batch", nn_batch_size),
    ("--sims", sims_per_move),
    ("--save-dir", save_dir),
)
engine_args = ["--train"]
for flag_name, flag_value in engine_flag_values:
    engine_args.extend((flag_name, str(flag_value)))
if load_model_path:
    engine_args += ["--load-model", load_model_path]

direct_run_command = [executable_path, *engine_args]
print("Direct execution command:")
print(" ".join(direct_run_command))

# Set LD_LIBRARY_PATH for Libtorch .so files.
# The engine is launched directly (not via `bazel run`), so the dynamic loader must
# be told where Libtorch's shared libraries live.
libtorch_lib_dir = os.path.join(libtorch_path_release, "lib")
print(f"Setting LD_LIBRARY_PATH to include: {libtorch_lib_dir}")
if not os.path.isdir(libtorch_lib_dir):
    raise FileNotFoundError(f"Libtorch lib directory not found: {libtorch_lib_dir}")

# Sanity check that Libtorch's copy of libgomp (GNU OpenMP runtime) is present.
# Its file name appears to carry a build-specific hash suffix
# (e.g. libgomp-98b21ff3.so.1) that can differ between Libtorch releases, so
# match by pattern rather than pinning a single hash.
libgomp_matches = sorted(glob.glob(os.path.join(libtorch_lib_dir, "libgomp*.so*")))
if not libgomp_matches:
    print(f"ERROR: Expected libgomp file not found in: {libtorch_lib_dir}")
    print(f"Contents of {libtorch_lib_dir}:")
    try:
        print(os.listdir(libtorch_lib_dir))
    except OSError as e:
        print(f"Could not list directory contents: {e}")
    raise FileNotFoundError(f"No libgomp shared library found in Libtorch: {libtorch_lib_dir}")
libgomp_file_path = libgomp_matches[0]
print(f"Confirmed {libgomp_file_path} exists.")

# Prepare the environment for the direct execution. The Libtorch directory is
# prepended so its libraries are found first; any existing value is kept unchanged
# (no stripping of empty entries).
run_env = os.environ.copy()
existing_ld_path = run_env.get('LD_LIBRARY_PATH', '')
run_env['LD_LIBRARY_PATH'] = os.pathsep.join(p for p in (libtorch_lib_dir, existing_ld_path) if p)
print(f"Updated LD_LIBRARY_PATH for execution: {run_env['LD_LIBRARY_PATH']}")

# Execute the command directly using subprocess.Popen.
# The engine's combined stdout/stderr is streamed into the notebook as it arrives.
# If the cell is interrupted or an error occurs, the child process is stopped.
# Without this, a KeyboardInterrupt (not an Exception subclass) would leave the
# training process running in the background.
print("\nExecuting C++ training process directly...")
process = None
try:
    process = subprocess.Popen(
        direct_run_command,
        cwd=local_repo_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # Interleave stderr with stdout
        text=True,
        bufsize=1,  # Line buffered
        env=run_env,
    )

    # Relay output line by line until the child closes its stdout.
    for line in process.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()  # Make each line appear immediately in the notebook

    rc = process.wait()
    if rc == 0:
        print("\n--- Training process completed successfully ---")
    else:
        print(f"\n--- Training process failed with exit code {rc} ---")

except Exception as e:
    print(f"An error occurred during training execution: {e}")
finally:
    # Never leave an orphaned training process behind (e.g. after an interrupt).
    if process is not None and process.poll() is None:
        print("\nStopping training process...")
        process.terminate()
        try:
            process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
    if process is not None and process.stdout is not None:
        process.stdout.close()

print("\n--- Colab Notebook Finished ---")

--- Step 1: Cloning/Updating Repository ---
Repository '4pc-ffa-chaturaji-mcts-cpp' already exists at /content/4pc-ffa-chaturaji-mcts-cpp. Removing for a fresh clone...
Removed /content/4pc-ffa-chaturaji-mcts-cpp
Cloning https://github.com/Anurag-Baundwal/4pc-ffa-chaturaji-mcts-cpp into /content/4pc-ffa-chaturaji-mcts-cpp...

Clone successful.

--- Step 2: Installing Bazel ---
Bazel already installed at: /usr/bin/bazel
Bazel version:
bazel 8.2.1


--- Step 3: Downloading and Extracting Libtorch ---
libtorch_release Libtorch already extracted at /content/libtorch/libtorch_release/libtorch

--- Step 4: Mounting Google Drive ---
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Creating model directory (if needed): /content/drive/MyDrive/cpp_engine_models

--- Step 5: Modifying WORKSPACE file for Colab (Linux Libtorch) ---
Generating new WORKSPACE content for Colab paths and Linux Libtorch...
Generated WORKSPAC

KeyboardInterrupt: 