1. Install Dependencies

In [1]:
# HPS dependencies
# ftfy, regex and tqdm are installed first, followed by CLIP (straight from its GitHub
# repository) and the hpsv2 package.
# NOTE(review): none of these installs is version-pinned, so a fresh run may resolve to
# different releases than the ones that produced the saved outputs — consider pinning.
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
! pip install hpsv2

# Stable Diffusion dependencies
# diffusers provides the pipeline classes imported in the "Imports" section.
! pip install diffusers

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-4x58psef
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-4x58psef
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
!mkdir -p clip && wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz -P /usr/local/lib/python3.11/dist-packages/hpsv2/src/open_clip

--2025-03-08 02:56:36--  https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz [following]
--2025-03-08 02:56:36--  https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1356917 (1.3M) [application/octet-stream]
Saving to: ‘/usr/local/lib/python3.11/dist-packages/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz.2’


2025-03-08 02:56:36 (174 MB/s) - ‘/usr/local/lib/python3.11/dist-packages/hpsv2/sr

2. Imports

In [34]:
import os
import re
import gc
from datetime import datetime
import random
import argparse
from tqdm import tqdm
from google.colab import drive

from abc import ABC, abstractmethod
from typing import Union, List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusion3Pipeline

import clip
import hpsv2
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
import PIL
from PIL import Image

3. Connect to Google Drive

In [4]:
# Mount Google Drive into the Colab runtime; force_remount=True asks for a fresh mount.
drive.mount("/content/drive",force_remount=True)
# Work from the Drive root so the relative paths used later in this notebook
# ("drawbench_data.csv", "outputs/...") resolve inside "My Drive".
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


4. Model Code

In [5]:
class ModelLoadingError(Exception):
    """Signals that a model (or its checkpoint) could not be loaded."""

class InferenceError(Exception):
    """Signals that running a model on a batch of inputs failed."""

In [6]:
class BaseModel(ABC):
    """Abstract interface implemented by every model wrapper in this notebook."""

    @abstractmethod
    def load_model(self):
        """
        Make the underlying model usable: load open weights, or connect to the API
        of a closed-source model.
        """

    @abstractmethod
    def inference(
        self,
        inputs: Union[List[str], torch.Tensor],
        captions: Optional[List[str]] = None,
    ) -> Union[torch.Tensor, List[float]]:
        """
        Process one batch of inputs, optionally paired with captions.

        Args:
            inputs (Union[List[str], torch.Tensor]): Either a batch of text prompts or a batch of images.
            captions (Optional[List[str]]): Text captions that accompany the inputs; used by reward models.

        Returns:
            Union[torch.Tensor, List[float]]: The model outputs for the batch, or one reward score per input.
        """

