# Run `run.py` in Google Colab
Clone the repository if needed, install its dependencies, and launch the training or evaluation script with parameterized arguments.

In [None]:
import os
import subprocess
import sys
from pathlib import Path

# Make sure the notebook is running from inside a checkout of the project,
# then install the project's Python requirements into the current interpreter.
repo_url = "https://github.com/Chris0lsen/fp-dataset-artifacts.git"
repo_dir = Path("fp-dataset-artifacts")

# Only clone/chdir when the current directory is not already named like the
# repository; re-running the cell is therefore harmless.
if Path.cwd().name != repo_dir.name:
    if not repo_dir.exists():
        clone_cmd = ["git", "clone", repo_url, repo_dir.name]
        subprocess.run(clone_cmd, check=True)
    os.chdir(repo_dir)

# Reported in both cases so the user always sees where run.py will be looked up.
print(f"Working directory: {Path.cwd()}")

# "python -m pip" targets the interpreter running this kernel.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
    check=True,
)

In [None]:
# Adjust these values as needed; the launcher cells below read them when
# assembling the run.py command line.

# --- What to run -----------------------------------------------------------
task = "nli"
dataset = "snli"
do_train = True
do_eval = False

# --- Model and output ------------------------------------------------------
model_id = "google/electra-small-discriminator"
output_dir = "./trained_model"

# --- Data limits (None means the corresponding flag is not passed) ---------
max_length = 128
max_train_samples = None
max_eval_samples = None

# --- Extra training arguments ----------------------------------------------
# Each entry is forwarded as "--<key> <value>"; entries whose value is None
# are skipped.
training_arg_overrides = {
    "per_device_train_batch_size": 8,
    "num_train_epochs": 3.0,
}
# Appended verbatim after all other arguments,
# e.g., ["--resume_from_checkpoint", "./trained_model"]
extra_args = []

# --- Logging and checkpoints -----------------------------------------------
# Set to False if you have wandb configured
disable_wandb = True
# When False, disable periodic checkpoint saves
keep_intermediate_checkpoints = False

# --- Google Drive export ---------------------------------------------------
# Flip to True to copy artifacts into Google Drive after a run
save_to_drive = False
# Leave as-is unless you mount elsewhere (then update drive_output_dir too,
# since it is a separate literal path)
drive_mount_point = "/content/drive"
# Destination folder in Drive
drive_output_dir = "/content/drive/MyDrive/fp-trained-models"


In [None]:
import os
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

# Assemble the run.py command line from the configuration cell, run it, and
# optionally copy the output directory into Google Drive afterwards.
cli_args = [sys.executable, "run.py", "--task", task, "--output_dir", output_dir]

# Boolean switches are passed as bare flags.
for switch, enabled in (("--do_train", do_train), ("--do_eval", do_eval)):
    if enabled:
        cli_args.append(switch)

# String options are skipped when empty or None.
for flag, setting in (("--dataset", dataset), ("--model", model_id)):
    if setting:
        cli_args += [flag, setting]

# Numeric limits are skipped only when None, so an explicit 0 is still passed.
for flag, limit in (
    ("--max_length", max_length),
    ("--max_train_samples", max_train_samples),
    ("--max_eval_samples", max_eval_samples),
):
    if limit is not None:
        cli_args += [flag, str(limit)]

# Work on a copy so the user's dict in the configuration cell is not mutated.
effective_overrides = dict(training_arg_overrides)
if not keep_intermediate_checkpoints and "save_strategy" not in effective_overrides:
    # Respect an explicit save_strategy from the user; otherwise turn
    # periodic checkpoint saves off.
    effective_overrides["save_strategy"] = "no"

for option, setting in effective_overrides.items():
    if setting is not None:
        cli_args += [f"--{option}", str(setting)]

cli_args.extend(extra_args)
Path(output_dir).mkdir(parents=True, exist_ok=True)

# The child process gets a copy of the environment; the kernel's own
# environment is left untouched.
env = dict(os.environ)
if disable_wandb:
    env["WANDB_DISABLED"] = "true"

# Mount Drive before the (potentially long) run so an authorization problem
# surfaces immediately rather than after training.
if save_to_drive:
    try:
        from google.colab import drive as gdrive
    except ImportError as exc:
        raise RuntimeError("save_to_drive=True requires running inside Google Colab") from exc
    else:
        print(f"Mounting Google Drive at {drive_mount_point} (you may be prompted to authorize)...")
        gdrive.mount(drive_mount_point, force_remount=False)

