In [1]:
import io
from pprint import pprint

import daft
import ray
from daft import DataFrame, col
from PIL import Image
USE_RAY_REMOTE = True

# Packages installed into every Ray worker's runtime env: a pinned Daft nightly
# wheel, OpenAI CLIP straight from GitHub, and the image / S3 dependencies.
RAY_WORKER_PIP_PACKAGES = [
    "https://anaconda.org/daft-nightly/getdaft/0.0.22%2Bdev0034.d218a0c/download/getdaft-0.0.22%2Bdev0034.d218a0c-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl",
    "git+https://github.com/openai/CLIP.git",
    "pillow",
    "s3fs",
]

if USE_RAY_REMOTE:
    ray.init(address="ray://localhost:10001", runtime_env={"pip": RAY_WORKER_PIP_PACKAGES})
    pprint(ray.available_resources())

# Daft executes on Ray in both modes; without the remote connection above this
# presumably falls back to a local Ray instance.
daft.context.set_runner_ray()

{'CPU': 64.0,
 'GPU': 8.0,
 'accelerator_type:V100': 8.0,
 'memory': 556793856000.0,
 'node:10.0.0.132': 1.0,
 'node:10.0.0.157': 1.0,
 'node:10.0.0.25': 1.0,
 'node:10.0.1.107': 1.0,
 'node:10.0.1.154': 1.0,
 'node:10.0.1.175': 1.0,
 'node:10.0.1.198': 1.0,
 'node:10.0.2.127': 1.0,
 'node:10.0.2.203': 1.0,
 'object_store_memory': 166766611563.0}


DaftContext(runner_config=_RayRunnerConfig(address=None, max_tasks_per_core=None, max_refs_per_core=None, batch_dispatch_coeff=None), disallow_set_runner=True)

In [None]:
coco_df = DataFrame.read_parquet("s3://daft-public-data/coco-2017/mscoco.parquet")

# Size the working sample to the backend: 64 partitions / 10k rows on the
# remote cluster, 8 partitions / 200 rows for a local run.
num_partitions, row_limit = (64, 10000) if USE_RAY_REMOTE else (8, 200)
coco_df = coco_df.repartition(num_partitions).limit(row_limit)

coco_df.show(5)

In [None]:
# One row per image URL (the same URL can appear on several caption rows).
unique_urls_df = coco_df.select("URL")
images_df = unique_urls_df.distinct()
images_df.show(5)

In [None]:

def _decode_and_resize(data):
    """Decode downloaded image bytes into a 512x512 PIL image.

    Falsy payloads (e.g. ``None`` or ``b""``) map to ``None`` so the rows can be
    filtered out below instead of crashing the UDF.
    """
    if not data:
        return None
    return Image.open(io.BytesIO(data)).resize((512, 512))


# `col` must be in scope here (it is imported in the notebook's first cell);
# relying on the later CLIP cell's import breaks Restart & Run All.
images_df = images_df.with_column(
    "image",
    images_df["URL"].url.download().apply(_decode_and_resize, return_type=Image.Image),
).where(~col("image").is_null())
images_df.show(5)

In [None]:
import numpy as np
from daft import udf, col
from daft.resource_request import ResourceRequest
from typing import List
import clip
import torch

USE_GPU = True


class ClipExtractor:
    """Shared setup for the CLIP embedding UDFs.

    Loads a CLIP checkpoint once per instance and provides a helper that slices
    inputs into fixed-size batches for the subclasses' ``__call__``.
    """

    def __init__(
        self,
        model_name: str = "ViT-B/32",
        batch_size: int = 128,
    ) -> None:
        """
        Args:
            model_name: CLIP checkpoint name forwarded to ``clip.load``.
            batch_size: Number of inputs encoded per forward pass.
        """
        self.device = "cuda" if USE_GPU else "cpu"
        self.batch_size = batch_size

        self.model, self.preprocess = clip.load(model_name, device=self.device, jit=True)
        # The embedding width differs across CLIP checkpoints, so read it from the
        # loaded model instead of hard-coding 512 (which only fits the default
        # ViT-B/32). text_projection maps the text transformer into the shared
        # embedding space, so its last dimension is the embedding size.
        self.num_dims = int(self.model.text_projection.shape[-1])

    @staticmethod
    def batched(iterable, n=32):
        """Yield consecutive slices of ``iterable`` holding at most ``n`` items each."""
        for start in range(0, len(iterable), n):
            # Slicing clamps at the end, so the last batch may be shorter.
            yield iterable[start:start + n]
    
        
