## Text Annotation to Synthetic Image Pipeline

This notebook demonstrates using advanced diffusion models to generate synthetic images locally from descriptive text annotations (captions) produced by BLIP-2 captioning and stored as JSON files.

In [None]:
# Standard library imports
import os
import json
import gc
import logging
import time
from multiprocessing import Pool, cpu_count, current_process, Manager, get_context

# Third-party library imports
import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from transformers import pipeline
from diffusers import DiffusionPipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from IPython.display import display
import warnings

# Local/application-specific imports
import bittensor as bt
from bitmind.constants import PROMPT_GENERATOR_NAMES, PROMPT_GENERATOR_ARGS, DIFFUSER_NAMES, DIFFUSER_ARGS
from multiprocessing_tasks import worker_initializer, generate_images_for_chunk

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logging (1: filter out INFO, 2: additionally filter out WARNING, 3: additionally filter out ERROR)
import tensorflow as tf  # Import TensorFlow after setting the log level

In [None]:
# Configure logging: show INFO and above from the root logger
logging.basicConfig(level=logging.INFO)
# Suppress FutureWarnings from diffusers module
warnings.filterwarnings("ignore", category=FutureWarning, module='diffusers')
# Set device for model operations: prefer CUDA, fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# NOTE(review): load_diffuser contains a float32 fallback for CPU, but this guard
# aborts the notebook before it can ever be used -- confirm that is intended.
if device == "cpu":
    raise RuntimeError("This script requires a GPU because it uses torch.float16.")  # Abort early when no GPU is available
# Ensure that this script uses 'spawn' method for starting multiprocessing tasks;
# `ctx` is used below to create the worker pool.
ctx = get_context("spawn")

In [None]:
def list_datasets(base_dir):
    """Return the names of the immediate subdirectories of ``base_dir``.

    Entries that are not directories are ignored. Names are returned in the
    order in which ``os.listdir`` yields them.
    """
    dataset_names = []
    for entry in os.listdir(base_dir):
        candidate = os.path.join(base_dir, entry)
        if os.path.isdir(candidate):
            dataset_names.append(entry)
    return dataset_names

def load_annotations(base_dir, dataset):
    """Load every JSON annotation file of one dataset.

    Args:
        base_dir: Directory that contains one subdirectory per dataset.
        dataset: Name of the dataset subdirectory to read.

    Returns:
        A list holding the parsed content of each ``*.json`` file found
        directly inside ``base_dir/dataset``, ordered by file name.

    Raises:
        FileNotFoundError: If the dataset directory does not exist.
        json.JSONDecodeError: If one of the files is not valid JSON.
    """
    path = os.path.join(base_dir, dataset)
    annotations = []
    # os.listdir order depends on the filesystem; sorting makes "the first N
    # annotations" the same on every run and every machine.
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(".json"):
            continue
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(os.path.join(path, filename), 'r', encoding='utf-8') as file:
            annotations.append(json.load(file))
    return annotations

def load_diffuser(model_name):
    """Instantiate the diffusion pipeline ``model_name`` on the global device.

    The pipeline is created with the per-model keyword arguments stored in
    ``DIFFUSER_ARGS[model_name]``. Weights are requested as float32 when the
    device is "cpu" and as float16 otherwise. The pipeline is moved to the
    device before being returned.
    """
    bt.logging.info(f"Loading image generation model ({model_name})...")

    if device == "cpu":
        weights_dtype = torch.float32
    else:
        weights_dtype = torch.float16
    extra_kwargs = DIFFUSER_ARGS[model_name]

    pipe = DiffusionPipeline.from_pretrained(
        model_name,
        torch_dtype=weights_dtype,
        **extra_kwargs,
    )
    pipe.to(device)
    return pipe

