# Integrating Custom Models with LookBench

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SerendipityOneInc/look-bench/blob/main/notebooks/03_custom_model.ipynb)

This notebook demonstrates how to **integrate your own custom models** into the LookBench framework using the registry pattern.

## What You'll Learn

1. Create a custom model class inheriting from `BaseModel`
2. Register your model using the `@register_model` decorator
3. Implement required methods (`load_model`, `get_transform`)
4. Test your model with the LookBench dataset
5. Integrate Hugging Face models

📄 **Paper**: [arxiv.org/abs/2601.14706](https://arxiv.org/abs/2601.14706)

## Setup

In [None]:
# Install required packages
!pip install -q torch torchvision transformers datasets pillow pandas pyarrow pyyaml tqdm matplotlib

# Clone LookBench repository
!git clone https://github.com/SerendipityOneInc/look-bench.git
%cd look-bench

import sys
sys.path.append('/content/look-bench')

print("✅ Setup complete!")

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from datasets import load_dataset
import numpy as np

from models.base import BaseModel
from models.registry import register_model, list_available_models

print("✅ Imports successful!")

## Example 1: ResNet-50 Model

Let's create a simple ResNet-50 model for fashion image retrieval.

In [None]:
@register_model("resnet50", metadata={
    "description": "ResNet-50 pretrained on ImageNet",
    "framework": "PyTorch",
    "input_size": 224,
    "embedding_dim": 2048
})
class ResNet50Model(BaseModel):
    """ResNet-50 model for image embedding extraction"""
    
    @classmethod
    def load_model(cls, model_name: str = "resnet50", model_path: str = None):
        """
        Load ResNet-50 model
        
        Args:
            model_name: Name of the model
            model_path: Optional path to custom weights
            
        Returns:
            Tuple of (model, wrapper_instance)
        """
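        # Load ResNet-50 with ImageNet weights. The `weights` argument replaces
        # the deprecated `pretrained=True` (torchvision >= 0.13).
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)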
        
        # Remove classification head, keep feature extractor
        model = nn.Sequential(*list(model.children())[:-1])
        
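        # Load custom weights if provided. The checkpoint keys must match the
        # truncated nn.Sequential above ("0.weight", "1.weight", ...), not the
        # original torchvision names ("conv1.weight", "bn1.weight", ...).
        if model_path:
            state_dict = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state_dict)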
        
        model.eval()
        
        # Create a wrapper that flattens the output
        class ModelWrapper(nn.Module):
            def __init__(self, backbone):
                super().__init__()
                self.backbone = backbone
            
            def forward(self, x):
                features = self.backbone(x)
                return features.squeeze(-1).squeeze(-1)  # Flatten spatial dimensions
        
        wrapped_model = ModelWrapper(model)
        
        return wrapped_model, cls()
    
    @classmethod
    def get_transform(cls, input_size: int = 224):
        """
        Get image preprocessing transform
        
        Args:
            input_size: Input image size
            
        Returns:
            torchvision.transforms composition
        """
        return transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

print("✅ ResNet-50 model registered!")
print(f"Available models: {list_available_models()}")

## Test the Custom Model

Let's load the model and test it with a sample image from LookBench.

In [None]:
# Load the model
print("Loading ResNet-50 model...")
model, wrapper = ResNet50Model.load_model()
transform = ResNet50Model.get_transform()

# Move to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print("✅ Model on CUDA")

# Load sample image
dataset = load_dataset("srpone/look-bench")
sample_img = dataset['real_studio_flat']['query'][0]['image']

# Preprocess and extract features
img_tensor = transform(sample_img).unsqueeze(0)

if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()

with torch.no_grad():
    features = model(img_tensor)

print(f"\n✅ Feature extraction successful!")
print(f"   Feature shape: {features.shape}")
print(f"   Feature dimension: {features.shape[1]}")
print(f"   Feature norm: {torch.norm(features).item():.4f}")

## Example 2: Custom Architecture

Here's how to integrate a model with a custom architecture and flexible checkpoint loading.

In [None]:
@register_model("custom_fashion_model", metadata={
    "description": "Custom fashion embedding model",
    "framework": "PyTorch",
    "input_size": 256,
    "embedding_dim": 512
})
class CustomFashionModel(BaseModel):
    """Custom fashion model trained on fashion data"""
    
    @classmethod
    def load_model(cls, model_name: str = "custom_fashion_model", model_path: str = None):
        """Load custom model"""
        # Define your model architecture
        class YourCustomArchitecture(nn.Module):
            def __init__(self, embedding_dim=512):
                super().__init__()
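                # Your model architecture here. The `weights` argument replaces
                # the deprecated `pretrained=True` (torchvision >= 0.13).
                self.backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)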
                self.backbone.fc = nn.Linear(self.backbone.fc.in_features, embedding_dim)
            
            def forward(self, x):
                return self.backbone(x)
        
        # Instantiate model
        model = YourCustomArchitecture(embedding_dim=512)
        
        # Load trained weights if provided
        if model_path:
            checkpoint = torch.load(model_path, map_location='cpu')
            
            # Handle different checkpoint formats
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['model_state_dict'])
                elif 'state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['state_dict'])
                else:
                    model.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
        
        model.eval()
        return model, cls()
    
    @classmethod
    def get_transform(cls, input_size: int = 256):
        """Custom preprocessing pipeline"""
        return transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

print("✅ Custom fashion model registered!")
print(f"Available models: {list_available_models()}")

## Example 3: Hugging Face Models

Integrate models from the Hugging Face Hub (such as GR-Lite).

In [None]:
from transformers import AutoModel

@register_model("gr_lite", metadata={
    "description": "GR-Lite model from srpone",
    "framework": "PyTorch/Transformers",
    "input_size": 336,
    "embedding_dim": 1024
})
class GRLiteModel(BaseModel):
    """GR-Lite model for fashion retrieval"""
    
    @classmethod
    def load_model(cls, model_name: str = "srpone/gr-lite", model_path: str = None):
        """Load GR-Lite model from Hugging Face"""
        # Load model from Hugging Face Hub or local path
        if model_path:
            model = AutoModel.from_pretrained(model_path)
        else:
            model = AutoModel.from_pretrained(model_name)
        
        model.eval()
        return model, cls()
    
    @classmethod
    def get_transform(cls, input_size: int = 336):
        """Get transform for GR-Lite"""
        return transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

print("✅ GR-Lite model registered!")
print(f"\nAll available models: {list_available_models()}")

## Test Similarity Computation

Let's test the custom model with similarity computation.

In [None]:
import matplotlib.pyplot as plt

# Load ResNet-50 model
model, _ = ResNet50Model.load_model()
transform = ResNet50Model.get_transform()

if torch.cuda.is_available():
    model = model.cuda()

# Load dataset
query_img = dataset['real_studio_flat']['query'][0]['image']
gallery_imgs = [dataset['real_studio_flat']['gallery'][i]['image'] for i in range(3)]

def extract_features(images, model, transform):
    """Extract features from images"""
    features = []
    for img in images:
        img_tensor = transform(img).unsqueeze(0)
        if torch.cuda.is_available():
            img_tensor = img_tensor.cuda()
        with torch.no_grad():
            feat = model(img_tensor)
        features.append(feat.cpu().numpy())
    return np.vstack(features)

# Extract features
query_feat = extract_features([query_img], model, transform)
gallery_feats = extract_features(gallery_imgs, model, transform)

# Compute cosine similarity
query_norm = query_feat / np.linalg.norm(query_feat, axis=1, keepdims=True)
gallery_norm = gallery_feats / np.linalg.norm(gallery_feats, axis=1, keepdims=True)
similarities = np.dot(query_norm, gallery_norm.T)[0]

print("✅ Similarity computation:")
for i, sim in enumerate(similarities):
    print(f"   Gallery image {i}: {sim:.4f}")

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(query_img)
axes[0].set_title('Query', fontweight='bold')
axes[0].axis('off')

for i, (img, sim) in enumerate(zip(gallery_imgs, similarities)):
    axes[i+1].imshow(img)
    axes[i+1].set_title(f'Gallery {i}\nSim: {sim:.3f}')
    axes[i+1].axis('off')

plt.tight_layout()
plt.show()

print("\n✅ Custom model integration test passed!")
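
Beyond inspecting three similarity scores by eye, retrieval is usually scored by ranking the whole gallery for each query. The sketch below ranks a gallery by cosine similarity and computes a simple Recall@k with NumPy. The random embeddings, the item-ID arrays, and the helper names `rank_gallery` and `recall_at_k` are illustrative assumptions, not LookBench's evaluator, whose metric definitions may differ.

```python
import numpy as np


def rank_gallery(query_feats: np.ndarray, gallery_feats: np.ndarray) -> np.ndarray:
    """Return gallery indices sorted by descending cosine similarity, per query."""
    q = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    g = gallery_feats / np.linalg.norm(gallery_feats, axis=1, keepdims=True)
    sims = q @ g.T                    # shape: (num_queries, num_gallery)
    return np.argsort(-sims, axis=1)  # best match first


def recall_at_k(ranking: np.ndarray, query_ids: np.ndarray,
                gallery_ids: np.ndarray, k: int) -> float:
    """Fraction of queries with at least one matching ID in the top-k results."""
    top_k_ids = gallery_ids[ranking[:, :k]]               # (num_queries, k)
    hits = (top_k_ids == query_ids[:, None]).any(axis=1)  # one flag per query
    return float(hits.mean())


# Toy data: 5 gallery items; each query is a slightly noisy copy of one of them.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(5, 16))
gallery_ids = np.array([10, 11, 12, 13, 14])
queries = gallery[[3, 1]] + 0.01 * rng.normal(size=(2, 16))
query_ids = np.array([13, 11])

ranking = rank_gallery(queries, gallery)
print("Top-1 gallery index per query:", ranking[:, 0])
print("Recall@1:", recall_at_k(ranking, query_ids, gallery_ids, k=1))
```

The same two functions work unchanged on the `query_feat` and `gallery_feats` arrays from the cell above, provided matching item IDs are available for both sides.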

## Configure Your Model

To use your custom model with LookBench's main evaluation pipeline, add it to `configs/config.yaml`:

```yaml
# Custom ResNet-50 model
resnet50:
  enabled: true
  model_name: "resnet50"
  model_path: null
  input_size: 224
  embedding_dim: 2048
  device: "cuda"

# Custom trained model
custom_fashion_model:
  enabled: true
  model_name: "custom_fashion_model"
  model_path: "/path/to/your/model/weights.pth"
  input_size: 256
  embedding_dim: 512
  device: "cuda"

# GR-Lite from Hugging Face
gr_lite:
  enabled: true
  model_name: "srpone/gr-lite"
  model_path: null
  input_size: 336
  embedding_dim: 1024
  device: "cuda"
```

Then you can run evaluation:
```bash
python main.py --model resnet50
```
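
How the pipeline consumes these entries is not shown in this notebook. As a rough sketch, a loader could parse the YAML into a dictionary and keep only the entries marked `enabled: true`. The dictionary below mirrors what `yaml.safe_load` would return for a config shaped like the one above, with one entry switched off purely to illustrate the filtering. `enabled_models` is a hypothetical helper, not LookBench's `ConfigManager`.

```python
from typing import Any, Dict

# Stand-in for the parsed YAML (YAML `null` becomes Python `None`).
# `custom_fashion_model` is disabled here only to demonstrate filtering.
config: Dict[str, Dict[str, Any]] = {
    "resnet50": {
        "enabled": True, "model_name": "resnet50", "model_path": None,
        "input_size": 224, "embedding_dim": 2048, "device": "cuda",
    },
    "custom_fashion_model": {
        "enabled": False, "model_name": "custom_fashion_model",
        "model_path": "/path/to/your/model/weights.pth",
        "input_size": 256, "embedding_dim": 512, "device": "cuda",
    },
    "gr_lite": {
        "enabled": True, "model_name": "srpone/gr-lite", "model_path": None,
        "input_size": 336, "embedding_dim": 1024, "device": "cuda",
    },
}


def enabled_models(cfg: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Keep only entries whose `enabled` flag is true (missing flag = disabled)."""
    return {name: entry for name, entry in cfg.items() if entry.get("enabled", False)}


for name, entry in enabled_models(config).items():
    # Prefer an explicit weights path; otherwise fall back to the model name.
    source = entry["model_path"] or entry["model_name"]
    print(f"{name}: load from {source}, input {entry['input_size']}px")
```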

## Full Evaluation with Custom Model

You can now use your custom model with the full evaluation pipeline from notebook 02.

In [None]:
# Example: Quick evaluation with custom model
print("You can now use your custom model for full evaluation!")
print("\nNext steps:")
print("  1. Use your model with ConfigManager and ModelManager")
print("  2. Run full evaluation from notebook 02")
print("  3. Compare with baseline models")
print("  4. Submit results to LookBench leaderboard")

print("\n✅ All examples completed successfully!")

## Best Practices

### 1. Feature Normalization
- L2-normalize features before retrieval
- On unit-length vectors the dot product equals cosine similarity, so ranking reduces to a single matrix multiplication
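
As a small numeric check of why normalization matters, the sketch below compares raw dot products with dot products of L2-normalized vectors. The helper name `l2_normalize` and the `eps` guard against all-zero rows are illustrative choices, not part of LookBench.

```python
import numpy as np


def l2_normalize(feats: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit length; `eps` avoids division by zero."""
    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    return feats / np.maximum(norms, eps)


query = np.array([[3.0, 4.0]])
gallery = np.array([
    [6.0, 8.0],    # same direction as the query, twice as long
    [4.0, -3.0],   # orthogonal to the query
])

# Raw dot products grow with vector length, not just with direction.
raw_scores = query @ gallery.T
# After normalization the scores are cosine similarities in [-1, 1].
cosine_scores = l2_normalize(query) @ l2_normalize(gallery).T

print("raw:", raw_scores)
print("cosine:", cosine_scores)
```

The first gallery vector points in the same direction as the query, so its cosine score is 1 regardless of its length, while the orthogonal vector scores 0.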

### 2. Input Preprocessing
- Use the same preprocessing as training
- Match normalization statistics to your model
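
To make the point about normalization statistics concrete, the NumPy sketch below reproduces what `ToTensor()` followed by `Normalize(mean, std)` computes for one image: scale pixels to [0, 1], move channels first, then apply `(x - mean) / std` per channel. The function name `to_normalized_chw` and the uniform gray test image are illustrative only.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc_uint8: np.ndarray,
                      mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """NumPy equivalent of ToTensor() + Normalize(mean, std) for one image."""
    x = image_hwc_uint8.astype(np.float32) / 255.0  # [0, 255] -> [0, 1]
    x = np.transpose(x, (2, 0, 1))                  # HWC -> CHW
    return (x - mean[:, None, None]) / std[:, None, None]


# A uniform gray 4x4 RGB image as a stand-in for real data.
gray = np.full((4, 4, 3), 128, dtype=np.uint8)

with_imagenet_stats = to_normalized_chw(gray, IMAGENET_MEAN, IMAGENET_STD)
with_half_stats = to_normalized_chw(
    gray, np.full(3, 0.5, dtype=np.float32), np.full(3, 0.5, dtype=np.float32)
)

print("ImageNet stats:", with_imagenet_stats[:, 0, 0])
print("mean=std=0.5:  ", with_half_stats[:, 0, 0])
```

The same pixel maps to different input values under the two sets of statistics, so a model trained with one set sees shifted inputs if evaluated with the other.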

### 3. Model Checkpoints
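- Handle different checkpoint formats (a bare `state_dict`, or a dict that wraps it under `model_state_dict` or `state_dict`)
- Support both full pickled models and `state_dict` files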
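
The branching in `CustomFashionModel.load_model` can be factored into a small helper. The sketch below unwraps the two wrapper keys used earlier and also strips the `module.` prefix that `torch.nn.DataParallel` adds to parameter names. The helper name `extract_state_dict` is hypothetical, and plain integers stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Dict, Mapping

# Wrapper keys handled in the CustomFashionModel example above.
WRAPPER_KEYS = ("model_state_dict", "state_dict")


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a flat name -> value mapping from a possibly wrapped checkpoint."""
    for key in WRAPPER_KEYS:
        if key in checkpoint and isinstance(checkpoint[key], Mapping):
            checkpoint = checkpoint[key]
            break

    # Strip the "module." prefix left by DataParallel-wrapped models.
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in checkpoint.items()
    }


# Integers stand in for weight tensors.
wrapped = {"epoch": 3, "model_state_dict": {"module.fc.weight": 1, "module.fc.bias": 2}}
bare = {"fc.weight": 1, "fc.bias": 2}

print(extract_state_dict(wrapped))
print(extract_state_dict(bare))
```

With PyTorch available, the result could be passed straight to `model.load_state_dict(extract_state_dict(torch.load(model_path, map_location='cpu')))`.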

### 4. Device Management
- Check CUDA availability before moving the model to the GPU
- Move input tensors to the same device as the model

### 5. Evaluation Mode
- Set model to `.eval()` mode
- Use `torch.no_grad()` for inference

## Next Steps

1. **Evaluate your model**: Run full evaluation on all LookBench subsets
2. **Compare with baselines**: See how your model compares to CLIP, SigLIP, and GR-Lite
3. **Fine-tune**: Use LookBench for model training and improvement
4. **Submit results**: Share your results with the community

### Useful Links

- 📄 **Paper**: https://arxiv.org/abs/2601.14706
- 🏠 **Project**: https://serendipityoneinc.github.io/look-bench-page/
- 🤗 **Dataset**: https://huggingface.co/datasets/srpone/look-bench
- 🤗 **GR-Lite**: https://huggingface.co/srpone/gr-lite
- 💻 **GitHub**: https://github.com/SerendipityOneInc/look-bench

**Happy modeling! 🚀**