@udf(return_type=np.ndarray)
class ImageClipExtractor(ClipExtractor):
    """Extracts L2-normalized CLIP embeddings from images."""

    def __call__(self, images: List[Image.Image]) -> np.ndarray:
        """Embed a partition of images with CLIP's image encoder.

        Args:
            images: PIL images. Nulls must be filtered out beforehand
                (presumably ``self.preprocess`` fails on ``None``).

        Returns:
            float64 array with one unit-length embedding row per input image,
            shape ``(len(images), self.num_dims)``.
        """
        if not images:
            # Keep the same ndarray contract as the non-empty path.
            return np.zeros((0, self.num_dims))

        batch_embeddings = []
        for images_batch in ClipExtractor.batched(images, n=self.batch_size):
            pixels = torch.stack([self.preprocess(img) for img in images_batch]).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(pixels).detach().cpu().float()
            features = features / features.norm(p=2, dim=1, keepdim=True)
            batch_embeddings.append(features.numpy())

        # Concatenate once instead of re-copying a growing array every batch;
        # cast to float64 so the dtype matches the empty-input path.
        return np.concatenate(batch_embeddings, axis=0).astype(np.float64)
    


@udf(return_type=np.ndarray)
class TextClipExtractor(ClipExtractor):
    """Extracts L2-normalized CLIP embeddings from text."""

    def __call__(self, text: List[str]) -> np.ndarray:
        """Embed a partition of strings with CLIP's text encoder.

        Args:
            text: Strings to embed, e.g. image captions.

        Returns:
            float64 array with one unit-length embedding row per input string,
            shape ``(len(text), self.num_dims)``.
        """
        if not text:
            # Keep the same ndarray contract as the non-empty path.
            return np.zeros((0, self.num_dims))

        batch_embeddings = []
        for text_batch in ClipExtractor.batched(text, n=self.batch_size):
            # truncate=True: by default clip.tokenize raises on inputs longer than
            # the model's context length instead of truncating them.
            tokens = clip.tokenize(text_batch, truncate=True).to(self.device)
            with torch.no_grad():
                features = self.model.encode_text(tokens).detach().cpu().float()
            features /= features.norm(p=2, dim=-1, keepdim=True)
            batch_embeddings.append(features.numpy())

        # Concatenate once instead of re-copying a growing array every batch;
        # cast to float64 so the dtype matches the empty-input path.
        return np.concatenate(batch_embeddings, axis=0).astype(np.float64)
    


In [None]:
# Each CLIP UDF asks for a quarter of a GPU, so several extractor instances
# can be packed onto one device.
clip_gpu_request = ResourceRequest(num_gpus=0.25)

images_df = images_df.with_column(
    "image_clip_embedding",
    ImageClipExtractor(col("image")),
    resource_request=clip_gpu_request,
)

text_df = coco_df.with_column(
    "text_clip_embedding",
    TextClipExtractor(col("TEXT")),
    resource_request=clip_gpu_request,
)

In [None]:
# Pair every image embedding with each caption row that shares its URL.
joined_df = images_df.join(text_df, on='URL')


In [None]:
import numpy as np
@udf(return_type=float)
def cosine_similarity(A: List[np.ndarray], B: List[np.ndarray]) -> List[float]:
    """Row-wise cosine similarity between two columns of embedding vectors.

    The CLIP extractors in this notebook already L2-normalize their outputs, but
    normalizing here keeps the function correct for arbitrary vectors rather
    than silently degrading to a raw dot product.

    Returns:
        One Python float per ``(a, b)`` pair (matching ``return_type=float``);
        NaN when either vector has zero norm.
    """
    similarities = []
    for a, b in zip(A, B):
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        similarities.append(float(np.dot(a, b) / denom) if denom else float("nan"))
    return similarities


joined_df = joined_df.with_column(
    "cosine_similarity",
    cosine_similarity(col("image_clip_embedding"), col("text_clip_embedding")),
)

In [None]:
# Inspect the query plan for joined_df (CLIP embeddings, URL join, similarity UDF).
joined_df.explain()

In [None]:
%%time

result = joined_df.select('URL', 'TEXT', 'image', 'cosine_similarity').collect()

In [None]:
%%time

best_caption = (result.groupby('URL')
                .max('cosine_similarity') 
                .join(result,
                      on=['URL', 'cosine_similarity'])
                .sort("cosine_similarity", desc=True)).collect()
best_caption.explain()

In [None]:
# Show the 10 highest-scoring image/caption pairs (best_caption is sorted by similarity, descending).
best_caption.show(10)