In [None]:
# Re-import edited Python modules automatically before each cell executes.
# NOTE(review): no project-local modules are imported in this notebook as
# shown, so this may be unnecessary — confirm before removing.
%load_ext autoreload
%autoreload 2

In [None]:
from diffusers import FluxPipeline
import torch

In [None]:
# Encoding-only setup: the denoising transformer is skipped (transformer=None)
# because process_sample only uses the VAE, image processor and encode_prompt.
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    torch_dtype=DTYPE,
)
pipe = pipe.to(DEVICE)

In [None]:
from torchvision import transforms
from PIL import Image

def process_sample(image_path, prompt_path):
    """Encode one image/caption pair into Flux training inputs.

    The image goes through ``pipe.image_processor.preprocess`` and the VAE
    encoder. A latent is drawn from the encoder distribution with
    ``latent_dist.sample()``, so repeated calls differ unless the torch RNG is
    seeded. The latent is then normalized with the VAE config's
    ``shift_factor`` and ``scaling_factor``. The caption text is encoded with
    ``pipe.encode_prompt``, using the same text for ``prompt`` and
    ``prompt_2``.

    Relies on the module-level ``pipe``, ``DEVICE`` and ``DTYPE``.

    Args:
        image_path: Path to the image file (any mode Pillow can convert to RGB).
        prompt_path: Path to a UTF-8 text file holding the caption.

    Returns:
        dict with CPU tensors ``clean_latents``, ``prompt_embeds`` and
        ``pooled_prompt_embeds``. ``squeeze(0)`` drops dim 0 when it has size 1
        (one image / one prompt per call).
    """
    # Convert to RGB so grayscale/CMYK/palette files still yield the 3-channel
    # input the VAE expects; the context manager closes the file handle.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    with open(prompt_path, "r", encoding="utf-8") as f:
        prompt = f.read().strip()

    # Pure inference: without no_grad the outputs would carry grad_fn links back
    # into the GPU autograd graph (keeping VAE/T5 activations alive) and would
    # be pickled as grad-requiring tensors.
    with torch.no_grad():
        pixels = pipe.image_processor.preprocess(img).to(device=DEVICE, dtype=DTYPE)
        latent = pipe.vae.encode(pixels).latent_dist.sample()
        latent = (latent - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
        prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            device=DEVICE,
        )

    return {
        "clean_latents": latent.cpu().squeeze(0),
        "prompt_embeds": prompt_embeds.cpu().squeeze(0),
        "pooled_prompt_embeds": pooled_prompt_embeds.cpu().squeeze(0),
    }

In [None]:
# Smoke-test the encoder on a single image/caption pair before processing the
# whole directory.
image_path = "/workspaces/torch-basics/data/285058/000.jpeg"
prompt_path = "/workspaces/torch-basics/data/285058/000.txt"
sample = process_sample(image_path, prompt_path)
# Show shape and dtype per entry instead of dumping raw tensor values.
{name: (tuple(tensor.shape), tensor.dtype) for name, tensor in sample.items()}

In [None]:
import os
import lmdb
import pickle

def process_directory(directory, extensions=(".jpeg", ".jpg")):
    """Yield encoded samples for every captioned image under ``directory``.

    Walks ``directory`` recursively in sorted, deterministic order. For each
    file whose name ends with one of ``extensions``, the caption is expected in
    a sibling file with the same stem and a ``.txt`` suffix
    (``000.jpeg`` -> ``000.txt``). Images without a caption file are skipped.

    Args:
        directory: Root directory to search.
        extensions: Case-sensitive filename suffix, or iterable of suffixes,
            treated as images. The default matches the original behavior.

    Yields:
        Whatever ``process_sample(image_path, prompt_path)`` returns for each
        image/caption pair.
    """
    # Accept a single suffix string as well as a collection of suffixes.
    suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
    for root, dirs, files in os.walk(directory):
        # os.walk/scandir yields entries in arbitrary filesystem order. Sorting
        # (dirs in place, so the traversal follows it) keeps the order of
        # yielded samples, and hence the LMDB keys, reproducible.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(suffixes):
                continue
            image_path = os.path.join(root, file)
            # Swap only the final extension; str.replace would also rewrite a
            # ".jpg"/".jpeg" appearing earlier in the name.
            prompt_path = os.path.join(root, os.path.splitext(file)[0] + ".txt")
            if os.path.exists(prompt_path):
                yield process_sample(image_path, prompt_path)

def save_lmdb(samples, lmdb_path, map_size=int(1e12)):
    """Pickle each sample and store it in an LMDB database.

    All samples are written in a single write transaction, under zero-padded
    8-digit keys (``b"00000000"``, ``b"00000001"``, ...) in iteration order.

    Args:
        samples: Iterable of picklable objects, e.g. the generator returned by
            ``process_directory``. It is consumed lazily, so any work it does
            (such as GPU encoding) runs while the write transaction is open.
        lmdb_path: Path passed to ``lmdb.open``.
        map_size: Maximum database size in bytes, passed to ``lmdb.open``.
            Defaults to 1e12 (1 TB), as before.

    Returns:
        The number of samples written.

    Note:
        Keys beyond the new sample count are never deleted. Re-running into an
        existing database with fewer samples therefore leaves stale entries.
        Values are pickles: only load databases from trusted sources.
    """
    env = lmdb.open(lmdb_path, map_size=map_size)
    written = 0
    try:
        with env.begin(write=True) as txn:
            for index, sample in enumerate(samples):
                txn.put(f"{index:08}".encode(), pickle.dumps(sample))
                written = index + 1
    finally:
        # Release the environment even if a sample fails to encode or pickle.
        env.close()
    return written

In [None]:
# Encode every captioned image found under `directory` (walked recursively by
# process_directory) and write the results to an LMDB file next to it.
directory = "/workspaces/torch-basics/data/285058"
lmdb_path = f"{directory}.lmdb"
samples = process_directory(directory)  # lazy generator, consumed by save_lmdb
save_lmdb(samples, lmdb_path)