print("Running:", shlex.join(map(str, cli_args)))
# Output is captured, so nothing is shown until run.py has exited.
result = subprocess.run(cli_args, env=env, text=True, capture_output=True, check=False)
for heading, captured, stream in (
    ("\nstdout:\n", result.stdout, sys.stdout),
    ("\nstderr:\n", result.stderr, sys.stderr),
):
    if captured:
        print(heading, captured, file=stream)
# Raise only after echoing the output so the failure reason is visible.
if result.returncode != 0:
    raise RuntimeError(f"run.py exited with status {result.returncode}")

if save_to_drive:
    source_dir = Path(output_dir)
    if not source_dir.is_dir():
        raise FileNotFoundError(f"Expected output directory '{source_dir}' not found")
    dest_root = Path(drive_output_dir)
    dest_root.mkdir(parents=True, exist_ok=True)
    # Artifacts land in a subfolder named after the local output directory.
    dest_dir = dest_root.joinpath(source_dir.name)
    print(f"Copying artifacts to {dest_dir}...")
    # dirs_exist_ok lets a re-run copy over an earlier export.
    shutil.copytree(source_dir, dest_dir, dirs_exist_ok=True)
    print("Artifacts copied to Google Drive.")


In [None]:
import os
import shlex
import subprocess
import sys
from pathlib import Path

# Evaluate a model with run.py.
#
# Use the fine-tuned checkpoint unless you override here. The checkpoint in
# output_dir is preferred; otherwise the base model_id is evaluated.
#
# A directory that merely exists is not treated as a checkpoint: the training
# cell creates output_dir before run.py starts, so it can be empty after a
# failed or skipped training run. The presence of config.json is used as the
# marker of a saved model.
# NOTE(review): assumes run.py saves the final model with save_pretrained /
# Trainer.save_model, which writes config.json at the top of output_dir —
# confirm against run.py.
checkpoint_dir = Path(output_dir)
if (checkpoint_dir / "config.json").is_file():
    eval_model_path = str(checkpoint_dir)
else:
    # Keep the hub id as a plain string (or None when unset). Wrapping it in
    # Path would fail on None, turn "" into ".", and rewrite the "/" separator
    # on non-POSIX hosts.
    eval_model_path = model_id or None
    print(
        f"No saved model (config.json) found in '{checkpoint_dir}'; "
        f"evaluating model_id={model_id!r} instead."
    )

cli_args = [
    sys.executable,
    "run.py",
    "--task",
    task,
    "--output_dir",
    output_dir,
    "--do_eval",
]

if dataset:
    cli_args.extend(["--dataset", dataset])
# eval_model_path is a str or None here, so this guard really can skip --model
# (a Path object would always be truthy).
if eval_model_path:
    cli_args.extend(["--model", str(eval_model_path)])
if max_length is not None:
    cli_args.extend(["--max_length", str(max_length)])
if max_eval_samples is not None:
    cli_args.extend(["--max_eval_samples", str(max_eval_samples)])

# Only the eval batch size is forwarded from the training overrides; the other
# entries are training-specific.
per_device_eval_bs = training_arg_overrides.get("per_device_eval_batch_size")
if per_device_eval_bs is not None:
    cli_args.extend(["--per_device_eval_batch_size", str(per_device_eval_bs)])

cli_args.extend(extra_args)

# The child process gets a copy of the environment; the kernel's own
# environment is left untouched.
env = os.environ.copy()
if disable_wandb:
    env["WANDB_DISABLED"] = "true"

print("Running eval:", " ".join(shlex.quote(str(arg)) for arg in cli_args))
# Output is captured, so nothing is shown until run.py has exited.
result = subprocess.run(cli_args, check=False, capture_output=True, text=True, env=env)
if result.stdout:
    print("\nstdout:\n", result.stdout)
if result.stderr:
    print("\nstderr:\n", result.stderr, file=sys.stderr)
# Raise only after echoing the output so the failure reason is visible.
if result.returncode != 0:
    raise RuntimeError(f"run.py evaluation exited with status {result.returncode}")
