In [1]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset
import re
from torchvision import transforms
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from typing import Tuple
from transformers import AutoProcessor, CLIPModel


# process question when input model
def pre_question(question, max_ques_words):
    question = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        question.lower(),
    ).replace(' \t', ' ').replace('is/are', 'is').replace('near/in', 'in')
    question = question.replace('>', 'more than ').replace('-yes/no', '')
    question = question.replace('x ray', 'xray').replace('x-ray', 'xray')
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question

# process answer when input model
def pre_answer(answer):
    answer = str(answer)
    answer = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        answer.lower(),
    ).replace(' \t', ' ')
    answer = answer.replace('x ray', 'xray').replace('x-ray', 'xray')
    answer = answer.replace(' - ', '-')
    return answer

def visual_feature(model, image, proj):
    inputs_img = {'pixel_values':0}
    inputs_img['pixel_values'] = image.unsqueeze(0)
    image_features = model.get_image_features(**inputs_img)
    image_features_proj = proj(image_features)
    return image_features_proj
    
#reference of paper or github repo
class img_projection(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(img_projection, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(nn.Dropout(p=0.5))
                layers.append(act())
        self.model = nn.Sequential(*layers)


class vqa_dataset(Dataset):
    def __init__(self, ann_file, transform, vqa_root, eos='[SEP]', split="train", max_ques_words=30,
                 answer_list='', clip_model = None, feature_proj=None):
        self.split = split
        self.ann =  pd.read_pickle(ann_file)

        self.transform = transform
        self.vqa_root = vqa_root
        self.max_ques_words = max_ques_words
        self.eos = eos

        self.clip_model = clip_model
        self.feature_proj = feature_proj

        if split == 'test':
            self.max_ques_words = 50  # do not limit question length during test
            self.answer_list = json.load(open(answer_list, 'r'))

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):

        ann = self.ann[index]

        image_path = os.path.join(self.vqa_root, ann['image'])
        image = Image.open(image_path + '.jpg').convert('RGB')
        image = self.transform(image)
        image_proj = visual_feature(self.clip_model, image, self.feature_proj)
        # prompt_info = ', its organ is {}, the type of answer is {}, the type of question is {}'\
        #               .format(ann['image_organ'], ann['answer_type'], ann['question_type'])
        # ann['question'] = ann['question'] + prompt_info


        if self.split == 'test':
            question = pre_question(ann['question'], self.max_ques_words)
            question_id = ann['qid']
            return image, question, question_id

        elif self.split == 'train':

            question = pre_question(ann['question'], self.max_ques_words)
            answers = ann['answer']
            answers = pre_answer(answers)
            # answers = [pre_answer(answers)]
            # answers = [answer + self.eos for answer in answers]

            return {'context':image_proj, 'question':question, 'answer':answers}

normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
train_transform = transforms.Compose([
        # transforms.RandomResizedCrop(384, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
        transforms.Resize((224,224), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        # RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
        #                                       'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])


train_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/qas/train/train_qa.pkl'
image_path = '/data/user-data/sa25729/MICCAI_2024/datasets/pvqa/images/train'

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
proj = img_projection((512, 128))

train_dataset = vqa_dataset(train_path, train_transform, image_path, split='train', clip_model = model, feature_proj = proj)

2024-09-27 17:38:12.458547: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
train_dataset[0]['context'].size(), train_dataset[0], type(train_dataset[0]['question']), type(train_dataset[0]['answer'])

(torch.Size([1, 128]),
 {'context': tensor([[ 0.2929,  0.2182, -0.1688,  0.3002,  0.1204,  0.1557,  0.0088, -0.5786,
           -0.3105,  0.3801,  0.3762, -0.2273, -0.2410,  0.2423,  0.5150,  0.4706,
            0.3800, -0.0644,  0.0296, -0.3151, -0.0946, -0.0570,  0.0251, -0.0463,
            0.2365,  0.0978, -0.6446, -0.0106, -0.2707, -0.2651,  0.2845,  0.2669,
           -0.2838,  0.1725,  0.3792, -0.4023, -0.5583,  0.0692,  0.3028, -0.1613,
            0.0449,  0.1432,  0.5284,  0.4707, -0.1834, -0.0068, -0.2882,  0.4106,
            0.0736, -0.0572,  0.0292,  0.3399, -0.4908,  0.0587, -0.0130,  0.0515,
           -0.2022, -0.2150, -0.2021, -0.0681,  0.4435, -0.1914, -0.0609,  0.2424,
            0.1118, -0.1728, -0.3249, -0.1466, -0.2586,  0.1690,  0.1559, -0.3386,
           -0.3448,  0.0263, -0.1364, -0.0922,  0.1560,  0.1121,  0.1112, -0.3842,
            0.1458,  0.2866,  0.0873, -0.3112,  0.4269, -0.0166,  0.4459,  0.1424,
           -0.1220, -0.2386, -0.5953,  0.0027,  0.311

In [4]:
from lit_llama.tokenizer import Tokenizer
from pathlib import Path

def prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool = True):
    
    full_prompt = generate_prompt_qa(example, item = 'question')
    full_prompt_and_response = full_prompt + example["answer"]
    # print('full_prompt_and_response', full_prompt_and_response)
    
    # encoded_full_prompt = tokenize(tokenizer, full_prompt, max_length=max_length, eos=False)
    encoded_full_prompt_and_response = tokenize(tokenizer, full_prompt_and_response, eos=True, max_length=max_length)
    # print('encoded_full_prompt_and_response', encoded_full_prompt_and_response.size(), encoded_full_prompt_and_response)

    ans = generate_prompt_qa(example, item = 'answer')
    encoded_ans = tokenize(tokenizer, ans, eos=True, max_length=max_length)
    # print('encoded_ans', encoded_ans, len(encoded_ans))
    
    #concating the img feature with tokens
    x = example["context"].transpose(0,1)
    y = encoded_full_prompt_and_response.unsqueeze(0).transpose(0,1)
    cqa = torch.cat((x, y), dim=0).transpose(0,1)
    # print(x.size(), y.size(), cqa.size())
    
    if len(cqa)==max_length:
        return None
    # print(f"Length of tokens: {len(encoded_full_prompt_and_response)}")

    # The labels are the full prompt with response, but with the prompt masked out
    labels = cqa.clone()
    
    if mask_inputs:
        labels[:, 0:-len(encoded_ans)] = IGNORE_INDEX

    return {**example, "input_ids": cqa, "labels": labels}

def tokenize(tokenizer: Tokenizer, string: str, max_length: int, eos=True) -> torch.Tensor:
    return tokenizer.encode(string, bos=True, eos=eos, max_length=max_length)


def generate_prompt_qa(example, item):
    """Generates a standardized message to prompt the model with an instruction, optional input and a
    'response' field."""
    if item == "question": 
        return f"### Question:\n{example[item]}\n\n### Answer:\n"
    else:
        return f"### Answer:\n{example[item]}"
    # return f"### Context:\n{example['context']}\n\n### Question:\n{example['question']}\n\n### Answer:\n"


In [5]:
from tqdm import tqdm

IGNORE_INDEX = -1
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)
training_file_path = Path("data/squad2")
train_sample_set = []

for i in range(50):
    sample_set = prepare_sample(train_dataset[i], tokenizer, max_length=512, mask_inputs= True)
    if sample_set == None:
        continue
    else:
        train_sample_set.append(sample_set)
        i += 1
    torch.save(train_sample_set, training_file_path.parent / "train_temp.pt") #"/data/user-data/sa25729/lit_llama_qa/lit_llama_qa/data/train_temp.pt"

# for sample in tqdm(train_dataset):
#     sample_set = prepare_sample(sample, tokenizer, max_length=512, mask_inputs= True)
#     if sample_set == None:
#         continue
#     else:
#         train_sample_set.append(sample_set)
#     torch.save(train_sample_set, training_file_path.parent / "train_temp.pt")

In [8]:
def load_datasets(data_dir):
    train_data = torch.load(os.path.join(data_dir, "train_temp.pt"))
    # val_data = torch.load(os.path.join(data_dir, "test.pt"))
    # return train_data, val_data
    return train_data


data_dir = "/data/user-data/sa25729/lit_llama_qa/lit-llama-qa/data/"
train_data = load_datasets(data_dir)
len(train_data), train_data[0].keys(), train_data[0]['input_ids'].size(), train_data[0]['labels'].size()

(50,
 dict_keys(['context', 'question', 'answer', 'input_ids', 'labels']),
 torch.Size([1, 157]),
 torch.Size([1, 157]))

In [9]:
train_data[0]['labels'].type(torch.int64), train_data[0]['input_ids'].type(torch.int64),

(tensor([[   -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
             -1,    -1,    -1,    -1,   

In [10]:
micro_batch_size =1

def get_batch(data: list):
    ix = torch.randint(len(data)-1, (micro_batch_size,))
    print('ix', ix) 
    
    input_ids = [data[i]["input_ids"].type(torch.float32) for i in ix] #torch.int64
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    # print('input_ids',len(input_ids) ,input_ids[0].size(), input_ids[1].size(), input_ids[2].size(), input_ids[3].size(), input_ids[4].size())
    # print('labels', len(labels),labels[0].size(), labels[1].size(), labels[2].size(), labels[3].size(), labels[4].size())
    
    print('input_ids',len(input_ids) ,input_ids[0].size()) 
    print('labels', len(labels),labels[0].size())

    # print('input_ids',input_ids ,input_ids[0].size()) 
    # print('labels', labels,labels[0].size())

    max_len = max(len(s) for s in input_ids)
    print('max_len', max_len)

    def pad_right(x, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(x)
        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
    # x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))

    print('x', x.size())
    print('x', y.size())
    # print('x',len(x) ,x[0].size(), x[1].size(), x[2].size(), x[3].size(), x[4].size())
    # print('y', len(y),y[0].size(), y[1].size(), y[2].size(), y[3].size(), y[4].size())
    return x, y

input_ids, targets = get_batch(train_data)

ix tensor([19])
input_ids 1 torch.Size([1, 157])
labels 1 torch.Size([1, 157])
max_len 1
x torch.Size([1, 1, 157])
x torch.Size([1, 1, 157])


In [11]:
input_ids, targets


(tensor([[[ 7.3019e-02,  3.0004e-01, -2.1092e-01,  2.7332e-01,  5.6172e-02,
            3.1421e-01,  2.0004e-01, -3.9866e-01, -3.6370e-01,  3.8636e-03,
            4.6104e-01, -3.8067e-01, -2.3545e-01,  1.6375e-01,  2.8970e-01,
            5.8310e-01,  2.5817e-01,  2.2851e-01,  1.7788e-01, -3.9559e-01,
           -8.2401e-02, -2.1959e-02, -7.0833e-02,  1.3922e-01,  2.1348e-01,
            2.0838e-01, -5.1258e-02,  2.0160e-01, -4.7469e-01, -2.2535e-01,
            3.8030e-02,  2.6900e-01, -7.5497e-01,  2.4113e-01,  1.7330e-01,
           -5.0797e-01, -4.2867e-01,  5.5541e-02,  4.3060e-01,  1.2487e-01,
            6.1188e-02,  3.0710e-01,  2.7663e-01,  6.8004e-01, -1.5852e-01,
           -1.0050e-01, -4.2995e-01,  3.5580e-01,  1.6571e-01, -1.7752e-01,
            6.5570e-02,  5.1389e-02, -5.8529e-01, -4.5564e-02, -1.8349e-01,
           -9.3535e-02, -1.9400e-01, -4.8248e-02, -2.2423e-01,  7.4156e-02,
            1.8327e-01,  6.2321e-03, -1.2860e-01,  1.9064e-01,  1.8702e-01,
           -

In [49]:
!ls 

ls: cannot access '/data/user-data/sa25729/lit_llama_qa/data/train_temp.pt': No such file or directory


In [45]:
!pwd

/data/user-data/sa25729/lit_llama_qa/lit-llama-qa


In [31]:
sample = train_dataset[0]
IGNORE_INDEX = -1
tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)

sample_set = prepare_sample(sample, tokenizer, max_length=512, mask_inputs= True)


sample_set

{'context': tensor([[-0.1911, -0.2390, -0.1190, -0.4613, -0.2316,  0.0827, -0.2930,  0.2204,
          -0.1746,  0.5785, -0.0113, -0.2014, -0.3710,  0.0827, -0.2906, -0.4311,
          -0.1014,  0.1405, -0.0646, -0.2600, -0.4056,  0.2690, -0.1107,  0.3428,
          -0.3609,  0.5873, -0.0112, -0.3733,  0.1605,  0.3588,  0.0437,  0.0793,
          -0.1228, -0.0583, -0.0052,  0.0739, -0.0917,  0.2080, -0.3682, -0.1224,
           0.3304, -0.1189, -0.1956, -0.4890,  0.1407, -0.0748, -0.0239, -0.1615,
           0.1607, -0.2159, -0.1074,  0.1661,  0.1009,  0.2857,  0.4329,  0.5538,
           0.2900,  0.0875,  0.1731,  0.0848,  0.2790,  0.2409,  0.0900, -0.3946,
           0.1820,  0.0960, -0.1258, -0.3767,  0.0630, -0.1480,  0.1856,  0.1505,
          -0.1298,  0.2001, -0.0829, -0.3554,  0.0613,  0.0326, -0.0554, -0.0528,
          -0.1541, -0.1876,  0.2919,  0.0029,  0.1859,  0.1724,  0.2765, -0.1543,
           0.1515,  0.4419,  0.4060,  0.3493,  0.0931,  0.0111,  0.1780, -0.2226,
     

In [18]:
def generate_prompt_qa_(example, item):
    """Generates a standardized message to prompt the model with an instruction, optional input and a
    'response' field."""
    if item == "question": 
        return f"### Question:\n{example[item]}\n\n### Answer:\n"
    else:
        return example[item]
    # return f"### Context:\n{example['context']}\n\n### Question:\n{example['question']}\n\n### Answer:\n"

temp = generate_prompt_qa_(train_dataset[0], item = 'context')
type(temp)

torch.Tensor

In [45]:


sample_set

full_prompt_and_response ### Question:
where are liver stem cells oval cells located

### Answer:
in the canals of hering
encoded_full_prompt_and_response torch.Size([29]) tensor([    1, 16121, 12945, 31871,    13,  3272,   397, 13357, 10700,  3984,
          269,  1735,  3984,  3521,    13,    13,  8458, 31922, 19775, 31871,
           13,   261,   266,   473,   811,   287,   600,   281,     2],
       dtype=torch.int32)
x torch.Size([128, 1])
y torch.Size([29, 1]) tensor([[    1],
        [16121],
        [12945],
        [31871],
        [   13],
        [ 3272],
        [  397],
        [13357],
        [10700],
        [ 3984],
        [  269],
        [ 1735],
        [ 3984],
        [ 3521],
        [   13],
        [   13],
        [ 8458],
        [31922],
        [19775],
        [31871],
        [   13],
        [  261],
        [  266],
        [  473],
        [  811],
        [  287],
        [  600],
        [  281],
        [    2]], dtype=torch.int32)
torch.Size([128,

{'context': tensor([[-0.1055,  0.2734,  0.1262, -0.2347,  0.0028,  0.0040, -0.1914,  0.0465,
           0.2062,  0.1054, -0.2364,  0.3519, -0.1473,  0.0259,  0.0820,  0.0109,
           0.0139,  0.3067, -0.2592,  0.0848,  0.2514,  0.3518,  0.4497, -0.2951,
          -0.3130, -0.1118, -0.0941,  0.3054, -0.5071,  0.0687, -0.1363,  0.4278,
          -0.4031, -0.1676,  0.0071,  0.0520,  0.3752,  0.1141,  0.3719, -0.3148,
          -0.2305, -0.0538, -0.0989,  0.2517,  0.1017, -0.2824,  0.1597, -0.2711,
           0.1097,  0.2318,  0.1044,  0.1743,  0.3498,  0.2412, -0.1490,  0.1832,
           0.1183, -0.0080, -0.1818, -0.1111,  0.4689, -0.1080, -0.1939, -0.3927,
           0.0463, -0.0191, -0.0605, -0.1500,  0.2321,  0.1685, -0.1704,  0.1670,
          -0.5345,  0.2472,  0.6871, -0.0969,  0.1618, -0.3084,  0.2111, -0.0923,
           0.0460,  0.3786, -0.3205, -0.0489,  0.1974, -0.0695,  0.1415,  0.1238,
           0.3380, -0.0533, -0.1027,  0.1114, -0.1339,  0.2812, -0.1582, -0.3171,
     

In [None]:


tokenizer_path = Path("checkpoints/lit-llama/tokenizer.model")
tokenizer = Tokenizer(tokenizer_path)

train_sample_set = []
for sample in tqdm(train_dataset):
        sample_set = prepare_sample(sample, tokenizer, max_seq_length=512, mask_inputs= True)
        if sample_set == None:
            continue
        else:
            train_sample_set.append(sample_set)
    torch.save(train_sample_set, training_file_path.parent / "train_temp.pt")

In [None]:




inputs1 = {'pixel_values':0}
inputs1['pixel_values'] = train_loader.dataset[5][0].unsqueeze(0)
print(inputs1['pixel_values'].shape)
image_features = model.get_image_features(**inputs1)
print(image_features.shape)

proj = img_projection((512, 128))
image_features_proj = proj(image_features)
print(image_features_proj.shape)