In [7]:
class HPSv1Model(BaseModel):
    """HPSv1 reward model: CLIP ViT-L/14 loaded with the weights of an HPSv1 checkpoint."""

    def __init__(self, model_path: str):
        """
        Args:
            model_path (str): Path to the HPSv1 model checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.load_model()

    def load_model(self):
        """
        Build CLIP ViT-L/14, load the checkpoint's 'state_dict' into it and switch to eval mode.

        Raises:
            ModelLoadingError: If the checkpoint file is missing, has no 'state_dict' entry,
                or any other error occurs while loading.
        """
        try:
            self.model, self.preprocess_function = clip.load("ViT-L/14", device=self.device)
            # map_location keeps a checkpoint saved on a GPU loadable on a CPU-only runtime.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            if "state_dict" not in checkpoint:
                raise ModelLoadingError("Checkpoint does not contain 'state_dict'.")

            self.model.load_state_dict(checkpoint["state_dict"])
            self.tokenizer = clip.tokenize
            self.model.eval()

        except ModelLoadingError:
            # Already carries a precise message; do not wrap it a second time below.
            raise
        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Error loading model: {e}") from e

    @staticmethod
    def _validate_batch(inputs, captions, allow_token_tensor: bool = False) -> None:
        """Check argument types and that images and captions have the same batch size."""
        if not isinstance(inputs, torch.Tensor):
            raise TypeError("Expected 'inputs' to be of type torch.Tensor (i.e. images).")

        if allow_token_tensor and isinstance(captions, torch.Tensor):
            num_captions = captions.shape[0]
        elif isinstance(captions, list) and all(isinstance(c, str) for c in captions):
            num_captions = len(captions)
        elif allow_token_tensor:
            raise TypeError("Expected 'captions' to be a list of strings or a tensor of token ids.")
        else:
            raise TypeError("Expected 'captions' to be a list of strings.")

        if inputs.shape[0] != num_captions:
            raise ValueError("Number of 'inputs' and 'captions' must match.")

    def _score(self, inputs: torch.Tensor, captions: Union[List[str], torch.Tensor]) -> torch.Tensor:
        """Return the per-pair image/caption cosine similarity scaled by 100, as a tensor."""
        image_features = self.model.encode_image(inputs.to(self.device))

        if isinstance(captions, torch.Tensor):
            text_tokens = captions.to(self.device)
        else:
            text_tokens = self.tokenizer(captions).to(self.device)
        text_features = self.model.encode_text(text_tokens)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Convert cosine similarity scores to percentages as in the original paper.
        # The diagonal pairs image i with caption i.
        return (image_features @ text_features.T).diag() * 100

    def inference(self, inputs: torch.Tensor, captions: Union[List[str], torch.Tensor]) -> List[float]:
        """
        Runs inference on a batch of images and corresponding captions, without gradients.

        Args:
            inputs (torch.Tensor): Batch of images; the first dimension is the batch size.
            captions (Union[List[str], torch.Tensor]): One caption per image, either as raw
                strings or as an already tokenized tensor whose first dimension is the batch size.

        Returns:
            List[float]: One reward score per (image, caption) pair.

        Raises:
            TypeError: If 'inputs' or 'captions' has an unsupported type.
            ValueError: If the number of images and captions differ.
            InferenceError: If the model fails while scoring the batch.
        """
        self._validate_batch(inputs, captions, allow_token_tensor=True)

        try:
            with torch.no_grad():
                similarity_scores = self._score(inputs, captions)
            return similarity_scores.tolist()
        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

    def inference_with_grad(self, inputs: torch.Tensor, captions: List[str]) -> torch.Tensor:
        """
        Runs inference on a batch of images and corresponding captions with autograd enabled.

        Scores are computed exactly as in `inference`. CLIP's forward() returns
        (logits_per_image, logits_per_text) rather than embeddings, so the features are
        taken from encode_image/encode_text instead of from the forward call.

        Args:
            inputs (torch.Tensor): Batch of images; the first dimension is the batch size.
            captions (List[str]): One caption per image.

        Returns:
            torch.Tensor: One reward score per (image, caption) pair, still attached to the graph.

        Raises:
            TypeError: If 'inputs' is not a tensor or 'captions' is not a list of strings.
            ValueError: If the number of images and captions differ.
            InferenceError: If the model fails while scoring the batch.
        """
        self._validate_batch(inputs, captions)

        try:
            return self._score(inputs, captions)
        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [8]:
class HPSv2Model(BaseModel):
    """HPSv2 reward model: an open_clip ViT-H-14 loaded with the weights of an HPSv2 checkpoint."""

    def __init__(self, model_path: str):
        """
        Args:
            model_path (str): Path to the HPSv2 model checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.load_model()

    def load_model(self):
        """
        Build the ViT-H-14 model, load the checkpoint's 'state_dict' into it and switch to eval mode.

        Raises:
            ModelLoadingError: If the checkpoint file is missing, has no 'state_dict' entry,
                or any other error occurs while loading.
        """
        try:
            # The third returned value is kept as the image preprocessing function
            # (presumably the evaluation transform — confirm against hpsv2's open_clip).
            self.model, _, self.preprocess_function = create_model_and_transforms(
                "ViT-H-14",
                "laion2B-s32B-b79K",
                precision="amp",
                device=self.device,
                jit=False,
                force_quick_gelu=False,
                force_custom_text=False,
                force_patch_dropout=False,
                force_image_size=None,
                pretrained_image=False,
                image_mean=None,
                image_std=None,
                light_augmentation=True,
                aug_cfg={},
                output_dict=True,
                with_score_predictor=False,
                with_region_predictor=False
            )

            # map_location keeps a checkpoint saved on a GPU loadable on a CPU-only runtime.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            if "state_dict" not in checkpoint:
                raise ModelLoadingError("Checkpoint does not contain 'state_dict'.")

            self.model.load_state_dict(checkpoint["state_dict"])
            self.tokenizer = get_tokenizer("ViT-H-14")
            self.model.eval()

        except ModelLoadingError:
            # Already carries a precise message; do not wrap it a second time below.
            raise
        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Error loading model: {e}") from e

    @staticmethod
    def _validate_batch(inputs, captions, allow_token_tensor: bool = False) -> None:
        """Check argument types and that images and captions have the same batch size."""
        if isinstance(inputs, torch.Tensor):
            num_images = inputs.shape[0]
        elif isinstance(inputs, list) and all(isinstance(i, PIL.Image.Image) for i in inputs):
            num_images = len(inputs)
        else:
            raise TypeError(
                "Expected 'inputs' to be a list of PIL.Image objects or a torch.Tensor of preprocessed images."
            )

        if allow_token_tensor and isinstance(captions, torch.Tensor):
            num_captions = captions.shape[0]
        elif isinstance(captions, list) and all(isinstance(c, str) for c in captions):
            num_captions = len(captions)
        elif allow_token_tensor:
            raise TypeError("Expected 'captions' to be a list of strings or a tensor of token ids.")
        else:
            raise TypeError("Expected 'captions' to be a list of strings.")

        if num_images != num_captions:
            raise ValueError("Number of 'inputs' and 'captions' must match.")

    def _score(self, inputs, captions) -> torch.Tensor:
        """Return the per-pair image/caption similarity scaled by 100, as a tensor."""
        if isinstance(inputs, torch.Tensor):
            images = inputs
        else:
            # The model consumes tensors, so PIL images must go through the preprocessing
            # transform and be stacked into one batch first.
            images = torch.stack([self.preprocess_function(image) for image in inputs])
        images = images.to(self.device)

        if isinstance(captions, torch.Tensor):
            text_tokens = captions.to(self.device)
        else:
            text_tokens = self.tokenizer(captions).to(self.device)

        with torch.cuda.amp.autocast():
            outputs = self.model(images, text_tokens)
            image_features, text_features = outputs["image_features"], outputs["text_features"]
            # The diagonal pairs image i with caption i.
            return (image_features @ text_features.T).diag() * 100

    def inference(self, inputs: torch.Tensor, captions: Union[List[str], torch.Tensor]) -> List[float]:
        """
        Runs inference on a batch of images and corresponding captions, without gradients.

        Args:
            inputs: Either a list of PIL images (preprocessed here) or a tensor of already
                preprocessed images whose first dimension is the batch size.
            captions (Union[List[str], torch.Tensor]): One caption per image, either as raw
                strings or as an already tokenized tensor whose first dimension is the batch size.

        Returns:
            List[float]: One reward score per (image, caption) pair.

        Raises:
            TypeError: If 'inputs' or 'captions' has an unsupported type.
            ValueError: If the number of images and captions differ.
            InferenceError: If preprocessing or the model fails while scoring the batch.
        """
        self._validate_batch(inputs, captions, allow_token_tensor=True)

        try:
            with torch.no_grad():
                similarity_scores = self._score(inputs, captions)
            return similarity_scores.tolist()

        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e


    def inference_with_grad(self, inputs: torch.Tensor, captions: List[str]) -> torch.Tensor:
        """
        Runs inference on a batch of images and corresponding captions with autograd enabled.

        Args:
            inputs: Either a list of PIL images (preprocessed here) or a tensor of already
                preprocessed images whose first dimension is the batch size.
            captions (List[str]): One caption per image.

        Returns:
            torch.Tensor: One reward score per (image, caption) pair, still attached to the graph.

        Raises:
            TypeError: If 'inputs' has an unsupported type or 'captions' is not a list of strings.
            ValueError: If the number of images and captions differ.
            InferenceError: If preprocessing or the model fails while scoring the batch.
        """
        self._validate_batch(inputs, captions)

        try:
            return self._score(inputs, captions)

        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [9]:
