In [1]:
# Imports
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoProcessor, LlavaForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


### Load dataset

In [2]:
# Load the "train" split of the "AbdulMuqtadir/Doc_VQA_subset" dataset.
# This is a small subset intended only for quick testing, not for a full training run.
data = load_dataset("AbdulMuqtadir/Doc_VQA_subset", split="train")
# Bare last expression: the notebook displays the dataset summary (features and num_rows).
data

Dataset({
    features: ['question', 'docId', 'answers', 'data_split', 'bounding_boxes', 'word_list', 'image_raw', 'ground_truth'],
    num_rows: 10
})

In [11]:
# Take the first record and unpack the fields needed for a single example:
# the question text, the first listed answer, and the raw document image.
sample = data[0]
query, answer, image = (
    sample["question"],
    sample["answers"][0],
    sample["image_raw"],
)
# Display all three values as the cell output.
(query, answer, image)

('What is the Voucher Number ?',
 '8',
 <PIL.PngImagePlugin.PngImageFile image mode=L size=1490x653>)

In [None]:
# Single-turn prompt: the `<image>` placeholder, then the question, then an open
# "ASSISTANT:" cue left for the model to complete.
prompt = f"USER: <image>\nQuestion: {query}\nASSISTANT:"



In [13]:
# Run the prompt text and the image through the processor in one call, asking for
# PyTorch tensors padded/truncated to a fixed length of 1024 tokens.
# NOTE(review): `processor` is created in the "Load Processor and Model" cell further
# down in this notebook, so on a fresh kernel that cell must be run before this one.
inputs = processor(
    text=prompt,
    images=image,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=1024,
)

In [18]:
# Pull the three model inputs out of the processor output in one unpacking step.
# Expected shapes: input_ids [1,1024], attention_mask [1,1024], pixel_values [1,3,336,336].
input_ids, attention_mask, pixel_values = (
    inputs[field] for field in ("input_ids", "attention_mask", "pixel_values")
)

In [23]:
# Remove the leading axis (when it has size 1) so a flat sequence of ids is decoded.
ids = torch.squeeze(input_ids, 0)
# Keep special tokens in the decoded string so padding and markers stay visible.
text = processor.tokenizer.decode(
    ids,
    skip_special_tokens=False,
)
text


'<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

### Load Processor and Model

In [2]:
# Load the processor published with the 'llava-hf/llava-1.5-7b-hf' checkpoint.
# NOTE(review): `use_fast` is left unset here. The warning this cell prints says a slow
# image processor is used and that `use_fast=True` becomes the default in transformers
# v4.52, which may cause minor differences in outputs after an upgrade.
processor = AutoProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


### Dataset preprocessing and DataLoader Setup

In [5]:
from llavadataset import LlavaDataset

In [6]:
# Build the training dataset from the same dataset id, handing it the processor.
# NOTE(review): LlavaDataset comes from the local `llavadataset` module, which is not
# shown in this notebook; presumably it applies the processor to each item and builds
# the `labels` tensor — confirm against that module.
dataset = LlavaDataset("AbdulMuqtadir/Doc_VQA_subset", processor)

In [14]:
# Fetch one preprocessed item and check the per-sample tensor shapes (no batch axis).
sample = dataset[0]
input_ids, attention_mask, pixel_values, labels = (
    sample[field]
    for field in ("input_ids", "attention_mask", "pixel_values", "labels")
)

# Print the four shapes on one line, space-separated.
print(*(tensor.shape for tensor in (input_ids, attention_mask, pixel_values, labels)))


torch.Size([1024]) torch.Size([1024]) torch.Size([3, 336, 336]) torch.Size([1024])


In [16]:
# NOTE(review): DataLoader is already imported in the first cell of this notebook, so
# this re-import is redundant.
from torch.utils.data import DataLoader
# Serve the dataset in shuffled batches of 4.
# NOTE(review): no random seed is set in the visible cells, so the batch composition
# will differ between runs.
training_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Bare last expression: displays the DataLoader repr.
training_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7ed14a62b650>

In [19]:
# Peek at a single batch from the loader to confirm the collated tensor shapes.
for batch_idx, batch in enumerate(training_dataloader):
    # Expected: [batch_size, 1024] for input_ids, attention_mask and labels;
    # [batch_size, 3, 336, 336] for pixel_values.
    input_ids, attention_mask, pixel_values, labels = (
        batch[field]
        for field in ("input_ids", "attention_mask", "pixel_values", "labels")
    )

    print("input_ids shape:", input_ids.shape)
    print("attention_mask shape:", attention_mask.shape)
    print("pixel_values shape:", pixel_values.shape)
    print("labels shape:", labels.shape)
    break  # the first batch is enough for a shape check


input_ids shape: torch.Size([4, 1024])
attention_mask shape: torch.Size([4, 1024])
pixel_values shape: torch.Size([4, 3, 336, 336])
labels shape: torch.Size([4, 1024])


### Training Loop

In [None]:
# Imports
import pytorch_lightning as pl

class LlavaTraining(pl.LightningModule):
    """LightningModule that trains a wrapped model on image + text batches.

    Args:
        config: Object exposing ``lr`` and ``epochs`` attributes; both are read in
            ``configure_optimizers``.
        model: The model to train. It is called with ``input_ids``,
            ``attention_mask``, ``pixel_values`` and ``labels`` keyword arguments
            and must return an object with a ``loss`` attribute.
        processor: Stored on the module; not used by the methods defined here.
    """

    def __init__(self, config, model, processor):
        # `super()` must be *called* to get the parent proxy; the previous
        # `super.__init__()` raised a TypeError, so the LightningModule base class
        # was never initialised.
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        """Run one forward pass on ``batch`` and return the loss to optimise.

        Args:
            batch: Mapping with ``input_ids``, ``attention_mask``, ``pixel_values``
                and ``labels`` entries, as produced by the training DataLoader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss reported by the model for this batch.
        """
        # 1. Extract the inputs
        input_ids = batch["input_ids"]            # expected shape: [batch_size, 1024]
        attention_mask = batch["attention_mask"]  # expected shape: [batch_size, 1024]
        pixel_values = batch["pixel_values"]      # expected shape: [batch_size, 3, 336, 336]
        labels = batch["labels"]                  # expected shape: [batch_size, 1024]

        # 2. Forward pass; passing `labels` asks the model to compute the loss itself
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels,
        )

        # 3. Read the loss from the model output (previously referenced the undefined
        #    name `output`, which raised NameError on the first step)
        loss = outputs.loss

        # 4. Log the loss per step and aggregated per epoch
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Create the AdamW optimizer and a cosine-annealing learning-rate schedule.

        Returns:
            A ``([optimizer], [scheduler])`` pair. The learning rate comes from
            ``config.lr`` and the cosine period (``T_max``) from ``config.epochs``.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.epochs)
        return [optimizer], [scheduler]

