# Vision-Caption Projector Training

This notebook trains the projector component of the vision-caption model on the COCO Captions dataset.

**What this trains:** Only the MLP projector that maps vision features to language embeddings.

**What stays frozen:** SigLIP vision encoder and Qwen2-1.5B language decoder.

**Dataset:** COCO train2017 (~118k images, ~590k captions)

**Training time:** ~2-3 hours on a T4 GPU (full dataset)
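For a sense of scale, here is a back-of-the-envelope parameter count for the only trainable part. It assumes a two-layer MLP (Linear, GELU, Linear, the LLaVA-1.5-style projector design) and a SigLIP-base encoder with hidden size 768. The repository's actual layer sizes may differ. Qwen2-1.5B's hidden size is 1536. The helper names are hypothetical.

```python
# Hypothetical projector shape: Linear(vision_dim -> lm_dim) -> GELU -> Linear(lm_dim -> lm_dim).
# The dims are assumptions, not values read from the vision-caption repo.
VISION_DIM = 768   # assumed SigLIP-base hidden size
LM_DIM = 1536      # Qwen2-1.5B hidden size


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one dense layer."""
    return in_features * out_features + out_features


def mlp_projector_params(vision_dim: int, lm_dim: int) -> int:
    # GELU has no parameters, so only the two linear layers count
    return linear_params(vision_dim, lm_dim) + linear_params(lm_dim, lm_dim)


projector = mlp_projector_params(VISION_DIM, LM_DIM)
frozen_lm = 1.5e9  # nominal size taken from the model name "Qwen2-1.5B"
print(f"Trainable projector params: {projector:,}")
print(f"Share of the decoder's size: {projector / frozen_lm:.2%}")
```

Under these assumptions the projector has about 3.5M parameters (roughly 0.24% of the frozen decoder's nominal size).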

## Setup

In [None]:
# Clone repository
!git clone https://github.com/asynced24/vision-caption.git
%cd vision-caption

In [None]:
# Install dependencies
%pip install -e . -q
%pip install pycocotools -q

## Download COCO Dataset

Full COCO train2017 dataset: 118k images (~18GB) + annotations (~1GB).

**Direct links:**
- Images: http://images.cocodataset.org/zips/train2017.zip
- Annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip
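After the download, `captions_train2017.json` lists images under `"images"` (with `id` and `file_name`) and captions under `"annotations"` (with `image_id` and `caption`), so each image ends up with several captions. The sketch below groups captions per image file. It uses a tiny made-up sample in that layout so it runs without the 18 GB download; `captions_by_file` is a hypothetical helper, not part of the repo.

```python
import json
from collections import defaultdict

# Tiny hand-written sample in the captions_*.json layout (not real COCO data)
sample = json.loads("""
{
  "images": [
    {"id": 1, "file_name": "000000000001.jpg"},
    {"id": 2, "file_name": "000000000002.jpg"}
  ],
  "annotations": [
    {"id": 10, "image_id": 1, "caption": "A cat sleeping on a couch."},
    {"id": 11, "image_id": 1, "caption": "A gray cat curled up on cushions. "},
    {"id": 12, "image_id": 2, "caption": "Two people riding bikes down a street."}
  ]
}
""")


def captions_by_file(coco: dict) -> dict[str, list[str]]:
    """Map each image file name to the list of its (whitespace-stripped) captions."""
    id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}
    grouped: dict[str, list[str]] = defaultdict(list)
    for ann in coco["annotations"]:
        grouped[id_to_file[ann["image_id"]]].append(ann["caption"].strip())
    return dict(grouped)


pairs = captions_by_file(sample)
print(len(pairs), "images,", sum(len(c) for c in pairs.values()), "captions")
```

On the real file, load it with `json.load(open("coco_data/annotations/captions_train2017.json"))` and pass the result to the same function; the totals should come out near ~118k images and ~590k captions.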

In [None]:
from pathlib import Path

# Create data directory
data_dir = Path("coco_data")
data_dir.mkdir(exist_ok=True)

# Download train2017 images (~18GB)
if not (data_dir / "train2017").exists():
    print("Downloading COCO train2017 images...")
    !wget http://images.cocodataset.org/zips/train2017.zip -P coco_data
    !unzip -q coco_data/train2017.zip -d coco_data
    !rm coco_data/train2017.zip
    print("✓ Images downloaded")
else:
    print("✓ Images already downloaded")

# Download annotations (~1GB)
if not (data_dir / "annotations").exists():
    print("Downloading COCO annotations...")
    !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P coco_data
    !unzip -q coco_data/annotations_trainval2017.zip -d coco_data
    !rm coco_data/annotations_trainval2017.zip
    print("✓ Annotations downloaded")
else:
    print("✓ Annotations already downloaded")

print("\nDataset ready!")
print(f"Images: {data_dir / 'train2017'}")
print(f"Captions: {data_dir / 'annotations' / 'captions_train2017.json'}")

## Training Configuration

Adjust the parameters below. For a quick test run, set `MAX_SAMPLES = 10000`; for full training, set `MAX_SAMPLES = None`.
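To estimate how long a configuration will run, it helps to count optimizer steps: `ceil(samples / BATCH_SIZE) * EPOCHS`. Whether `train.py` treats one image or one image-caption pair as a "sample" is not stated here, so this sketch (with a hypothetical `optimizer_steps` helper) shows both, using the approximate counts from the introduction and assuming the last partial batch of each epoch is kept.

```python
import math


def optimizer_steps(num_samples: int, batch_size: int, epochs: int) -> int:
    """Total optimizer steps, assuming the last partial batch of each epoch is kept."""
    return math.ceil(num_samples / batch_size) * epochs


# Approximate sample counts; batch size 32 and 3 epochs match the defaults below
for label, n in [("MAX_SAMPLES = 10000", 10_000),
                 ("~118k images", 118_000),
                 ("~590k captions", 590_000)]:
    print(f"{label}: {optimizer_steps(n, batch_size=32, epochs=3):,} steps")
```

The two full-dataset readings differ by a factor of about five, which matters when budgeting GPU time.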

In [None]:
IMAGES_DIR = "coco_data/train2017"
ANNOTATIONS_FILE = "coco_data/annotations/captions_train2017.json"
OUTPUT_DIR = "checkpoints"
EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_SAMPLES = None  # None = full dataset (~118k images)

## Start Training

This takes about 2-3 hours on a T4 GPU for the full dataset. The loss should fall from roughly 0.5 to roughly 0.05.
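If the logged loss is the mean per-token cross-entropy in nats (the usual convention for causal language-model training, but an assumption here), exponentiating it gives perplexity: the effective number of equally likely choices per caption token. A minimal sketch with a hypothetical `perplexity` helper:

```python
import math


def perplexity(mean_token_nll: float) -> float:
    """Perplexity from a mean per-token cross-entropy measured in nats."""
    return math.exp(mean_token_nll)


for loss in (0.5, 0.05):
    print(f"loss {loss:.2f} -> perplexity {perplexity(loss):.2f}")
```

Under that assumption, a loss of 0.5 corresponds to a perplexity of about 1.65 and 0.05 to about 1.05, meaning the decoder is nearly certain of each next caption token.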

In [None]:
# Build the train.py command from the configuration above
cmd = (f"python train.py --images-dir {IMAGES_DIR} --annotations-file {ANNOTATIONS_FILE} "
       f"--output-dir {OUTPUT_DIR} --epochs {EPOCHS} --batch-size {BATCH_SIZE} --lr {LEARNING_RATE}")
if MAX_SAMPLES is not None:  # None means "use the full dataset"
    cmd += f" --max-samples {MAX_SAMPLES}"

!{cmd}

## Test Trained Model

Generate a caption with the newly trained projector.

In [None]:
from vision_caption import ModelConfig, load_model
from PIL import Image
import requests
from io import BytesIO

# Load model with trained projector
config = ModelConfig()
config.projector_path = "checkpoints/projector_final.pt"
model = load_model(config)

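# Fetch an example image (fail loudly on HTTP errors) and normalize it to RGB
url = "https://images.unsplash.com/photo-1518791841217-8f162f1e1131"
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")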

# Generate caption
caption = model.generate(image)
print(f"\nGenerated caption: {caption}")

display(image)

## Download Trained Weights

Download the trained projector so you can use it locally or share it.
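Before sharing the file, it can help to record its size and a SHA-256 checksum so recipients can confirm they received an intact copy. A minimal stdlib sketch; `sha256_of` is a hypothetical helper, not part of the repo.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 digest of a file, read in chunks to keep memory use flat."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


ckpt = Path("checkpoints/projector_final.pt")
if ckpt.exists():
    print(f"{ckpt.name}: {ckpt.stat().st_size / 1e6:.1f} MB, sha256 {sha256_of(str(ckpt))}")
```

A recipient can run the same function on their copy and compare the two digests.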

In [None]:
from google.colab import files

# Download the final trained projector
files.download('checkpoints/projector_final.pt')

print("\n✓ Downloaded projector_final.pt")
print("\nTo use it:")
print("1. Place it in your project's 'checkpoints/' directory")
print("2. Set config.projector_path = 'checkpoints/projector_final.pt'")
print("3. Run: model = load_model(config)")

## Launch Gradio Demo (Optional)

Test the trained model interactively.

In [None]:
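# Patch colab_app.py so it loads the trained projector weights
import fileinput

PATCH = 'config.projector_path = "checkpoints/projector_final.pt"'

# Skip if an earlier run of this cell already inserted the line
with open("colab_app.py") as f:
    already_patched = PATCH in f.read()

if not already_patched:
    for line in fileinput.input("colab_app.py", inplace=True):
        print(line, end="")  # with inplace=True, stdout is written back to the file
        if "config = ModelConfig()" in line:
            # Reuse the matched line's indentation so the result stays valid Python
            indent = line[: len(line) - len(line.lstrip())]
            print(f"{indent}{PATCH}")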

# Launch demo with trained model
!python colab_app.py