class BaseDiffusionModel(BaseModel):
    """Text-to-image wrapper around a diffusers pipeline loaded with `from_pretrained`."""

    def __init__(self, model_path: str, offload_to_cpu: bool = False, resolution: Optional[int] = None, **kwargs):
        """
        Args:
            model_path (str): Path or repository ID of the diffusion model checkpoint.
            offload_to_cpu (bool): If True, enable model CPU offload instead of moving the
                whole pipeline onto the device.
            resolution (Optional[int]): Side length used for both height and width of the
                generated images. When None/0, the pipeline's own default is used.
            **kwargs: Forwarded unchanged to the pipeline's `from_pretrained`.
        """
        # Fixed seed used for every prompt's generator in `inference`.
        self.seed = 42

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_path = model_path
        self.offload_to_cpu = offload_to_cpu
        self.resolution = resolution
        self.kwargs = kwargs

        self.diffusion_pipeline = self._get_diffusion_pipeline()
        self.load_model()

    def _get_diffusion_pipeline(self):
        """ Subclasses should override this to return the correct pipeline. """
        return DiffusionPipeline

    def load_model(self):
        """
        Instantiate the pipeline and place it on the device (or enable CPU offload).

        Raises:
            ModelLoadingError: If the checkpoint is missing, memory runs out, or loading fails.
        """
        try:
            pipeline = self.diffusion_pipeline.from_pretrained(
                self.model_path,
                **self.kwargs
            )
            if self.offload_to_cpu:
                # Offloading manages device placement itself; moving the whole pipeline to
                # the GPU beforehand would load every component there at once and largely
                # defeat the memory saving.
                pipeline.enable_model_cpu_offload()
            else:
                pipeline = pipeline.to(self.device)
            self.model = pipeline

        except MemoryError as e:
            # Drop any previously loaded pipeline so its memory can be reclaimed.
            if hasattr(self, "model"):
                del self.model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise ModelLoadingError(
                f"Memory error occurred while loading the model. Consider using a smaller model: {e}"
            ) from e
        except FileNotFoundError as e:
            raise ModelLoadingError(f"Model checkpoint not found at '{self.model_path}'.") from e
        except Exception as e:
            raise ModelLoadingError(f"Failed to load diffusion model: {e}") from e

    def inference(
        self, inputs: List[str], captions: Optional[List[str]] = None
    ):
        """
        Runs inference on a batch of prompts.

        Args:
            inputs (List[str]): Text prompts, one image is generated per prompt.
            captions (Optional[List[str]]): Unused; present for interface compatibility.

        Returns:
            The `.images` attribute of the pipeline output for the given prompts.

        Raises:
            TypeError: If 'inputs' is not a list of strings.
            InferenceError: If the pipeline call fails.
        """
        if not isinstance(inputs, list) or not all(isinstance(c, str) for c in inputs):
            raise TypeError("Expected 'inputs' to be a list of strings.")

        try:
            # Create one generator per prompt to ensure reproducibility
            generators = [
                torch.Generator(self.device).manual_seed(self.seed) for _ in inputs
            ]
            call_kwargs = {"prompt": inputs, "generator": generators}
            if self.resolution:
                # use 1:1 aspect ratio
                call_kwargs["height"] = self.resolution
                call_kwargs["width"] = self.resolution
            return self.model(**call_kwargs).images

        except Exception as e:
            raise InferenceError(f"Inference failed: {e}") from e

