In [1]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset
import re
from torchvision import transforms
import pandas as pd
import numpy as np

import torch
from typing import Tuple

from transformers import AutoTokenizer, CLIPVisionModel
from torchvision import transforms
from torch import nn
from collections import OrderedDict
from transformers.activations import ACT2FN



# process question when input model
def pre_question(question, max_ques_words):
    question = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        question.lower(),
    ).replace(' \t', ' ').replace('is/are', 'is').replace('near/in', 'in')
    question = question.replace('>', 'more than ').replace('-yes/no', '')
    question = question.replace('x ray', 'xray').replace('x-ray', 'xray')
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question

# process answer when input model
def pre_answer(answer):
    """Normalize an answer value into a lower-case, punctuation-free string.

    Args:
        answer: Raw answer; any object is accepted and converted with ``str``.

    Returns:
        The cleaned answer string, with ``x ray``/``x-ray`` unified to ``xray``
        and spaced hyphens (``' - '``) collapsed to ``'-'``.
    """
    text = str(answer).lower()
    text = re.sub(r"([,.'!?\"()*#:;~])", '', text)
    for old, new in ((' \t', ' '), ('x ray', 'xray'), ('x-ray', 'xray'), (' - ', '-')):
        text = text.replace(old, new)
    return text

def visual_feature(model, image, proj):
    """Encode a single image and project its patch features.

    Args:
        model: Callable vision encoder invoked as ``model(pixel_values=...)``;
            its result must expose ``last_hidden_state``.
        image: Un-batched image tensor; a leading batch dimension of size 1 is
            added before calling ``model``.
        proj: Callable applied to the selected features (the multimodal
            projector).

    Returns:
        ``proj`` applied to ``last_hidden_state`` with the first token position
        dropped.
    """
    # The encoder expects a batch, so wrap the single image.
    outputs = model(pixel_values=image.unsqueeze(0))

    # Skip index 0 along the sequence axis (presumably the CLS token for a
    # CLIP vision tower -- verify if a different encoder is used).
    patch_features = outputs.last_hidden_state[:, 1:]

    # Use the projector passed in by the caller, not the notebook-level global
    # `projector`, so the function works regardless of kernel state.
    return proj(patch_features)
    

class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text embedding width.

    ``config`` is a mapping that must provide ``"hidden_size"`` (input width),
    ``"hidden_size_text"`` (hidden and output width) and
    ``"projector_hidden_act"`` (a key of ``ACT2FN``).
    """

    def __init__(self, config):
        super().__init__()

        vision_dim = config["hidden_size"]
        text_dim = config["hidden_size_text"]

        # Submodules are created in the same order as before so parameter
        # registration and random initialisation order are unchanged.
        self.linear_1 = nn.Linear(vision_dim, text_dim, bias=True)
        self.act = ACT2FN[config["projector_hidden_act"]]
        self.linear_2 = nn.Linear(text_dim, text_dim, bias=True)

    def forward(self, image_features):
        """Apply linear -> activation -> linear to ``image_features``."""
        return self.linear_2(self.act(self.linear_1(image_features)))


class vqa_dataset(Dataset):
    """VQA dataset yielding projected image features with cleaned text.

    Args:
        ann_file: Path to a pickled annotation collection, loaded with
            ``pd.read_pickle``. Entries are indexed by integer position and
            are expected to provide ``'image'``, ``'question'`` and either
            ``'answer'`` (train) or ``'qid'`` (test).
        transform: Callable applied to the loaded RGB PIL image; must return a
            tensor.
        vqa_root: Directory containing the images; ``'.jpg'`` is appended to
            each annotation's ``'image'`` value.
        eos: Stored on the instance; not otherwise used by this class.
        split: ``'train'`` or ``'test'``.
        max_ques_words: Word cap for questions (overridden to 50 for test).
        answer_list: Path to a JSON file of candidate answers; only read when
            ``split == 'test'``.
        clip_model: Vision encoder handed to ``visual_feature``.
        feature_proj: Projector handed to ``visual_feature``.
    """

    def __init__(self, ann_file, transform, vqa_root, eos='[SEP]', split="train", max_ques_words=30,
                 answer_list='', clip_model = None, feature_proj=None):
        self.split = split
        # NOTE: unpickling executes arbitrary code -- only load trusted files.
        self.ann = pd.read_pickle(ann_file)

        self.transform = transform
        self.vqa_root = vqa_root
        self.max_ques_words = max_ques_words
        self.eos = eos

        self.clip_model = clip_model
        self.feature_proj = feature_proj

        if split == 'test':
            # Use a looser, 50-word cap on questions at test time.
            self.max_ques_words = 50
            # Context manager so the answer file is closed deterministically.
            with open(answer_list, 'r') as answer_file:
                self.answer_list = json.load(answer_file)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return one sample.

        Returns:
            * test split: ``(image, question, question_id)`` where ``image``
              is the transformed image tensor moved to CUDA.
            * train split: ``{'context', 'question', 'answer'}`` where
              ``'context'`` is the projected visual feature.

        Raises:
            ValueError: if ``self.split`` is neither ``'train'`` nor ``'test'``
                (previously this silently returned ``None``).
        """
        if self.split not in ('train', 'test'):
            raise ValueError(
                f"Unsupported split {self.split!r}; expected 'train' or 'test'."
            )

        ann = self.ann[index]

        image_path = os.path.join(self.vqa_root, ann['image'])
        # Close the underlying file as soon as the RGB copy has been made.
        with Image.open(image_path + '.jpg') as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image).cuda()

        question = pre_question(ann['question'], self.max_ques_words)

        if self.split == 'test':
            # The test split returns the raw image tensor, so the encoder and
            # projector forward pass is skipped entirely.
            return image, question, ann['qid']

        image_proj = visual_feature(self.clip_model, image, self.feature_proj)
        answers = pre_answer(ann['answer'])

        return {'context':image_proj, 'question':question, 'answer':answers}

