1. Install Dependencies

In [1]:
# HPS dependencies
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
! pip install hpsv2

# Stable Diffusion dependencies
! pip install diffusers

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-w_h8sd0t
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-w_h8sd0t
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-



In [2]:
!mkdir -p clip && wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz -P /usr/local/lib/python3.11/dist-packages/hpsv2/src/open_clip

--2025-03-04 02:24:12--  https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz
Resolving github.com (github.com)... 20.27.177.113
Connecting to github.com (github.com)|20.27.177.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz [following]
--2025-03-04 02:24:12--  https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1356917 (1.3M) [application/octet-stream]
Saving to: ‘/usr/local/lib/python3.11/dist-packages/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz’


2025-03-04 02:24:13 (5.71 MB/s) - ‘/usr/local/lib/python3.11/dist-packages/hpsv2/src/o

2. Imports

In [43]:
import os
import re
import random
from google.colab import drive

from abc import ABC, abstractmethod
from typing import Union, List, Dict, Optional

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusion3Pipeline

import clip
import hpsv2
import PIL
from PIL import Image

3. Connect to Google Drive

In [4]:
# Mount Google Drive at /content/drive; force_remount=True remounts even if a mount
# from an earlier run of this cell is still present.
drive.mount("/content/drive",force_remount=True)
# Switch the working directory to the Drive root so relative paths used afterwards
# resolve inside "My Drive".
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


4. Model Code

In [5]:
class ModelLoadingError(Exception):
    """Signals that a model could not be loaded."""

class InferenceError(Exception):
    """Signals that something went wrong while running inference."""


In [6]:
class BaseModel(ABC):
    """Abstract interface shared by every model wrapper in this notebook.

    Concrete subclasses must implement both `load_model` and `inference`;
    the class cannot be instantiated until they do.
    """

    @abstractmethod
    def load_model(self):
        """
        Make the model ready for use: load open weights, or open an API
        connection when the model is closed-source.
        """

    @abstractmethod
    def inference(
        self, inputs: Union[List[str], torch.Tensor], captions: Optional[List[str]] = None
    ) -> Union[torch.Tensor, List[float]]:
        """
        Process one batch of inputs, optionally paired with captions.

        Args:
            inputs (Union[List[str], torch.Tensor]): Either a batch of text prompts
                or a batch of images.
            captions (Optional[List[str]]): Text captions that accompany the inputs;
                used by reward models.

        Returns:
            Union[torch.Tensor, List[float]]: The batch of model outputs, or a list
            of reward scores.
        """

