In [2]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset
import re
from torchvision import transforms
import pandas as pd
import numpy as np

import torch
from typing import Tuple
from transformers import AutoTokenizer, CLIPVisionModel
from torch import nn
from collections import OrderedDict
from transformers.activations import ACT2FN

from lit_llama.tokenizer import Tokenizer
from pathlib import Path
from typing import Optional
import torch.nn as nn
from transformers import AutoModelForCausalLM


# process question when input model
def pre_question(question, max_ques_words):
    question = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        question.lower(),
    ).replace(' \t', ' ').replace('is/are', 'is').replace('near/in', 'in')
    question = question.replace('>', 'more than ').replace('-yes/no', '')
    question = question.replace('x ray', 'xray').replace('x-ray', 'xray')
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question

# process answer when input model
def pre_answer(answer):
    answer = str(answer)
    answer = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        answer.lower(),
    ).replace(' \t', ' ')
    answer = answer.replace('x ray', 'xray').replace('x-ray', 'xray')
    answer = answer.replace(' - ', '-')
    return answer

def visual_feature(model, image, proj):
    inputs_img = {'pixel_values':0}
    inputs_img['pixel_values'] = image.unsqueeze(0)
    
    #image feature
    image_features = model(**inputs_img)
    selected_image_features = image_features.last_hidden_state[:, 1:]

    #feature embedding
    img_emd = projector(selected_image_features)
    return img_emd
    

class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.linear_1 = nn.Linear(config["hidden_size"], config["hidden_size_text"], bias=True)
        self.act = ACT2FN[config["projector_hidden_act"]]
        self.linear_2 = nn.Linear(config["hidden_size_text"], config["hidden_size_text"], bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
    return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)


def generate_prompt_qa(item, content):
    """Generates a standardized message to prompt the model with an instruction, optional input and a
    'response' field."""
    if item == "question": 
        return f"Context:\nImage\nQuestion:\n{content}\nAnswer:\n"
    else:
        return f"Answer:\n{content}"



class vqa_dataset(Dataset):
    def __init__(self, ann_file, transform, vqa_root, eos='[SEP]', max_ques_words=100, max_length = 512,
                 answer_list='', clip_model = None, feature_proj=None, tokenizer = None):
        
        self.ann =  pd.read_pickle(ann_file)

        self.transform = transform
        self.vqa_root = vqa_root
        self.max_ques_words = max_ques_words
        self.eos = eos

        self.clip_model = clip_model
        self.feature_proj = feature_proj
        self.tokenizer = tokenizer
        self.max_length = max_length


    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):

        ann = self.ann[index]

        image_path = os.path.join(self.vqa_root, ann['image'])
        image = Image.open(image_path + '.jpg').convert('RGB')
        image = self.transform(image).cuda()
        
        image_proj = visual_feature(self.clip_model, image, self.feature_proj)
        question = pre_question(ann['question'], self.max_ques_words)
        answers = ann['answer']
        answers = pre_answer(answers)

        full_prompt = generate_prompt_qa('question',question)
        full_prompt_and_response = full_prompt + answers
        encoded_full_prompt = tokenize(self.tokenizer, full_prompt, max_length=self.max_length, eos=False)
        encoded_full_prompt_and_response = tokenize(self.tokenizer, full_prompt_and_response, eos=True, max_length=self.max_length)
        label = encoded_full_prompt_and_response.clone()
        label[:len(encoded_full_prompt)] = -1
        
        
        if len(encoded_full_prompt_and_response)==self.max_length:
            return None

        return {'context':image_proj, 'question':question, 'answer':answers, 'qa_token': encoded_full_prompt_and_response, 'tokens': full_prompt_and_response, 'q_token': encoded_full_prompt, 'label': label}

normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
train_transform = transforms.Compose([
        # transforms.RandomResizedCrop(384, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
        transforms.Resize((224,224), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        # RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
        #                                       'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])


train_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/qas/train/train_qa.pkl'
val_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/qas/val/val_qa.pkl'
test_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/qas/test/test_qa.pkl'

train_image_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/images/train'
val_image_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/images/val'
test_image_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/images/test'

config = {}
config["hidden_size"] = 768
config["hidden_size_text"] = 4096
config["projector_hidden_act"] = 'gelu'

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").cuda()
model.eval()
projector = LlavaMultiModalProjector(config).cuda()

tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)

train_dataset = vqa_dataset(train_path, train_transform, train_image_path, clip_model = model, feature_proj = projector, tokenizer = tokenizer)
val_dataset = vqa_dataset(val_path, train_transform, val_image_path, clip_model = model, feature_proj = projector, tokenizer = tokenizer)
test_dataset = vqa_dataset(test_path, train_transform, test_image_path, clip_model = model, feature_proj = projector, tokenizer = tokenizer)


In [3]:
train_dataset[0]['context'].size(), train_dataset[0].keys(), train_dataset[0]['question'], train_dataset[0]['answer'], len(train_dataset), train_dataset[0]['qa_token']


(torch.Size([1, 196, 4096]),
 dict_keys(['context', 'question', 'answer', 'qa_token', 'tokens', 'q_token', 'label']),
 'where are liver stem cells oval cells located',
 'in the canals of hering',
 19755,
 tensor([    1, 22430, 31871,    13,  9895,    13,  6347, 28312, 31871,    13,
          3272,   397, 13357, 10700,  3984,   269,  1735,  3984,  3521,    13,
          5092,  2055,   265, 31871,    13,   261,   266,   473,   811,   287,
           600,   281,     2], dtype=torch.int32))

In [4]:
val_dataset[0]['context'].size(), val_dataset[0].keys(), val_dataset[0]['question'], val_dataset[0]['answer'], len(val_dataset), val_dataset[0]['qa_token']



(torch.Size([1, 196, 4096]),
 dict_keys(['context', 'question', 'answer', 'qa_token', 'tokens', 'q_token', 'label']),
 'what have lost their nuclei',
 'neutrophils',
 6279,
 tensor([    1, 22430, 31871,    13,  9895,    13,  6347, 28312, 31871,    13,
         11906,   435,  2953,   518, 13323, 31827,    13,  5092,  2055,   265,
         31871,    13,   652,   319,  8982,  3640,     2], dtype=torch.int32))

In [5]:
test_dataset[0]['context'].size(), test_dataset[0].keys(),test_dataset[0]['question'], test_dataset[0]['answer'], len(test_dataset), test_dataset[0]['qa_token']



(torch.Size([1, 196, 4096]),
 dict_keys(['context', 'question', 'answer', 'qa_token', 'tokens', 'q_token', 'label']),
 'what are positively charged  thus allowing the compaction of the negatively charged dna',
 'the histone subunits',
 6761,
 tensor([    1, 22430, 31871,    13,  9895,    13,  6347, 28312, 31871,    13,
         11906,   397, 18807,  6448,  4487,  5905,   266,   520,  2794,   287,
           266, 23162,  6448,   294,  2244,    13,  5092,  2055,   265, 31871,
            13,  1134,  1285,   534,   845,   356,   916,     2],
        dtype=torch.int32))

In [4]:
train_dataset[0]['tokens'] , train_dataset[0]['q_token'], train_dataset[0]['qa_token'], train_dataset[0]['label'], train_dataset[0]['label'].size()
#  1, 22430, 31871,    13, 31903,  8326, 31901,    13,

('Context:\nImage\nQuestion:\nwhere are liver stem cells oval cells located\nAnswer:\nin the canals of hering',
 tensor([    1, 22430, 31871,    13,  9895,    13,  6347, 28312, 31871,    13,
          3272,   397, 13357, 10700,  3984,   269,  1735,  3984,  3521,    13,
          5092,  2055,   265, 31871,    13], dtype=torch.int32),
 tensor([    1, 22430, 31871,    13,  9895,    13,  6347, 28312, 31871,    13,
          3272,   397, 13357, 10700,  3984,   269,  1735,  3984,  3521,    13,
          5092,  2055,   265, 31871,    13,   261,   266,   473,   811,   287,
           600,   281,     2], dtype=torch.int32),
 tensor([ -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,
          -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1, 261, 266, 473,
         811, 287, 600, 281,   2], dtype=torch.int32),
 torch.Size([33]))

In [18]:
from lit_llama.tokenizer import Tokenizer
from pathlib import Path
from typing import Optional
import torch.nn as nn
from transformers import AutoModelForCausalLM


config["padded_vocab_size"]= 32000
config["n_embd"] = 4096

# the 4th position of qa_token, with the position of text "Image"
config["image_token_index"] = 9895 
micro_batch_size = 2
IGNORE_INDEX = -1

def prepare_sample(example: dict, max_length: int, mask_inputs: bool = True):

    ix = torch.randint(len(example), (micro_batch_size,))
    print("ix", ix)

    input_ids = [example[i]["qa_token"].type(torch.int64) for i in ix]
    # input_ids = torch.stack([example[i]["qa_token"].type(torch.int64) for i in ix])
    labels = [example[i]["label"] for i in ix]
    print('input_ids', len(input_ids), len(labels))

    max_len = max(len(s) for s in input_ids)
    print('max_len', max_len)

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    # y = torch.stack([pad_right(y, pad_id=-1) for y in labels])
    # print("xy", x.size(), y.size())

    wte=nn.Embedding(config['padded_vocab_size'], config['n_embd'])
    text_ebd = wte(x)
    print("emb_x", text_ebd.size())

    # image_features = [example[i]["context"] for i in ix]
    image_features = torch.stack([example[i]["context"].squeeze(0) for i in ix])
    print("image_features", image_features.size(), type(image_features))

    num_images, num_image_patches, embed_dim = image_features.shape
    print(num_images, num_image_patches, embed_dim)

    batch_size, sequence_length = x.shape
    print(batch_size, sequence_length)

    special_image_token_mask = x == config["image_token_index"]
    # print("x", x)
    # print("special_image_token_mask", special_image_token_mask)

    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    print("num_special_image_tokens", num_special_image_tokens)

    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + x.size(1)
    print("max_embed_dim", max_embed_dim)

    batch_indices, non_image_indices = torch.where(x != config["image_token_index"])
    print("batch_indices, non_image_indices", batch_indices, non_image_indices)

    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    print("new_token_positions", new_token_positions)

    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    nb_image_pad = nb_image_pad.cuda()
    print("nb_image_pad", nb_image_pad)
    
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
    print("text_to_overwrite", text_to_overwrite.size())

    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=image_features.dtype, device=image_features.device)

    final_embedding[batch_indices, text_to_overwrite] = text_ebd[batch_indices, non_image_indices].cuda()
    print('final_embedding[batch_indices, text_to_overwrite]', final_embedding[batch_indices, text_to_overwrite].size())

    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    print("image_to_overwrite", image_to_overwrite.size())
    
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
    # print("image_to_overwrite1", image_to_overwrite)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
    print("final_embedding[image_to_overwrite]", final_embedding[image_to_overwrite].size())

    print("final_embedding", final_embedding.size())

    #getting labels
    temp_ebd_label = torch.zeros(batch_size, max_embed_dim, device=image_features.device).long() - 1
    print("temp_ebd_label", temp_ebd_label.size())

    labels = []
    for i in range(len(ix)):
        ids = ix[i]
        print("ids", ids)
        temp_len = len(example[ids]["label"])
        temp_ebd_label[i][-temp_len:] = example[ids]["label"]
        labels.append(temp_ebd_label[i])
    labels_ = torch.stack([y for y in labels])
    print("labels_", labels_.size(), labels_)
        
        
    
    

    return { "input_ids": final_embedding, "labels": labels_}

sample_set = prepare_sample(example = train_dataset, max_length=512, mask_inputs= True)

ix tensor([ 6406, 12253])
input_ids 2 2
max_len 24
emb_x torch.Size([2, 24, 4096])
image_features torch.Size([2, 196, 4096]) <class 'torch.Tensor'>
2 196 4096
2 24
num_special_image_tokens tensor([1, 1])
max_embed_dim tensor(219)
batch_indices, non_image_indices tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) tensor([ 0,  1,  2,  3,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23,  0,  1,  2,  3,  5,  6,  7,  8,  9, 10, 11, 12, 13,
        14, 15, 16, 17, 18, 19, 20, 21, 22, 23])
