In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import tqdm
from sklearn.metrics import accuracy_score
import numpy as np

from datasets import load_dataset, Image

from transformers import BlipForQuestionAnswering, BlipProcessor, BlipConfig, BlipModel

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, image_files, text, processor, num_labels):
        self.image_files = image_files
        self.text = text
        self.processor = processor
        self.num_labels = num_labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        text = self.text
        image = self.image_files[idx]['img']
        label = self.image_files[idx]['label']
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        # encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        encoding = self.processor(image, text, return_tensors="pt")

        # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        targets = torch.zeros(self.num_labels)
        targets[label] = 1
        encoding["labels"] = targets

        return encoding

In [4]:
dataset = load_dataset('cifar10')

Found cached dataset cifar10 (/root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
label_list = dataset["train"].features["label"].names
num_labels = len(label_list)

config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
config.id2label = {str(i): label for i, label in enumerate(label_list)}
config.label2id = {label: str(i) for i, label in enumerate(label_list)}
config.num_labels = num_labels
config.max_length = 1
config.text_config.max_length = 1

processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

train_dataset = ImageDataset(image_files=dataset["train"], text="What is this image?", processor=processor, num_labels=num_labels)
test_dataset = ImageDataset(image_files=dataset["test"], text="What is this image?", processor=processor, num_labels=num_labels)

In [6]:
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", config=config)
model = model.to(device)

In [7]:
def collate_fn(batch):
  input_ids = [item['input_ids'] for item in batch]
  pixel_values = [item['pixel_values'] for item in batch]
  attention_mask = [item['attention_mask'] for item in batch]
  # token_type_ids = [item['token_type_ids'] for item in batch]
  labels = [item['labels'] for item in batch]

  # create padded pixel values and corresponding pixel mask
  # encoding = processor.feature_extractor.pad_and_create_pixel_mask(pixel_values, return_tensors="pt")

  # create new batch
  batch = {}
  batch['input_ids'] = torch.stack(input_ids)
  batch['attention_mask'] = torch.stack(attention_mask)
  # batch['token_type_ids'] = torch.stack(token_type_ids)
  # batch['pixel_values'] = encoding['pixel_values']
  # batch['pixel_mask'] = encoding['pixel_mask']
  batch['pixel_values'] = torch.stack(pixel_values)
  batch['labels'] = torch.stack(labels)
  # batch['labels'] = torch.Tensor(labels).type(torch.LongTensor).unsqueeze(1)

  return batch

In [8]:
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=128, shuffle=True)
val_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=128, shuffle=True)

In [9]:
for name, param in model.named_parameters():
    param.requires_grad = False

In [10]:
model = model.to(device)

In [11]:
cls_model = nn.Sequential(
    nn.Linear(in_features=model.text_decoder.config.vocab_size, out_features=num_labels, bias=True)
)
cls_model = cls_model.to(device)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cls_model.parameters(), lr=6e-4)
num_epochs = 30

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.25)

best_params = None
best_val_accuracy = -1

for epoch in range(1):
    model.train()
    total_loss = 0.0
    
    train_predictions = []
    train_labels_eval = []
    step = 0
    
    for batch in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch = {k:v.to(device) for k,v in batch.items()}

        optimizer.zero_grad()
        
        # import ipdb; ipdb.set_trace()
        outputs = model.generate(**batch)  # N,2
        outputs = outputs[:,1]
        outputs = nn.functional.one_hot(outputs, num_classes = model.text_decoder.config.vocab_size).type(torch.FloatTensor)
        outputs = outputs.to(device)
        labels = batch['labels']

        outputs_cls = cls_model(outputs)
        loss = criterion(outputs_cls, labels)
        loss.backward()
        
        optimizer.step()

        _, preds = torch.max(outputs_cls, 1)
        train_predictions.extend(preds.cpu().numpy())
        train_labels_eval.extend(labels.cpu().numpy())

        total_loss += loss
        step += 1
            
    scheduler.step()
    
    train_loss = total_loss / len(train_dataloader)
    
    # Evaluate the model on the validation set
    model.eval()
    val_predictions = []
    val_labels_eval = []
        
    with torch.no_grad():
        step = 0

        for batch in val_dataloader:
            batch = {k:v.to(device) for k,v in batch.items()}
            
            outputs = model.generate(**batch)
            outputs = outputs[:,1]
            outputs = nn.functional.one_hot(outputs, num_classes = model.text_decoder.config.vocab_size).type(torch.FloatTensor)
            outputs = outputs.to(device)

            outputs_cls = cls_model(outputs)
            _, preds = torch.max(outputs_cls, 1)
            labels = batch['labels']
            
            val_predictions.extend(preds.cpu().numpy())
            val_labels_eval.extend(labels.cpu().numpy())

            step += 1
    
    val_labels_idx = [np.argmax(tensor) for tensor in val_labels_eval]
    val_accuracy = accuracy_score(val_labels_idx, val_predictions)
    
    train_labels_idx = [np.argmax(tensor) for tensor in train_labels_eval]
    train_accuracy = accuracy_score(train_labels_idx, train_predictions)
    
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Training Acc: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    # print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

Input length of input_ids is 1, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Epoch 1/30:   0%|          | 1/391 [00:04<29:36,  4.55s/it]Input length of input_ids is 1, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Epoch 1/30:   1%|          | 2/391 [00:07<22:12,  3.43s/it]Input length of input_ids is 1, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Epoch 1/30:   1%|          | 3/391 [00:09<20:11,  3.12s/it]Input length of input_ids is 1, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Epoch 1/30:   1%|          | 4/391 [00:12<18:57,  2.94s/it]Input length of input_ids is 1, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Epoch 1/30:   1%|▏ 

Epoch 1/30, Loss: 2.1659, Training Acc: 0.7526, Validation Accuracy: 0.7803