In [7]:
class HPSv1Model(BaseModel):
    """Reward model that scores image/caption pairs with a CLIP ViT-L/14 backbone
    whose weights are overwritten by an HPSv1 checkpoint."""

    def __init__(self, model_path: str):
        """
        Args:
            model_path (str): Path to the HPSv1 model checkpoint.

        Raises:
            ModelLoadingError: If the checkpoint is missing or cannot be loaded.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.load_model()

    def load_model(self):
        """
        Load CLIP ViT-L/14 and replace its weights with the checkpoint's 'state_dict'.

        Sets `self.model` and `self.preprocess`.

        Raises:
            ModelLoadingError: If the checkpoint file does not exist, lacks a
                'state_dict' entry, or any other error occurs while loading.
        """
        try:
            self.model, self.preprocess = clip.load("ViT-L/14", device=self.device)
            # map_location keeps loading working on a CPU-only runtime even when the
            # checkpoint's tensors were saved from a GPU.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            if "state_dict" not in checkpoint:
                raise ModelLoadingError("Checkpoint does not contain 'state_dict'.")

            self.model.load_state_dict(checkpoint["state_dict"])

        except ModelLoadingError:
            # Already the right type with a precise message; do not wrap it again below.
            raise
        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Error loading model: {e}") from e

    def inference(self, inputs: torch.Tensor, captions: List[str]) -> List[float]:
        """
        Runs inference on a batch of images and corresponding captions.

        Args:
            inputs (torch.Tensor): Batch of images; its first dimension is the batch size.
            captions (List[str]): One caption per image, in the same order.

        Returns:
            List[float]: One reward score per (image, caption) pair: the cosine
            similarity of the normalized CLIP embeddings multiplied by 100.

        Raises:
            TypeError: If `inputs` is not a tensor or `captions` is not a list of strings.
            ValueError: If the number of images and captions differ.
            InferenceError: If encoding or scoring fails.
        """
        if not isinstance(inputs, torch.Tensor):
            raise TypeError("Expected 'inputs' to be of type torch.Tensor (i.e. images).")
        if not isinstance(captions, list) or not all(isinstance(c, str) for c in captions):
            raise TypeError("Expected 'captions' to be a list of strings.")
        if inputs.shape[0] != len(captions):
            raise ValueError("Number of 'inputs' and 'captions' must match.")

        try:
            with torch.no_grad():
                image_features = self.model.encode_image(inputs.to(self.device))
                text_tokens = clip.tokenize(captions).to(self.device)
                text_features = self.model.encode_text(text_tokens)

                # L2-normalize so the dot product below is a cosine similarity.
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                # Convert cosine similarity scores to percentages as in the original paper.
                # Only the diagonal is kept: image i is scored against caption i.
                similarity_scores = (image_features @ text_features.T).diag() * 100

            return similarity_scores.tolist()
        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [8]:
class HPSv2Model(BaseModel):
    """Reward model that scores image/caption pairs through `hpsv2.score`."""

    def __init__(self, model_path: str):
        """
        Args:
            model_path (str): Path to the HPSv2 model checkpoint. Must be
            either 'v2.0' or 'v2.1'.

        Raises:
            ValueError: If `model_path` is not one of the two supported versions.
            ModelLoadingError: If the warm-up scoring call fails.
        """
        if model_path not in ["v2.0", "v2.1"]:
            raise ValueError("Expected 'model_path' to be either 'v2.0' or 'v2.1'.")
        self.model_path = model_path
        self.load_model()

    def load_model(self):
        """
        Warm up the requested HPS version by scoring a blank placeholder image.

        The warm-up uses `self.model_path` so that the version used later by
        `inference` is the one exercised (and cached) here, not always 'v2.0'.

        Raises:
            ModelLoadingError: If the warm-up call fails for any reason.
        """
        try:
            temp_image = PIL.Image.new("RGB", (256, 256), color=(255, 255, 255))
            _ = hpsv2.score(temp_image, '<prompt>', hps_version=self.model_path) # Also caches the model

        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Error loading model: {e}") from e

    def inference(self, inputs: List[PIL.Image.Image], captions: List[str]) -> List[float]:
        """
        Runs inference on a batch of images and corresponding captions.

        Args:
            inputs (List[PIL.Image.Image]): Images to score.
            captions (List[str]): One caption per image, in the same order.

        Returns:
            List[float]: One reward score per (image, caption) pair, i.e. the first
            element returned by `hpsv2.score` multiplied by 100.

        Raises:
            TypeError: If `inputs` is not a list of PIL images or `captions` is not a
                list of strings.
            ValueError: If the number of images and captions differ.
            InferenceError: If scoring fails.
        """
        if not isinstance(inputs, list) or not all(isinstance(i, PIL.Image.Image) for i in inputs):
            raise TypeError("Expected 'inputs' to be a list of PIL.Image objects.")
        if not isinstance(captions, list) or not all(isinstance(c, str) for c in captions):
            raise TypeError("Expected 'captions' to be a list of strings.")
        if len(inputs) != len(captions):
            raise ValueError("Number of 'inputs' and 'captions' must match.")

        try:
            similarity_scores = []
            # Each image is scored against its own caption, one pair per call.
            for image, caption in zip(inputs, captions):
                reward = hpsv2.score(image, caption, hps_version=self.model_path)
                similarity_scores.append(reward[0] * 100)
            return similarity_scores

        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [9]:
class BaseDiffusionModel(BaseModel):
    """Text-to-image model backed by a `diffusers` pipeline loaded via `from_pretrained`."""

    def __init__(self, model_path: str, offload_to_cpu: bool = False, resolution: Optional[int] = None, **kwargs):
        """
        Args:
            model_path (str): Path or repository ID of the diffusion model checkpoint.
            offload_to_cpu (bool): If True, enable model CPU offload on the pipeline
                instead of moving the whole pipeline to `self.device`.
            resolution (Optional[int]): If set, used as both height and width of the
                generated images; otherwise the pipeline's defaults apply.
            **kwargs: Forwarded unchanged to the pipeline's `from_pretrained`.

        Raises:
            ModelLoadingError: If the pipeline cannot be loaded.
        """
        # Fixed seed used for every generator created in `inference`.
        self.seed = 42

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.offload_to_cpu = offload_to_cpu
        self.resolution = resolution
        self.kwargs = kwargs

        self.diffusion_pipeline = self._get_diffusion_pipeline()
        self.load_model()

    def _get_diffusion_pipeline(self):
        """ Subclasses should override this to return the correct pipeline. """
        return DiffusionPipeline

    def load_model(self):
        """
        Instantiate the pipeline and place it on the target device.

        Sets `self.model` on success.

        Raises:
            ModelLoadingError: On out-of-memory, missing checkpoint, or any other
                loading failure.
        """
        try:
            pipeline = self.diffusion_pipeline.from_pretrained(
                self.model_path,
                **self.kwargs
            )
            if self.offload_to_cpu:
                # Do not move the pipeline to the GPU first: offloading manages device
                # placement itself, and loading everything onto the GPU beforehand
                # costs the very memory that offloading is meant to save.
                pipeline.enable_model_cpu_offload()
            else:
                pipeline = pipeline.to(self.device)
            self.model = pipeline

        except (MemoryError, torch.cuda.OutOfMemoryError) as e:
            # CUDA out-of-memory is not a MemoryError subclass, so it is listed explicitly.
            pipeline = None  # drop the local reference so the memory can be reclaimed
            if hasattr(self, "model"):
                del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise ModelLoadingError(
                f"Memory error occurred while loading the model. Consider using a smaller model: {e}"
            ) from e
        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Failed to load diffusion model: {e}") from e

    def inference(
        self, inputs: List[str], captions: Optional[List[str]] = None
    ):
        """
        Runs inference on a batch of prompts.

        Args:
            inputs (List[str]): Text prompts, one image is requested per prompt.
            captions (Optional[List[str]]): Unused; present only to match `BaseModel.inference`.

        Returns:
            The `.images` attribute of the pipeline output for the given prompts,
            or an empty list when `inputs` is empty.

        Raises:
            TypeError: If `inputs` is not a list of strings.
            InferenceError: If the pipeline call fails.
        """
        if not isinstance(inputs, list) or not all(isinstance(c, str) for c in inputs):
            raise TypeError("Expected 'inputs' to be a list of strings.")
        if not inputs:
            # Nothing to generate; avoid calling the pipeline with an empty batch.
            return []

        try:
            # Create one generator per prompt to ensure reproducibility. Every generator
            # is seeded with the same value, so a prompt's starting noise does not depend
            # on its position in the batch.
            generators = [
                torch.Generator(self.device).manual_seed(self.seed) for _ in inputs
            ]
            call_kwargs = {"prompt": inputs, "generator": generators}
            if self.resolution:
                # use 1:1 aspect ratio
                call_kwargs["height"] = self.resolution
                call_kwargs["width"] = self.resolution
            return self.model(**call_kwargs).images

        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [10]:
class StableDiffusionModel(BaseDiffusionModel):
    """Diffusion model wrapper that picks the pipeline class from the model's name."""

    def __init__(self, model_path: str, offload_to_cpu: bool = False, resolution: int = None, **kwargs):
        """
        Note:
            model_path (str): Path to the Stable Diffusion model.
                              Must include 'stable-diffusion-1', 'stable-diffusion-2', or 'stable-diffusion-3' after '<repo-owner>/'
                              for simplicity.
        """
        # Default to float16 weights unless the caller chose a dtype explicitly.
        # On GPUs that support it, torch.bfloat16 is an alternative worth considering:
        # lower memory usage with a precision profile closer to FP32.
        kwargs.setdefault("torch_dtype", torch.float16)
        super().__init__(model_path, offload_to_cpu, resolution, **kwargs)

    def _get_diffusion_pipeline(self):
        """Return the pipeline class whose version pattern matches the model name.

        Only the last '/'-separated component of `model_path` is inspected, lowercased.
        Patterns are tried in order (1, then 2, then 3) and the first match wins.

        Raises:
            ModelLoadingError: If no version pattern matches.
        """
        model_name = self.model_path.split("/")[-1].lower()
        pipelines_by_pattern = (
            (r'(stable-diffusion-1|v-?1)', StableDiffusionPipeline),
            (r'(stable-diffusion-2|v-?2)', DiffusionPipeline),
            (r'(stable-diffusion-3|v-?3)', StableDiffusion3Pipeline),
        )
        for pattern, pipeline_cls in pipelines_by_pattern:
            if re.search(pattern, model_name):
                return pipeline_cls

        raise ModelLoadingError(
            "Model path must contain one of: 'stable-diffusion-1', 'stable-diffusion-2', or 'stable-diffusion-3'."
        )

