# Vision-Caption Projector Training

This notebook trains the vision-caption projector on the COCO 2017 training captions.

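Setup checklist:
- Enable a GPU accelerator
- Enable Internet access (Kaggle)
- On Kaggle, add the dataset `awsaf49/coco-2017-dataset`. Elsewhere (e.g. Colab), Step 2 downloads COCO itself: through the Kaggle API if `~/.kaggle/kaggle.json` exists, otherwise directly from images.cocodataset.org.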

Run cells in order.

In [None]:
# Step 1: Clone and install
import os
import subprocess
import sys

# This cell chdirs into the repo. Without the cwd check, re-running it would look
# for ./vision-caption *inside* the checkout, clone a nested copy, and chdir into that.
if os.path.basename(os.getcwd()) != "vision-caption":
    if not os.path.isdir("vision-caption"):
        subprocess.run(
            ["git", "clone", "https://github.com/asynced24/vision-caption.git"],
            check=True,
        )
    os.chdir("vision-caption")

# Editable install of the checkout (presumably provides the `vision_caption`
# package imported in Step 4), plus pycocotools.
subprocess.run([sys.executable, "-m", "pip", "install", "-e", ".", "-q"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "pycocotools", "-q"], check=True)
print("Setup complete")

In [None]:
# Step 2: Get COCO dataset
# - Kaggle: uses /kaggle/input/coco-2017-dataset/coco2017
# - Colab: uses Kaggle API if kaggle.json exists, otherwise direct download
from pathlib import Path
import os
import subprocess
import sys

# Both start unresolved; the discovery logic below fills them in (None = not found).
IMAGES_DIR = ANNOTATIONS_FILE = None


def _find_coco_paths(root: Path) -> tuple[str | None, str | None]:
    images = None
    annotations = None

    img_candidates = list(root.glob("**/train2017"))
    if img_candidates:
        images = str(img_candidates[0])

    ann_candidates = list(root.glob("**/captions_train2017.json"))
    if ann_candidates:
        annotations = str(ann_candidates[0])

    return images, annotations


def _validate_paths(images: str | None, annotations: str | None) -> tuple[str | None, str | None]:
    if images and not os.path.isdir(images):
        images = None
    if annotations and not os.path.isfile(annotations):
        annotations = None
    return images, annotations


def _download_and_extract(url: str, zip_path: Path, dest: Path) -> None:
    """Download ``url`` to ``zip_path`` with wget, unzip it into ``dest``, then delete the zip.

    Raises:
        subprocess.CalledProcessError: if wget or unzip exits non-zero.
    """
    subprocess.run(
        ["wget", "--progress=bar:force:noscroll", "-O", str(zip_path), url],
        check=True,
    )
    subprocess.run(["unzip", "-q", str(zip_path), "-d", str(dest)], check=True)
    zip_path.unlink(missing_ok=True)


# Strategy 1: Kaggle, where the dataset is attached under /kaggle/input.
is_kaggle = os.path.isdir("/kaggle/input") or bool(os.environ.get("KAGGLE_URL_BASE"))
if is_kaggle:
    images = "/kaggle/input/coco-2017-dataset/coco2017/train2017"
    annotations = "/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json"
    IMAGES_DIR, ANNOTATIONS_FILE = _validate_paths(images, annotations)

    if (IMAGES_DIR is None or ANNOTATIONS_FILE is None) and os.path.isdir("/kaggle/input"):
        # In case the dataset is attached under a different slug or nesting than
        # the hard-coded path above assumes, search the inputs before giving up.
        IMAGES_DIR, ANNOTATIONS_FILE = _validate_paths(*_find_coco_paths(Path("/kaggle/input")))

    if IMAGES_DIR is None or ANNOTATIONS_FILE is None:
        raise RuntimeError(
            "COCO not found under /kaggle/input. Add awsaf49/coco-2017-dataset and re-run Step 2."
        )

# Strategy 2: Kaggle API download, when credentials are present.
if IMAGES_DIR is None or ANNOTATIONS_FILE is None:
    kaggle_json = Path.home() / ".kaggle" / "kaggle.json"

    if kaggle_json.exists():
        subprocess.run([sys.executable, "-m", "pip", "install", "kaggle", "-q"], check=True)
        from kaggle.api.kaggle_api_extended import KaggleApi

        api = KaggleApi()
        api.authenticate()

        dataset_slug = "awsaf49/coco-2017-dataset"
        out_dir = Path("coco_data")
        out_dir.mkdir(exist_ok=True)
        api.dataset_download_files(dataset_slug, path=str(out_dir), unzip=True)

        # The archive's internal layout is not fixed here, so search for it.
        IMAGES_DIR, ANNOTATIONS_FILE = _validate_paths(*_find_coco_paths(out_dir))

# Strategy 3: direct download from cocodataset.org.
if IMAGES_DIR is None or ANNOTATIONS_FILE is None:
    data_dir = Path("coco_data")
    data_dir.mkdir(exist_ok=True)

    # NOTE: only the presence of each extracted directory is checked. An
    # interrupted unzip leaves a partial folder that is treated as complete;
    # delete it and re-run this cell to retry.
    if not (data_dir / "train2017").exists():
        print("Downloading train2017 images...")
        _download_and_extract(
            "http://images.cocodataset.org/zips/train2017.zip",
            data_dir / "train2017.zip",
            data_dir,
        )

    if not (data_dir / "annotations").exists():
        print("Downloading annotations...")
        _download_and_extract(
            "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
            data_dir / "annotations_trainval2017.zip",
            data_dir,
        )

    IMAGES_DIR, ANNOTATIONS_FILE = _validate_paths(
        str(data_dir / "train2017"),
        str(data_dir / "annotations" / "captions_train2017.json"),
    )

print("IMAGES_DIR:", IMAGES_DIR)
print("ANNOTATIONS_FILE:", ANNOTATIONS_FILE)
# Fail here rather than announcing success with unusable None paths.
if IMAGES_DIR is None or ANNOTATIONS_FILE is None:
    raise RuntimeError(
        "COCO train2017 images/captions still missing after all download attempts. "
        "Check the output above and re-run Step 2."
    )
print("Dataset ready")

In [None]:
# Step 3: Train projector
import os
import subprocess
import sys

# Hyperparameters, passed to train.py verbatim as CLI strings.
EPOCHS = "3"
BATCH_SIZE = "32"
LEARNING_RATE = "1e-3"

# globals().get(...) so that running this cell before Step 2 (e.g. after a kernel
# restart) gives the intended message instead of a NameError.
images_dir = globals().get("IMAGES_DIR")
annotations_file = globals().get("ANNOTATIONS_FILE")
if not images_dir or not annotations_file:
    raise RuntimeError("IMAGES_DIR/ANNOTATIONS_FILE not set. Run Step 2 first.")

# train.py is resolved relative to cwd, which Step 1 sets to the repo checkout.
if not os.path.isfile("train.py"):
    raise RuntimeError("train.py not found in the current directory. Run Step 1 first.")

subprocess.run(
    [
        sys.executable,
        "train.py",
        "--images-dir", images_dir,
        "--annotations-file", annotations_file,
        "--output-dir", "checkpoints",
        "--epochs", EPOCHS,
        "--batch-size", BATCH_SIZE,
        "--lr", LEARNING_RATE,
    ],
    check=True,
)
print("Training complete")

In [None]:
# Step 4: Quick test
import os
from vision_caption import ModelConfig, load_model
from PIL import Image
import requests
from io import BytesIO

projector_path = "checkpoints/projector_final.pt"
if not os.path.isfile(projector_path):
    print("No trained projector found yet. Run Step 3 first.")
else:
    config = ModelConfig()
    config.projector_path = projector_path
    model = load_model(config)

    url = "https://images.unsplash.com/photo-1518791841217-8f162f1e1131"
    # Bound the request and fail loudly on HTTP errors. Otherwise a stalled
    # connection blocks the cell indefinitely, and an error page surfaces as an
    # opaque PIL "cannot identify image file" error.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # convert("RGB") forces PIL's lazy decode now and normalises the mode to
    # 3-channel RGB (presumably what model.generate expects — confirm).
    image = Image.open(BytesIO(response.content)).convert("RGB")

    print("Caption:", model.generate(image))
    display(image)  # IPython/Jupyter built-in; not available outside a notebook

In [None]:
# Step 5: Download weights
import os

projector_path = "checkpoints/projector_final.pt"
if not os.path.isfile(projector_path):
    print("No trained projector found yet. Run Step 3 first.")
else:
    # Only the Colab import decides the environment. A failure inside
    # files.download() itself should surface, not be misreported as "on Kaggle".
    try:
        from google.colab import files
    except ImportError:
        print("Kaggle: download from the Output tab (checkpoints/projector_final.pt)")
    else:
        files.download(projector_path)
        print("Downloaded from Colab")