# Spectra: Object Recognition Model Training

This notebook implements the training pipeline for `spectra`, an object recognition model trained with CLIP-style contrastive learning on LAION-5B, WIT, PMD, and expanded detection datasets (COCO, OpenImages, etc.).

## Features
- **Backbone**: `laion/CLIP-ViT-B-16-laion2B-s34B-b88K` (OpenCLIP ViT-B/16)
- **Streaming Data**: Zero-disk usage for massive datasets.
- **Robust Checkpointing**: Resumable training from HF Hub (Landmarks).
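- **Phase 2 Training**: Intended to include COCO, OpenImages, Objects365, VOC, LVIS, and Visual Genome. The data pipeline in this notebook currently reuses the Phase 1 URL-based streams for Phase 2.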
- **Optimizations**: Mixed Precision (FP16), Gradient Accumulation, Torch Compile.

In [None]:
# Imports
import os
import json
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, interleave_datasets
from transformers import CLIPModel, CLIPProcessor
from accelerate import Accelerator
from huggingface_hub import login, HfApi, hf_hub_download
from dotenv import load_dotenv
from tqdm.auto import tqdm
from pathlib import Path

# Load environment variables from a local .env file, if one exists.
# Point env_path at another location when the file lives elsewhere.
env_path = Path(".env")

if env_path.exists():
    # override=True lets values in the file replace variables already set in the process.
    load_dotenv(dotenv_path=env_path, override=True)
    print("Loaded .env file")
else:
    print("Warning: .env file not found in current directory")

HF_TOKEN = os.getenv("HF_TOKEN")

if HF_TOKEN:
    login(token=HF_TOKEN)
    # Never echo any part of the token: cell outputs are saved with the
    # notebook and end up in version control or shared copies.
    print("Logged in to the Hugging Face Hub.")
else:
    print("WARNING: HF_TOKEN not found in .env. Please set it for Hub uploads.")

Logged in to HF with token ending in ...[redacted]


In [3]:
# Configuration
# Every tunable value for the run lives in CONFIG; later cells read from it by key.
CONFIG = dict(
    # Model checkpoints
    model_name="laion/CLIP-ViT-B-16-laion2B-s34B-b88K",  # ViT-B/16 Backbone
    fallback_model="openai/clip-vit-base-patch32",
    # Where checkpoints go, locally and on the Hub
    output_dir="./spectra_checkpoints",
    hub_model_id="990aa/spectra",
    # Optimisation
    batch_size=64,  # Per device
    grad_accum_steps=4,
    learning_rate=1e-4,
    weight_decay=0.1,
    max_steps=100000,  # Increased for multi-phase
    warmup_steps=2000,
    mixed_precision="fp16",
    image_size=224,
    # Checkpoint / upload bookkeeping
    push_to_hub=True,
    checkpoint_interval=1000,
    state_file="training_state.json",
    model_card_file="README.md",
)

# The accelerator takes its precision mode and accumulation window from CONFIG.
accelerator = Accelerator(
    log_with="all",
    project_dir="logs",
    mixed_precision=CONFIG["mixed_precision"],
    gradient_accumulation_steps=CONFIG["grad_accum_steps"],
)

print(
    "Accelerator setup: "
    f"{accelerator.device}, Mixed Precision: {accelerator.mixed_precision}"
)

Accelerator setup: cuda, Mixed Precision: fp16


In [4]:
# Checkpointing & Resumption Logic
def save_landmark(step: int, consumed_samples: int) -> None:
    """
    Write the training progress ("landmark") to disk and mirror it to the Hub.

    Only the main process does any work. The state file is written locally
    first; uploading it (and the model card, when one exists on disk) is
    best-effort, so an upload failure is printed rather than raised.

    Args:
        step (int): The current global training step.
        consumed_samples (int): The total number of samples consumed so far.
    """
    if accelerator.is_main_process:
        state_path = CONFIG["state_file"]
        with open(state_path, "w") as handle:
            json.dump(
                {"global_step": step, "consumed_samples": consumed_samples}, handle
            )

        hub = HfApi()
        try:
            # Both uploads target the same model repository.
            destination = {"repo_id": CONFIG["hub_model_id"], "repo_type": "model"}

            hub.upload_file(
                path_or_fileobj=state_path,
                path_in_repo=state_path,
                **destination,
            )
            print(f"Landmark saved at step {step}")

            # The model card is optional and always lands at README.md in the repo.
            card_path = CONFIG["model_card_file"]
            if os.path.exists(card_path):
                hub.upload_file(
                    path_or_fileobj=card_path,
                    path_in_repo="README.md",
                    **destination,
                )
                print("Model card uploaded")
        except Exception as err:
            print(f"Failed to upload to Hub: {err}")


def load_landmark() -> dict | None:
    """
    Load the latest checkpoint state from the Hugging Face Hub.

    Loading is best-effort: any failure to download or parse the state file
    results in ``None`` so training starts from scratch. The cause is printed,
    so an authentication or network problem is not mistaken for a missing
    checkpoint.

    Returns:
        dict | None: The state dictionary (guaranteed to contain
        ``global_step`` and ``consumed_samples``) if found, else None.
    """
    try:
        path = hf_hub_download(
            repo_id=CONFIG["hub_model_id"], filename=CONFIG["state_file"]
        )
        with open(path, "r") as f:
            state = json.load(f)
    except Exception as e:
        # Keep the fallback, but say why: "file does not exist" and
        # "could not reach the Hub" call for very different reactions.
        print(
            f"No previous state found ({type(e).__name__}: {e}). Starting from scratch."
        )
        return None

    # The resume logic indexes both keys directly, so reject partial files here.
    required_keys = ("global_step", "consumed_samples")
    if not isinstance(state, dict) or any(key not in state for key in required_keys):
        print(
            f"Previous state file is malformed (expected keys {required_keys}). "
            "Starting from scratch."
        )
        return None

    print(f"Found previous state: Step {state['global_step']}")
    return state

In [5]:
# Data Pipeline Construction
from PIL import Image
import requests
from io import BytesIO

def _has_image_and_text(example: dict) -> bool:
    """Return True when ``example`` has a non-empty string under both 'image' and 'text'."""
    image_ref = example.get("image")
    caption = example.get("text")
    return (
        isinstance(image_ref, str)
        and isinstance(caption, str)
        and bool(image_ref)
        and bool(caption)
    )


def _wit_to_pair(example: dict) -> dict:
    """Project a WIT row onto the shared ``{'image': url, 'text': caption}`` schema."""
    # NOTE(review): this reads the caption from top-level keys. If the WIT
    # export in use nests captions elsewhere, 'text' is always empty and the
    # filter drops every WIT row -- confirm against the dataset card.
    return {
        "image": example.get("image_url", ""),
        "text": example.get("caption_reference_description", "")
        or example.get("caption", ""),
    }


def get_dataset_stream(phase: str = "phase1") -> object:
    """
    Initialize and return a streaming dataset for the specified phase.

    Each source is normalised to two string columns, ``image`` (a URL) and
    ``text`` (a caption), and rows missing either one are filtered out. A
    source that fails to load is skipped with a warning, and the sampling
    probabilities of the remaining sources are renormalised.

    Args:
        phase (str): The training phase ('phase1' or 'phase2'). 'phase2'
            currently reuses the 'phase1' streams.

    Returns:
        object: An interleaved streaming dataset.

    Raises:
        ValueError: If ``phase`` is not 'phase1' or 'phase2'.
        RuntimeError: If none of the source datasets could be loaded.
    """
    # Reject a mistyped phase up front; otherwise it falls through to the
    # "check your internet connection" error below, which points the wrong way.
    if phase not in ("phase1", "phase2"):
        raise ValueError(f"Unknown phase {phase!r}; expected 'phase1' or 'phase2'.")

    print(f"Initializing Data Stream for {phase}...")

    if phase == "phase2":
        # For Phase 2, we'll use Phase 1 datasets as COCO has PIL images that cause issues
        # You can add proper Phase 2 datasets that provide URLs instead of PIL images
        print("Phase 2: Using URL-based datasets (COCO with PIL images skipped)")
        return get_dataset_stream("phase1")

    datasets_list = []
    probs = []

    # 1. LAION-5B (using laion2B-en subset for demo)
    try:
        ds_laion = load_dataset("laion/laion2B-en", split="train", streaming=True)

        # Select only the columns we need and rename them
        ds_laion = ds_laion.select_columns(["URL", "TEXT"])
        ds_laion = ds_laion.rename_columns({"URL": "image", "TEXT": "text"})
        ds_laion = ds_laion.filter(_has_image_and_text)

        datasets_list.append(ds_laion)
        probs.append(0.5)
        print("✓ LAION-2B loaded")
    except Exception as e:
        print(f"Warning: Could not load LAION-2B: {e}")

    # 2. WIT (Wikipedia Image Text)
    try:
        ds_wit = load_dataset("wikimedia/wit_base", split="train", streaming=True)

        ds_wit = ds_wit.map(_wit_to_pair, remove_columns=ds_wit.column_names)
        ds_wit = ds_wit.filter(_has_image_and_text)

        datasets_list.append(ds_wit)
        probs.append(0.3)
        print("✓ WIT loaded")
    except Exception as e:
        print(f"Warning: Could not load WIT: {e}")

    # 3. Conceptual Captions
    try:
        ds_cc = load_dataset(
            "google-research-datasets/conceptual_captions",
            split="train",
            streaming=True,
        )

        ds_cc = ds_cc.rename_columns({"image_url": "image"})
        ds_cc = ds_cc.select_columns(["image", "caption"])
        ds_cc = ds_cc.rename_columns({"caption": "text"})
        ds_cc = ds_cc.filter(_has_image_and_text)

        datasets_list.append(ds_cc)
        probs.append(0.2)
        print("✓ Conceptual Captions loaded")
    except Exception as e:
        print(f"Warning: Could not load Conceptual Captions: {e}")

    # If no datasets loaded at all
    if not datasets_list:
        print("ERROR: No datasets could be loaded.")
        raise RuntimeError(
            "No datasets available. Please check your internet connection and dataset access permissions."
        )

    # Renormalise so the weights of the sources that did load sum to 1.
    total_weight = sum(probs)
    probs = [p / total_weight for p in probs]

    print(f"Interleaving {len(datasets_list)} datasets with probabilities: {probs}")

    return interleave_datasets(
        datasets_list, probabilities=probs, seed=42, stopping_strategy="first_exhausted"
    )


def collate_fn(batch: list) -> dict | None:
    """
    Collate function to process a batch of items with robust image loading.

    Args:
        batch (list): List of dataset items.

    Returns:
        dict | None: Processed batch tensors or None if empty.
    """
    processor = CLIPProcessor.from_pretrained(CONFIG["model_name"])
    texts = []
    images = []

    for item in batch:
        if item is None or not isinstance(item, dict):
            continue

        img = item.get("image")
        txt = item.get("text")

        if not img or not txt:
            continue

        try:
            # Handle different image formats
            if isinstance(img, str):
                # URL or file path
                if img.startswith(('http://', 'https://')):
                    # Download from URL with retry
                    try:
                        response = requests.get(
                            img, 
                            timeout=5, 
                            headers={'User-Agent': 'Mozilla/5.0'}
                        )
                        response.raise_for_status()
                        img = Image.open(BytesIO(response.content))
                    except Exception:
                        continue
                else:
                    # Local file path
                    try:
                        img = Image.open(img)
                    except Exception:
                        continue
            elif isinstance(img, dict):
                # HuggingFace datasets format with bytes
                if 'bytes' in img:
                    try:
                        img = Image.open(BytesIO(img['bytes']))
                    except Exception:
                        continue
                elif 'path' in img:
                    try:
                        img = Image.open(img['path'])
                    except Exception:
                        continue
                else:
                    continue
            elif hasattr(img, 'convert'):
                # Already a PIL Image
                pass
            else:
                # Skip if we can't handle it
                continue

            # Convert to RGB
            if hasattr(img, 'mode'):
                if img.mode != "RGB":
                    img = img.convert("RGB")
            else:
                continue

            # Validate image
            if hasattr(img, 'size'):
                width, height = img.size
                if width < 10 or height < 10:  # Skip tiny images
                    continue

            images.append(img)
            texts.append(str(txt)[:77])  # Truncate for CLIP
            
        except Exception:
            # Skip problematic images
            continue

    if not images:
        return None
    
    try:
        return processor(
            text=texts, 
            images=images, 
            return_tensors="pt", 
            padding=True, 
            truncation=True
        )
    except Exception:
        return None


In [None]:
# Model Setup
# Load the configured backbone. If that fails, load CONFIG["fallback_model"]
# instead and record the switch in CONFIG so later cells (processor loading,
# Hub pushes) use the matching checkpoint.
try:
    model = CLIPModel.from_pretrained(CONFIG["model_name"])
    print(f"Loaded {CONFIG['model_name']}")
except Exception as load_error:
    print(
        f"Failed to load {CONFIG['model_name']}, falling back to {CONFIG['fallback_model']}"
    )
    # Show the cause: the fallback is a different architecture, so training it
    # by accident (e.g. because of a missing token) should be easy to spot.
    print(f"Reason: {type(load_error).__name__}: {load_error}")
    model = CLIPModel.from_pretrained(CONFIG["fallback_model"])
    CONFIG["model_name"] = CONFIG["fallback_model"]

# Freeze early layers: the first vision-encoder blocks keep their pretrained weights.
NUM_FROZEN_VISION_LAYERS = 6
for param in model.vision_model.encoder.layers[:NUM_FROZEN_VISION_LAYERS].parameters():
    param.requires_grad = False

# Torch Compile (optional; training proceeds uncompiled if it is unavailable)
try:
    model = torch.compile(model)
    print("Model compiled.")
except Exception as compile_error:
    print("torch.compile skipped.")
    print(f"Reason: {type(compile_error).__name__}: {compile_error}")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Failed to load laion/CLIP-ViT-B-16-laion2B-s34B-b88K, falling back to openai/clip-vit-base-patch32


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Model compiled.


: 

In [7]:
# Training Loop with Phases
optimizer = torch.optim.AdamW(
    model.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"]
)
model, optimizer = accelerator.prepare(model, optimizer)
# NOTE(review): CONFIG["warmup_steps"] is not used by this cell; no learning-rate
# scheduler is created. Confirm whether a warmup schedule was intended.

# Check for resumption
start_step = 0
consumed_samples = 0
saved_state = load_landmark()
if saved_state:
    start_step = saved_state["global_step"]
    consumed_samples = saved_state["consumed_samples"]
    # Load weights
    try:
        accelerator.load_state(CONFIG["output_dir"])
        print("Weights loaded.")
    except Exception as load_error:
        # NOTE(review): nothing below actually pulls weights from the Hub; the
        # run continues from step `start_step` with whatever weights are in memory.
        print(
            "Warning: State file found but weights could not be loaded from local dir. Pulling from Hub if possible."
        )
        print(f"Reason: {type(load_error).__name__}: {load_error}")


def build_train_dataloader(phase: str, skip_samples: int = 0):
    """
    Build the accelerator-prepared DataLoader for one training phase.

    Args:
        phase (str): Phase name passed to ``get_dataset_stream``.
        skip_samples (int): Number of leading stream items to skip (used when resuming).

    Returns:
        The DataLoader after ``accelerator.prepare``.
    """
    stream = get_dataset_stream(phase)
    stream = stream.shuffle(buffer_size=1000, seed=42)
    if skip_samples > 0:
        stream = stream.skip(skip_samples)  # Skip processed samples
    loader = DataLoader(
        stream,
        batch_size=CONFIG["batch_size"],
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )
    return accelerator.prepare(loader)


# Define Phases
phases = ["phase1", "phase2"]
phase_switch_step = CONFIG["max_steps"] // 2  # Simple logic to switch phases halfway
current_phase_idx = 1 if start_step > phase_switch_step else 0

# Stream items drawn per fetched batch across all processes.
samples_per_batch = CONFIG["batch_size"] * accelerator.num_processes
# Only the first loader built after a resume skips already-consumed samples.
pending_skip = consumed_samples

model.train()
progress_bar = tqdm(range(start_step, CONFIG["max_steps"]))
global_step = start_step

# Each pass of this loop builds a fresh dataloader. Rebinding the loader name
# inside a `for batch in loader` loop does not change the iterator already
# running, so the phase switch has to leave the inner loop and re-enter it.
while global_step < CONFIG["max_steps"]:
    if current_phase_idx == 0 and global_step >= phase_switch_step:
        print("Switching to Phase 2...")
        current_phase_idx = 1
        pending_skip = 0

    train_dataloader = build_train_dataloader(phases[current_phase_idx], pending_skip)
    pending_skip = 0
    reached_boundary = False

    for batch in train_dataloader:
        # Count every fetched batch, including ones the collate function
        # rejected: those items were still consumed from the stream.
        consumed_samples += samples_per_batch

        if batch is None:
            continue

        with accelerator.accumulate(model):
            # CLIP only computes its contrastive loss when asked to.
            outputs = model(**batch, return_loss=True)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        if not accelerator.sync_gradients:
            continue

        progress_bar.update(1)
        global_step += 1

        if global_step % 100 == 0:
            accelerator.print(f"Step {global_step}: Loss {loss.item()}")

        if global_step % CONFIG["checkpoint_interval"] == 0:
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                accelerator.save_state(CONFIG["output_dir"])
                save_landmark(global_step, consumed_samples)

                if CONFIG["push_to_hub"]:
                    # Hub pushes are best-effort, like the landmark upload: a
                    # transient network error should not end a long run.
                    try:
                        accelerator.unwrap_model(model).push_to_hub(
                            CONFIG["hub_model_id"]
                        )
                        CLIPProcessor.from_pretrained(CONFIG["model_name"]).push_to_hub(
                            CONFIG["hub_model_id"]
                        )
                    except Exception as push_error:
                        print(f"Failed to push model to Hub: {push_error}")

        # Leave the loader only right after an optimizer step, so no partially
        # accumulated gradients are carried across a phase boundary.
        finished = global_step >= CONFIG["max_steps"]
        phase_done = current_phase_idx == 0 and global_step >= phase_switch_step
        if finished or phase_done:
            reached_boundary = True
            break

    if not reached_boundary:
        # The stream ran out before the step budget was reached.
        print(f"Data stream exhausted at step {global_step}.")
        break

progress_bar.close()
print("Training Complete")

No previous state found. Starting from scratch.
Initializing Data Stream for phase1...


Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]



README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/330 [00:00<?, ?it/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

✓ WIT loaded


README.md: 0.00B [00:00, ?B/s]

✓ Conceptual Captions loaded
Interleaving 2 datasets with probabilities: [0.6, 0.4]


: 

: 