In [None]:
from daft import DataFrame, col, udf

# S3 location of the image-metadata CSV (gzip-compressed, per the .gz suffix).
IMAGES_CSV_S3_PATH = "s3://amazon-berkeley-objects/images/metadata/images.csv.gz"
# Build a DataFrame from that CSV. Its image_id column is the join key used
# against the listings further down in this notebook.
images_df = DataFrame.from_csv(IMAGES_CSV_S3_PATH)

In [None]:
# Show the image-metadata DataFrame as this cell's output.
images_df

In [None]:
# S3 location of one listings-metadata file (gzip-compressed JSON, per the suffix).
LISTING_JSON_S3_PATH = "s3://amazon-berkeley-objects/listings/metadata/listings_0.json.gz"
# Build a DataFrame from that JSON file; only this single listings file is read here.
listings_df = DataFrame.from_json(LISTING_JSON_S3_PATH)

In [None]:
%%time
# Show the listings DataFrame; the %%time magic on this cell reports how long that takes.
listings_df

In [None]:
# Inspect four listing columns. bullet_point, product_type and item_name are the
# ones parsed by the UDFs defined further down; material is only shown here.
listings_df.select(col("bullet_point"), col("product_type"), col("item_name"), col("material"))

In [None]:
from daft.expressions import udf
import pandas as pd

@udf(return_type=dict)
def get_first(lists):
    """Return the first element of every row in ``lists``.

    Args:
        lists: An iterable whose items are sequences (or ``None``).

    Returns:
        A list with one entry per input row: the row's first element, or
        ``None`` when the row is ``None`` or empty.
    """
    # The explicit ``is not None`` guard keeps a missing row from raising
    # TypeError in len(); the other list-parsing UDFs in this notebook
    # likewise map missing/empty rows to None.
    return [l[0] if l is not None and len(l) > 0 else None for l in lists]

@udf(return_type=str)
def get_first_en_or_none(lists):
    """Pick the first English ``value`` from each row of tagged entries.

    Each row is expected to be a sequence of mappings carrying the keys
    ``"language_tag"`` and ``"value"``. An entry counts as English when its
    language tag starts with ``"en"``.

    Returns:
        A ``pd.Series`` with one item per row: the first English value, or
        ``None`` when the row is empty/missing or has no English entry.
    """
    first_values = []
    for entries in lists:
        if not entries:
            # Missing or empty row: nothing to pick from.
            first_values.append(None)
            continue
        english_values = [
            entry["value"]
            for entry in entries
            if entry["language_tag"].startswith("en")
        ]
        if english_values:
            first_values.append(english_values[0])
        else:
            first_values.append(None)
    return pd.Series(first_values)

@udf(return_type=str)
def get_en_concat(lists):
    """Join all English ``value`` strings of each row with single spaces.

    Each row is expected to be a sequence of mappings carrying the keys
    ``"language_tag"`` and ``"value"``. An entry counts as English when its
    language tag starts with ``"en"``.

    Returns:
        A ``pd.Series`` with one item per row: the space-joined English
        values, or ``None`` when the row is empty/missing or has no English
        entry.
    """
    joined_values = []
    for entries in lists:
        if not entries:
            # Missing or empty row: nothing to concatenate.
            joined_values.append(None)
            continue
        english_values = [
            entry["value"]
            for entry in entries
            if entry["language_tag"].startswith("en")
        ]
        if english_values:
            joined_values.append(" ".join(english_values))
        else:
            joined_values.append(None)
    return pd.Series(joined_values)

@udf(return_type=bool)
def is_not_null(bullet_points: pd.Series):
    """Return a boolean mask that is True where ``bullet_points`` is not null.

    Args:
        bullet_points: Column values as a pandas Series.

    Returns:
        The Series produced by ``bullet_points.notnull()``.
    """
    present_mask = bullet_points.notnull()
    return present_mask

@udf(return_type=str)
def cast_str(l):
    """Wrap ``l`` in a pandas Series.

    The values themselves are passed through untouched; the string typing
    comes only from the ``return_type=str`` declared on the UDF decorator.
    """
    as_series = pd.Series(data=l)
    return as_series

# Build the SHOES subset of the listings in four chained steps:
#   1. add product_type_parsed: the "value" of the first product_type entry;
#   2. keep only rows whose product_type_parsed equals "SHOES";
#   3. project to the English bullet points (joined into "details"), the
#      parsed product type, the first English item name and main_image_id;
#   4. drop rows whose "details" is null (get_en_concat returns None when a
#      listing has no bullet points or none tagged as English).
processed_listings = listings_df.with_column(
    # NOTE(review): .as_py(dict)["value"] presumably treats each get_first result
    # as a Python dict and indexes its "value" key — confirm against the daft docs.
    "product_type_parsed", cast_str(get_first(col("product_type")).as_py(dict)["value"])
).where(
    col("product_type_parsed") == "SHOES"
).select(
    get_en_concat(col("bullet_point")).alias("details"),
    col("product_type_parsed"),
    # NOTE(review): no .alias() here, so the output column name is whatever
    # daft assigns to a UDF expression — confirm it is the intended name.
    get_first_en_or_none(col("item_name")),
    col("main_image_id")
).where(
    is_not_null(col("details"))
)

