In [None]:
from PIL import Image 
import requests
import torch 
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
import torchvision.datasets as dset
from torchvision.datasets import CocoCaptions
import torchvision.transforms as T
from typing import Any, Callable, List, Optional, Tuple, Union
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import json
from tqdm import tqdm

In [None]:
class CocoCaptions_custimized(CocoCaptions):
    """COCO Captions dataset limited to the first ``slice`` image ids, yielding the id too.

    Items are ``(image_id, image, captions)`` where ``captions`` is whatever
    ``CocoCaptions._load_target`` returns for that image (after ``transforms``).

    Args:
        root: directory holding the COCO images.
        annFile: path to the COCO captions annotation JSON.
        slice: keep only the first ``slice`` ids of the base class's id list;
            ``None`` keeps every image.
        transform: forwarded to the torchvision base class.
        target_transform: forwarded to the torchvision base class.
        transforms: forwarded to the torchvision base class.
    """

    def __init__(
        self,
        root: Union[str, Path],
        annFile: str,
        slice: Optional[int] = 5000,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        # CocoDetection's signature is (root, annFile, transform, target_transform,
        # transforms). Pass everything by keyword so `transforms` can never land in
        # the `annFile` slot. The base class loads the annotation file once and
        # builds the sorted image-id list itself.
        super().__init__(
            root,
            annFile,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
        )
        if slice is not None:
            self.ids = self.ids[:slice]

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Return ``(image_id, image, captions)`` for dataset position ``index``.

        Raises:
            ValueError: if ``index`` is not an ``int``.
        """
        if not isinstance(index, int):
            raise ValueError(f"Index must be of type integer, got {type(index)} instead.")

        img_id = self.ids[index]
        image = self._load_image(img_id)
        captions = self._load_target(img_id)

        if self.transforms is not None:
            image, captions = self.transforms(image, captions)

        return img_id, image, captions

In [None]:
# Per-channel normalisation statistics (three channels) named after OpenAI's dataset.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

# Image preprocessing applied to every COCO image before it reaches the model:
# resize to exactly 448x448 (a (h, w) tuple ignores aspect ratio), convert to a
# float tensor, then normalise with the statistics above.
transform = T.Compose([
    T.Resize((448, 448), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])

In [None]:
def collate_keep_captions(batch):
    """Collate ``(image_id, image, captions)`` samples into ``(ids, images, captions)``.

    Images are stacked into one tensor. Ids and caption lists stay per-sample
    lists, so ``captions[i]`` is the full reference list for ``ids[i]``.
    torch's ``default_collate`` would instead transpose each caption list, and it
    raises when images in one batch carry different numbers of captions.
    """
    ids, images, captions = zip(*batch)
    return list(ids), torch.stack(images), list(captions)


cap = CocoCaptions_custimized(
    root='./val2014',
    annFile='./annotations/captions_val2014.json',
    transform=transform,
)
batch_size = 1
dataloader = DataLoader(
    cap,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    collate_fn=collate_keep_captions,
)

In [None]:
EMU2_REPO_ID = "BAAI/Emu2"

tokenizer = AutoTokenizer.from_pretrained(EMU2_REPO_ID)

# Build the model skeleton inside accelerate's init_empty_weights context, so no
# real weights are allocated here; load_checkpoint_and_dispatch fills them in later.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        EMU2_REPO_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

In [None]:
# Plan how to spread the model over GPU 0 and GPU 1 under per-device memory caps.
# Modules whose class is 'Block' or 'LlamaDecoderLayer' are never split across devices.
device_map = infer_auto_device_map(
    model,
    max_memory={0: '36GiB', 1: '38GiB'},
    no_split_module_classes=['Block', 'LlamaDecoderLayer'],
)
# Force the LM head onto GPU 0, overriding whatever placement was inferred above.
device_map["model.decoder.lm.lm_head"] = 0

# NOTE(review): absolute, machine-specific path to one pinned snapshot in a
# root user's Hugging Face cache; adjust for other hosts/revisions.
EMU2_CHECKPOINT_DIR = '/root/.cache/huggingface/hub/models--BAAI--Emu2/snapshots/fa835ec101e52da5e081695107e1ddd3c7c4d88a'
model = load_checkpoint_and_dispatch(model, EMU2_CHECKPOINT_DIR, device_map=device_map).eval()

# Prompt sent for every image; presumably '[<IMG_PLH>]' marks where the image goes.
query = '[<IMG_PLH>]Describe the image in details:'

In [None]:
# Caption every image; results maps image id -> (generated caption, reference captions).
results = {}
for ids, image, target in tqdm(dataloader, desc="captioning"):
    # One prompt per sample actually in this batch: the last batch is smaller
    # than batch_size whenever len(cap) is not a multiple of it.
    inputs = model.build_input_ids(
        text=[query] * len(ids),
        tokenizer=tokenizer,
    )
    # Move to GPU and cast to the model's bfloat16 dtype in a single call.
    image = image.to("cuda", dtype=torch.bfloat16)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image=image,
            max_new_tokens=64,
            length_penalty=-1,
        )
    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # int() accepts both plain ints and 0-d tensors, depending on the collate function.
    for img_id, generated, refs in zip(ids, output_text, target):
        results[int(img_id)] = (generated, refs)

In [None]:
# Write results as JSON: {image_id: [generated_caption, reference_captions], ...},
# where reference_captions is the per-image target produced by the DataLoader.
# json turns the integer image-id keys into strings and the tuples into lists.
Path('captions.json').write_text(json.dumps(results))