In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# General dependencies
from tqdm.notebook import tqdm
from functools import partial

# Deep learning dependencies
import torch
from torch.utils.data import random_split, DataLoader
from transformers import Blip2ForConditionalGeneration

# Self-defined functions
from dataset import ImageCaptioningDataset
from featurizer import ImageFeaturizer

In [3]:
def featurize_func(encoding, featurizer, caching=True):
    """Attach featurizer outputs for one sample to its encoding.

    Args:
        encoding: Mapping for a single sample; its ``'image'`` entry is passed
            to the featurizer. The mapping is modified in place.
        featurizer: Object exposing ``get_image_features(image)`` that returns
            a pair ``(object_features, bbox_tensor)``.
        caching: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The same ``encoding`` object, now also holding the
        ``'object_features'`` and ``'bbox'`` entries.
    """
    image = encoding['image']
    extracted = featurizer.get_image_features(image)
    region_features, boxes = extracted

    # Keep insertion order: features first, then boxes.
    encoding['object_features'] = region_features
    encoding['bbox'] = boxes
    return encoding
    

def _as_detached_tensor(value):
    """Return ``value`` as a tensor without autograd history.

    Values that are already tensors are detached instead of being passed
    through ``torch.tensor``, which emits a copy-construct ``UserWarning`` for
    tensor inputs. Any other value is converted with ``torch.tensor``.
    """
    if isinstance(value, torch.Tensor):
        return value.detach()
    return torch.tensor(value)


def collate_fn(batch, processor):
    """Merge a list of per-example dicts into one batch dict.

    The keys of the first example decide which fields are processed:

    * ``image_id`` / ``image``: dropped.
    * ``pixel_values``: stacked with ``torch.stack`` as-is.
    * ``geometry`` / ``action`` / ``object_features`` / ``bbox``: each value is
      converted to a tensor (existing tensors are detached) and stacked.
    * ``text``: tokenized with padding; stored as ``input_ids`` and
      ``attention_mask``.
    * ``agent_classifier`` / ``s3_fileUrl``: kept as a plain Python list.
    * any other key: tokenized with padding; only ``input_ids`` is kept, stored
      under the original key.

    Args:
        batch: Non-empty sequence of example dicts sharing the same keys.
        processor: Object whose ``tokenizer`` attribute is callable as
            ``tokenizer(texts, padding=True, return_tensors="pt")`` and returns
            a mapping with ``"input_ids"`` and ``"attention_mask"``.

    Returns:
        dict mapping field names to batched values.

    Raises:
        IndexError: If ``batch`` is empty.
    """
    skipped_keys = ("image_id", "image")
    tensor_keys = ("geometry", "action", "object_features", "bbox")
    passthrough_keys = ("agent_classifier", "s3_fileUrl")

    processed_batch = {}
    for key in batch[0].keys():
        if key in skipped_keys:
            continue

        values = [example[key] for example in batch]

        if key == "pixel_values":
            processed_batch[key] = torch.stack(values)
        elif key in tensor_keys:
            # torch.stack allocates a fresh tensor, so detaching (rather than
            # copying) each element first never aliases the inputs.
            processed_batch[key] = torch.stack([_as_detached_tensor(v) for v in values])
        elif key == "text":
            text_inputs = processor.tokenizer(values, padding=True, return_tensors="pt")
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
        elif key in passthrough_keys:
            processed_batch[key] = values
        else:
            other_inputs = processor.tokenizer(values, padding=True, return_tensors="pt")
            processed_batch[key] = other_inputs["input_ids"]

    return processed_batch

In [14]:
# Build the captioning dataset with caching enabled.
# NOTE(review): `cache_dir` is an absolute, machine-specific path; adjust it
# when running on another system.
dataset = ImageCaptioningDataset(
    json_file = './processed_dataset_4.json',
    cache_dir = '/ix1/xjia/yuw253/av_cache',  # target directory
    caching = True
)

# Attach object features and bounding boxes to every sample.
featurizer = ImageFeaturizer()
featurized_dataset = dataset.map(partial(featurize_func, featurizer=featurizer))

# Split dataset into train/validation (90% / remainder).
train_ratio = 0.9
total_size = len(featurized_dataset)
train_size = int(total_size * train_ratio)
valid_size = total_size - train_size

# Seed the split so the same examples land in the validation set on every
# kernel restart; an unseeded split makes validation results incomparable
# between runs and lets validation examples drift into training when
# fine-tuning is resumed.
split_seed = 42
train_dataset, valid_dataset = random_split(
    featurized_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Build dataloaders; both share the same batch size and collate function.
batch_size = 16
batch_collate = partial(collate_fn, processor=dataset.processor)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=batch_collate)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                              collate_fn=batch_collate)

In [None]:
# Touch every sample once so the dataset's caching (enabled at construction)
# is exercised for each index; the fetched item itself is discarded.
progress = tqdm(range(len(dataset)))
for i in progress:
    _ = dataset[i]

In [45]:
from transformers import Blip2Config, Blip2ForConditionalGeneration

# Default Blip2Config (Salesforce/blip2-opt-2.7b style configuration).
configuration = Blip2Config()

# Build the model from the configuration alone: weights are randomly
# initialized, no pretrained checkpoint is loaded here.
model = Blip2ForConditionalGeneration(config=configuration)

In [None]:
from transformers import Blip2ForConditionalGeneration, AutoProcessor
from peft import PeftModel, PeftConfig

# Adapter identifier; its PEFT config names the base checkpoint to load.
peft_model_id = "SeeonQwQ/blip2_frame_v4.0"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base BLIP-2 model named by the adapter config, then wrap it with
# the adapter weights (loaded with is_trainable=True).
model = Blip2ForConditionalGeneration.from_pretrained(
    config.base_model_name_or_path,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    peft_model_id,
    is_trainable=True,
)

# NOTE(review): the processor checkpoint is hard-coded rather than taken from
# config.base_model_name_or_path; confirm the two refer to compatible models.
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

In [46]:
model

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0): Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
        )
        (1): Blip2EncoderLayer(
          (self_attn): 