In [None]:
## GPU
def generate_images(annotations, diffuser, save_dir, num_images, batch_size, diffuser_name):
    """Generate one image per annotation and save each as a PNG in ``save_dir``.

    Args:
        annotations: List of annotation dicts. Each must contain a
            'description' key (used as the prompt) and may contain an
            'index' key (used in the output file name).
        diffuser: Callable invoked as ``diffuser(prompt=...)``; the first
            entry of the result's ``.images`` is saved.
        save_dir: Output directory; created if it does not exist.
        num_images: Upper bound on how many annotations, taken from the
            start of the list, are processed.
        batch_size: Not used by this function; kept so existing callers
            keep working.
        diffuser_name: Not used by this function; kept so existing callers
            keep working.

    Returns:
        List of the file paths written, in generation order.

    Raises:
        KeyError: If a processed annotation has no 'description' key.
    """
    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Map whitespace, path separators and characters that common filesystems
    # reject in file names to '_' so that any caption yields a usable name.
    unsafe_chars = str.maketrans({ch: '_' for ch in ' \t\r\n/\\<>:"|?*'})

    generated_images = []
    start_time = time.time()

    # Inference only: gradient tracking is not needed.
    with torch.no_grad():
        for i in range(min(num_images, len(annotations))):
            start_loop = time.time()
            annotation = annotations[i]
            prompt = annotation['description']
            index = annotation.get('index', "missing_index")

            logging.info(f"Annotation {i}: {json.dumps(annotation, indent=2)}")

            generated_image = diffuser(prompt=prompt).images[0]
            logging.info(f"Type of generated image: {type(generated_image)}")

            # Convert tensor outputs to PIL images before saving.
            if isinstance(generated_image, torch.Tensor):
                img = ToPILImage()(generated_image)
            else:
                img = generated_image

            # File name: first 50 prompt characters plus the annotation index.
            file_stem = f"{prompt[:50]}-{index}".translate(unsafe_chars)
            img_filename = os.path.join(save_dir, f"{file_stem}.png")
            img.save(img_filename)
            generated_images.append(img_filename)
            loop_time = time.time() - start_loop
            logging.info(f"Image saved to {img_filename}")
            logging.info(f"Image generation time: {loop_time:.2f} seconds")

    total_time = time.time() - start_time
    logging.info(f"Total processing time: {total_time:.2f} seconds")
    return generated_images


def load_and_initialize_diffuser(diffuser_name, previous_diffuser=None):
    """Release ``previous_diffuser`` (when given) and load ``diffuser_name``.

    When a previous pipeline is supplied it is converted to float32 if it is
    float16, moved to the CPU, and the local reference is dropped before the
    garbage collector runs and the CUDA cache is emptied. The newly loaded
    pipeline is returned.
    """
    if previous_diffuser is None:
        return load_diffuser(diffuser_name)

    logging.info("Deleting previous diffuser, freeing memory")
    # Half-precision weights are converted to float32 before the move to CPU.
    is_half_precision = previous_diffuser.dtype == torch.float16
    if is_half_precision:
        previous_diffuser = previous_diffuser.to(dtype=torch.float32)
    previous_diffuser.to('cpu')
    # Only this function's reference is removed; the caller may still hold one.
    del previous_diffuser
    gc.collect()
    torch.cuda.empty_cache()

    return load_diffuser(diffuser_name)

def test_diffuser_on_dataset(dataset, annotations, diffuser, output_dir, num_images, batch_size, diffuser_name):
    """Generate images for one dataset with one diffuser.

    Images are written to ``output_dir/<dataset>/<diffuser>``, where both path
    components are the part of the respective name after its last '/'.

    Args:
        dataset: Dataset name (may contain '/').
        annotations: Annotation dicts passed on to ``generate_images``.
        diffuser: Loaded pipeline passed on to ``generate_images``.
        output_dir: Root directory for generated images.
        num_images: Forwarded to ``generate_images``.
        batch_size: Forwarded to ``generate_images``.
        diffuser_name: Diffuser name (may contain '/').

    Any exception raised during generation is logged with its traceback and
    swallowed, so one failing dataset/diffuser pair does not stop a longer run.
    """
    # rsplit returns the whole string when no '/' is present, so no separate
    # membership check is required.
    dataset_name = dataset.rsplit('/', 1)[-1]
    diffuser_name = diffuser_name.rsplit('/', 1)[-1]
    save_dir = os.path.join(output_dir, dataset_name, diffuser_name)
    logging.info(f"Testing {diffuser_name} on annotation dataset {dataset} at {save_dir}...")
    os.makedirs(save_dir, exist_ok=True)

    try:
        generate_images(annotations, diffuser, save_dir, num_images, batch_size, diffuser_name)
    except Exception as e:
        # Deliberately best-effort; logging.exception keeps the traceback that
        # a plain error message would lose.
        logging.exception(f"Failed to generate images with {diffuser_name}: {str(e)}")
    else:
        logging.info("Images generated and saved successfully.")

