# Spectra: Object Recognition Model Training

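This notebook implements the training pipeline for **Spectra**, an object-recognition model trained with CLIP-style contrastive learning. Phase 1 uses image–text data: LAION-5B via its `laion2B-en` subset, WIT, and PMD-style caption data (the loader currently wires up Conceptual Captions). Phase 2 adds expanded detection datasets (COCO, OpenImages, etc.).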

## Features
- **Backbone**: `laion/CLIP-ViT-B-16-laion2B-s34B-b88K` (OpenCLIP ViT-B/16)
- **Streaming Data**: Datasets are streamed, so massive corpora are never fully downloaded to disk.
- **Robust Checkpointing**: Resumable training from HF Hub (Landmarks).
- **Phase 2 Training**: Targets COCO, OpenImages, Objects365, VOC, LVIS, and Visual Genome. The current loader streams COCO, Visual Genome, and Objects365.
- **Optimizations**: Mixed Precision (FP16), Gradient Accumulation, Torch Compile.

In [1]:
# Imports
import os
import json
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, interleave_datasets
from transformers import CLIPModel, CLIPProcessor
from accelerate import Accelerator
from huggingface_hub import login, HfApi, hf_hub_download
from dotenv import load_dotenv
from tqdm.auto import tqdm
from pathlib import Path

# Load environment variables from a local .env file.
# override=True makes values in .env win over variables already set in the shell.
env_path = Path(".env")  # adjust if the .env file lives elsewhere

if env_path.exists():
    load_dotenv(dotenv_path=env_path, override=True)
    print("Loaded .env file")
else:
    print("Warning: .env file not found in current directory")

# The token may come from .env or from the surrounding environment.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    # Never echo any part of the token: notebook outputs are routinely shared and committed.
    print("Logged in to HF")
else:
    print("WARNING: HF_TOKEN not set (checked .env and the environment). Please set it for Hub uploads.")

Loaded .env file


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in to HF with token ending in ...[redacted]


In [2]:
# Configuration
# Every tunable setting for the run; later cells read these through CONFIG[...].
CONFIG = dict(
    # --- model ---
    model_name="laion/CLIP-ViT-B-16-laion2B-s34B-b88K",  # ViT-B/16 Backbone
    fallback_model="openai/clip-vit-base-patch32",
    # --- outputs / Hub ---
    output_dir="./spectra_checkpoints",
    hub_model_id="990aa/spectra",
    # --- optimisation ---
    batch_size=64,  # Per device
    grad_accum_steps=4,
    learning_rate=1e-4,
    weight_decay=0.1,
    max_steps=100_000,  # Increased for multi-phase
    warmup_steps=2_000,
    mixed_precision="fp16",
    image_size=224,
    # --- checkpointing ---
    push_to_hub=True,
    checkpoint_interval=1_000,
    state_file="training_state.json",
    model_card_file="README.md",
)

# The Accelerator owns device placement, mixed precision and gradient accumulation.
accelerator = Accelerator(
    project_dir="logs",
    log_with="all",
    gradient_accumulation_steps=CONFIG["grad_accum_steps"],
    mixed_precision=CONFIG["mixed_precision"],
)

print(
    f"Accelerator setup: {accelerator.device}, "
    f"Mixed Precision: {accelerator.mixed_precision}"
)

Accelerator setup: cpu, Mixed Precision: fp16




In [3]:
# Checkpointing & Resumption Logic
def save_landmark(step: int, consumed_samples: int) -> None:
    """
    Save a training checkpoint (landmark) and upload it to the Hugging Face Hub.

    Only the main process does any work. The landmark is a small JSON file
    (``CONFIG["state_file"]``) holding the counters that ``load_landmark`` reads
    back to resume. Model weights are not part of it. If a local model card
    (``CONFIG["model_card_file"]``) exists, it is uploaded as the repo README.
    Hub failures are printed and swallowed so training keeps running.

    Args:
        step (int): The current global training step.
        consumed_samples (int): The total number of samples consumed so far.
    """
    if not accelerator.is_main_process:
        return

    state = {
        "global_step": step,
        "consumed_samples": consumed_samples,
    }
    with open(CONFIG["state_file"], "w") as f:
        json.dump(state, f)

    api = HfApi()
    try:
        # upload_file does not create the repository. Ensure it exists so the first
        # landmark is not lost when no weights have been pushed yet.
        api.create_repo(repo_id=CONFIG["hub_model_id"], repo_type="model", exist_ok=True)
        api.upload_file(
            path_or_fileobj=CONFIG["state_file"],
            path_in_repo=CONFIG["state_file"],
            repo_id=CONFIG["hub_model_id"],
            repo_type="model",
        )
        print(f"Landmark saved at step {step}")

        # Upload model card
        if os.path.exists(CONFIG["model_card_file"]):
            api.upload_file(
                path_or_fileobj=CONFIG["model_card_file"],
                path_in_repo="README.md",
                repo_id=CONFIG["hub_model_id"],
                repo_type="model",
            )
            print("Model card uploaded")
    except Exception as e:
        # Best effort: a Hub outage must not kill a long training run.
        print(f"Failed to upload to Hub: {e}")

def load_landmark() -> dict | None:
    """
    Load the latest checkpoint state from the Hugging Face Hub.

    Downloads ``CONFIG["state_file"]`` from ``CONFIG["hub_model_id"]`` and
    parses it as JSON. Any failure is treated as "no previous state". That
    includes a missing repo or file, but also network, auth and JSON errors.
    The cause is printed, because a transient failure here restarts training
    from step 0.

    Returns:
        dict | None: The state dict (with ``global_step`` and
        ``consumed_samples``) if a usable one was found, else None.
    """
    try:
        path = hf_hub_download(repo_id=CONFIG["hub_model_id"], filename=CONFIG["state_file"])
        with open(path, "r") as f:
            state = json.load(f)
    except Exception as err:
        print(f"No previous state found. Starting from scratch. ({type(err).__name__}: {err})")
        return None

    # The training cell reads both keys; reject incomplete files here instead of crashing there.
    required = ("global_step", "consumed_samples")
    if not isinstance(state, dict) or any(key not in state for key in required):
        print(f"Previous state file is missing {required}. Starting from scratch.")
        return None

    print(f"Found previous state: Step {state['global_step']}")
    return state

In [None]:
# Data Pipeline Construction
def format_detection_dataset(item: dict, text_key: str = None, label_key: str = None) -> dict:
    """
    Standardize dataset items to {image, text} format.
    
    Args:
        item (dict): The dataset item.
        text_key (str, optional): Key for text caption.
        label_key (str, optional): Key for object labels.
        
    Returns:
        dict: Formatted item with 'image' and 'text' keys.
    """
    # Helper to standardize (image, text) format
    img = item.get("image")
    txt = ""
    
    if text_key and text_key in item:
        txt = item[text_key]
        if isinstance(txt, list):
            txt = txt[0]  # Take first caption
    elif label_key and label_key in item:
        # Convert labels to text (simplified)
        # In real scenario, map IDs to class names
        labels = item[label_key]
        txt = f"A photo of object {labels}" # Placeholder logic
        
    return {"image": img, "text": txt}

def _as_text(value) -> str:
    """Coerce a caption-like value to ``str`` (``None`` -> ``""``) so every stream's 'text' has one type."""
    return "" if value is None else str(value)


def _laion_pair(x: dict) -> dict:
    """LAION row -> image URL + caption."""
    return {
        "image": x.get("URL") or x.get("url") or x.get("image"),
        "text": x.get("TEXT") or x.get("caption") or "",
    }


def _wit_pair(x: dict) -> dict:
    """WIT row -> image URL + caption."""
    # Prefer the URL so the phase-1 'image' column holds URL strings in every source.
    # LAION and Conceptual Captions only provide URLs, and mixing URL strings with
    # decoded images in one column keeps interleave_datasets from aligning features.
    # collate_fn downloads string images.
    # NOTE(review): wit_base may keep reference captions only in nested per-language
    # fields; caption_attribution_description is a top-level fallback -- confirm schema.
    return {
        "image": x.get("image_url") or x.get("image"),
        "text": (
            x.get("caption_reference_description")
            or x.get("caption")
            or x.get("caption_attribution_description")
            or ""
        ),
    }


def _cc_pair(x: dict) -> dict:
    """Conceptual Captions row -> image URL + caption."""
    return {"image": x.get("image_url") or x.get("image"), "text": x.get("caption") or ""}


def _coco_pair(x: dict) -> dict:
    """COCO row -> image + 'Objects: ...' text built from up to 5 category labels."""
    objects = x.get("objects", {})
    labels = objects.get("category", []) if isinstance(objects, dict) else []
    text = f"Objects: {', '.join(map(str, labels[:5]))}" if labels else "An image"
    return {"image": x.get("image"), "text": text}


def _vg_pair(x: dict) -> dict:
    """Visual Genome row -> image + first region phrase (empty/missing regions give '')."""
    regions = x.get("regions") or [{}]
    return {
        "image": x.get("image") or x.get("url"),
        "text": x.get("phrase") or regions[0].get("phrase", ""),
    }


def _objects365_pair(x: dict) -> dict:
    """Objects365 row -> image + stringified 'objects' field (placeholder text)."""
    return {"image": x.get("image"), "text": x.get("objects") or "Various objects"}


# Per phase: (display label, Hub dataset path, config name, row mapper, sampling weight).
_PHASE_SOURCES = {
    "phase1": [
        ("LAION-2B", "laion/laion2B-en", None, _laion_pair, 0.5),
        ("WIT", "wikimedia/wit_base", None, _wit_pair, 0.3),
        ("Conceptual Captions", "google-research-datasets/conceptual_captions", None, _cc_pair, 0.2),
    ],
    "phase2": [
        ("COCO", "detection-datasets/coco", None, _coco_pair, 0.4),
        ("Visual Genome", "visual_genome", "region_descriptions_v1.2.0", _vg_pair, 0.3),
        ("Objects365", "objects365", None, _objects365_pair, 0.3),
    ],
}


def _streamed_pairs(path: str, to_pair, config_name: str | None = None):
    """
    Stream the ``train`` split of ``path`` reduced to ``{"image", "text"}`` pairs.

    The source's own columns are removed during the map, so all sources share
    exactly two columns, as ``interleave_datasets`` needs. Rows whose mapping
    raises, or that have no image, are filtered out. Map/filter are lazy, so
    only ``load_dataset`` errors surface from this call.
    """
    stream = load_dataset(path, name=config_name, split="train", streaming=True)

    def safe_pair(example: dict) -> dict:
        try:
            pair = to_pair(example)
        except Exception:
            # map must always return a dict; image=None marks the row for the filter below.
            return {"image": None, "text": ""}
        return {"image": pair.get("image"), "text": _as_text(pair.get("text"))}

    # column_names may be None when features are unknown; then nothing is removed.
    stream = stream.map(safe_pair, remove_columns=stream.column_names)
    return stream.filter(lambda example: example["image"] is not None)


def get_dataset_stream(phase: str = "phase1") -> object:
    """
    Initialize and return a streaming dataset for the specified phase.

    Each source that loads is mapped to ``{"image", "text"}`` and interleaved
    with its sampling weight (renormalized over the sources that loaded). If no
    phase-2 source loads, this falls back to phase 1.

    Args:
        phase (str): The training phase ('phase1' or 'phase2').

    Returns:
        object: An interleaved streaming dataset (stops at the first exhausted source).

    Raises:
        RuntimeError: If no dataset could be loaded (including for an unknown phase).
    """
    print(f"Initializing Data Stream for {phase}...")

    datasets_list = []
    probs = []

    for label, path, config_name, to_pair, weight in _PHASE_SOURCES.get(phase, []):
        try:
            datasets_list.append(_streamed_pairs(path, to_pair, config_name))
        except Exception as e:
            print(f"Warning: Could not load {label}: {e}")
            continue
        probs.append(weight)
        print(f"✓ {label} loaded")

    # Fallback to Phase 1 if no Phase 2 datasets loaded
    if phase == "phase2" and not datasets_list:
        print("Warning: Phase 2 datasets not accessible. Falling back to Phase 1 streams.")
        return get_dataset_stream("phase1")

    if not datasets_list:
        print("ERROR: No datasets could be loaded.")
        raise RuntimeError("No datasets available. Please check your internet connection and dataset access permissions.")

    # Normalize probabilities over the sources that actually loaded.
    total = sum(probs)
    probs = [p / total for p in probs]

    print(f"Interleaving {len(datasets_list)} datasets with probabilities: {probs}")

    return interleave_datasets(
        datasets_list,
        probabilities=probs,
        seed=42,
        stopping_strategy="first_exhausted",
    )

def collate_fn(batch: list) -> dict | None:
    """
    Collate function to process a batch of items.
    
    Args:
        batch (list): List of dataset items.
        
    Returns:
        dict | None: Processed batch tensors or None if empty.
    """
    processor = CLIPProcessor.from_pretrained(CONFIG["model_name"])
    texts = []
    images = []
    
    for item in batch:
        if item is None:
            continue
            
        img = item.get("image")
        txt = item.get("text")
        
        if img and txt:
            try:
                # Handle URL images (download if needed)
                if isinstance(img, str):
                    from PIL import Image
                    import requests
                    from io import BytesIO
                    
                    response = requests.get(img, timeout=5)
                    img = Image.open(BytesIO(response.content))
                
                # Convert to RGB
                if img.mode != "RGB":
                    img = img.convert("RGB")
                    
                images.append(img)
                texts.append(str(txt)[:77]) # Truncate for CLIP
            except Exception:
                continue
            
    if not images:
        return None
    return processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)

In [5]:
# Model Setup
NUM_FROZEN_VISION_LAYERS = 6  # earliest ViT encoder blocks kept at their pretrained values

try:
    model = CLIPModel.from_pretrained(CONFIG["model_name"])
    print(f"Loaded {CONFIG['model_name']}")
except Exception as load_err:
    # Show the cause: otherwise the log hides that training silently switched backbones
    # (the fallback is a ViT-B/32, not the ViT-B/16 described at the top).
    print(
        f"Failed to load {CONFIG['model_name']} ({type(load_err).__name__}: {load_err}), "
        f"falling back to {CONFIG['fallback_model']}"
    )
    model = CLIPModel.from_pretrained(CONFIG["fallback_model"])
    # Keep the processor / push_to_hub in sync with the weights actually in use.
    CONFIG["model_name"] = CONFIG["fallback_model"]

# Freeze early layers of the vision encoder.
for param in model.vision_model.encoder.layers[:NUM_FROZEN_VISION_LAYERS].parameters():
    param.requires_grad = False

# Torch Compile (optional speed-up; any failure leaves the eager model in place)
try:
    model = torch.compile(model)
    print("Model compiled.")
except Exception as compile_err:
    print(f"torch.compile skipped ({compile_err}).")

Failed to load laion/CLIP-ViT-B-16-laion2B-s34B-b88K, falling back to openai/clip-vit-base-patch32


config.json: 0.00B [00:00, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]

Model compiled.


In [6]:
# Training Loop with Phases
# Phase 1 runs until PHASE_SWITCH_STEP; phase 2 runs until max_steps.
PHASES = ["phase1", "phase2"]
PHASE_SWITCH_STEP = CONFIG["max_steps"] // 2
# Samples per optimizer step across all processes (same formula used for consumed_samples).
SAMPLES_PER_STEP = CONFIG["batch_size"] * CONFIG["grad_accum_steps"] * accelerator.num_processes

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])

# Linear warmup over CONFIG["warmup_steps"], then constant LR.
# accelerate's prepared scheduler steps once per process on every real optimizer step,
# hence the num_processes factor.
_warmup_sched_steps = max(1, CONFIG["warmup_steps"] * accelerator.num_processes)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda sched_step: min(1.0, (sched_step + 1) / _warmup_sched_steps)
)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)


def _base_model(wrapped):
    """Return the plain CLIPModel beneath accelerate/DDP and torch.compile wrappers."""
    unwrapped = accelerator.unwrap_model(wrapped)
    return getattr(unwrapped, "_orig_mod", unwrapped)


def _make_dataloader(phase: str, skip_samples: int):
    """Build the shuffled, resumable stream for ``phase`` and wrap it in a prepared DataLoader."""
    stream = get_dataset_stream(phase).shuffle(buffer_size=1000, seed=42)
    stream = stream.skip(skip_samples)  # skip samples already consumed from this phase's stream
    loader = DataLoader(stream, batch_size=CONFIG["batch_size"], collate_fn=collate_fn, num_workers=4, pin_memory=True)
    return accelerator.prepare(loader)


def _checkpoint(step: int, samples: int) -> None:
    """Save local state on every rank, then (main rank) push weights and upload the landmark."""
    accelerator.wait_for_everyone()
    # save_state must run on all processes (each rank stores its own RNG state).
    accelerator.save_state(CONFIG["output_dir"])
    if not accelerator.is_main_process:
        return
    if CONFIG["push_to_hub"]:
        try:
            _base_model(model).push_to_hub(CONFIG["hub_model_id"])
            CLIPProcessor.from_pretrained(CONFIG["model_name"]).push_to_hub(CONFIG["hub_model_id"])
        except Exception as err:
            print(f"Warning: push_to_hub failed at step {step}: {err}. Landmark not updated.")
            return
    # Upload the landmark only after the weights it points to are on the Hub.
    save_landmark(step, samples)


# Check for resumption
start_step = 0
consumed_samples = 0
saved_state = load_landmark()
if saved_state:
    start_step = saved_state["global_step"]
    consumed_samples = saved_state["consumed_samples"]
    try:
        accelerator.load_state(CONFIG["output_dir"])
        print("Weights loaded.")
    except Exception as local_err:
        print(f"Warning: State file found but weights could not be loaded from local dir ({local_err}). Pulling from Hub.")
        try:
            hub_weights = CLIPModel.from_pretrained(CONFIG["hub_model_id"]).state_dict()
            _base_model(model).load_state_dict(hub_weights)
            print("Weights loaded from the Hub (optimizer/scheduler state not restored).")
        except Exception as hub_err:
            # Resuming the step counter without the matching weights would silently shorten training.
            print(f"Warning: could not restore weights from the Hub ({hub_err}). Restarting from step 0.")
            start_step, consumed_samples = 0, 0

model.train()
progress_bar = tqdm(total=CONFIG["max_steps"], initial=start_step)
global_step = start_step
current_phase_idx = 1 if start_step >= PHASE_SWITCH_STEP else 0

while global_step < CONFIG["max_steps"] and current_phase_idx < len(PHASES):
    phase = PHASES[current_phase_idx]
    phase_end_step = PHASE_SWITCH_STEP if current_phase_idx == 0 else CONFIG["max_steps"]
    # Samples consumed before this phase began. Only the remainder came from this phase's stream.
    phase_start_samples = 0 if current_phase_idx == 0 else PHASE_SWITCH_STEP * SAMPLES_PER_STEP
    train_dataloader = _make_dataloader(phase, max(0, consumed_samples - phase_start_samples))

    for batch in train_dataloader:
        if batch is None:
            # NOTE(review): in multi-process runs a rank skipping a batch can desync DDP.
            continue

        with accelerator.accumulate(model):
            # CLIPModel only computes the contrastive loss when return_loss=True.
            outputs = model(**batch, return_loss=True)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            consumed_samples += SAMPLES_PER_STEP

            if global_step % 100 == 0:
                accelerator.print(f"Step {global_step}: Loss {loss.item()}")

            if global_step % CONFIG["checkpoint_interval"] == 0:
                _checkpoint(global_step, consumed_samples)

        if global_step >= phase_end_step:
            break
    else:
        # for-else: the stream ran dry before this phase's step budget was reached.
        accelerator.print(f"{phase} stream exhausted at step {global_step}.")

    current_phase_idx += 1
    if current_phase_idx < len(PHASES) and global_step < CONFIG["max_steps"]:
        print(f"Switching to {PHASES[current_phase_idx]}...")

progress_bar.close()
print("Training Complete")

No previous state found. Starting from scratch.
Initializing Data Stream for phase1...


Resolving data files:   0%|          | 0/128 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/330 [00:00<?, ?it/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

KeyError: 'image'