# ACIE Training on Google Colab

This notebook trains the ACIE model using the project folder from your Google Drive.

## Quick Start
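1. **Drive Setup**: Upload your project folder to the top level of your Google Drive and name it `ACIE` (the notebook searches `/content/drive/My Drive/ACIE` for `setup.py`).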
2. **Runtime**: Ensure you are using a GPU runtime (`Runtime` > `Change runtime type` > `T4 GPU`).

In [None]:
# 1. Mount Google Drive
# Mounts Drive at /content/drive, then locates the project by searching for
# setup.py under SEARCH_ROOT. On success the working directory is switched to
# the project root and that root is made importable; on failure PROJECT_ROOT
# stays None and a diagnostic is printed.
from google.colab import drive
import os
import sys
import glob

drive.mount('/content/drive')

# 2. Find Project Root (Robust Search)
SEARCH_ROOT = "/content/drive/My Drive/ACIE"
PROJECT_ROOT = None

print(f"Searching for setup.py in {SEARCH_ROOT}...")

# Recursive search: setup.py may sit in SEARCH_ROOT itself or in any subfolder.
candidates = glob.glob(f"{SEARCH_ROOT}/**/setup.py", recursive=True)

if candidates:
    # The shortest matching path is used (presumably so the outermost
    # setup.py wins over copies nested deeper in the tree).
    candidates.sort(key=len)
    setup_path = candidates[0]
    PROJECT_ROOT = os.path.dirname(setup_path)

    print(f"✅ Found setup.py at: {setup_path}")
    print(f"📂 Setting working directory to: {PROJECT_ROOT}")

    os.chdir(PROJECT_ROOT)
    # Re-running this cell must not keep appending the same entry.
    if PROJECT_ROOT not in sys.path:
        sys.path.append(PROJECT_ROOT)
else:
    print(f"❌ NOT FOUND: Could not find setup.py in {SEARCH_ROOT}.")
    # Distinguish "folder missing" from "folder present but no setup.py".
    if not os.path.isdir(SEARCH_ROOT):
        print(f"The folder {SEARCH_ROOT} does not exist. Check the folder name in your Drive.")

In [None]:
# 3. Install Dependencies
import os
if os.path.exists("setup.py"):
    print(f"Installing dependencies from: {os.getcwd()}")
    !pip install -q pytorch-lightning torchmetrics python-dotenv
    !pip install -q "numpy<2.0" pandas scipy networkx
    !pip install -q "bcrypt<4.0.0" passlib python-jose[cryptography]
    !pip install -e .
else:
    print("‚ùå Setup.py still not found.")

In [None]:
# 4. Find Datasets (Recursive)
# Collects every CSV anywhere under SEARCH_ROOT (including subfolders such as
# lib/) and symlinks each one into a flat local data/ folder, keyed by file
# name. If none are found, falls back to CSVs in the current directory.
import shutil

DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Find ALL CSVs in the entire ACIE folder structure
print(f"Searching for CSVs in {SEARCH_ROOT}...")
found_csvs = glob.glob(f"{SEARCH_ROOT}/**/*.csv", recursive=True)

if found_csvs:
    print(f"✅ Found {len(found_csvs)} CSV files. Linking to local data/ folder...")
    for csv in found_csvs:
        filename = os.path.basename(csv)
        target = os.path.join(DATA_DIR, filename)
        # lexists (not exists) so a dangling link left by an earlier run is
        # treated as present instead of making os.symlink raise
        # FileExistsError.
        if not os.path.lexists(target):
            os.symlink(csv, target)
        elif os.path.realpath(target) != os.path.realpath(csv):
            # A different file with the same name is already linked; the
            # first one is kept, but say so instead of dropping it silently.
            print(f"Skipping {csv}: {target} already exists.")
    print(f"Linked to: {os.path.abspath(DATA_DIR)}")
else:
    print("⚠️ Warning: No CSV files found ANYWHERE in ACIE folder!")
    # Fallback to current dir if user put them there manually
    if glob.glob("*.csv"):
        print("Found CSVs in current directory instead.")
        DATA_DIR = "."

In [None]:
# 5. Run Training (Combined 20k)
# Using '20k' automatically picks up:
# - Observational (20k)
# - Interventions/Shifts (20k)
# (10k files are skipped due to dimension mismatch)
DATASET_SIZE = "20k"
MAX_EPOCHS = 20
BATCH_SIZE = 64
OUTPUT_DIR = "outputs/colab_run_combined"

OBS_DIM = 11000 if DATASET_SIZE == "20k" else 6000
LATENT_DIM = 4000 if DATASET_SIZE == "20k" else 2000

cmd = f"python acie/training/train.py --data_dir {DATA_DIR} --output_dir {OUTPUT_DIR} --dataset_size {DATASET_SIZE} --obs_dim {OBS_DIM} --latent_dim {LATENT_DIM} --max_epochs {MAX_EPOCHS} --batch_size {BATCH_SIZE} --gpus 1"

print(f"Starting training command: {cmd}")
!{cmd}

In [None]:
# 6. Save Results
dest_output = os.path.join(SEARCH_ROOT, "outputs_combined")

if os.path.exists(OUTPUT_DIR):
    print(f"Copying results to {dest_output}...")
    if not os.path.exists(dest_output):
        os.makedirs(dest_output)
    !cp -r {OUTPUT_DIR}/* "{dest_output}/"