# Finetuning Nomic Vision Embeddings

This notebook demonstrates how to finetune the `nomic-embed-image-v1` model on custom image data.

In [None]:
# Use the %pip magic (not !pip) so packages are installed into the environment of the running kernel
%pip install nomic torch datasets transformers pillow torchvision

In [None]:
from nomic import atlas
import torch
from torch import nn
from transformers import AutoImageProcessor, AutoModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import os

In [None]:
# Initialize the model and image processor
# Both objects are loaded from the same Hugging Face Hub repository id.
# NOTE(review): confirm this repository id exists on the Hub; Nomic's published
# vision checkpoints appear to be named `nomic-embed-vision-*` — TODO verify.
# NOTE(review): Nomic checkpoints presumably ship custom modeling code, which
# would require passing `trust_remote_code=True` below — confirm before running.
model_name = "nomic-ai/nomic-embed-image-v1"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [None]:
# Custom dataset class for images
class ImageDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs them with labels.

    Args:
        image_paths: Sequence of paths to image files.
        labels: Sequence of numeric labels; ``labels[i]`` belongs to
            ``image_paths[i]``.
        transform: Optional callable applied to each loaded RGB PIL image.
            When ``None``, a default pipeline is used: resize to 224x224,
            convert to a tensor, then normalize per channel.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail early instead of hitting an IndexError mid-epoch (or silently
        # ignoring surplus labels) when the two sequences are misaligned.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        # Explicit None check: a caller-supplied transform that happens to be
        # falsy (e.g. an empty container of transforms) must not be discarded.
        # NOTE(review): the mean/std below look like the usual ImageNet
        # statistics — confirm they match the checkpoint's own preprocessing.
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with ``'pixel_values'`` (the transformed image) and
            ``'label'`` (a float32 scalar tensor).
        """
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent image, so it stays usable afterwards.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')

        # self.transform is always set by __init__, so no truthiness check is needed.
        image = self.transform(image)

        return {
            'pixel_values': image,
            'label': torch.tensor(self.labels[idx], dtype=torch.float32)
        }

In [None]:
# Placeholder training data - swap in your own image files and labels.
# train_labels[i] is the label for train_image_paths[i].
train_image_paths = [
    "path/to/image1.jpg",
    "path/to/image2.jpg",
]  # Replace with actual paths
train_labels = [0, 1]  # Example labels

# Wrap the samples in a dataset, then batch and shuffle them for training
train_dataset = ImageDataset(image_paths=train_image_paths, labels=train_labels)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

In [None]:
# Training configuration
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
model = model.to(device)

# AdamW over every model parameter; mean-squared-error as the training loss
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
criterion = nn.MSELoss()

In [None]:
# Training loop
num_epochs = 3

# The regression head is built ONCE (on the first batch, when the embedding
# width is known) and gets its own optimizer. Re-creating it inside the loop
# would re-randomize its weights on every step and it would never be trained.
projection = None
head_optimizer = None

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=pixel_values)
        embeddings = outputs.last_hidden_state[:, 0, :]  # Use [CLS] token embedding

        if projection is None:
            # Simple linear head mapping each embedding to one scalar prediction
            projection = nn.Linear(embeddings.shape[1], 1).to(device)
            # Reuse the backbone optimizer's learning rate for the head
            head_optimizer = torch.optim.AdamW(
                projection.parameters(), lr=optimizer.defaults['lr']
            )
        head_optimizer.zero_grad()

        # squeeze(-1) drops only the output dimension, so a batch of one stays
        # shape (1,) and matches `labels` instead of collapsing to a 0-d tensor.
        predictions = projection(embeddings).squeeze(-1)

        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()
        head_optimizer.step()

        total_loss += loss.item()

    # Guard against an empty loader so the epoch summary never divides by zero
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')

In [None]:
# Save the finetuned model
# Persist the model and the image processor into the same directory
output_dir = 'finetuned_nomic_vision_model'
for artifact in (model, processor):
    artifact.save_pretrained(output_dir)

## Using the Finetuned Model

Here's how to use the finetuned model to generate embeddings for images:

In [None]:
def get_image_embedding(image_path):
    """Compute an embedding for a single image with the module-level `model`.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        numpy.ndarray with a leading batch dimension of 1, holding the first
        token of the model's last hidden state for the image.
    """
    # Load and preprocess image; the context manager closes the file
    # deterministically, and convert() returns an independent image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0).to(device)

    # Generate embedding in eval mode: a model left in train mode (as after a
    # training loop that calls model.train()) keeps dropout-style layers
    # active, which makes embeddings noisy. The previous mode is restored
    # afterwards so callers that resume training are unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(pixel_values=image)
            embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    finally:
        model.train(was_training)

    return embedding

# Example usage
test_image_path = "path/to/test_image.jpg"  # Replace with actual path
if os.path.exists(test_image_path):
    embedding = get_image_embedding(test_image_path)
    print(f"Embedding shape: {embedding.shape}")
else:
    # Say so explicitly: with the placeholder path the cell would otherwise
    # finish silently and look as if nothing had been run.
    print(f"Skipping example: no file found at {test_image_path!r}")