# Per-channel mean / std used to normalise the image tensor (presumably the
# CLIP preprocessing statistics -- confirm against the checkpoint's processor).
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
train_transform = transforms.Compose([
        # transforms.RandomResizedCrop(384, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
        # torchvision's documented way to select bicubic resampling is the
        # InterpolationMode enum rather than the PIL integer constant.
        transforms.Resize((336,336), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        # RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
        #                                       'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])


# Locations of the pickled training annotations and the training images.
train_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/qas/train/train_qa.pkl'
image_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/images/train'

# Projector dimensions: vision feature width in, text embedding width out.
config = {
    "hidden_size": 1024,
    "hidden_size_text": 4096,
    "projector_hidden_act": 'gelu',
}

# Vision encoder first, then the projector, both moved to the GPU.
model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336").cuda()
projector = LlavaMultiModalProjector(config).cuda()

train_dataset = vqa_dataset(
    ann_file=train_path,
    transform=train_transform,
    vqa_root=image_path,
    split='train',
    clip_model=model,
    feature_proj=projector,
)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  return torch.load(checkpoint_file, map_location=map_location)


In [2]:
# Fetch the sample once: every `train_dataset[0]` access reloads the image,
# re-applies the random transform and re-runs the encoder and projector, so
# indexing repeatedly is slow and the displayed fields would come from
# different passes.
first_sample = train_dataset[0]
first_sample['context'].size(), first_sample.keys(), first_sample['question'], first_sample['answer'], len(train_dataset)


(torch.Size([1, 576, 4096]),
 dict_keys(['context', 'question', 'answer']),
 'where are liver stem cells oval cells located',
 'in the canals of hering',
 19755)

In [3]:
from lit_llama.tokenizer import Tokenizer
from pathlib import Path
from typing import Optional
import torch.nn as nn
from transformers import AutoModelForCausalLM


# Size of the token-embedding table built in prepare_sample (rows x width).
config.update({
    "padded_vocab_size": 32000,
    "n_embd": 4096,
})


def prepare_sample(
    example: dict,
    tokenizer: Tokenizer,
    max_length: int,
    mask_inputs: bool = True,
    wte: Optional[nn.Embedding] = None,
    image_token_position: int = 4,
):
    """Fuse one sample's text embeddings with its image features and build labels.

    The prompt produced by ``generate_prompt_qa`` plus the answer is tokenized
    and embedded. Every occurrence of the placeholder token (the id found at
    ``image_token_position`` in the tokenized prompt) is expanded into
    ``num_image_patches`` slots, which are filled with ``example["context"]``.

    Args:
        example: Dict with ``"question"``, ``"answer"`` and ``"context"``;
            ``"context"`` is unpacked as
            ``(num_images, num_image_patches, embed_dim)``.
        tokenizer: Tokenizer forwarded to ``tokenize``.
        max_length: Maximum token length forwarded to ``tokenize``.
        mask_inputs: If True, every label position before the answer span is
            set to ``IGNORE_INDEX``. If False, text positions carry their
            token ids and only image slots are ignored.
        wte: Token embedding table to use. When omitted, a new randomly
            initialised ``nn.Embedding`` is built from ``config`` on every
            call (the historical behaviour); pass a shared table to get
            consistent text embeddings across samples.
        image_token_position: Index in the tokenized prompt whose token id is
            treated as the image placeholder. Defaults to 4, the position the
            original code hard-coded.

    Returns:
        ``example`` extended with ``"input_ids"`` (the fused embedding of shape
        ``(1, fused_len, embed_dim)``) and ``"labels"`` (shape
        ``(1, fused_len)``).

    Raises:
        ValueError: if the number of image slots does not match the number of
            image feature vectors.

    Note:
        Relies on the notebook-level names ``config`` (only when ``wte`` is
        omitted) and ``IGNORE_INDEX``, which must exist when this is called.
    """
    full_prompt = generate_prompt_qa(example, item = 'question')
    full_prompt_and_response = full_prompt + example["answer"]

    # Token ids with an explicit batch dimension of 1.
    token_ids = tokenize(
        tokenizer, full_prompt_and_response, eos=True, max_length=max_length
    ).unsqueeze(0)

    image_features = example["context"]
    # Follow the image features' device instead of assuming CUDA.
    device = image_features.device
    num_images, num_image_patches, embed_dim = image_features.shape

    if wte is None:
        wte = nn.Embedding(config['padded_vocab_size'], config['n_embd'])
    text_ebd = wte(token_ids.to(wte.weight.device))

    batch_size = token_ids.size(0)

    # Id of the placeholder token, kept 2-D so it broadcasts row-wise.
    image_token_index = token_ids[:, image_token_position][:, None]
    special_image_token_mask = token_ids == image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)

    # Each placeholder grows from 1 slot to num_image_patches slots.
    max_embed_dim = int(
        num_special_image_tokens.max() * (num_image_patches - 1) + token_ids.size(1)
    )

    batch_indices, non_image_indices = torch.where(token_ids != image_token_index)

    # Position of every original token in the expanded sequence; a placeholder
    # maps to the last slot of its expanded span.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1

    nb_image_pad = (max_embed_dim - 1 - new_token_positions[:, -1]).to(device)

    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=image_features.dtype, device=device)

    final_embedding[batch_indices, text_to_overwrite] = text_ebd[batch_indices, non_image_indices].to(device)

    # Rows that are still all-zero are the slots reserved for image features.
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)

    # Labels: start fully ignored so image slots never contribute to the loss.
    labels = torch.full((batch_size, max_embed_dim), IGNORE_INDEX, dtype=token_ids.dtype)
    if not mask_inputs:
        # Unmasked prompt: text positions carry their own token ids
        # (previously these positions were left as zeros).
        labels[batch_indices, text_to_overwrite] = token_ids[batch_indices, non_image_indices]

    # The answer span ("Answer:\n<answer>" + EOS, without BOS) is right-aligned.
    ans = generate_prompt_qa(example, item = 'answer')
    encoded_ans = tokenize(tokenizer, ans, eos=True, max_length=max_length)
    encoded_ans = encoded_ans[1:].unsqueeze(0)
    labels[:, -encoded_ans.size(1):] = encoded_ans

    return {**example, "input_ids": final_embedding, "labels": labels}

