# Daft Demo

## This notebook demonstrates the following Daft features:
- Initializing our cluster
- Daft Data Repos
- Use a Python Dataclass to define a Schema
- Load existing data from our Data Repos
- Write a function to download data from the web
- Write a function to decode and resize the image
- Write our own Schema for image storage
- Save downloaded images to the cloud
- Write our own embedding extractor for batch inference
- Save our embeddings to a data repo
- Preview our Schema

## Initializing our cluster

In [None]:
import daft

# Initialize Daft before any of the Datarepo APIs used below.
# NOTE(review): whether this starts a local runtime or connects to a remote
# cluster depends on daft's defaults/configuration — not visible here.
daft.init()

## Daft Data Repos

In [None]:
from daft import Datarepo

# Show the IDs of the existing data repos; as the last bare expression in the
# cell, its value is displayed as the cell output.
Datarepo.list_ids()

## Defining our Own Schema
- ORM for binary data
- Translates to parquet under the hood
- Support for logical types such as images, NumPy arrays, and any custom types you define yourself

In [None]:
import dataclasses
from daft import dataclass

@dataclass
class OpenImagesMetadata:
    """Metadata row for one image in the 'openimages-dc-8000-v2' data repo.

    Passed as ``data_type`` when reading that repo; ``download_single`` fetches
    the image at ``url``. Field names and order are part of the stored schema,
    so they should not be renamed or reordered.
    """
    url: str  # location of the image to download
    size: int  # NOTE(review): units not visible here (bytes? pixels?) — confirm against the repo
    id: str  # image identifier; presumably matches the repo's column name, hence it shadows the builtin

## Reading Data from our Data Repo

In [None]:
# Load the 'openimages-dc-8000-v2' data repo, typing each row as OpenImagesMetadata.
readback = Datarepo.from_id('openimages-dc-8000-v2', data_type=OpenImagesMetadata)


## Previewing our Data

In [None]:
# Print the repo object, then display a 4-row preview.
print(readback)
readback.show(4)

In [None]:
@dataclass
class ImageBinaryPayload:
    """Raw downloaded bytes for one image, together with its source URL.

    ``download_single`` stores ``b''`` in ``data`` on a non-200 response.
    """
    url: str  # URL the bytes were fetched from
    data: bytes = dataclasses.field(repr=False)  # raw response body; left out of repr so printing doesn't dump binary data

In [None]:
import concurrent.futures
import requests
from typing import List

def download_single(
    image_metadata: OpenImagesMetadata, timeout: float = 30.0
) -> ImageBinaryPayload:
    """Download the image referenced by ``image_metadata.url``.

    Best-effort: any failure (non-200 status, connection error, timeout)
    produces a payload with empty ``data`` instead of raising, so one bad URL
    does not abort the whole map. Downstream decoding substitutes a
    placeholder image for empty payloads.

    Args:
        image_metadata: Row whose ``url`` is fetched.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout; without one, a stalled server can block the worker
            indefinitely.

    Returns:
        ImageBinaryPayload holding the response body, or ``b''`` on failure.
    """
    url = image_metadata.url
    try:
        r = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Network-level failures degrade the same way as a bad status code.
        return ImageBinaryPayload(url, b'')
    if r.status_code == 200:
        return ImageBinaryPayload(url, r.content)
    return ImageBinaryPayload(url, b'')


