# Vision-Caption Projector Training

Train the projector on COCO Captions. Works on **Colab** and **Kaggle**.

**IMPORTANT:** Enable GPU + Internet first!
- **Colab:** Runtime → Change runtime type → T4 GPU
- **Kaggle:** Settings → Accelerator → GPU T4 x2, then Settings → Internet → ON

**If Kaggle logs show `Accelerator: None`, GPU is not enabled.**
**If `git clone` fails with `Could not resolve host`, Internet is OFF.**

In [None]:
# Step 1: Clone and install (fails fast if Internet/GPU are off)
import os
import socket
import subprocess
import sys

# Overall setup status: starts optimistic, is set to False below when the
# Internet or GPU check fails, and is read by the later cells before they run.
SETUP_OK = True


def has_internet() -> bool:
    """Return True when the hostname ``github.com`` can be resolved.

    Only a DNS lookup is performed; an ``OSError`` from the resolver is
    reported as ``False`` instead of being raised.
    """
    try:
        socket.gethostbyname("github.com")
    except OSError:
        return False
    else:
        return True


# Connectivity gate: the clone and pip steps below are skipped when the
# github.com lookup fails.
if not has_internet():
    print("Internet is OFF. In Kaggle: Settings → Internet → ON.")
    SETUP_OK = False

# Reuse the current directory when the cell is re-run from inside the checkout;
# otherwise target a "vision-caption" folder under the working directory.
repo_dir = (
    os.getcwd()
    if os.path.basename(os.getcwd()) == "vision-caption"
    else os.path.abspath("vision-caption")
)

if SETUP_OK:
    # Clone only when no checkout directory exists yet.
    if not os.path.isdir(repo_dir):
        subprocess.run(
            ["git", "clone", "https://github.com/asynced24/vision-caption.git"],
            check=True,
        )

    os.chdir(repo_dir)

    # Editable install of the project first, then pycocotools; check=True
    # aborts the cell on the first failing install.
    for package_args in (["-e", "."], ["pycocotools"]):
        subprocess.run(
            [sys.executable, "-m", "pip", "install", *package_args, "-q"],
            check=True,
        )

    # torch is imported only here, after the installs above have run.
    try:
        import torch

        if not torch.cuda.is_available():
            print("GPU not detected. Enable Accelerator (Kaggle) or T4 GPU (Colab).")
            SETUP_OK = False
    except Exception as exc:
        # Any failure while importing or probing torch counts as a failed setup.
        print(f"GPU check failed: {exc}")
        SETUP_OK = False

if SETUP_OK:
    print("✓ Setup complete")
else:
    print("Setup incomplete. Fix settings and re-run this cell.")

In [None]:
# Step 2: Download COCO dataset (~19GB)
from pathlib import Path

DATA_OK = False

data_dir = Path("coco_data")
data_dir.mkdir(exist_ok=True)

if not SETUP_OK:
    print("Skipping download because setup failed. Fix settings above first.")
else:
    if not (data_dir / "train2017").exists():
        print("Downloading images...")
        !wget -q http://images.cocodataset.org/zips/train2017.zip -P coco_data
        !unzip -q coco_data/train2017.zip -d coco_data
        !rm coco_data/train2017.zip

    if not (data_dir / "annotations").exists():
        print("Downloading annotations...")
        !wget -q http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P coco_data
        !unzip -q coco_data/annotations_trainval2017.zip -d coco_data
        !rm coco_data/annotations_trainval2017.zip

    if (data_dir / "train2017").exists() and (data_dir / "annotations" / "captions_train2017.json").exists():
        DATA_OK = True

print("✓ Dataset ready!" if DATA_OK else "Dataset not ready. Check Internet and rerun this cell.")

In [None]:
# Step 3: Train projector (full COCO)
import subprocess
import sys

if SETUP_OK and DATA_OK:
    # Run train.py with the same interpreter as the notebook; check=True
    # raises if the training process exits with a non-zero status.
    subprocess.run(
        [
            sys.executable, "train.py",
            "--images-dir", "coco_data/train2017",
            "--annotations-file", "coco_data/annotations/captions_train2017.json",
            "--output-dir", "checkpoints",
            "--epochs", "3",
            "--batch-size", "32",
            "--lr", "1e-3",
        ],
        check=True,
    )
elif not SETUP_OK:
    print("Setup not complete. Fix settings in Step 1 and rerun.")
else:
    print("Dataset not ready. Fix Step 2 and rerun.")

In [None]:
# Step 4: Test trained model
from vision_caption import ModelConfig, load_model
from PIL import Image
import requests
from io import BytesIO

# Point the default config at the projector checkpoint expected from the
# Step 3 training run (its --output-dir is "checkpoints").
config = ModelConfig()
config.projector_path = "checkpoints/projector_final.pt"
model = load_model(config)

# Test image
url = "https://images.unsplash.com/photo-1518791841217-8f162f1e1131"
# Bound the request so the cell cannot hang indefinitely, and fail on an HTTP
# error here rather than letting Image.open choke on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

print(f"Caption: {model.generate(image)}")
# display() is not imported in this cell; it is presumably supplied by the
# notebook (IPython) environment.
display(image)

In [None]:
# Step 5: Download trained weights
try:
    # A failing import is treated as "not running on Colab". Only the import
    # is guarded, so errors raised by the download itself are not mislabelled.
    from google.colab import files
except ImportError:
    print("✓ On Kaggle: Click Output tab → Download projector_final.pt")
else:
    files.download('checkpoints/projector_final.pt')
    print("✓ Downloaded (Colab)")