def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
    """Encode ``string`` with ``tokenizer``, always prepending BOS.

    ``eos`` and ``max_length`` are forwarded unchanged to ``tokenizer.encode``.
    """
    encoded = tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
    return encoded


def generate_prompt_qa(example, item):
    """Build the prompt text for one field of ``example``.

    For ``item == "question"`` the full prompt (context placeholder, question
    and an empty answer slot) is returned; for any other ``item`` only the
    ``Answer:`` section containing ``example[item]`` is returned.
    """
    if item != "question":
        return f"Answer:\n{example[item]}"
    return f"Context:\nimage\nQuestion:\n{example[item]}\nAnswer:\n"


In [4]:
from tqdm import tqdm

# Label value written by prepare_sample for positions excluded from the loss.
IGNORE_INDEX = -1
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)
training_file_path = Path("data/squad2")
train_sample_set = []

# Only the first 10 samples are converted here.
for i in range(10):
    sample_set = prepare_sample(train_dataset[i], tokenizer, max_length=512, mask_inputs= True)
    if sample_set is None:
        continue
    train_sample_set.append(sample_set)

# Write the file once, after all samples are collected, instead of rewriting
# it on every iteration. Resolves to "data/train_data_fuse.pt" relative to the
# current working directory.
torch.save(train_sample_set, training_file_path.parent / "train_data_fuse.pt") #"/data/user-data/sa25729/lit_llama_qa/lit_llama_qa/data/train_temp.pt"

