In [61]:
import warnings
warnings.filterwarnings('ignore')

import os 
import json 
import pandas as pd 
import torch
from PIL import Image
from transformers import ViltProcessor
import numpy as np


In [11]:
class Config:
    """Paths and file names for the dataset used in this notebook.

    Every attribute is a plain string computed once, when the class body runs.
    NOTE(review): assumes the kernel is started one directory below the project
    root (``base_dir`` is the parent of the current working directory).
    """

    # Bare file names; callers join them with `data_dir`.
    train_file = 'data_train.csv'
    eval_file = 'data_eval.csv'
    answer_space_file = 'answer_space.txt'

    # Directory layout: <parent of cwd>/dataset and <parent of cwd>/dataset/images.
    base_dir = os.path.dirname(os.getcwd())
    data_dir = os.path.join(base_dir, 'dataset')
    image_dir = os.path.join(base_dir, 'dataset', 'images')


cfg = Config()

In [12]:
# Candidate answers, one per line; a label's class id is its position in this list.
# Explicit UTF-8 so the result does not depend on the platform's default encoding.
with open(os.path.join(cfg.data_dir, cfg.answer_space_file), encoding='utf-8') as f:
    # Skip blank lines (e.g. a stray empty line at the end) so '' never becomes a class.
    # NOTE(review): a blank line in the middle of the file would shift later ids; confirm none exist.
    answer_space = [ans.strip() for ans in f if ans.strip()]

train_df = pd.read_csv(os.path.join(cfg.data_dir, cfg.train_file))
eval_df = pd.read_csv(os.path.join(cfg.data_dir, cfg.eval_file))

print(f"Train Data Size : {len(train_df)}")
print(f"Eval Data Size : {len(eval_df)}")

Train Data Size : 9974
Eval Data Size : 2494


In [13]:
# Answer string -> class index (its position in answer_space), and the inverse mapping.
label2id = dict(zip(answer_space, range(len(answer_space))))
id2label = {idx: label for label, idx in label2id.items()}

In [37]:
def prepare_annotations(data_df : pd.DataFrame, label2id : dict) -> list:
    """Convert a question/answer DataFrame into one annotation dict per row.

    Args:
        data_df: frame with 'question', 'image_id' and 'answer' columns, where
            'answer' holds one or more comma-separated answer strings.
        label2id: mapping from answer string to class index.

    Returns:
        A list (row order preserved) of dicts with keys:
        'question', 'image_id' -- copied from the row;
        'answer' -- every stripped answer, duplicates kept;
        'labels' -- class ids of the distinct answers, in first-seen order;
        'scores' -- 1.0 for each entry of 'labels'.

    Raises:
        KeyError: if an answer is not present in ``label2id``.
    """
    annotations = []
    for idx, row in data_df.iterrows():
        answer = [ans.strip() for ans in row['answer'].split(',')]
        # dict.fromkeys de-duplicates while keeping first-seen order; each distinct
        # answer gets a full score of 1.0 however many times it repeats.
        try:
            labels = [label2id[ans] for ans in dict.fromkeys(answer)]
        except KeyError as err:
            raise KeyError(f"answer {err.args[0]!r} in row {idx!r} is not in label2id") from err

        annotations.append({
            'question' : row['question'],
            'image_id' : row['image_id'],
            'answer' : answer,
            'labels' : labels,
            'scores' : [1.0] * len(labels),
        })
    return annotations


In [38]:
train_annotations = prepare_annotations(data_df = train_df, label2id = label2id)
eval_annotations = prepare_annotations(data_df = eval_df, label2id = label2id)

# Display one example so the cell's saved output is actually produced by this cell.
train_annotations[0]

{'question': 'what is the object on the shelves',
 'image_id': 'image100',
 'answer': ['cup'],
 'labels': [149],
 'scores': [1.0]}

In [42]:
class VQADataset(torch.utils.data.Dataset):
    """Image/question dataset that yields processor encodings plus soft-label targets.

    Each item is the processor's encoding of (image, question) with the leading
    batch dimension removed, plus a float ``labels`` vector of length
    ``len(id2label)`` holding each annotated label's score.
    """

    def __init__(self, annotations, processor, image_dir, id2label):
        # annotations: list of dicts with 'image_id', 'question', 'labels', 'scores'.
        self.annotations = annotations
        # Called as processor(image, text, ...) and expected to return batched tensors.
        self.processor = processor
        # Images are read from <image_dir>/<image_id>.png.
        self.image_dir = image_dir
        # Only its size is used here: it fixes the length of the target vector.
        self.id2label = id2label

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_path = os.path.join(self.image_dir, f"{annotation['image_id']}.png")
        # Decode inside `with` so the file handle is released deterministically, and
        # force 3-channel RGB so grayscale / RGBA PNGs get a consistent channel layout
        # (NOTE(review): presumably what the image processor's normalization assumes).
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        text = annotation['question']

        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        # Remove only the batch dimension; a bare squeeze() would also drop any
        # other axis that happens to have size 1.
        for key in list(encoding.keys()):
            encoding[key] = encoding[key].squeeze(0)

        # based on: https://github.com/dandelin/ViLT/blob/762fd3975c180db6fc88f577cf39549983fa373a/vilt/modules/objectives.py#L301
        targets = torch.zeros(len(self.id2label))
        for label, score in zip(annotation['labels'], annotation['scores']):
            targets[label] = score
        encoding["labels"] = targets
        return encoding

In [48]:
# Pretrained ViLT checkpoint whose processor builds the text and image inputs.
MODEL_CHECKPOINT = "dandelin/vilt-b32-mlm"
processor = ViltProcessor.from_pretrained(MODEL_CHECKPOINT)

In [49]:
# Small slice of the training annotations, used by the inspection cells below.
N_DEBUG_SAMPLES = 100
dataset = VQADataset(
    annotations=train_annotations[:N_DEBUG_SAMPLES],
    processor=processor,
    image_dir=cfg.image_dir,
    id2label=id2label,
)

In [50]:
# Which model inputs (plus 'labels') a single item provides.
first_item = dataset[0]
first_item.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])

In [51]:
# Turn the padded token ids back into text to eyeball the tokenization.
first_input_ids = dataset[0]['input_ids']
processor.decode(first_input_ids)

'[CLS] what is the object on the shelves [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [65]:
# Class ids = indices of every non-zero entry of the target vector.
# as_tuple=True yields a flat index tensor for the single dimension; indexing row [0]
# of the (k, 1) matrix returned by the default form would keep only the first answer.
labels = torch.nonzero(dataset[0]['labels'], as_tuple=True)[0].tolist()
[id2label[label] for label in labels]

['cup']

In [57]:
# NOTE(review): the saved output below (a bare 149) came from running this cell (In[57])
# before the cell that assigns `labels` as a list via .tolist() (In[65]); Restart & Run All
# to refresh it.
labels

149