# PictSure Simple Example

This notebook demonstrates the basic usage of PictSure for few-shot image classification.

In [1]:
import os
import random
from PIL import Image
import torch
from PictSure import PictSure

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: compute device, RNG seed, and the CIFAR-10 test split
from torchvision import datasets, transforms
from collections import defaultdict

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed the RNGs so the randomly sampled context/query images are the same on
# every Restart & Run All
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Load the CIFAR-10 test split (downloaded into ./data on first use); items are
# returned as (tensor image, integer label) pairs
cifar_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
print(f"CIFAR-10 test set loaded: {len(cifar_dataset)} images")

# Group dataset indices by class label. Reading `targets` directly avoids
# decoding and transforming every image just to look at its label.
class_to_indices = defaultdict(list)
for idx, label in enumerate(cifar_dataset.targets):
    class_to_indices[label].append(idx)

# Human-readable class names, indexed by integer label
CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(f"Classes: {CIFAR_CLASSES}")

Using device: cuda
CIFAR-10 test set loaded: 10000 images
Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [None]:
# Fetch the pre-trained PictSure checkpoint from HuggingFace, then move it
# onto the selected device
pretrained = PictSure.from_pretrained("pictsure/pictsure-dinov2")
model = pretrained.to(DEVICE)
print("Model loaded successfully")



Model loaded successfully


In [4]:
# Prepare context images - NUM_CONTEXT_PER_CLASS images for every class in CIFAR_CLASSES
NUM_CONTEXT_PER_CLASS = 5

context_images = []
context_labels = []

# Build the tensor -> PIL converter once instead of once per image
to_pil = transforms.ToPILImage()

print("Loading context images from CIFAR-10...\n")

# Iterate over the class list itself so the loop follows CIFAR_CLASSES rather
# than a hard-coded class count
for class_id, class_name in enumerate(CIFAR_CLASSES):
    # Select NUM_CONTEXT_PER_CLASS distinct random images from this class
    indices = random.sample(class_to_indices[class_id], NUM_CONTEXT_PER_CLASS)

    for idx in indices:
        img, _ = cifar_dataset[idx]
        # The dataset yields tensors; the context list holds PIL images
        context_images.append(to_pil(img))
        context_labels.append(class_id)

    print(f"Class {class_id} ({class_name}): loaded {NUM_CONTEXT_PER_CLASS} images")

print(f"\nTotal context images: {len(context_images)}")
print(f"Labels distribution: {context_labels}")

Loading context images from CIFAR-10...

Class 0 (airplane): loaded 5 images
Class 1 (automobile): loaded 5 images
Class 2 (bird): loaded 5 images
Class 3 (cat): loaded 5 images
Class 4 (deer): loaded 5 images
Class 5 (dog): loaded 5 images
Class 6 (frog): loaded 5 images
Class 7 (horse): loaded 5 images
Class 8 (ship): loaded 5 images
Class 9 (truck): loaded 5 images

Total context images: 50
Labels distribution: [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9]


In [5]:
# Set context - the model handles all normalization internally
# Guard against a partially re-run sampling cell: every context image must
# have exactly one label before the pair is handed to the model.
assert len(context_images) == len(context_labels), (
    f"context mismatch: {len(context_images)} images vs {len(context_labels)} labels"
)
model.set_context_images(context_images, context_labels)
print("Context set successfully")

Context set successfully


In [6]:
# Make a single test prediction
# Pick a random class (the count comes from CIFAR_CLASSES), then a random test
# image of that class.
# NOTE(review): the image is drawn from the same index pool the context set was
# sampled from, so it can coincide with a context image — confirm this is
# acceptable for a single demo prediction.
test_class = random.randrange(len(CIFAR_CLASSES))
test_idx = random.choice(class_to_indices[test_class])
test_image_tensor, test_label = cifar_dataset[test_idx]
test_image = transforms.ToPILImage()(test_image_tensor)

# Predict with gradient tracking disabled; this is inference only
with torch.no_grad():
    prediction = model.predict(test_image)

print(f"Test image from class: {test_label} ({CIFAR_CLASSES[test_label]})")
print(f"Predicted class: {prediction} ({CIFAR_CLASSES[prediction] if 0 <= prediction < len(CIFAR_CLASSES) else 'INVALID'})")
print(f"Result: {'✓ CORRECT' if prediction == test_label else '✗ INCORRECT'}")

Test image from class: 4 (deer)
Predicted class: 4 (deer)
Result: ✓ CORRECT


In [7]:
# Test accuracy over several episodes.
# Each episode samples a fresh context set plus one query image that is
# guaranteed not to be part of that episode's context set.
num_episodes = 50
results = []

# Loop-invariant helpers: the class count and the tensor -> PIL converter
num_classes = len(CIFAR_CLASSES)
to_pil = transforms.ToPILImage()

print(f"Testing {num_episodes} episodes with different context/query combinations...\n")

for episode in range(num_episodes):
    # Create new context set for this episode (NUM_CONTEXT_PER_CLASS images per class)
    episode_context_images = []
    episode_context_labels = []
    episode_context_indices = set()

    for class_id in range(num_classes):
        indices = random.sample(class_to_indices[class_id], NUM_CONTEXT_PER_CLASS)
        # Remember which dataset items were used so the query can avoid them
        episode_context_indices.update(indices)
        for idx in indices:
            img, _ = cifar_dataset[idx]
            episode_context_images.append(to_pil(img))
            episode_context_labels.append(class_id)

    # Set context for this episode
    model.set_context_images(episode_context_images, episode_context_labels)

    # Select a random query image of a random class, excluding images already
    # shown to the model as context (otherwise the accuracy is inflated by leakage)
    test_class = random.randrange(num_classes)
    query_candidates = [
        idx for idx in class_to_indices[test_class]
        if idx not in episode_context_indices
    ]
    test_idx = random.choice(query_candidates)
    test_image_tensor, test_label = cifar_dataset[test_idx]
    test_image = to_pil(test_image_tensor)

    # Predict with gradient tracking disabled; this is inference only
    with torch.no_grad():
        prediction = model.predict(test_image)

    is_correct = (prediction == test_label)
    results.append(is_correct)

    # Print the running accuracy every 10 episodes
    if (episode + 1) % 10 == 0:
        acc_so_far = sum(results) / len(results) * 100
        print(f"Episode {episode+1}/{num_episodes} - Accuracy: {acc_so_far:.1f}%")

# Final results
accuracy = sum(results) / len(results) * 100
print(f"\n{'='*60}")
print(f"Final Accuracy: {accuracy:.1f}% ({sum(results)}/{num_episodes} correct)")
print(f"{'='*60}")

Testing 50 episodes with different context/query combinations...

Episode 10/50 - Accuracy: 100.0%
Episode 20/50 - Accuracy: 95.0%
Episode 30/50 - Accuracy: 93.3%
Episode 40/50 - Accuracy: 95.0%
Episode 50/50 - Accuracy: 94.0%

Final Accuracy: 94.0% (47/50 correct)