new_token_positions tensor([[  0,   1,   2,   3, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208,
         209, 210, 211, 212, 213, 214, 215, 216, 217, 218],
        [  0,   1,   2,   3, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208,
         209, 210, 211, 212, 213, 214, 215, 216, 217, 218]])
nb_image_pad tensor([0, 0], device='cuda:0')
text_to_overwrite torch.Size(

In [19]:
sample_set

{'input_ids': tensor([[[ 0.9946,  0.2502, -0.1851,  ...,  0.0617, -0.4781, -1.6970],
          [ 0.2183,  1.5907,  2.1551,  ...,  0.7390, -0.0546,  0.5157],
          [-1.0250,  0.0686, -1.0518,  ..., -1.0820,  0.8993,  1.0964],
          ...,
          [ 1.3632, -1.1058, -0.6484,  ...,  1.0083, -0.6516, -0.7756],
          [ 0.3430,  0.6047, -0.7093,  ...,  0.1265,  0.9999,  0.8103],
          [-1.2783,  0.8782,  0.4749,  ..., -0.3862,  0.1837, -0.0648]],
 
         [[ 0.9946,  0.2502, -0.1851,  ...,  0.0617, -0.4781, -1.6970],
          [ 0.2183,  1.5907,  2.1551,  ...,  0.7390, -0.0546,  0.5157],
          [-1.0250,  0.0686, -1.0518,  ..., -1.0820,  0.8993,  1.0964],
          ...,
          [-0.5672,  2.1402, -0.2583,  ..., -0.0036,  0.7955,  0.3287],
          [ 1.1418, -0.8197,  1.1176,  ...,  0.1624,  0.0037, -1.6648],
          [ 0.3430,  0.6047, -0.7093,  ...,  0.1265,  0.9999,  0.8103]]],
        device='cuda:0', grad_fn=<IndexPutBackward0>),
 'labels': tensor([[  -1,   -1,  

In [14]:
#getting labels and keep other positions masked
x = train_dataset[0]['q_token']
y = train_dataset[0]['qa_token']
y[:len(x)] = -1
y

tensor([ -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,
         -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1, 261, 266, 473,
        811, 287, 600, 281,   2], dtype=torch.int32)

In [17]:
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
model_path = Path("checkpoints/lit-llama/tokenizer.model")
processor = SentencePieceProcessor(model_file=str(model_path))
processor.bos_id(), processor.eos_id(), processor.pad_id()

(1, 2, -1)

In [4]:
from tqdm import tqdm

IGNORE_INDEX = -1
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)
training_file_path = Path("data/squad2")
train_sample_set = []

for i in range(50):
    sample_set = prepare_sample(train_dataset[i], tokenizer, max_length=512, mask_inputs= True)
    if sample_set == None:
        continue
    else:
        train_sample_set.append(sample_set)
        i += 1
    torch.save(train_sample_set, training_file_path.parent / "train_data_fuse_v1.pt") #"/data/user-data/sa25729/lit_llama_qa/lit_llama_qa/data/train_temp.pt"

def load_datasets(data_dir):
    train_data = torch.load(os.path.join(data_dir, "train_data_fuse_v1.pt"))
    # val_data = torch.load(os.path.join(data_dir, "test.pt"))
    # return train_data, val_data
    return train_data


data_dir = "/data/user-data/sa25729/lit_llama_qa/lit-llama-qa/data/"
train_data = load_datasets(data_dir)
len(train_data), train_data[0].keys(), train_data[0]['input_ids'].size(), train_data[0]['labels'].size()

  train_data = torch.load(os.path.join(data_dir, "train_data_fuse_v1.pt"))


(50,
 dict_keys(['context', 'question', 'answer', 'input_ids', 'labels']),
 torch.Size([1, 228, 4096]),
 torch.Size([1, 228]))

In [5]:
train_data[0]['context'].size()

torch.Size([1, 196, 4096])

In [45]:
from lit_llama.tokenizer import Tokenizer

temp = "hello"

x = torch.tensor([  1, 22430, 31871,    13, 31903,  8326, 31901,    13,  6347, 28312,
        31871,    13,  3272,   397, 13357, 10700,  3984,   269,  1735,  3984,
         3521,    13,  5092,  2055,   265, 31871,    13,   261,   266,   473,
          811,   287,   600,   281,     2])

x_ = torch.tensor([  1, 22430, 31871,    13,   8326, 13])
def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
    return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)

def tokenize_dec(tokenizer: Tokenizer, token: str):
    return tokenizer.decode(token)
    
tokenize(tokenizer, temp, eos=True, max_length=50)
tokenize_dec(tokenizer, x_)


'Context:\nimage\n'

In [89]:
x = torch.tensor([  1, 22430, 31871,    13, 31903,  8326, 31901,    13,  6347, 28312,
        31871,    13,  3272,   397, 13357, 10700,  3984,   269,  1735,  3984,
         3521,    13,  5092,  2055,   265, 31871,    13,   261,   266,   473,
          811,   287,   600,   281,     2]).unsqueeze(0)
x[:,4]

tensor([31903])

In [1]:
from torchvision import transforms

In [None]:

    

    # wte=nn.Embedding(config['padded_vocab_size'], config['n_embd'])
    # text_ebd = wte(encoded_full_prompt_and_response).unsqueeze(0)

    # image_features = example["context"]
    # num_images, num_image_patches, embed_dim = image_features.shape 

    # batch_size, sequence_length = encoded_full_prompt_and_response.unsqueeze(0).shape

    # encoded_full_prompt_and_response_ = encoded_full_prompt_and_response.unsqueeze(0)
    
    # image_token_index = encoded_full_prompt_and_response_[:,4]
    
    # special_image_token_mask = encoded_full_prompt_and_response_ == encoded_full_prompt_and_response_[:,4]

    # num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)

    # max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + encoded_full_prompt_and_response_.size(1)

    # batch_indices, non_image_indices = torch.where(encoded_full_prompt_and_response_ != image_token_index)

    # new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1

    # nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    # nb_image_pad = nb_image_pad.cuda()

    # text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # final_embedding = torch.zeros(
    #     batch_size, max_embed_dim, embed_dim, dtype=image_features.dtype, device=image_features.device)
    
    # final_embedding[batch_indices, text_to_overwrite] = text_ebd[batch_indices, non_image_indices].cuda()

    # image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    
    # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]

    # if image_to_overwrite.sum() != image_features.shape[:-1].numel():
    #     raise ValueError(
    #         f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
    #         f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
    #     )

    # final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
    