def cleanup_diffuser(diffuser):
    """Move ``diffuser`` off the GPU and release cached CUDA memory.

    A float16 pipeline is first converted to float32, then the pipeline is
    moved to the CPU. The local reference is deleted before garbage collection
    and before the CUDA cache is emptied. Returns ``None``.
    """
    logging.info("Deleting diffuser, freeing memory")

    # Half-precision weights are converted to float32 before the move to CPU.
    stored_as_half = diffuser.dtype == torch.float16
    if stored_as_half:
        diffuser = diffuser.to(dtype=torch.float32)
    diffuser.to('cpu')

    # Only this function's reference is removed; the caller may still hold one.
    del diffuser
    gc.collect()
    torch.cuda.empty_cache()

def test_diffusers_on_datasets(annotations_dir, output_dir, num_images=1, batch_size=2):
    """Run every diffuser in ``DIFFUSER_NAMES`` over every annotation dataset.

    Args:
        annotations_dir: Directory with one subdirectory of JSON annotations
            per dataset.
        output_dir: Root directory for generated images.
        num_images: Forwarded to ``test_diffuser_on_dataset``.
        batch_size: Forwarded to ``test_diffuser_on_dataset``.

    Each diffuser is loaded once, applied to all datasets, and then cleaned up
    before the next one is loaded.
    """
    datasets = list_datasets(annotations_dir)
    for diffuser_name in DIFFUSER_NAMES:
        logging.info(f"Loading and initializing diffuser: {diffuser_name}")
        diffuser = load_and_initialize_diffuser(diffuser_name)
        try:
            for dataset in datasets:
                annotations = load_annotations(annotations_dir, dataset)
                test_diffuser_on_dataset(dataset, annotations, diffuser, output_dir, num_images, batch_size, diffuser_name)
        finally:
            # Release the model even if loading annotations fails part-way,
            # so the next diffuser does not compete with it for GPU memory.
            cleanup_diffuser(diffuser)

In [None]:
## Multiprocessing loop
def multiprocess_generate_images(annotations_dir, output_dir, num_processes=None):
    """Generate images for all datasets with a pool of worker processes.

    Args:
        annotations_dir: Directory with one subdirectory of JSON annotations
            per dataset.
        output_dir: Root directory; images for a dataset are directed to
            ``output_dir/<model_name>/<dataset>``.
        num_processes: Number of worker processes. Defaults to the CPU count
            minus one, but at least 1.

    For each model in ``DIFFUSER_NAMES`` a fresh pool is created from the
    module-level ``ctx`` context; every worker is initialised through
    ``worker_initializer(model_name, device, DIFFUSER_ARGS)``. Each dataset's
    annotations are split into at most ``num_processes`` contiguous chunks that
    are handed to ``generate_images_for_chunk(chunk, save_dir)``.
    Datasets without annotations are skipped.
    """
    if num_processes is None:
        num_processes = max(1, cpu_count() - 1)  # Leaves one CPU core free

    datasets = list_datasets(annotations_dir)
    for model_name in DIFFUSER_NAMES:
        logging.info(f"Processing with model: {model_name}")
        with ctx.Pool(processes=num_processes, initializer=worker_initializer, initargs=(model_name, device, DIFFUSER_ARGS)) as pool:
            for dataset in datasets:
                annotations = load_annotations(annotations_dir, dataset)
                if not annotations:
                    # An empty list would give chunk_size == 0, and range()
                    # rejects a step of zero.
                    logging.warning(f"No annotations found for dataset {dataset}; skipping")
                    continue
                save_dir = os.path.join(output_dir, model_name, dataset)

                # Split annotations into chunks for each worker (ceiling division)
                chunk_size = (len(annotations) + num_processes - 1) // num_processes
                chunks = [annotations[i:i + chunk_size] for i in range(0, len(annotations), chunk_size)]

                # starmap blocks until every chunk has been processed.
                pool.starmap(generate_images_for_chunk, [(chunk, save_dir) for chunk in chunks])
                logging.info(f"Completed processing for dataset {dataset} with model {model_name}")

