In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%pwd

'/ocean/projects/asc170022p/mtragoza/med_vqa'

In [2]:
import sys, os
import numpy as np
import torch
import open_clip
import transformers as T
import datasets
torch.cuda.is_available()

True

In [3]:
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_properties(i).name)

NVIDIA RTX A6000
NVIDIA RTX A6000


In [4]:
%%time
class ImageEncoder(torch.nn.Module):
    
    @classmethod
    def from_name(cls, name, **kwargs):
        if name == 'CLIP':
            url = 'laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K' # (256, 1024)
        elif name == 'PMC-CLIP':
            url = 'ryanyip7777/pmc_vit_l_14'

        model, train_preprocess, val_preprocess = \
            open_clip.create_model_and_transforms(f'hf-hub:{url}', device=torch.device('cuda'))
    
        return cls(model.visual, train_preprocess, val_preprocess, n_patches=256, embed_size=1024, **kwargs)

    def __init__(self, model, train_preprocess, val_preprocess, n_patches, embed_size, embed_type):
        super().__init__()

        self.train_preprocess = train_preprocess
        self.val_preprocess = val_preprocess

        self.model = model
        self.model.output_tokens = True
        self.model.proj = None
    
        self.embed_type = embed_type
        self.embed_size = embed_size
        
        if embed_type == 'both':
            self.seq_length = n_patches + 1
        elif embed_type == 'patch':
            self.seq_length = n_patches
        elif embed_type == 'global':
            self.seq_length = 1
        
    def forward(self, images):

        global_embeddings, patch_embeddings = self.model(images)
        global_embeddings = global_embeddings.unsqueeze(1)
        
        if self.embed_type == 'global':
            image_embeddings = global_embeddings
        elif self.embed_type == 'patch':
            image_embeddings = patch_embeddings  
        elif self.embed_type == 'both':
            image_embeddings = torch.cat([global_embeddings, patch_embeddings], dim=1)
        
        assert image_embeddings.shape[1:] == (self.seq_length, self.embed_size), image_embeddings.shape
        return image_embeddings

image_encoder = ImageEncoder.from_name('PMC-CLIP', embed_type='both')

CPU times: user 7.04 s, sys: 1.36 s, total: 8.4 s
Wall time: 4.7 s


In [5]:
%%time
class TextDecoder(torch.nn.Module):
    
    @classmethod
    def from_name(cls, name):
        if name == 'LLaMA':
            url = 'meta-llama/Llama-2-7b-hf'
        elif name == 'PMC-LLaMA':
            url = 'chaoyi-wu/PMC_LLAMA_7B'
        
        model = T.LlamaForCausalLM.from_pretrained(url, device_map='auto')
        tokenizer = T.LlamaTokenizer.from_pretrained(url)
        
        return cls(model, tokenizer, max_length=512, embed_size=4096)

    def __init__(self, model, tokenizer, max_length, embed_size):
        super().__init__()

        self.model = model
        self.tokenizer = tokenizer

        self.max_length = max_length
        self.embed_size = embed_size
        
    def forward(self, input_embeddings, mask, labels):
        assert input_embeddings.shape[1] <= self.max_length, input_embeddings.shape
        assert input_embeddings.shape[2] == self.embed_size, input_embeddings.shape

        return self.model.forward(
            inputs_embeds=input_embeddings,
            attention_mask=mask,
            labels=labels
        )

    def generate(self, input_embeddings, mask, **kwargs):
        assert input_embeddings.shape[1] <= self.max_length, input_embeddings.shape
        assert input_embeddings.shape[2] == self.embed_size, input_embeddings.shape

        return self.model.generate(
            inputs_embeds=input_embeddings,
            attention_mask=mask,
            **kwargs
        )

text_decoder = TextDecoder.from_name('PMC-LLaMA')

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


CPU times: user 5.87 s, sys: 7.43 s, total: 13.3 s
Wall time: 12 s


In [6]:
class MultimodalFusion(torch.nn.Module):
    
    def __init__(self, image_embed_size, text_embed_size, device):
        super().__init__()
        self.image_embed_size = image_embed_size
        self.text_embed_size = text_embed_size

        self.project_image = torch.nn.Linear(image_embed_size, text_embed_size, device=device)
        
    def forward(self, image_embeddings, text_embeddings, mask):
        batch_size, image_length = image_embeddings.shape[:2]

        image_embeddings = self.project_image(image_embeddings)
        combined_embeddings = torch.cat([image_embeddings, text_embeddings], dim=1)

        image_mask = torch.ones((batch_size, image_length), device=mask.device)
        combined_mask = torch.cat([image_mask, mask], dim=1)

        return combined_embeddings, combined_mask

In [7]:
class VQAModel(torch.nn.Module):
    
    def __init__(self, image_encoder, text_encoder, text_decoder):
        super().__init__()

        self.image_encoder = image_encoder
        self.text_encoder  = text_encoder
        self.text_decoder  = text_decoder

        self.fusion_module = MultimodalFusion(
            image_encoder.embed_size,
            text_decoder.embed_size,
            device=text_decoder.model.device
        )
        
    def combine_multimodal_inputs(self, images, padded_tokens, mask):   
        image_embeddings = self.image_encoder(images)
        text_embeddings  = self.text_encoder(padded_tokens)

        combined_embeddings, combined_mask = self.fusion_module(
            image_embeddings, text_embeddings, mask
        )
        return combined_embeddings, combined_mask
    
    def forward(self, images, padded_tokens, mask):       
        input_embeddings, mask = self.combine_multimodal_inputs(images, padded_tokens, mask)
        
        dummy_tokens = torch.zeros(
            (images.shape[0], self.image_encoder.seq_length),
            dtype=padded_tokens.dtype,
            device=padded_tokens.device
        )
        labels = torch.cat([dummy_tokens, padded_tokens], dim=1)
        
        output = self.text_decoder.forward(input_embeddings, mask, labels)
        return output

    def generate(self, images, padded_tokens, mask, **kwargs):
        input_embeddings, mask = self.combine_multimodal_inputs(images, padded_tokens, mask)
        output = self.text_decoder.generate(input_embeddings, mask, **kwargs)
        return output


