In [None]:
import io

import daft
from daft import DataFrame
from PIL import Image

# Run Daft queries with the Ray runner.
daft.context.set_runner_ray()

In [None]:
# Working-sample size and downstream parallelism for this demo.
SAMPLE_ROWS = 200
NUM_PARTITIONS = 8

coco_raw_df = DataFrame.read_parquet("s3://daft-public-data/coco-2017/mscoco.parquet")
# Limit first, then repartition: repartitioning before the limit would shuffle
# the entire table only to discard all but SAMPLE_ROWS rows. Splitting the small
# sample into NUM_PARTITIONS partitions keeps the downstream UDFs parallel.
coco_df = coco_raw_df.limit(SAMPLE_ROWS).repartition(NUM_PARTITIONS)
coco_df.show()

In [None]:
# Deduplicate to one row per image URL (a URL can appear on several caption rows).
images_df = coco_df.select("URL")
images_df = images_df.distinct()
images_df.show()

In [None]:

def _decode_and_resize(data):
    """Decode raw downloaded bytes into a PIL image resized to 512x512 pixels."""
    buffer = io.BytesIO(data)
    return Image.open(buffer).resize((512, 512))


# Download each URL's bytes and turn them into a fixed-size PIL image column.
downloaded = images_df["URL"].url.download()
images_df = images_df.with_column(
    "image",
    downloaded.apply(_decode_and_resize, return_type=Image.Image),
)
images_df.show()

In [None]:
import numpy as np
from daft import udf, col
from daft.resource_request import ResourceRequest
from typing import List
import clip
import torch

@udf(return_type=np.ndarray)
class ImageClipExtractor:
    """Daft class UDF that embeds PIL images with OpenAI CLIP.

    The CLIP model and its preprocessing transform are loaded once in
    ``__init__`` and reused for every batch passed to ``__call__``.
    """

    def __init__(self, model_name: str="ViT-B/32") -> None:
        # Prefer the GPU (the call site requests one via ResourceRequest), but
        # fall back to CPU instead of crashing where CUDA is unavailable.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, device=self.device)

    def __call__(self, images: List[Image.Image]) -> np.ndarray:
        """Embed a batch of images.

        Args:
            images: batch of PIL images from the ``image`` column.

        Returns:
            A float32 array with one L2-normalised embedding row per image,
            or an empty list when the batch is empty.
        """
        if len(images) == 0:
            # torch.stack raises on an empty sequence, so short-circuit here.
            return []
        image_input = torch.stack([self.preprocess(img) for img in images]).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            # Move to CPU as float32 (the model may compute in reduced precision).
            image_features = image_features.detach().cpu().float()
            norm = image_features.norm(p=2, dim=1, keepdim=True)
        # Unit-normalise each row so a dot product equals cosine similarity.
        return (image_features / norm).numpy()
    
    
# Attach a CLIP image embedding to every image row; each UDF instance asks for one GPU.
images_df = images_df.with_column(
    "image_clip_embedding",
    ImageClipExtractor(col("image")),
    resource_request=ResourceRequest(num_gpus=1),
)

In [None]:
@udf(return_type=np.ndarray)
class TextClipExtractor:
    """Daft class UDF that embeds caption strings with OpenAI CLIP.

    The CLIP model is loaded once in ``__init__`` and reused for every batch
    passed to ``__call__``.
    """

    def __init__(self, model_name: str="ViT-B/32") -> None:
        # Prefer the GPU (the call site requests one via ResourceRequest), but
        # fall back to CPU instead of crashing where CUDA is unavailable.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load also returns an image preprocessing transform; it is unused
        # for text but kept so the attribute layout matches ImageClipExtractor.
        self.model, self.preprocess = clip.load(model_name, device=self.device)

    def __call__(self, text: List[str]) -> np.ndarray:
        """Embed a batch of captions.

        Args:
            text: batch of caption strings from the ``TEXT`` column.

        Returns:
            A float32 array with one L2-normalised embedding row per caption,
            or an empty list when the batch is empty.
        """
        if len(text) == 0:
            return []
        # truncate=True: clip.tokenize raises for inputs longer than the model's
        # context length unless truncation is enabled, so a single long caption
        # would otherwise fail the whole batch.
        tokens = clip.tokenize(text, truncate=True).to(self.device)

        with torch.no_grad():
            features = self.model.encode_text(tokens)
            # Move to CPU as float32 (the model may compute in reduced precision).
            features = features.detach().cpu().float()
            # Unit-normalise each row so a dot product equals cosine similarity.
            features /= features.norm(p=2, dim=-1, keepdim=True)
        return features.numpy()
# Embed the TEXT column of every coco_df row with CLIP on a GPU-backed UDF.
text_df = coco_df.with_column(
    "text_clip_embedding",
    TextClipExtractor(col("TEXT")),
    resource_request=ResourceRequest(num_gpus=1),
)

In [None]:
# Pair image rows with the text rows that share the same URL.
joined_df = images_df.join(
    text_df,
    on="URL",
)


In [None]:
import numpy as np
@udf(return_type=float)
def cosine_similarity(A: List[np.ndarray], B: List[np.ndarray]) -> List[float]:
    """Row-wise cosine similarity between two columns of embedding vectors.

    Both embedding columns in this notebook are already unit-normalised, in
    which case this equals their dot product. Dividing by the norms keeps the
    result a true cosine similarity for any input.

    Args:
        A: embedding vectors, one per row.
        B: embedding vectors, one per row, paired positionally with ``A``.

    Returns:
        One Python float per row (matching ``return_type=float``); 0.0 when
        either vector has zero norm.
    """
    similarities = []
    for a, b in zip(A, B):
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        # Define similarity with a zero vector as 0.0 rather than dividing by zero.
        similarities.append(float(np.dot(a, b) / denom) if denom > 0 else 0.0)
    return similarities


joined_df = joined_df.with_column(
    "cosine_similarity",
    cosine_similarity(col("image_clip_embedding"), col("text_clip_embedding")),
)

In [None]:
# Print the query plan Daft will use to compute joined_df.
joined_df.explain()

In [None]:
# Materialise only the columns needed to rank captions per image.
result = (
    joined_df
    .select("URL", "TEXT", "image", "cosine_similarity")
    .collect()
)

In [None]:
# For every image URL, keep the caption row(s) whose text embedding is most
# similar to the image embedding: take the per-URL maximum score, then join back
# on (URL, score) to recover the matching rows.
# NOTE: ties on the maximum score yield more than one row for that URL.
best_caption_plan = (
    result.groupby("URL")
    .max("cosine_similarity")
    .join(result, on=["URL", "cosine_similarity"])
    .sort("cosine_similarity", desc=True)
)
# Inspect the plan before paying for execution, then materialise it.
best_caption_plan.explain()
best_caption = best_caption_plan.collect()

In [None]:
# Display the highest-similarity caption row(s) per image URL, best scores first.
best_caption.show()