# ACIE Training on Google Colab

This notebook provides a streamlined guide to training the Astronomical Counterfactual Inference Engine (ACIE).

## Quick Start
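1. **Upload**: Upload `ACIE_Project.zip` (generated by `scripts/prepare_for_colab.sh`) to an `ACIE` folder in your Google Drive (`My Drive/ACIE/`). This is the location set by `DRIVE_PATH` in the first code cell.
2. **Data**: Create a `data/` folder next to the zip file (`My Drive/ACIE/data/`) and upload your CSV datasets there.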
3. **Runtime**: Ensure you are using a GPU runtime (`Runtime` > `Change runtime type` > `T4 GPU`).

In [None]:
# 1. Mount Google Drive
# Makes the user's Drive visible under /content/drive so the project archive
# and datasets can be read from it in the following cells.
import os

from google.colab import drive

drive.mount("/content/drive")

# CONFIG -- adjust these if your files live elsewhere.
# DRIVE_PATH: Drive folder holding ACIE_Project.zip and the data/ folder.
# WORK_DIR:   local (VM) directory the project is unpacked into.
DRIVE_PATH = "/content/drive/My Drive/ACIE"
WORK_DIR = "/content/ACIE_Train"

In [None]:
# 2. Setup Env (Unzip & Link Data)
# Unpacks the project archive from Drive into WORK_DIR (only once) and links the
# Drive-hosted data folder into it. Safe to re-run.
import shutil

zip_path = os.path.join(DRIVE_PATH, "ACIE_Project.zip")

if not os.path.exists(WORK_DIR):
    # Check for the archive *before* creating WORK_DIR. Otherwise a missing zip
    # leaves an empty WORK_DIR behind, and every re-run skips extraction as
    # "already exists". Without the project nothing below can work, so stop here.
    if not os.path.exists(zip_path):
        raise FileNotFoundError(
            f"{zip_path} not found! Upload ACIE_Project.zip to {DRIVE_PATH}."
        )

    print(f"Creating work directory: {WORK_DIR}")
    # Extract into a sibling staging directory and rename it into place only
    # after extraction succeeds. An interrupted or failed unzip is then never
    # mistaken for a finished one on the next run.
    staging_dir = os.path.normpath(WORK_DIR) + ".partial"
    shutil.rmtree(staging_dir, ignore_errors=True)  # leftovers of an aborted run
    os.makedirs(staging_dir)
    print(f"Unzipping {zip_path}...")
    # Unlike a shell `!unzip`, whose failure does not raise in Python, this
    # raises on a corrupt archive.
    # NOTE(review): zipfile-based extraction does not restore Unix permission
    # bits; fine for `python script.py` usage, confirm nothing needs +x.
    shutil.unpack_archive(zip_path, staging_dir, format="zip")
    os.rename(staging_dir, WORK_DIR)
else:
    print(f"Work directory {WORK_DIR} already exists.")

# Link Data Folder (read directly from Drive to save space/time)
data_link = os.path.join(WORK_DIR, "data")
data_drive_path = os.path.join(DRIVE_PATH, "data")
# A symlink whose target is gone (e.g. Drive was not mounted) makes
# os.path.exists() return False, yet os.symlink() would still fail with
# FileExistsError. Remove such a dangling link so it can be recreated.
if os.path.islink(data_link) and not os.path.exists(data_link):
    os.unlink(data_link)
if not os.path.exists(data_link):
    if os.path.exists(data_drive_path):
        print(f"Symlinking data from {data_drive_path}...")
        os.symlink(data_drive_path, data_link)
    else:
        print(f"Warning: Data folder not found at {data_drive_path}")

In [None]:
# 3. Install Dependencies
os.chdir(WORK_DIR)
print(f"Current directory: {os.getcwd()}")

print("Installing dependencies...")
!pip install -q pytorch-lightning torchmetrics python-dotenv
!pip install -q "numpy<2.0" pandas scipy networkx
!pip install -q "bcrypt<4.0.0" passlib python-jose[cryptography]
!pip install -e .

In [None]:
# 4. Run Training
# Configuration
DATASET_SIZE = "10k"
MAX_EPOCHS = 20
BATCH_SIZE = 64
OUTPUT_DIR = "outputs/colab_run1"

cmd = f"python acie/training/train.py --data_dir data/ --output_dir {OUTPUT_DIR} --dataset_size {DATASET_SIZE} --max_epochs {MAX_EPOCHS} --batch_size {BATCH_SIZE} --gpus 1"

print(f"Starting training command: {cmd}")
!{cmd}

In [None]:
# 5. Save Results
# Copy this run's output folder back to Drive as <DRIVE_PATH>/outputs/<run name>.
import shutil

dest_output = os.path.join(DRIVE_PATH, "outputs")
if os.path.exists(OUTPUT_DIR):
    # `cp -r SRC DEST` copies into DEST/<name> only when DEST already exists.
    # On the very first run it would instead create DEST as a copy of SRC, so
    # the run folder name was lost. Create DEST and target the run folder explicitly.
    run_name = os.path.basename(os.path.normpath(OUTPUT_DIR))
    dest_run = os.path.join(dest_output, run_name)
    os.makedirs(dest_output, exist_ok=True)
    print(f"Copying results to {dest_output}...")
    # dirs_exist_ok merges into an existing copy of the same run, as cp -r did.
    shutil.copytree(OUTPUT_DIR, dest_run, dirs_exist_ok=True)
else:
    print(f"Warning: {OUTPUT_DIR} not found; nothing to copy.")