In [None]:
import os
import shutil
import urllib.request

PARQUET_URL = "https://huggingface.co/datasets/ChristophSchuhmann/improved_aesthetics_6.5plus/resolve/main/data/train-00000-of-00001-6f24a7497df494ae.parquet"
PARQUET_PATH = "laion_improved_aesthetics_6_5.parquet"

if not os.path.exists(PARQUET_PATH):
    # Stream into a temporary file and only move it into place once the download
    # has fully completed. Writing straight to PARQUET_PATH would leave an empty or
    # truncated file behind on failure, and the exists() check above would then
    # skip the download on every subsequent run.
    partial_path = PARQUET_PATH + ".part"
    with urllib.request.urlopen(PARQUET_URL, timeout=60) as response, open(partial_path, "wb") as f:
        shutil.copyfileobj(response, f)
    os.replace(partial_path, PARQUET_PATH)

In [None]:
# Hugging Face access token, read from the HUGGINGFACE_AUTH_TOKEN environment
# variable so that it never ends up in the notebook source or version control.
# Never hardcode the token here.
# See: https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_AUTH_TOKEN = os.getenv("HUGGINGFACE_AUTH_TOKEN", "")

In [None]:
from daft import DataFrame, col, udf

# Build a daft DataFrame from the parquet file downloaded in the first cell.
images_df = DataFrame.from_parquet(PARQUET_PATH)

In [None]:
# Display the loaded DataFrame as this cell's output.
images_df

In [None]:
import urllib.request
from urllib.error import HTTPError, URLError
import concurrent.futures
import threading
import PIL.Image
import boto3
import io


@udf(return_type=PIL.Image.Image)
def download_batch(url_col):
    """Download and decode one image per URL in ``url_col``.

    URLs are fetched concurrently on a thread pool. A row whose URL is missing,
    cannot be fetched, or does not decode as an image yields ``None`` instead of
    failing the whole batch, so callers can filter those rows out afterwards.

    Args:
        url_col: Iterable of URL strings for the batch.

    Returns:
        A list aligned with ``url_col`` holding RGB ``PIL.Image.Image`` objects,
        or ``None`` for rows that could not be downloaded or decoded.
    """
    import http.client

    # Per-request timeout in seconds. Without one, a single stalled host would
    # block its worker thread (and therefore the whole batch) indefinitely.
    timeout_s = 30

    def download_single(url):
        """Return the raw bytes at ``url``, or ``None`` on any fetch failure."""
        if url is None:
            return None
        try:
            # The context manager closes the connection even if read() fails midway.
            with urllib.request.urlopen(url, timeout=timeout_s) as response:
                return response.read()
        except (OSError, ValueError, http.client.HTTPException):
            # OSError covers HTTPError, URLError and socket timeouts; ValueError is
            # raised for malformed URLs; HTTPException covers e.g. IncompleteRead.
            return None

    def decode_single(payload):
        """Decode ``payload`` into an RGB image, or return ``None`` if it is not one."""
        if payload is None:
            return None
        try:
            with io.BytesIO(payload) as f:
                # convert() forces a full decode while the buffer is still open.
                return PIL.Image.open(f).convert("RGB")
        except (OSError, ValueError, PIL.Image.DecompressionBombError):
            # OSError includes PIL's UnidentifiedImageError and truncated-file errors.
            return None

    with concurrent.futures.ThreadPoolExecutor() as executor:
        payloads = list(executor.map(download_single, url_col))
    return [decode_single(payload) for payload in payloads]

In [None]:
# Keep the demo small: every row kept here is fetched over the network by the
# image-download UDF defined above.
NUM_SAMPLE_ROWS = 20
images_df = images_df.limit(NUM_SAMPLE_ROWS)

In [None]:
# Rows whose caption has at most this many characters are dropped.
MIN_CAPTION_LENGTH = 50


@udf(return_type=int)
def str_len(text_col, min_len=100):
    """Return the character length of each string in ``text_col``.

    Missing values (``None``) get length 0, so a downstream length filter drops
    them instead of raising ``TypeError``.

    ``min_len`` is accepted for backward compatibility but is not used; the
    threshold is applied by the caller's comparison.
    """
    return [0 if s is None else len(s) for s in text_col]


@udf(return_type=bool)
def is_not_null(c):
    """Return ``True`` for each element of ``c`` that is not ``None``."""
    return [x is not None for x in c]


# Filter on caption length *before* adding the image column so rows that would be
# discarded anyway are never downloaded, then drop rows whose download failed.
images_downloaded_df = (
    images_df
    .where(str_len(col("TEXT")) > MIN_CAPTION_LENGTH)
    .with_column("image", download_batch(col("URL")))
    .where(is_not_null(col("image")))
)

In [None]:
# Display the filtered DataFrame, including the downloaded "image" column.
images_downloaded_df

In [None]:
import torch
from diffusers import DiffusionPipeline

@udf(return_type=PIL.Image.Image)
class GenerateImageFromText:
    """UDF that renders one image per caption with Stable Diffusion v1.4 (CPU).

    The diffusion pipeline is created in ``__init__`` and used by ``__call__``.
    """

    def __init__(self):
        self.pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            # An unset environment variable yields "", which would be sent as an
            # (invalid) credential; pass None so no explicit token is supplied and
            # the hub library applies its default behaviour instead.
            use_auth_token=HUGGINGFACE_AUTH_TOKEN or None,
        )

    def __call__(self, text_col, num_steps=5):
        """Generate one image per prompt in ``text_col`` using ``num_steps`` inference steps."""
        return [self.pipeline(t, num_inference_steps=num_steps)["sample"][0] for t in text_col]


images_downloaded_df.with_column("generated_image", GenerateImageFromText(col("TEXT"))).show(1)

In [None]:
import torch
from diffusers import DiffusionPipeline

@udf(return_type=PIL.Image.Image, request_gpu=1)
class GenerateImageFromTextGPU:
    """GPU variant of the Stable Diffusion v1.4 caption-to-image UDF.

    Declared with ``request_gpu=1``; the pipeline is moved to ``cuda:0`` after
    loading, so a CUDA device must be available where this UDF runs.
    """

    def __init__(self):
        self.pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            # An unset environment variable yields "", which would be sent as an
            # (invalid) credential; pass None so no explicit token is supplied and
            # the hub library applies its default behaviour instead.
            use_auth_token=HUGGINGFACE_AUTH_TOKEN or None,
        )
        self.pipeline = self.pipeline.to("cuda:0")

    def __call__(self, text_col, num_steps=5):
        """Generate one image per prompt in ``text_col`` using ``num_steps`` inference steps."""
        return [self.pipeline(t, num_inference_steps=num_steps)["sample"][0] for t in text_col]


images_downloaded_df.with_column("generated_image", GenerateImageFromTextGPU(col("TEXT"))).show(1)