5. Test HPSv2 Benchmark

In [26]:
class DatasetFormatError(Exception):
    """Signals that a dataset does not have the expected format."""

class DatasetLoadingError(Exception):
    """Signals that a dataset could not be loaded properly."""

In [40]:
class BasePromptDataset(Dataset, ABC):
    """Prompt dataset grouped by category and flattened in round-robin order.

    Subclasses provide `load_dataset()`, which returns `{category: [prompt, ...]}`.
    The flattened result is stored in `self.samples` as a list of
    `{"category": ..., "prompt": ...}` dictionaries.
    """

    def __init__(self):
        """
        Load, validate and flatten the dataset.

        Raises:
            DatasetLoadingError: If `load_dataset()` raises; the original exception is
                attached as the cause.
            DatasetFormatError: If the loaded data is not a dict of lists of strings,
                or contains no prompts at all.
        """
        try:
            self.data = self.load_dataset()
        except Exception as e:
            # Chain explicitly so the original failure stays attached as __cause__.
            raise DatasetLoadingError(f"Failed to load dataset: {e}") from e

        if not isinstance(self.data, dict):
            raise DatasetFormatError(f"Expected 'load_dataset()' to return a dictionary, got '{type(self.data)}'.")

        for key, prompts in self.data.items():
            if not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
                raise DatasetFormatError(f"Expected a list of strings for category '{key}', but got '{type(prompts)}'")

        # Precompute samples with round-robin ordering
        self.samples = self._create_round_robin_samples()

    @abstractmethod
    def load_dataset(self) -> Dict[str, List[str]]:
        """To be implemented by subclasses."""
        pass

    def _create_round_robin_samples(self) -> List[Dict[str, str]]:
        """Ensure fair round-robin interleaving of prompts from all categories.

        Every non-empty category contributes exactly as many samples as the longest
        category has prompts; shorter categories wrap around and repeat their prompts.
        Empty categories contribute nothing.

        Raises:
            DatasetFormatError: If there are no categories or every category is empty.
        """
        samples = []
        categories = list(self.data.keys())
        category_prompts = [self.data[cat] for cat in categories]

        if not categories or all(len(prompts) == 0 for prompts in category_prompts):
            raise DatasetFormatError("Dataset is empty or contains only empty categories.")

        max_length = max(len(prompts) for prompts in category_prompts)

        # Round-robin interleaving
        for i in range(max_length):
            for category, prompts in zip(categories, category_prompts):
                if len(prompts) > 0:
                    prompt = prompts[i % len(prompts)]  # Cycle back for shorter lists
                    samples.append({"category": category, "prompt": prompt})

        return samples

    def __len__(self):
        """Number of flattened samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the `{"category", "prompt"}` dictionary at position `idx`."""
        return self.samples[idx]