In [47]:
class StableDiffusionModel(BaseDiffusionModel):
    """Diffusion wrapper that picks the pipeline class from the Stable Diffusion major version."""

    def __init__(self, model_path: str, offload_to_cpu: bool = False, resolution: int = None, **kwargs):
        """
        Note:
            model_path (str): Path to the Stable Diffusion model. For simplicity, the part after
                              '<repo-owner>/' must include 'stable-diffusion-1', 'stable-diffusion-2',
                              or 'stable-diffusion-3'.
        """
        # float16 is the default precision unless the caller passes its own torch_dtype.
        # On GPUs with torch.bfloat16 support, that dtype is an alternative worth considering:
        # lower memory usage with precision similar to FP32.
        kwargs.setdefault("torch_dtype", torch.float16)
        super().__init__(model_path, offload_to_cpu, resolution, **kwargs)

    def _get_diffusion_pipeline(self):
        """Return the pipeline class matching the version found in the last path component."""
        version_tag = self.model_path.split("/")[-1].lower()

        # Checked in this order; the first pattern that matches decides the pipeline.
        pipeline_by_pattern = (
            (r'(stable-diffusion-?(v-?|v)?1(?:-\d+)?)(.*)?$', StableDiffusionPipeline),
            (r'(stable-diffusion-?(v-?|v)?2(?:-\d+)?)(.*)?$', DiffusionPipeline),
            (r'(stable-diffusion-?(v-?|v)?3(?:-\d+)?)(.*)?$', StableDiffusion3Pipeline),
        )
        for version_pattern, pipeline_class in pipeline_by_pattern:
            if re.search(version_pattern, version_tag):
                return pipeline_class

        raise ValueError(
            "Model path must match 'stable-diffusion-1', 'stable-diffusion-v1', 'stable-diffusion-v-1', "
            "'stable-diffusion-2', 'stable-diffusion-v2', etc."
        )

In [11]:
class ModelFactory:
    """Factory that maps a model type string to the matching model wrapper."""

    @staticmethod
    def create_model(
        model_type: str, model_path: str,
        **kwargs,
    ) -> BaseModel:
        """
        Creates and returns an instance of a model subclass based on the model_type.

        Args:
            model_type (str): The type of model to create. Supported values are:
                - "hpsv1": For HPSv1 reward models.
                - "hpsv2": For HPSv2 reward models.
                - "sd": For stable diffusion text-to-image models.
            model_path (str): The path or repository ID of the model checkpoint.
            **kwargs: Extra options forwarded to StableDiffusionModel; ignored for the
                HPS reward models, whose constructors only take the checkpoint path.

        Returns:
            BaseModel: An instance of the requested model.

        Raises:
            ValueError: If an unsupported model_type is provided.
        """
        if model_type == "hpsv1":
            return HPSv1Model(model_path)
        elif model_type == "hpsv2":
            return HPSv2Model(model_path)
        elif model_type == "sd":
            return StableDiffusionModel(model_path, **kwargs)
        else:
            # Name the values that are actually accepted above.
            raise ValueError(
                f"Unsupported model type '{model_type}'. Use 'sd' for stable diffusion models "
                "or 'hpsv1'/'hpsv2' for HPS models."
            )

