In [1]:
import json
import os

from PIL import Image
from torch.utils.data import Dataset
from collections import Counter

In [2]:
class TextVQADataset(Dataset):
    def __init__(
        self, image_dir_path, annotations_path
    ):
        self.data = json.load(open(annotations_path, "r"))["data"]
        self.image_dir_path = image_dir_path
        
    def __len__(self):
        return len(self.data)

    def get_img_path(self, img_id):
        return os.path.join(self.image_dir_path, f"{img_id}.jpg")

    def most_frequent_string(self, strings):
        """
        Finds the most frequent string in an array of strings.
        
        :param strings: List of strings
        :return: The most frequent string and its count
        """
        if not strings:
            return None, 0  # Return None and 0 if the input list is empty
    
        # Count occurrences of each string
        counts = Counter(strings)
        
        # Find the string with the maximum count
        most_frequent = counts.most_common(1)[0]  # Returns a list of tuples [(string, count)]
        return most_frequent[0]
        
    def __getitem__(self, idx):
        question = self.data[idx]["question"]
        img_path = self.get_img_path(self.data[idx]["image_id"])
        image = Image.open(img_path)
        image.load()
        results = {
            "image": image,
            "image_id": self.data[idx]["image_id"],
            "question": question,
            "question_id": self.data[idx]["question_id"],
            "answers": self.data[idx]["answers"],
            
        }
        results["best_answer"] = self.most_frequent_string(results["answers"])
        return results

In [3]:
img_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/textvqa/train_val_images/train_images"
tr_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/textvqa/TextVQA_0.5.1_train.json"
val_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/textvqa/TextVQA_0.5.1_val.json"

In [4]:
train_dataset = TextVQADataset(img_dir_path, tr_annotations_path)

In [4]:
val_dataset = TextVQADataset(img_dir_path, val_annotations_path)

In [6]:
len(train_dataset)

34602

In [7]:
train_dataset[0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x730>,
 'image_id': '0054c91397f2fe05',
 'question': 'what is the brand of phone?',
 'question_id': 0,
 'answers': ['nokia',
  'nokia',
  'nokia',
  'nokia',
  'toshiba',
  'nokia',
  'nokia',
  'nokia',
  'nokia',
  'nokia'],
 'best_answer': 'nokia'}

In [6]:
len(val_dataset)

5000

In [26]:
from tqdm import tqdm

In [27]:
uniq_images = set([train_dataset[i]['image_id'] for i in tqdm(range(len(train_dataset)))])


  0%|                                                 | 0/34602 [00:00<?, ?it/s][A
 43%|██████████████                   | 14757/34602 [00:00<00:00, 147562.05it/s][A
100%|█████████████████████████████████| 34602/34602 [00:00<00:00, 151554.65it/s][A


In [28]:
val_uniq_images = set([val_dataset[i]['image_id'] for i in tqdm(range(len(val_dataset)))])


100%|███████████████████████████████████| 5000/5000 [00:00<00:00, 134718.67it/s][A


In [32]:
# This means, train and val images are separate => we can use val as our test data
len(uniq_images), len(uniq_images-val_uniq_images), len(val_uniq_images)

(21953, 21953, 3166)

## RICE Features

In [15]:
import sys
import pickle
sys.path.append("..")

In [13]:
from rice import RICES

rices = RICES(train_dataset, 'cuda', 32)

Precomputing features for RICES: 100%|██████████████████████████████| 1082/1082 [10:41<00:00,  1.69it/s]


In [16]:
save_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/features-cache/textvqa.pkl"
with open(save_path, 'wb') as f:
    pickle.dump(rices.features.cpu(),f)

with open(save_path, 'rb') as f:
    rice_cached_features = pickle.load(f)

In [17]:
rices.features.cpu()

tensor([[ 0.0099,  0.0753, -0.0317,  ...,  0.0568, -0.0489, -0.0386],
        [ 0.0173,  0.0363, -0.0251,  ...,  0.0230,  0.0046,  0.0255],
        [ 0.0173,  0.0363, -0.0251,  ...,  0.0230,  0.0046,  0.0255],
        ...,
        [ 0.0008,  0.0309,  0.0169,  ...,  0.0940,  0.0352, -0.0186],
        [-0.0028, -0.0605, -0.0097,  ...,  0.0221,  0.0069, -0.0292],
        [-0.0028, -0.0605, -0.0097,  ...,  0.0221,  0.0069, -0.0292]])

In [18]:
rice_cached_features

tensor([[ 0.0099,  0.0753, -0.0317,  ...,  0.0568, -0.0489, -0.0386],
        [ 0.0173,  0.0363, -0.0251,  ...,  0.0230,  0.0046,  0.0255],
        [ 0.0173,  0.0363, -0.0251,  ...,  0.0230,  0.0046,  0.0255],
        ...,
        [ 0.0008,  0.0309,  0.0169,  ...,  0.0940,  0.0352, -0.0186],
        [-0.0028, -0.0605, -0.0097,  ...,  0.0221,  0.0069, -0.0292],
        [-0.0028, -0.0605, -0.0097,  ...,  0.0221,  0.0069, -0.0292]])

### Running results for textVQA

In [8]:
!python ../vila_e2e_vqa_textvqa.py --n_shots=2 

[2024-11-21 21:21:05,797] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Fetching 17 files: 100%|██████████████████████| 17/17 [00:00<00:00, 2810.64it/s]
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:06<00:00,  3.15s/it]
new vqa_textvqa vila 3b
Namespace(model_name_or_path='Efficient-Large-Model/VILA1.5-3b', n_shots=2, use_random=False, n_random=0)
  0%|                                                    | 0/10 [00:00<?, ?it/s]{'query': ' <image> Question: what is proven about this brand? Short Answer: dependability <image> Question: what company is shown on the box? Short Answer: thomson<image> Question: what is the brand of this camera? Short Answer: ', 'output': 'thomson', 'references': ['nous les gosses', 'dakota', 'clos culombu', 'dakota digital', 'dakota', 'dakota', 'dakota digital', 'dakota digital', 'dakota', 'dakota']}
 10%|████▍                                       | 1/10 [00:00<00:08,  1.04it/s]{'query': ' <image>

In [10]:
!python ../vila_e2e_captioning_flickr8k.py --n_shots=2

[2024-11-21 22:04:54,919] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Fetching 17 files: 100%|██████████████████████| 17/17 [00:00<00:00, 7572.55it/s]
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:06<00:00,  3.16s/it]
Namespace(model_name_or_path='Efficient-Large-Model/VILA1.5-3b', n_shots=2, use_random=False, n_random=0)
100%|███████████████████████████████████████████| 10/10 [00:07<00:00,  1.35it/s]