In [44]:
class RoundRobinSampler(torch.utils.data.Sampler):
    """Sampler that cycles through categories while shuffling within each category.

    The index order is generated once, at construction time, and the same order is
    yielded on every iteration.
    """

    def __init__(self, dataset: "BasePromptDataset", seed: Optional[int] = None):
        """
        Args:
            dataset: Dataset exposing `samples`, a list of dictionaries that each have
                a "category" key (as built by `BasePromptDataset`).
            seed (Optional[int]): If given, shuffling uses a private
                `random.Random(seed)` and is reproducible. If None, the module-level
                `random` state is used.
        """
        self.dataset = dataset
        self.seed = seed
        self.indices = self._generate_indices()

    def _generate_indices(self):
        """
        Build the flat index order: category 1, category 2, ..., category K, category 1, ...

        Indices are grouped by reading each sample's own "category" field rather than
        assuming the arithmetic layout `i*K + j`. That assumption only holds when every
        category is non-empty and of equal length, and it breaks (out-of-range or
        missing indices) otherwise. Categories keep their order of first appearance in
        `dataset.samples`.

        Each category's indices are shuffled independently, then interleaved. If the
        groups have different sizes, exhausted categories are skipped so that every
        sample index still appears exactly once.
        """
        rng = random if self.seed is None else random.Random(self.seed)

        category_indices = {}
        for sample_idx, sample in enumerate(self.dataset.samples):
            category_indices.setdefault(sample["category"], []).append(sample_idx)

        if not category_indices:
            return []

        for indices in category_indices.values():
            rng.shuffle(indices)

        longest = max(len(indices) for indices in category_indices.values())
        ordered_indices = []
        for i in range(longest):
            for indices in category_indices.values():
                if i < len(indices):
                    ordered_indices.append(indices[i])
        return ordered_indices

    def __iter__(self):
        """Iterate over the precomputed index order."""
        return iter(self.indices)

    def __len__(self):
        """Number of indices the sampler yields."""
        return len(self.indices)


