In [68]:
%%capture
# Install OpenCLIP (the `open_clip` module used by VQAv2_Dataset); %%capture hides pip's output.
!pip install open_clip_torch

In [90]:
import torch
from torch.utils.data import Dataset
from torchvision.datasets import CocoDetection

import cv2
from PIL import Image
import pandas as pd
from matplotlib import pyplot as plt

from transformers import AutoTokenizer
# import open_clip

from tqdm.auto import tqdm
import json
import sys

In [2]:
# Location of the source CSV (file name suggests a 1000-example Russian VQAv2 sample).
# NOTE(review): absolute path, presumably a Colab Google Drive mount — adjust when running elsewhere.
path = "/content/drive/MyDrive/Colab Notebooks/VQA_hse/vqa_v2_rus_1000_examples.csv"

In [28]:
# Load the raw CSV into a DataFrame; later cells derive everything else from it.
dataset_pd = pd.read_csv(path)

In [14]:
# Strip the 'Unnamed: 0' column (presumably a row index saved into the CSV — confirm).
dataset = dataset_pd.drop(columns=['Unnamed: 0'])

In [15]:
# Write the frame as JSON (pandas' default layout) to the file "ru_VQAv2_1000" in the working directory.
dataset.to_json("ru_VQAv2_1000")

In [19]:
json_path = "ru_VQAv2_1000"
# Parse the whole file rather than only its first line, so the load also
# works when the JSON document is pretty-printed across several lines.
with open(json_path, 'r', encoding="utf-8") as f:
    dataset = json.load(f)

In [None]:
# Module-level tokenizer loaded from the ruGPT-3 medium checkpoint; the tokenization loop below uses it.
tokenizer = AutoTokenizer.from_pretrained("ai-forever/rugpt3medium_based_on_gpt2")

In [40]:
# Per-column mappings taken from the parsed JSON; both are indexed by the same keys.
answers, questions = dataset["answer"], dataset["question"]
# Accumulators that the tokenization loop fills, one entry per example.
images, query_tokens, answer_tokens = [], [], []

In [64]:
# Tokenize every question/answer pair and record the key of the row it came from.
for row_id, answer in answers.items():
    question = questions[row_id]
    # append(), not `+=`: the keys are strings, and `list += "10"` would add
    # the characters '1' and '0' as two separate entries.
    images.append(row_id)
    query_tokens.append(torch.tensor(tokenizer.encode(question), dtype=torch.int64))
    answer_tokens.append(torch.tensor(tokenizer.encode(answer), dtype=torch.int64))

