In [3]:
from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration, Blip2Model, Blip2Processor, Blip2ForConditionalGeneration, Blip2Config
import time
import torch
import accelerate
from torch.nn import functional as F
from typing import Any, Optional, Tuple, Union
from dataclasses import dataclass

In [None]:
# Download BLIP-2 (OPT-2.7b) weights and the matching processor from the Hugging Face Hub.
from transformers import BitsAndBytesConfig

BLIP2_MODEL_ID = "Salesforce/blip2-opt-2.7b"

processor = Blip2Processor.from_pretrained(BLIP2_MODEL_ID)
# Quantize to 8-bit through `quantization_config`; passing `load_in_8bit=True` directly to
# `from_pretrained` is deprecated. `device_map={"": 0}` places the whole model on GPU 0.
model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP2_MODEL_ID,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map={"": 0},
    torch_dtype=torch.float16,
)

In [7]:
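# Utils: helpers that run BLIP-2's vision encoder + Q-Former (`image_encoding`) and its language model (`forward_from_image_query_output`) as separate steps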

In [18]:
def forward_from_image_query_output(
    model,
    language_model_inputs: torch.FloatTensor,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
):
    """Run a BLIP-2 language model on projected query embeddings followed by a text prompt.

    This is the "language model" stage of the BLIP-2 forward pass, split out so that
    already-computed ``language_model_inputs`` can be reused.

    Args:
        model: BLIP-2 model; this function uses ``model.language_model`` and
            ``model.config.use_return_dict``.
        language_model_inputs: projected Q-Former outputs of shape
            ``(batch, num_query_tokens, hidden)``; they are prepended to the prompt embeddings.
        input_ids: prompt token ids of shape ``(batch, seq_len)``.
        attention_mask: mask for ``input_ids``; defaults to all ones. The query-token
            positions are always attended to.
        labels: optional target ids aligned with the last ``labels.size(1)`` positions of
            the combined sequence. Positions equal to -100 are ignored in the loss.

    Returns:
        If ``model.config.use_return_dict`` is true, a dict with keys ``loss`` (``None``
        without labels), ``logits`` and ``language_model_outputs``. Otherwise a tuple:
        ``(loss, logits, outputs)`` when labels are given, else ``(logits, outputs)``.
        When labels are given, ``logits`` is truncated to the last ``labels.size(1)`` positions.
    """
    output_attentions = None
    output_hidden_states = None
    return_dict = model.config.use_return_dict

    # The query tokens are always visible to the language model.
    language_model_attention_mask = torch.ones(
        language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
    )
    inputs_embeds = model.language_model.get_input_embeddings()(input_ids)
    inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    expected_device = language_model_attention_mask.device
    attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

    outputs = model.language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    logits = outputs.logits if return_dict else outputs[0]

    loss = None
    # The loss is computed here (not by the LM) because the sequence includes the query tokens.
    if labels is not None:
        labels = labels.to(logits.device)
        # Keep only the positions covered by the labels; when labels have the prompt's
        # length this drops the query-token positions.
        logits = logits[:, -labels.size(1):, :]
        # Shift so that tokens < n predict n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Vocabulary size comes from the logits themselves so the view always matches.
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="mean",
        )

    if not return_dict:
        output = (logits, outputs)
        return ((loss,) + output) if loss is not None else output

    return {
        'loss': loss,
        'logits': logits,
        'language_model_outputs': outputs,
    }

def image_encoding(
    model,
    processor,
    prompt,
    image=None,
):
    """Encode an image with BLIP-2's vision tower and Q-Former, projected into LM space.

    Args:
        model: BLIP-2 model exposing ``vision_model``, ``query_tokens``, ``qformer``,
            ``language_projection`` and ``config.use_return_dict``.
        processor: BLIP-2 processor that turns ``image`` and ``prompt`` into tensors.
        prompt: text prompt passed to the processor together with the image.
        image: image to encode. If omitted, the helper falls back to a global named
            ``image`` (what earlier versions silently did) and raises ``NameError`` if
            there is none.

    Returns:
        ``(language_model_inputs, vision_outputs)``. ``language_model_inputs`` is the
        Q-Former output passed through ``model.language_projection``; it is intended as
        the ``language_model_inputs`` argument of ``forward_from_image_query_output``.
        ``vision_outputs`` is the raw vision-encoder output.
    """
    if image is None:
        # Backward compatibility: this helper used to read a notebook-global `image`.
        try:
            image = globals()["image"]
        except KeyError:
            raise NameError(
                "image_encoding() needs an `image` argument; no global `image` is defined"
            ) from None

    # Previously this used an undefined global `curr_prompt` and ignored `prompt`.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, torch.float16)
    pixel_values = inputs.pixel_values

    output_attentions = None
    output_hidden_states = None
    return_dict = model.config.use_return_dict

    # Step 1: forward the images through the vision encoder,
    # to get image embeddings of shape (batch_size, seq_len, hidden_size).
    vision_outputs = model.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[0]
    print(f'Image Embedds: {image_embeds.shape}')

    # Step 2: forward the query tokens through the Q-Former, cross-attending to every
    # image embedding position.
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_outputs = model.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    query_output = query_outputs[0]

    # Step 3: project the query outputs into the language model's embedding space.
    return model.language_projection(query_output), vision_outputs

# Load in Data

In [14]:
# NOTE(review): not read by any visible cell yet — presumably the source for `dataset`.
INPUT_FILE = '../datasets/full_ds.csv'  # relative to the kernel's working directory
BATCH_SIZE = 2
BLOCK_SIZE = 250  # NOTE(review): unused in the visible cells — TODO confirm purpose
EPOCHS = 1
# Optional cap on optimizer steps per epoch (the training loops reference it);
# None means run every batch of the epoch.
MAX_STEPS = None

# Fine-tuning

In [None]:
from torch import nn
from torch.utils.data import DataLoader

# NOTE(review): `dataset` is never defined in this notebook; presumably a dict-like of
# splits built from INPUT_FILE. TODO: add a data-loading cell before this one.
train_dataloader = DataLoader(dataset["train"], batch_size=BATCH_SIZE)
# NOTE(review): the model was loaded in 8-bit; int8 weights are generally not trainable
# by AdamW directly — confirm, or train adapters (e.g. LoRA) instead.
optimizer = torch.optim.AdamW(model.parameters())

# Taken from the config cell; `epochs` stays defined because later cells read it.
epochs = EPOCHS
# Optional per-epoch step cap; read defensively since MAX_STEPS may not be defined.
max_steps = globals().get("MAX_STEPS")

model.train()
for _ in range(epochs):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        # Blip2ForConditionalGeneration.forward takes `input_ids` / `pixel_values` / `labels`
        # (the previous `text=` / `image=` keywords are rejected). Inputs are moved to the
        # model's device; pixels are cast to fp16 to match the fp16-loaded model.
        # Assumes batch["image"] holds preprocessed pixel tensors and batch["answers"]
        # holds token ids aligned with input_ids — TODO confirm against the dataset.
        out = model(
            input_ids=batch["input_ids"].to(model.device),
            pixel_values=batch["image"].to(model.device, torch.float16),
            labels=batch["answers"].to(model.device),
        )
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss.item()}")
        if max_steps is not None and idx >= max_steps - 1:
            break

In [None]:
# Training loop for batches keyed "input_ids" / "pixel_ids" / "labels".
# NOTE(review): reuses `model`, `train_dataloader` and `optimizer` from earlier cells, and
# duplicates the loop above — consider a single `train(dataloader, keys)` function.

# Optional per-epoch step cap; read defensively since MAX_STEPS may not be defined.
max_steps = globals().get("MAX_STEPS")

model.train()
for _ in range(EPOCHS):
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()

        # Move inputs to the model's device; pixels are cast to fp16 to match the
        # fp16-loaded model. Tensor shapes are not established here — TODO confirm.
        input_ids = batch["input_ids"].to(model.device)
        pixel_values = batch["pixel_ids"].to(model.device, torch.float16)
        labels = batch["labels"].to(model.device)

        # Blip2ForConditionalGeneration.forward takes `input_ids` / `pixel_values` / `labels`.
        out = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"Loss at step {idx} = {loss.item()}")
        if max_steps is not None and idx >= max_steps - 1:
            break