In [None]:
# Show the filtered SHOES listings as this cell's output.
processed_listings

In [None]:
# Attach image metadata to each listing by matching the listing's
# main_image_id against the image table's image_id.
joined_df = processed_listings.join(images_df, left_on=col("main_image_id"), right_on=col("image_id"))

In [None]:
# Show the listings joined with their image metadata.
joined_df

In [None]:
import concurrent.futures
import threading
import PIL.Image
import boto3
import io


@udf(return_type=PIL.Image.Image)
def download_batch(batch):
    """Download a batch of S3 objects concurrently and decode them as RGB images.

    Args:
        batch: Iterable of object locations as strings of the form
            ``s3://<bucket>/<key>``.

    Returns:
        A list of ``PIL.Image.Image`` objects converted to RGB, in the same
        order as ``batch``.

    Raises:
        ValueError: If a location has no ``/`` separating bucket and key.
        Any exception raised by boto3 or PIL for a failed download or decode
        propagates to the caller.
    """
    # One thread-local store shared by every worker of this batch. It has to
    # live outside download_single: creating threading.local() inside the
    # worker yields a fresh, empty object on every call, so nothing would
    # ever be reused.
    thread_state = threading.local()
    scheme_prefix = "s3://"

    def download_single(obj: str) -> bytes:
        # Lazily build one Session + client per worker thread and reuse it for
        # every object that thread handles.
        s3 = getattr(thread_state, "s3_client", None)
        if s3 is None:
            s3 = boto3.session.Session().client("s3")
            thread_state.s3_client = s3

        # Strip only the leading scheme, so a key that happens to contain
        # "s3://" is left intact.
        location = obj[len(scheme_prefix):] if obj.startswith(scheme_prefix) else obj
        bucket, key = location.split("/", maxsplit=1)

        body = s3.get_object(Bucket=bucket, Key=key)["Body"]
        try:
            return body.read()
        finally:
            # Release the underlying connection even if read() fails.
            body.close()

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # executor.map preserves input order.
        byte_contents = list(executor.map(download_single, batch))

    images = []
    for payload in byte_contents:
        with io.BytesIO(payload) as f:
            # convert() produces a new image, so the result does not depend
            # on the buffer once the with-block closes it.
            images.append(PIL.Image.open(f).convert("RGB"))
    return images
    
@udf(return_type=str)
def full_url(paths):
    """Prefix every relative image path with the small-images S3 location.

    Args:
        paths: Iterable of path strings relative to the ``images/small/``
            folder of the bucket.

    Returns:
        A ``pd.Series`` of full ``s3://`` URLs, one per input path.
    """
    base_url = "s3://amazon-berkeley-objects/images/small/"
    urls = []
    for path in paths:
        urls.append(base_url + path)
    return pd.Series(urls)

In [None]:
# Turn each row's relative image path into a full s3:// URL, then download and
# decode the image behind each URL into a new "image" column.
# NOTE(review): joined_df is rebound to itself, so re-running this cell operates
# on the already-extended DataFrame — consider stage-specific names.
joined_df = joined_df.with_column("s3_url", full_url(col("path")))
joined_df = joined_df.with_column("image", download_batch(col("s3_url")))

In [None]:
# Show the DataFrame with the added s3_url and image columns.
joined_df

In [None]:
import torch
from min_dalle import MinDalle

# TODO(jay): We should provide a UDF API to do expensive initializations once only
# Until such an API exists, the model is built once here at module level and the
# image-generation UDF below reaches it through the global name `model`.
# It is configured for CPU with float32 weights and is_mega=False.
# NOTE(review): models_root is a hardcoded /tmp path — presumably where the
# pretrained weights are stored; confirm against the min-dalle documentation.
model = MinDalle(
    models_root='/tmp/pretrained',
    dtype=torch.float32,
    device='cpu',
    is_mega=False, 
    is_reusable=True
)

@udf(return_type=PIL.Image.Image)
def generate_image_from_text(details):
    """Generate one image per text prompt using the module-level ``model``.

    Args:
        details: Iterable of text prompts.

    Returns:
        A list holding the result of ``model.generate_image`` for every
        prompt, in input order.
    """
    # Generation settings are identical for every prompt.
    # NOTE(review): seed=-1 presumably means "do not fix the seed", which would
    # make the output non-reproducible — confirm against min-dalle.
    generation_options = dict(
        seed=-1,
        grid_size=1,
        is_seamless=False,
        temperature=1,
        top_k=256,
        supercondition_factor=32,
        is_verbose=False,
    )
    generated = []
    for prompt in details:
        generated.append(model.generate_image(text=prompt, **generation_options))
    return generated

# Add a "generated_image" column produced from each listing's "details" text.
# NOTE(review): joined_df is rebound to itself again here; re-running the cell
# starts from the already-extended DataFrame.
joined_df = joined_df.with_column("generated_image", generate_image_from_text(col("details")))

In [None]:
# Show the final DataFrame, now including the generated_image column.
joined_df