def download_batch(batch: List[OpenImagesMetadata]) -> List[ImageBinaryPayload]:
    """Download every image in ``batch`` concurrently on a thread pool.

    Downloads are I/O-bound, so threads overlap the network waits.
    ``Executor.map`` yields results in input order, so the returned list lines
    up with ``batch``.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        payloads = list(pool.map(download_single, batch))
    return payloads


### Download via a non-batched map

In [None]:
%%time

# Download one image per call, timed by %%time for comparison with the batched
# version in the next cell.
# NOTE(review): this result is not used later (the pipeline continues from
# image_payload_batch), and if Datarepo.map is lazy, %%time measures only plan
# construction rather than the downloads — confirm.
image_payload_single = readback.map(download_single)

### Download via a batched map

In [None]:
%%time

# Download in batches of 64 rows; download_batch fetches each batch
# concurrently on a thread pool.
image_payload_batch = readback.map_batches(download_batch, batch_size=64)

## Decode and Resize our downloaded images

In [None]:
import io
import PIL.Image
from daft.fields import DaftImageField
from daft.types import DaftImageType

@dataclass
class ProcessedImageData:
    """An image decoded from a downloaded payload and resized to a square.

    ``img`` is stored through ``DaftImageField`` with JPEG encoding.
    """
    url: str
    img: PIL.Image.Image = DaftImageField(encoding=DaftImageType.Encoding.JPEG)

    @classmethod
    def from_image_payload(cls, payload: ImageBinaryPayload, size: int = 256) -> 'ProcessedImageData':
        """Decode ``payload.data`` into an RGB image of ``size`` x ``size`` pixels.

        Best-effort: if the bytes cannot be decoded (including the empty
        payload produced for failed downloads), a black RGB placeholder of the
        same size is used instead, so the row is kept.

        Args:
            payload: Downloaded image bytes and their source URL.
            size: Edge length in pixels of the square output. Aspect ratio is
                not preserved.

        Returns:
            ProcessedImageData for ``payload.url``.
        """
        try:
            with io.BytesIO(payload.data) as f:
                # ``Image.open`` is lazy, so pixel data is read by the
                # convert/resize calls below; both must run while ``f`` is open.
                img = PIL.Image.open(f)
                # Convert before resizing: Pillow always resizes mode "1"/"P"
                # images with nearest-neighbour, so converting first gives a
                # properly filtered thumbnail. Skip the copy if already RGB.
                if img.mode != "RGB":
                    img = img.convert("RGB")
                img = img.resize((size, size))
        except Exception:
            # Deliberately broad: any corrupt or undecodable input falls back
            # to a placeholder rather than failing the whole map.
            img = PIL.Image.new("RGB", (size, size))
        return cls(payload.url, img)

In [None]:
# Decode and resize every downloaded payload (from_image_payload's default size is 256).
resized_decoded_images = image_payload_batch.map(ProcessedImageData.from_image_payload)

### Let's look at our images

In [None]:
# Preview 4 decoded, resized images.
resized_decoded_images.show(4)

### Defining our Embedding data model

In [None]:
import numpy as np

@dataclass
class ProcessedEmbedding:
    """Per-image embedding record, saved to a data repo later in this notebook.

    daft's @dataclass derives the stored (Arrow/Parquet) schema from these
    fields, so their names, order and annotations should not change casually.
    """
    url: str  # source URL of the image the embedding was computed from
    model: str  # name of the model that produced the embedding (e.g. "resnet18")
    dim: int  # embedding dimensionality
    mean: float  # mean of the embedding values
    std: float  # standard deviation of the embedding values
    embedding: np.ndarray  # embedding values; BatchInferModel stores a (1, dim) array


### Defining our function for Batch Inference

In [None]:
from typing import Tuple

import torch
import torchvision

        
class BatchInferModel:
    """Callable that turns a batch of ProcessedImageData into ResNet-18 embeddings.

    Intended for ``map_batches``: the model and transforms are built once in
    ``__init__``, and each ``__call__`` processes one batch. No device
    placement is done, so inference runs on torch's default device
    (normally CPU).
    """

    def __init__(self):
        """Build the pretrained feature extractor and the input transform."""
        self.model_name = "resnet18"
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=...``; kept for compatibility with older
        # versions. Weights may be downloaded on first use.
        model = torchvision.models.resnet18(pretrained=True).eval()
        # Expose the output of the 'avgpool' node (ResNet's global average
        # pool, just before the classifier) under the key 'embedding'.
        self.feature_extractor = torchvision.models.feature_extraction.create_feature_extractor(
            model=model,
            return_nodes={'avgpool': 'embedding'},
        )
        # PIL image -> float tensor in [0, 1], then per-channel normalization
        # with the mean/std documented for torchvision's pretrained models.
        self.to_tensor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def prepare_batch(self, image_data: List[ProcessedImageData]) -> Tuple[torch.Tensor, List[str]]:
        """Stack the batch's images into a single normalized tensor.

        Returns:
            ``(tensor, urls)``. ``tensor`` stacks ``self.to_tensor(img)`` for
            every item along a new leading batch dimension. ``urls`` lists the
            items' URLs in the same order.

        Raises:
            RuntimeError: from ``torch.stack`` when ``image_data`` is empty.
        """
        urls = [item.url for item in image_data]
        tensors = [self.to_tensor(item.img) for item in image_data]
        return torch.stack(tensors), urls

    def __call__(self, image_data: List[ProcessedImageData]) -> List[ProcessedEmbedding]:
        """Extract one ResNet-18 embedding per image in ``image_data``.

        Returns:
            One ProcessedEmbedding per input, in input order. Each
            ``embedding`` is a float32 array of shape ``(1, dim)``, and
            ``mean``/``std`` are Python floats computed over it. An empty
            batch yields an empty list.
        """
        if not image_data:
            # torch.stack cannot stack zero tensors, and there is nothing to embed.
            return []
        with torch.no_grad():
            tensor, urls = self.prepare_batch(image_data)
            features = self.feature_extractor(tensor.float())['embedding']
            # Flatten the avgpool output to (batch, dim).
            np_embedding = features.view(len(image_data), -1).cpu().numpy()
        dim = np_embedding.shape[1]
        # NOTE(review): this splits into (1, dim) rows rather than 1-D (dim,)
        # vectors. The shape is kept unchanged so stored embeddings keep their
        # existing layout — confirm that this is intended.
        per_image_embedding = np.vsplit(np_embedding, np.arange(1, len(image_data)))
        return [
            ProcessedEmbedding(
                url=url,
                embedding=e,
                # Cast numpy scalars to float to match the dataclass annotations.
                mean=float(e.mean()),
                std=float(e.std()),
                model=self.model_name,
                dim=dim,
            )
            for url, e in zip(urls, per_image_embedding)
        ]

## Running large scale batch inference

In [None]:
%%time

# Run embedding extraction in batches of 8 images. The class itself (not an
# instance) is passed; presumably Daft constructs it so the model is loaded
# once per worker rather than per batch — confirm.
embeddings = resized_decoded_images.map_batches(BatchInferModel, batch_size=8)

In [None]:
# Preview 3 embedding records.
embeddings.show(3)

## Save our extracted embeddings to the cloud in Parquet

In [None]:
# Save the embeddings under the data repo name 'open-images-8k-processed-embeddings'.
embeddings.save('open-images-8k-processed-embeddings')

## Peeking under the Hood of Serialization

In [None]:
# Print the Arrow schema that Daft derived from the ProcessedEmbedding dataclass.
print(ProcessedEmbedding._daft_schema.arrow_schema())