In [86]:
class VQAv2_Dataset(Dataset):
    """Question/answer token dataset read from a JSON dump of VQA examples.

    Every item yields the question and answer token ids, each truncated or
    zero-padded to ``max_seq_len``, together with float masks marking the
    real (non-padding) positions, plus the item index.  Image loading is not
    wired in yet (see ``get_image``).
    """

    def __init__(self, config, dataset_path, coef_size=0.1,
                 tokenizer_name="", prefix_length=20, normalize_prefix=False):
        """Load and tokenize the first ``coef_size`` fraction of the examples.

        Args:
            config: object exposing ``decoder`` (default tokenizer name) and
                ``encoder`` (OpenCLIP model name).
            dataset_path: path to a JSON file holding ``"answer"`` and
                ``"question"`` entries, each a mapping that shares the same keys.
            coef_size: fraction of the examples to keep (``0.1`` keeps the
                first 10%).
            tokenizer_name: tokenizer to load; falls back to ``config.decoder``
                when empty.
            prefix_length: stored as-is and also used as ``max_seq_len``, the
                length every token sequence is padded or cut to.
            normalize_prefix: stored as-is; not used inside this class.
        """
        # Imported here so the class does not depend on a module-level
        # `import open_clip` having been executed.
        import open_clip

        if not tokenizer_name:
            tokenizer_name = config.decoder
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Only the preprocessing transform is kept; the model object itself
        # is created by this call but not stored.
        _, _, self.preprocess = open_clip.create_model_and_transforms(
            config.encoder, pretrained="laion400m_e32")
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix

        with open(dataset_path, 'r', encoding="utf-8") as f:
            dataset = json.load(f)
        answers = dataset["answer"]
        questions = dataset["question"]
        self.image_idx = []
        self.query_tokens = []
        self.answer_tokens = []

        # Integer item budget, clamped to the available examples.  Slicing the
        # key list keeps exactly `max_img` items and gives tqdm an int total.
        max_img = max(0, min(len(answers), int(len(answers) * coef_size)))
        for key in tqdm(list(answers)[:max_img]):
            # Use the tokenizer owned by this instance, not a notebook global.
            self.query_tokens.append(
                torch.tensor(self.tokenizer.encode(questions[key]), dtype=torch.int64))
            self.answer_tokens.append(
                torch.tensor(self.tokenizer.encode(answers[key]), dtype=torch.int64))
            # append(), not `+=`: keys are strings and `+=` would split them
            # into single characters, corrupting __len__.
            self.image_idx.append(key)
        sys.stdout.flush()

        self.max_seq_len = prefix_length

    def _pad_to_max_len(self, tokens):
        """Return ``tokens`` cut or zero-padded to ``max_seq_len`` and its float mask."""
        tokens = tokens[:self.max_seq_len]
        n_real = tokens.shape[0]
        # The mask comes from the real length, never from the token values,
        # so it cannot be confused by a genuine token id of 0.
        mask = torch.zeros(self.max_seq_len, dtype=torch.float32)
        mask[:n_real] = 1.0
        n_pad = self.max_seq_len - n_real
        if n_pad > 0:
            tokens = torch.cat((tokens, torch.zeros(n_pad, dtype=torch.int64)))
        return tokens, mask

    # Open question carried over from the original author: why were captions not padded?
    def pad_tokens(self, item: int):
        """Return ``(query_tokens, query_mask, answer_tokens, answer_mask)`` for ``item``.

        Tokens are int64 tensors of length ``max_seq_len`` with 0 in padded
        positions; masks are float tensors with 1.0 on real tokens and 0.0 on
        padding.  The stored sequences are left untouched, so calling this
        repeatedly for the same item always gives the same masks.
        """
        query_tokens, query_mask = self._pad_to_max_len(self.query_tokens[item])
        answer_tokens, answer_mask = self._pad_to_max_len(self.answer_tokens[item])
        return query_tokens, query_mask, answer_tokens, answer_mask

    def get_image(self, item):
        """Return the image for ``item`` resized to 256x256 as an RGB PIL image.

        NOTE(review): ``image_idx`` is filled with the JSON keys of the
        examples, not with pixel arrays, so ``cv2.resize`` cannot succeed on
        it as written.
        TODO: load the image from disk first (the earlier draft read it with
        ``cv2.imread`` from an image directory plus a per-item file name).
        """
        image_resized = cv2.resize(self.image_idx[item], (256,256))
        return Image.fromarray(cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB))

    def __len__(self) -> int:
        """Number of loaded examples."""
        return len(self.image_idx)

    def __getitem__(self, item):
        """Return padded question/answer tokens, their masks and the item index."""
        # Image branch is disabled until get_image can load real pixels:
        # image = self.get_image(item)
        # image = self.preprocess(image).unsqueeze(0)
        query_tokens, query_mask, answer_tokens, answer_mask = self.pad_tokens(item)
        # return query_tokens, query_mask, answer_tokens, answer_mask, image[0], item
        return query_tokens, query_mask, answer_tokens, answer_mask, item

    def show_image(self, item):
        """Display the image for ``item`` and print its decoded answer."""
        image = self.get_image(item)
        # Decode only the real answer tokens so padding does not show up in the text.
        text = self.tokenizer.decode(self.answer_tokens[item][:self.max_seq_len])
        plt.imshow(image)
        print(text)

In [79]:
class Config:
    """Names of the pretrained checkpoints the notebook works with."""

    # Default tokenizer name picked up by VQAv2_Dataset when none is passed.
    decoder: str = "ai-forever/rugpt3medium_based_on_gpt2"
    # Model name handed to open_clip.create_model_and_transforms.
    encoder: str = "ViT-B-16-plus-240"

In [80]:
# Single settings object handed to the dataset constructor.
config = Config()

In [87]:
# Build the dataset from the JSON dump with the constructor defaults (coef_size=0.1, prefix_length=20).
# Note: this rebinds `dataset`, which earlier cells used for the DataFrame and the parsed JSON dict.
dataset = VQAv2_Dataset(config, "ru_VQAv2_1000")

  0%|          | 0/100.0 [00:00<?, ?it/s]