In [None]:
# Root directory holding one subdirectory of JSON annotation files per dataset
ANNOTATIONS_DIR = "test_data/annotations/"
# Root directory under which the generated images are written
OUTPUT_DIR = "test_data/synthetics_from_annotations/"

In [None]:
# GPU
# Run every diffuser in DIFFUSER_NAMES over every annotation dataset, one image each.
# NOTE(review): batch_size is passed down to generate_images, which does not use it.
test_diffusers_on_datasets(ANNOTATIONS_DIR, OUTPUT_DIR, num_images=1, batch_size=8)

In [None]:
# CPU Multiprocessing
# multiprocess_generate_images(ANNOTATIONS_DIR, OUTPUT_DIR)

In [None]:
def test_specific_diffuser_on_specific_dataset(annotations_dir, output_dir, diffuser_name="stabilityai/stable-diffusion-xl-base-1.0", dataset_name="dalle-mini_open-images", num_images=1, batch_size=8):
    """Test a specific diffuser on a specific dataset.

    Args:
        annotations_dir: Directory that contains the dataset subdirectory.
        output_dir: Root directory for generated images.
        diffuser_name: Name of the diffusion model to load.
        dataset_name: Name of the annotation subdirectory to use.
        num_images: Forwarded to ``test_diffuser_on_dataset``.
        batch_size: Forwarded to ``test_diffuser_on_dataset``.
    """
    # Read the annotations first: this is cheap and fails fast on a wrong
    # path, before any model weights are loaded.
    annotations = load_annotations(annotations_dir, dataset_name)

    logging.info(f"Loading and initializing diffuser: {diffuser_name}")
    diffuser = load_and_initialize_diffuser(diffuser_name)
    try:
        logging.info(f"Testing diffuser: {diffuser_name} on dataset: {dataset_name}")
        test_diffuser_on_dataset(dataset_name, annotations, diffuser, output_dir, num_images, batch_size, diffuser_name)
    finally:
        # Always release the model, even if the run is interrupted.
        cleanup_diffuser(diffuser)

In [None]:
# Show the configured diffuser model names as the cell output
DIFFUSER_NAMES

In [None]:
# Output root for this single-model run
stable_diff = "stable_test/"
# NOTE: this rebinds ANNOTATIONS_DIR, which an earlier cell set to a different path;
# cells re-run after this one will see the new value.
ANNOTATIONS_DIR = "test_data/dataset/annotations/"
# Uses the function's default diffuser and dataset; processes up to 10 annotations.
test_specific_diffuser_on_specific_dataset(ANNOTATIONS_DIR, stable_diff, num_images=10)

Image Comparison Test

In [None]:
# # Standard libraries
# import random
# import logging

# # Third-party libraries
# import numpy as np
# import torch
# from transformers import AutoProcessor, Blip2ForConditionalGeneration

# # Bitmind-specific libraries
# from bitmind.image_dataset import ImageDataset
# from bitmind.constants import DATASET_META

# # Initialize seeds for reproducibility
# torch.manual_seed(0)
# random.seed(0)
# np.random.seed(0)

# device = "cuda" if torch.cuda.is_available() else "cpu"

# # Suppress logs
# transformers_level = logging.getLogger("transformers").getEffectiveLevel()
# huggingface_hub_level = logging.getLogger("huggingface_hub").getEffectiveLevel()
# logging.getLogger("transformers").setLevel(logging.ERROR)
# logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

# # Load the processor and model
# processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
# model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b-coco", torch_dtype=torch.float16) 
# model.to(device)

# # Restore log settings
# logging.getLogger("transformers").setLevel(transformers_level)
# logging.getLogger("huggingface_hub").setLevel(huggingface_hub_level)

# # Load real datasets
# print("Loading real datasets")
# real_image_datasets = [
#     ImageDataset(ds['path'], 'test', ds.get('name', None), ds['create_splits'])
#     for ds in DATASET_META['real']
# ]

In [None]:
# # Target dataset name and index of image to display
# target_dataset_name = 'dalle-mini/open-images'
# index_of_image = 10

# for dataset in real_image_datasets:
#     if dataset.huggingface_dataset_path == target_dataset_name:
#         for index, image_info in enumerate(dataset):
#             if index == index_of_image:
#                 display(image_info['image'])
#                 break
#         break