# Run `run.py` in Google Colab
Clone the repository if it is not already present, install its dependencies, and then launch the training or evaluation script with parameterized arguments.

In [None]:
import os
import subprocess
import sys
from pathlib import Path

# Where the project lives remotely and the directory name it is cloned into.
repo_url = "https://github.com/Chris0lsen/fp-dataset-artifacts.git"
repo_dir = Path("fp-dataset-artifacts")

# Only clone / change directory when we are not already inside a directory
# with the repository's name (which makes re-running this cell harmless).
if Path.cwd().name != repo_dir.name:
    if not repo_dir.exists():
        clone_cmd = ["git", "clone", repo_url, repo_dir.name]
        subprocess.run(clone_cmd, check=True)
    os.chdir(repo_dir)
print(f"Working directory: {Path.cwd()}")

# Install with the interpreter running this notebook so the packages land in
# the same environment; check=True raises if pip exits with an error.
pip_cmd = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]
subprocess.run(pip_cmd, check=True)

In [None]:
# Run configuration: adjust these values as needed, then run the launch cell.

# Always forwarded to run.py as --task and --output_dir.
task = "nli"
output_dir = "./trained_model"

# Mode switches: each one adds its bare flag (--do_train / --do_eval) when true.
do_train = True
do_eval = False

# Forwarded as --dataset / --model; an empty value leaves the flag off.
dataset = "snli"
model_id = "google/electra-small-discriminator"

# Forwarded as --max_length / --max_train_samples / --max_eval_samples;
# None leaves the corresponding flag off.
max_length = 128
max_train_samples = None
max_eval_samples = None

# Each entry is forwarded as "--<key> <value>"; entries set to None are skipped.
training_arg_overrides = {
    "per_device_train_batch_size": 8,
    "num_train_epochs": 3.0,
}

# Additional raw command-line tokens appended after everything else,
# e.g., ["--resume_from_checkpoint", "./trained_model"]
extra_args = []

In [None]:
import shlex
import subprocess
import sys
from pathlib import Path

# Assemble the command line for run.py from the configuration values, then
# launch it as a child process of the current interpreter.

# Base invocation: interpreter, script, and the two always-present options.
cli_args = [sys.executable, "run.py", "--task", task, "--output_dir", output_dir]

# Bare switches, added only when the matching toggle is truthy.
for switch, enabled in (("--do_train", do_train), ("--do_eval", do_eval)):
    if enabled:
        cli_args.append(switch)

# String-valued options, skipped when the value is empty or None.
for option, text in (("--dataset", dataset), ("--model", model_id)):
    if text:
        cli_args += [option, text]

# Numeric limits, skipped only when None (an explicit 0 is still passed).
for option, limit in (
    ("--max_length", max_length),
    ("--max_train_samples", max_train_samples),
    ("--max_eval_samples", max_eval_samples),
):
    if limit is not None:
        cli_args += [option, str(limit)]

# Every override becomes "--<name> <value>"; None-valued entries are dropped.
for name, setting in training_arg_overrides.items():
    if setting is not None:
        cli_args += [f"--{name}", str(setting)]

# Free-form extra tokens go last.
cli_args.extend(extra_args)

# Make sure the output directory exists before the script starts.
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Echo a shell-quoted, copy-pasteable form of the command, then run it;
# check=True raises CalledProcessError on a non-zero exit status.
rendered = " ".join(shlex.quote(str(token)) for token in cli_args)
print("Running:", rendered)
subprocess.run(cli_args, check=True)
