# 🌌 ACIE: Astronomical Counterfactual Inference Engine
## Google Colab Training Guide

This notebook provides a complete, step-by-step walkthrough for training the ACIE model on datasets stored in your Google Drive.

---

### 📋 Prerequisites
1.  **Google Drive Folder**: You should have a folder named `ACIE_Training` in your "My Drive".
2.  **Project Zip**: `ACIE_Project.zip` must be inside that folder.
3.  **Datasets**: Your CSV files must be in that SAME folder.
    *   `acie_observational_20k_x_20k.csv`
    *   `acie_hard_intervention_20k_x_20k.csv`
    *   `acie_environment_shift_20k_x_20k.csv`
    *   `acie_instrument_shift_20k_x_20k.csv`
    *   *(Note: the 10k datasets are not used for training, so that all training data shares the same dimensions.)*

### 🚀 Runtime Setup
Ensure you are using a GPU:
*   Go to **Runtime** > **Change runtime type**
*   Select **T4 GPU** (or A100 if available)

In [None]:
# @title 1. Initialize Environment
# Mount Google Drive to access your files
import glob
import os
import shutil
import sys
import zipfile

from google.colab import drive

print("Mounting Google Drive...")
drive.mount('/content/drive')

# Configuration
# Folder in "My Drive" that must contain ACIE_Project.zip and the CSV datasets.
DRIVE_FOLDER = "/content/drive/My Drive/ACIE_Training"  # Default location

# Fail fast: every later cell reads from this folder, so continuing without it
# only produces confusing errors further down (missing zip, failed pip install, ...).
if not os.path.isdir(DRIVE_FOLDER):
    print(f"❌ Error: Could not find '{DRIVE_FOLDER}'")
    print("   Please verify you created the folder 'ACIE_Training' in your Drive root.")
    raise FileNotFoundError(f"Training folder not found: {DRIVE_FOLDER}")

print(f"✅ Found Training Folder: {DRIVE_FOLDER}")

In [None]:
# @title 2. Setup Workspace
# Extract the project onto local disk: running code from the Drive mount is slow.
WORK_DIR = "/content/ACIE_Work"
ZIP_PATH = os.path.join(DRIVE_FOLDER, "ACIE_Project.zip")

if not os.path.exists(ZIP_PATH):
    print(f"❌ Critical Error: 'ACIE_Project.zip' not found in {DRIVE_FOLDER}.")
    print("   Please upload the zip file generated by 'setup_assist/prepare_for_colab.sh'.")
    # Stop here: later cells need the unpacked project and would fail less clearly.
    raise FileNotFoundError(ZIP_PATH)

# Clean up previous runs so stale files cannot leak into this one.
if os.path.exists(WORK_DIR):
    shutil.rmtree(WORK_DIR)

print(f"Unzipping project to {WORK_DIR}...")
# zipfile raises on a corrupt or truncated archive. The previous `!unzip` call's
# exit status was ignored, so success was reported even when extraction failed.
# NOTE(review): unlike `unzip`, zipfile does not restore Unix permission bits;
# confirm no extracted script needs its executable bit.
with zipfile.ZipFile(ZIP_PATH) as project_archive:
    project_archive.extractall(WORK_DIR)
print("✅ Project unpacked successfully.")

In [None]:
# @title 3. Link Datasets
# Symlink the Drive CSVs into the workspace instead of copying them, to save local disk space.
DATA_DIR_LOCAL = os.path.join(WORK_DIR, "data")
os.makedirs(DATA_DIR_LOCAL, exist_ok=True)

# Specific files to look for (20k size for training)
REQUIRED_FILES = [
    "acie_observational_20k_x_20k.csv",
    "acie_hard_intervention_20k_x_20k.csv",
    "acie_environment_shift_20k_x_20k.csv",
    "acie_instrument_shift_20k_x_20k.csv"
]

print(f"Scanning {DRIVE_FOLDER} for datasets...")
found_count = 0

for expected_file in REQUIRED_FILES:
    drive_path = os.path.join(DRIVE_FOLDER, expected_file)
    if not os.path.exists(drive_path):
        print(f"  ⚠️ Missing: {expected_file} (training will continue without it if others exist)")
        continue

    target = os.path.join(DATA_DIR_LOCAL, expected_file)
    # os.path.exists() follows symlinks, so a dangling link reports False and
    # os.symlink() would then raise FileExistsError. Refresh any existing link
    # so it always points at the current Drive file.
    if os.path.islink(target):
        os.remove(target)
    if not os.path.exists(target):
        os.symlink(drive_path, target)
    print(f"  ✅ Linked: {expected_file}")
    found_count += 1

if found_count == 0:
    print("\n❌ No valid 20k datasets found! Please check your file names.")
    # Training cannot run without data; stop instead of failing later in train.py.
    raise FileNotFoundError(f"None of {REQUIRED_FILES} were found in {DRIVE_FOLDER}")

print(f"\nReady to train with {found_count} datasets.")

In [None]:
# @title 4. Install Dependencies
os.chdir(WORK_DIR)
print(f"Working Directory: {os.getcwd()}")

print("Installing ACIE dependencies...")
!pip install -q pytorch-lightning torchmetrics python-dotenv
!pip install -q "numpy<2.0" pandas scipy networkx
!pip install -q "bcrypt<4.0.0" passlib python-jose[cryptography]
!pip install -e .
print("‚úÖ Dependencies installed.")

In [None]:
# @title 5. Start Training
# Training Configuration
OUTPUT_DIR = "outputs/colab_run_main"
DATASET_SIZE = "20k"      # Must match your CSV filenames
MAX_EPOCHS = 20           # Adjust complexity
BATCH_SIZE = 64
GPUS = 1

# Auto-configure dimensions based on dataset size
OBS_DIM = 11000 if DATASET_SIZE == "20k" else 6000
LATENT_DIM = 4000 if DATASET_SIZE == "20k" else 2000

print(f"Training Protocol: Combined Dataset ({DATASET_SIZE})")
print(f"Dimensions: Obs={OBS_DIM}, Latent={LATENT_DIM}")

cmd = (f"python acie/training/train.py "
       f"--data_dir data/ "
       f"--output_dir {OUTPUT_DIR} "
       f"--dataset_size {DATASET_SIZE} "
       f"--obs_dim {OBS_DIM} "
       f"--latent_dim {LATENT_DIM} "
       f"--max_epochs {MAX_EPOCHS} "
       f"--batch_size {BATCH_SIZE} "
       f"--gpus {GPUS}")

print(f"Executing: {cmd}\n")
!{cmd}

In [None]:
# @title 6. Save Results
# Sync the outputs back to your Google Drive (local /content storage does not survive the runtime).
DEST_DIR = os.path.join(DRIVE_FOLDER, "outputs/final_run")
# Resolve against WORK_DIR so this cell works regardless of the current directory.
output_path = os.path.join(WORK_DIR, OUTPUT_DIR)

if os.path.isdir(output_path):
    print(f"Saving results to: {DEST_DIR}...")
    # copytree also copies hidden files, which the previous `cp -r dir/*` skipped.
    # It raises on failure instead of still printing the success message.
    # dirs_exist_ok merges into an existing destination from an earlier run.
    shutil.copytree(output_path, DEST_DIR, dirs_exist_ok=True)
    print("✅ Training artifacts saved successfully!")
else:
    print("⚠️ No output directory found. Did training complete?")