In [38]:
class HPSV2PromptDataset(BasePromptDataset):
    """Prompt dataset backed by the HPSv2 benchmark prompts."""

    def load_dataset(self) -> Dict[str, List[str]]:
        """Fetch the "all" benchmark prompt set from `hpsv2` and return it as a plain dict."""
        prompts_by_category = hpsv2.benchmark_prompts("all")
        return {category: prompts for category, prompts in prompts_by_category.items()}

In [50]:
# Build the HPSv2 prompt dataset and a sampler that walks it in shuffled,
# category-interleaved order.
dataset = HPSV2PromptDataset()
sampler = RoundRobinSampler(dataset)
# With an explicit sampler, the DataLoader draws samples in the sampler's index order.
dataloader = DataLoader(dataset, batch_size=5, sampler=sampler)

# Sanity check: print only the first 10 batches. Each printed batch is a dict with a
# list of categories and a list of prompts (see the output below).
count = 0
for batch in dataloader:
    print(batch)
    count += 1
    if count == 10:
      break

{'category': ['anime', 'concept-art', 'paintings', 'photo', 'anime'], 'prompt': ['Anime oil painting of Rem from Re Zero.', 'The image is titled "Burning Memory" and features dark, dramatic, and highly detailed artwork by multiple artists, depicting a scene from the video game Bloodborne.', 'Oil painting of a man under a tree in the rain, by Greg Rutkowski.', 'A desk sitting next to a showroom of cars in it.', 'A digital painting of a cyberpunk anime woman with intricate and highly detailed features.']}
{'category': ['concept-art', 'paintings', 'photo', 'anime', 'concept-art'], 'prompt': ['A surreal image featuring a rainbow and neon glow with a biohazard scientist in a laboratory evacuation scene, showcasing a mix of gothic and neo-gothic styles with rich colors.', 'A creepy man dressed as a chicken is frightening children in a painting by Bussiere, Mullins, and Leyendecker.', 'A group of waiters standing in a line. ', 'A young man in a small Tokyo room with an open window sits at his