def load_datasets(data_dir):
    """Load the fused training samples stored in ``data_dir``.

    Reads ``train_data_fuse.pt`` with ``torch.load`` and returns its contents.
    Only the training split is loaded; a validation split (``test.pt``) was
    previously sketched here but is not read.
    """
    train_file = os.path.join(data_dir, "train_data_fuse.pt")
    return torch.load(train_file)


# NOTE(review): the samples are written to the relative path
# "data/train_data_fuse.pt" but read back from this absolute directory; the two
# only coincide when the kernel's working directory is the parent of this
# folder -- confirm, otherwise a stale file may be loaded.
data_dir = "/data/user-data/sa25729/lit_llama_qa/lit-llama-qa/data/"
train_data = load_datasets(data_dir)
# Sanity check: sample count, available keys, and fused input / label shapes.
len(train_data), train_data[0].keys(), train_data[0]['input_ids'].size(), train_data[0]['labels'].size()

  train_data = torch.load(os.path.join(data_dir, "train_data_fuse.pt"))


(10,
 dict_keys(['context', 'question', 'answer', 'input_ids', 'labels']),
 torch.Size([1, 608, 4096]),
 torch.Size([1, 608]))

In [5]:
# Shape of the projected image features stored with the first loaded sample.
train_data[0]['context'].size()

torch.Size([1, 576, 4096])

In [45]:
from lit_llama.tokenizer import Tokenizer

# Scratch inputs for checking tokenizer round-trips.
temp = "hello"

# Hand-written token ids (presumably an encoded prompt + answer -- TODO
# confirm); not used further in this cell.
x = torch.tensor([  1, 22430, 31871,    13, 31903,  8326, 31901,    13,  6347, 28312,
        31871,    13,  3272,   397, 13357, 10700,  3984,   269,  1735,  3984,
         3521,    13,  5092,  2055,   265, 31871,    13,   261,   266,   473,
          811,   287,   600,   281,     2])

# Short id sequence that is decoded at the end of this cell.
x_ = torch.tensor([  1, 22430, 31871,    13,   8326, 13])
# NOTE: same definition as the `tokenize` in the earlier prepare_sample cell;
# re-running this cell simply rebinds the name.
def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
    """Encode ``string`` with ``tokenizer``, always prepending BOS.

    ``eos`` and ``max_length`` are forwarded unchanged to ``tokenizer.encode``.
    """
    encoded = tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)
    return encoded

def tokenize_dec(tokenizer: Tokenizer, token: torch.Tensor):
    """Decode token ids back to text via ``tokenizer.decode``.

    Args:
        tokenizer: Tokenizer providing ``decode``.
        token: Token ids to decode. This notebook passes a 1-D tensor of ids,
            so the parameter is annotated as a tensor rather than ``str``.

    Returns:
        Whatever ``tokenizer.decode`` returns for ``token``.
    """
    return tokenizer.decode(token)
    
# The encode result is discarded; only the last expression (the decoded text
# of x_) is displayed as the cell output.
tokenize(tokenizer, temp, eos=True, max_length=50)
tokenize_dec(tokenizer, x_)


'Context:\nimage\n'

In [89]:
# Batched copy of the scratch token ids; column 4 is the position that
# prepare_sample reads to obtain its image-placeholder token id.
x = torch.tensor([  1, 22430, 31871,    13, 31903,  8326, 31901,    13,  6347, 28312,
        31871,    13,  3272,   397, 13357, 10700,  3984,   269,  1735,  3984,
         3521,    13,  5092,  2055,   265, 31871,    13,   261,   266,   473,
          811,   287,   600,   281,     2]).unsqueeze(0)
x[:,4]

tensor([31903])