5. Dataset Code

In [12]:
class DatasetFormatError(Exception):
    """Signals that a dataset does not have the expected structure."""

class DatasetLoadingError(Exception):
    """Signals that a dataset could not be loaded."""

In [13]:
class BasePromptDataset(Dataset, ABC):
    """
    Prompt dataset built from a {category: [prompts]} mapping.

    Samples are precomputed so that consecutive items alternate between categories
    (round-robin); each sample is a dict with the keys "category" and "prompt".
    """

    def __init__(self):
        try:
            loaded = self.load_dataset()
        except Exception as e:
            raise DatasetLoadingError(f"Failed to load dataset: {e}")
        self.data = loaded

        if not isinstance(self.data, dict):
            raise DatasetFormatError(f"Expected 'load_dataset()' to return a dictionary, got '{type(self.data)}'.")

        # Every category must map to a list made only of strings.
        for key, prompts in self.data.items():
            is_string_list = isinstance(prompts, list) and all(isinstance(p, str) for p in prompts)
            if not is_string_list:
                raise DatasetFormatError(f"Expected a list of strings for category '{key}', but got '{type(prompts)}'")

        # Flat, round-robin ordered view of the data used by __len__/__getitem__.
        self.samples = self._create_round_robin_samples()

    @abstractmethod
    def load_dataset(self) -> Dict[str, List[str]]:
        """Return the prompts grouped by category; implemented by subclasses."""

    def _create_round_robin_samples(self) -> List[Dict[str, str]]:
        """
        Interleave the prompts of all categories: one prompt per category per round.

        Empty categories contribute nothing; categories shorter than the longest one
        wrap around and repeat their prompts.
        """
        populated = [(category, prompts) for category, prompts in self.data.items() if len(prompts) > 0]
        if not populated:
            raise DatasetFormatError("Dataset is empty or contains only empty categories.")

        # One round per prompt of the longest category.
        num_rounds = max(len(prompts) for _, prompts in populated)

        return [
            {"category": category, "prompt": prompts[round_idx % len(prompts)]}
            for round_idx in range(num_rounds)
            for category, prompts in populated
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def num_categories(self) -> int:
        """Number of categories returned by load_dataset() (one per key of `self.data`)."""
        return len(self.data)

In [14]:
class HPSV2PromptDataset(BasePromptDataset):
    """Prompt dataset backed by the prompts returned from `hpsv2.benchmark_prompts("all")`."""

    def load_dataset(self) -> Dict[str, List[str]]:
        """Return the benchmark prompts as a plain dict, copied pair by pair."""
        benchmark = hpsv2.benchmark_prompts("all")
        prompts_by_category = dict(benchmark.items())
        return prompts_by_category

In [15]:
class DrawBenchPromptDataset(BasePromptDataset):
    """Prompt dataset read from "drawbench_data.csv" in the current working directory."""

    def load_dataset(self) -> Dict[str, List[str]]:
        """Group the "Prompts" column by "Category" and return {category: [prompts]}."""
        frame = pd.read_csv("drawbench_data.csv")
        prompts_per_category = frame.groupby("Category")["Prompts"].apply(list)
        return prompts_per_category.to_dict()

In [16]:
class ImagePromptDataset(Dataset):
    """Pairs images with their prompts and yields (transformed image, caption-or-tokens) items."""

    def __init__(
            self,
            # PIL.Image is the module; the image class is PIL.Image.Image. Annotating with
            # the module is also rejected by `typing` on Python versions before 3.11.
            image_list: List[PIL.Image.Image], prompt_list: List[Tuple[str, str]],
            image_transform_function: callable, text_tokenizer_function: callable = None
        ):
        """
        Args:
            image_list (List[PIL.Image.Image]): List of PIL images.
            prompt_list (List[Tuple[str, str]]): List of (category, prompt) tuples.
            image_transform_function (callable): Function to transform PIL images.
            text_tokenizer_function (callable): Function to tokenize text prompts. When None,
                items carry the raw prompt string instead of tokens.

        Raises:
            DatasetFormatError: If either list is empty or the two lists differ in length.
        """
        if len(image_list) == 0 or len(prompt_list) == 0:
            raise DatasetFormatError("Both image_list and prompt_list must be non-empty.")
        if len(image_list) != len(prompt_list):
            raise DatasetFormatError("Images and prompts must have the same length.")

        self.images = image_list
        self.prompts = prompt_list  # List of (category, prompt)
        self.image_transform_function = image_transform_function
        self.text_tokenizer_function = text_tokenizer_function

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (transformed image, tokens) for item `idx`; the category is dropped."""
        image = self.image_transform_function(self.images[idx])
        _, prompt = self.prompts[idx]
        if self.text_tokenizer_function is None:
            # No tokenizer configured: hand back the prompt text unchanged.
            tokens = prompt
        else:
            tokens = self.text_tokenizer_function(prompt)

        return image, tokens

In [17]:
class RoundRobinSampler(torch.utils.data.Sampler):
    """
    Sampler that keeps the round-robin category order of a BasePromptDataset while
    randomising which prompt of each category is drawn in each round.
    """

    def __init__(self, dataset: "BasePromptDataset"):
        """
        Args:
            dataset: Dataset exposing `data` ({category: [prompts]}) and `samples`
                (the flat round-robin list of {"category", "prompt"} dicts).
        """
        self.dataset = dataset
        # The order is drawn once here; every iteration over the sampler replays it.
        self.indices = self._generate_indices()

    def _generate_indices(self):
        """
        Build the flat index order: one index per category per round.

        The flat indices of each category are read from `dataset.samples` itself rather
        than computed as i*K + j, so the mapping stays correct when categories have
        different lengths or when empty categories were skipped by the dataset.

        Within a category only the first len(prompts) slots are shuffled: later slots are
        wrap-around repeats of the same prompts. A category shorter than the number of
        rounds therefore cycles through its shuffled order instead of running out.

        When all categories have the same length this yields the same order as the
        previous i*K + j implementation for the same random state.
        """
        slots_by_category = {}
        for flat_index, sample in enumerate(self.dataset.samples):
            slots_by_category.setdefault(sample["category"], []).append(flat_index)
        if not slots_by_category:
            return []

        shuffled_slots = {}
        for category, slots in slots_by_category.items():
            distinct_slots = slots[:len(self.dataset.data[category])]
            random.shuffle(distinct_slots)
            shuffled_slots[category] = distinct_slots

        num_rounds = max(len(slots) for slots in slots_by_category.values())

        ordered_indices = []
        for round_idx in range(num_rounds):
            for category, slots in shuffled_slots.items():
                ordered_indices.append(slots[round_idx % len(slots)])
        return ordered_indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

In [18]:
class DatasetFactory:
    """Factory that maps a dataset type string to the matching dataset class."""

    @staticmethod
    def create_dataset(
        dataset_type: str,
        **kwargs,
    ) -> Union[BasePromptDataset, ImagePromptDataset]:
        """
        Build the dataset registered under `dataset_type`.

        "drawbench" and "hps" take no constructor arguments; "imageandprompt" receives
        all keyword arguments. Any other value raises ValueError.
        """
        prompt_only_datasets = (
            ("drawbench", DrawBenchPromptDataset),
            ("hps", HPSV2PromptDataset),
        )
        for type_name, dataset_class in prompt_only_datasets:
            if dataset_type == type_name:
                return dataset_class()

        if dataset_type == "imageandprompt":
            return ImagePromptDataset(**kwargs)

        raise ValueError(f"Unknown dataset type: '{dataset_type}'.")

6. Define Arguments and Utils

In [48]:
def check_target_model(value):
    """
    argparse type-checker for --target_model_name.

    Returns `value` unchanged when it looks like '<repo-owner>/stable-diffusion-...'
    with a major version of 1, 2 or 3; raises argparse.ArgumentTypeError otherwise.
    """
    accepted = re.match(r'^[^/]+/(stable-diffusion-?(v-?|v)?[123](?:-\d+)?)(.*)?$', value)
    if accepted:
        return value

    raise argparse.ArgumentTypeError(
        "target_model_name must be in the format '<repo-owner>/stable-diffusion-[1|2|3]', "
        "'<repo-owner>/stable-diffusion-v[1|2|3]', or '<repo-owner>/stable-diffusion-v-[1|2|3]'."
    )

def check_dataset_name(value):
    """
    argparse type-checker for --dataset_name.

    Returns `value` unchanged when it is 'hps' or 'drawbench'; raises
    argparse.ArgumentTypeError otherwise.
    """
    supported_names = ('hps', 'drawbench')
    if value in supported_names:
        return value

    raise argparse.ArgumentTypeError(
        "dataset_name must be either 'hps' or 'drawbench'.")

def parse_model_args(argv=None):
    """
    Parse the command-line options of the image generation process.

    Args:
        argv (Optional[List[str]]): Argument strings to parse, without the program name.
            When None (the default) argparse reads sys.argv[1:], exactly as before. Passing
            a list lets a notebook call this function without overwriting sys.argv.

    Returns:
        argparse.Namespace: Parsed options. `num_samples_per_category` is always set:
        when not given it defaults to 5 for the 'hps' dataset and 2 for 'drawbench'.
    """
    parser = argparse.ArgumentParser(
        description="Argument parser for image generation process."
    )

    # Models group
    models = parser.add_argument_group("models")
    models.add_argument("--target_model_name", type=check_target_model, required=True,
                        help="HuggingFace model ID in format <repo-owner>/stable-diffusion-[1|2|3]")

    # Datasets group
    datasets = parser.add_argument_group("datasets")
    datasets.add_argument("--dataset_name", type=check_dataset_name, required=True,
                        help="Dataset for generating preliminary images: 'hps' or 'drawbench'")
    datasets.add_argument("--num_samples_per_category", type=int, default=None,
                        help="Number of text prompts per category (default: 5 for hps, 2 for drawbench)")
    datasets.add_argument("--shuffle", action="store_true",
                        help="Shuffle prompts prior to sampling (default: False)")

    # Misc group
    misc = parser.add_argument_group("misc")
    misc.add_argument("--inference_batch_size", type=int, default=4,
                        help="Batch size for target model inference (default: 4)")
    misc.add_argument("--no_save_image_results", dest="save_image_results", action="store_false",
                        help="Do not store images, prompts, and reward scores that pass threshold")
    misc.set_defaults(save_image_results=True)

    args = parser.parse_args(argv)
    # The default depends on the dataset, so it is resolved after parsing.
    if args.num_samples_per_category is None:
        if args.dataset_name == "hps":
            args.num_samples_per_category = 5
        else:  # drawbench
            args.num_samples_per_category = 2
    return args

In [20]:
def clear_cuda_memory_and_force_gc(force: bool = False, threshold: float = 0.7):
    """
    Forces garbage collection and clears the CUDA memory cache if the memory currently
    reserved by PyTorch exceeds a fraction of the device's total memory, or if
    explicitly forced.

    Args:
        force (bool): If True, garbage collection is forced and the CUDA cache is cleared
                      regardless of the memory threshold.
        threshold (float): Fraction of the device's total memory above which the cleanup
                           is triggered (default: 0.7).
    """
    if not torch.cuda.is_available():
        # No CUDA device to query or clear (the device queries below would fail here);
        # a forced call still collects Python garbage.
        if force:
            gc.collect()
        return

    # Use the current reservation, not the peak: a peak-based check would keep firing on
    # every call once the threshold had been crossed a single time.
    memory_reserved = torch.cuda.memory_reserved()
    memory_total = torch.cuda.get_device_properties("cuda").total_memory

    memory_threshold = memory_total * threshold
    if force or memory_reserved > memory_threshold:
        # Collect first so tensors freed by the garbage collector return to the cache
        # before the cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()

In [21]:
class SampledDataset(Dataset):
    """Dataset over an already-sampled batch of prompts, one {"category", "prompt"} dict per item."""

    def __init__(self, prompts):
        """
        Args:
            prompts: Mapping with parallel sequences under the keys "category" and "prompt".
        """
        records = []
        for category, prompt in zip(prompts["category"], prompts["prompt"]):
            records.append({"category": category, "prompt": prompt})
        self.data = records

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

7. Generate Base Images

In [40]:
def generate_images(generate_image_args):
    """
    Generate images for a sample of dataset prompts with the target Stable Diffusion model.

    Steps: build the model and prompt dataset, take the first
    2 * num_samples_per_category * num_categories prompts (optionally through
    RoundRobinSampler), generate images in batches of `inference_batch_size`, and, if
    `save_image_results` is set, write the images plus a "prompts.txt" index under
    "outputs/<model>/<dataset>/<timestamp>/".

    Args:
        generate_image_args: Namespace as returned by `parse_model_args()`; the attributes
            read here are target_model_name, dataset_name, num_samples_per_category,
            shuffle, inference_batch_size and save_image_results.

    Environment:
        HF_TOKEN: Hugging Face access token passed to the pipeline loader for
            Stable Diffusion 3 models. When unset, no explicit token is passed.
    """
    print(generate_image_args)
    model_name = generate_image_args.target_model_name
    kwargs = {}

    if re.search(r'(stable-diffusion-2|v-?2)', model_name):
        kwargs = {
            "resolution": 512,
        }
    if re.search(r'(stable-diffusion-3|v-?3)', model_name):
        kwargs = {
            "resolution": 1024,
            "offload_to_cpu": True,
            "text_encoder_3": None,
            "tokenizer_3": None,
            # Credentials must never be hardcoded in the notebook; read the token from
            # the environment instead.
            "token": os.environ.get("HF_TOKEN"),
        }

    model = ModelFactory.create_model(
        model_type="sd",
        model_path=model_name,
        **kwargs,
    )

    try:
        dataset = DatasetFactory.create_dataset(
            dataset_type=generate_image_args.dataset_name,
        )

        # Generate twice as many images as the number of samples per category for safety
        num_images_to_gen = 2 * generate_image_args.num_samples_per_category * dataset.num_categories()
        dataset_loader = DataLoader(
            dataset,
            batch_size=num_images_to_gen,
            sampler=RoundRobinSampler(dataset) if generate_image_args.shuffle else None,
        )

        # Only the first batch is used: it holds the prompts selected for generation.
        selected_prompts = next(iter(dataset_loader))
        sampled_dataset = SampledDataset(selected_prompts)
        sampled_dataset_loader = DataLoader(
            sampled_dataset, batch_size=generate_image_args.inference_batch_size, shuffle=False
        )

        final_images = []
        final_prompts = []
        # Wrapping the loader lets tqdm close the bar even if a batch fails.
        for batch in tqdm(sampled_dataset_loader, desc="Generating images from prompts"):
            batch_prompts = batch["prompt"]
            batch_categories = batch["category"]
            images = model.inference(inputs=batch_prompts)
            final_images.extend(images)
            final_prompts.extend(zip(batch_categories, batch_prompts))

        if generate_image_args.save_image_results:
            timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            output_dir = os.path.join(
                "outputs", model_name.split('/')[1], generate_image_args.dataset_name, timestamp
            )
            os.makedirs(output_dir, exist_ok=True)

            prompts_file = os.path.join(output_dir, "prompts.txt")
            with open(prompts_file, "w", encoding="utf-8") as pf:
                for idx, (img, prompt) in enumerate(zip(final_images, final_prompts)):
                    image_filename = os.path.join(output_dir, f"image_{idx}.png")
                    img.save(image_filename)
                    # `prompt` is the (category, prompt) tuple, written as its repr.
                    pf.write(f"Image {idx}: {prompt}\n")
    finally:
        # Drop the last reference to the pipeline before clearing, otherwise its weights
        # stay allocated and emptying the cache cannot release them.
        del model
        clear_cuda_memory_and_force_gc(force=True)

In [58]:
import argparse
import sys

# parse_model_args() reads its options from sys.argv, so emulate a command line here
# in place of the notebook kernel's own arguments.
sys.argv = [
    "script_name",  # Placeholder for script name (ignored by argparse)
    "--target_model_name", "stabilityai/stable-diffusion-3-medium-diffusers",
    "--dataset_name", "drawbench",
]

# Parse the emulated command line, then run image generation with the resulting namespace.
args = parse_model_args()
generate_images(args)

Namespace(target_model_name='stabilityai/stable-diffusion-3-medium-diffusers', dataset_name='drawbench', num_samples_per_category=2, shuffle=False, inference_batch_size=4, save_image_results=True)


model_index.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

Fetching 18 files:   0%|          | 0/18 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/247M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/574 [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/705 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/588 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/576 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/856 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/372 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/4.17G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/739 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Generating images from prompts:   0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:   9%|▉         | 1/11 [02:28<24:49, 148.94s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  18%|█▊        | 2/11 [04:56<22:12, 148.09s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  27%|██▋       | 3/11 [07:23<19:39, 147.43s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  36%|███▋      | 4/11 [09:49<17:09, 147.09s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  45%|████▌     | 5/11 [12:16<14:41, 146.99s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  55%|█████▍    | 6/11 [14:43<12:14, 146.91s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  64%|██████▎   | 7/11 [17:09<09:47, 146.80s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  73%|███████▎  | 8/11 [19:36<07:20, 146.81s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  82%|████████▏ | 9/11 [22:03<04:53, 146.82s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts:  91%|█████████ | 10/11 [24:30<02:26, 146.76s/it]

  0%|          | 0/28 [00:00<?, ?it/s]

Generating images from prompts: 100%|██████████| 11/11 [26:56<00:00, 146.96s/it]


In [57]:
# Run the cleanup helper with force=True so it clears the CUDA cache and collects garbage
# regardless of the memory threshold.
clear_cuda_memory_and_force_gc(force=True)