model = VQAModel(
    image_encoder=image_encoder,
    text_encoder=text_decoder.model.model.embed_tokens,
    text_decoder=text_decoder
)

In [8]:
class VQADataset(torch.utils.data.Dataset):
    
    @classmethod
    def from_name(cls, name, train_preprocess, val_preprocess, **kwargs):

        if name == 'VQA-RAD':
            url = 'flaviagiammarino/vqa-rad'
            val_split = 'test'
        elif name == 'SLAKE':
            url = 'BoKelvin/SLAKE'
            val_split = 'validation'

        ds = datasets.load_dataset(url, cache_dir='data')
        
        train_set = cls(ds['train'], train_preprocess, **kwargs)
        val_set = cls(ds[val_split], val_preprocess, **kwargs)
        test_set = cls(ds['test'], val_preprocess, **kwargs)

        return train_set, val_set, test_set
            
    def __init__(self, ds, image_preprocess, tokenizer, image_length, max_length, device):
        super().__init__()
        self.ds = ds
        
        # image preprocessor
        self.image_preprocess = image_preprocess
        
        # text tokenizer
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = tokenizer.eos_token

        assert max_length > image_length
        self.image_length = image_length
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        raw_image = self.ds[idx]['image']
        question = self.ds[idx]['question']
        answer = self.ds[idx]['answer']
        
        image = self.image_preprocess(raw_image)

        prompt = f'Answer the following question based on the provided image.\nQ: {question}\nA: '
                
        prompt_tokens = self.tokenizer.encode(prompt)
        answer_tokens = self.tokenizer.encode(answer)
        
        padded_tokens, mask = self.pad_tokens(prompt_tokens, answer_tokens)

        return (
            torch.tensor(image, device=self.device),
            torch.tensor(padded_tokens, device=self.device),
            torch.tensor(mask, device=self.device)
        )
    
    def pad_tokens(self, prompt_tokens, answer_tokens):
        pad = self.tokenizer.pad_token_id
        tokens = prompt_tokens + answer_tokens
        mask = [1 for i in prompt_tokens] + [0 for i in answer_tokens]
        padding = self.max_length - self.image_length - len(tokens)
        if padding > 0:
            tokens = tokens + [pad for i in range(padding)]
            mask = mask + [0 for i in range(padding)]
        elif padding < 0:
            tokens = tokens[:self.max_length - self.image_length]
            mask = mask[:self.max_length - self.image_length]
        return tokens, mask


train_set, val_set, test_set = VQADataset.from_name(
    'VQA-RAD',
    train_preprocess=model.image_encoder.train_preprocess,
    val_preprocess=model.image_encoder.val_preprocess,
    tokenizer=model.text_decoder.tokenizer,
    image_length=model.image_encoder.seq_length,
    max_length=model.text_decoder.max_length,
    device='cuda'
)

image, padded_tokens, mask = train_set[0]
for t in train_set[0]:
    print(t.shape, t.dtype, t.device)

torch.Size([3, 224, 224]) torch.float32 cuda:0
torch.Size([255]) torch.int64 cuda:0
torch.Size([255]) torch.int64 cuda:0


  torch.tensor(image, device=self.device),


In [9]:
image, padded_tokens, mask = train_set[0]
output = model.generate(
    image.unsqueeze(0), padded_tokens.unsqueeze(0), mask.unsqueeze(0), max_new_tokens=256
)
output

  torch.tensor(image, device=self.device),
  return F.conv2d(input, weight, bias, self.stride,


tensor([[29896, 29900, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896, 29896,
         29896, 29896, 29896, 29896, 29896, 29896, 2

In [10]:
model.text_decoder.tokenizer.decode(output[0])

'1011111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111'

In [None]:
import tqdm
import peft

peft_config = peft.LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        'q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'lm_head'
    ],
    bias='none',
    task_type='CAUSAL_LM'
)
model.text_decoder.model = peft.get_peft_model(model.text_decoder.model, peft_config)

batch_size = 4
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)

optimizer = torch.optim.AdamW(model.text_decoder.model.parameters(), lr=1e-5)
model.train()

for epoch in range(10):
    total_loss = 0
    for step, batch in enumerate(t:=tqdm.tqdm(train_loader)):
        images, padded_tokens, mask = batch
        outputs = model.forward(images, padded_tokens, mask)
        loss = outputs.loss.detach().float()
        t.set_description(f'loss = {loss:.4f}')
        total_loss += loss
        outputs.loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    total_loss /= len(train_loader)
    print(total_loss)


  torch.tensor(image, device=self.device),
loss = 18.8498:   2%|▏         | 11/449 [00:46<26:57,  3.69s/it] 

In [None]:
image, padded_tokens, mask = dataset[37]
print(llama_tokenizer.decode(padded_tokens))

In [None]:
output = model.generate(
    image.unsqueeze(0), padded_tokens.unsqueeze(0), mask.unsqueeze(0), max_new_tokens=256
)
llama_tokenizer.decode(output[0])