# Single-Frame Pipeline for Training a BLIP Model on the DriveLM Dataset

## Installing Libraries

In [None]:
# Install PEFT (LoRA adapters) and Hugging Face Transformers (BLIP / BLIP-2 models and processors).
!pip install peft
!pip install transformers



In [None]:
# Mount Google Drive at /content/drive; the notebook reads its data from and writes
# its results to drive/MyDrive/DriveLM.
from google.colab import drive, files

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import pdb

import torch
from transformers import BlipProcessor, BlipForQuestionAnswering, \
TrainingArguments, Trainer, BertTokenizerFast, Blip2ForConditionalGeneration, Blip2Processor, AutoTokenizer
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import time
from peft import LoraConfig, get_peft_model, LoftQConfig
from copy import deepcopy
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import json
import asyncio
import pdb
# Device used for the model and collated batches: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
!mkdir DriveLM
!mkdir DriveLM/results
!unzip -q drive/MyDrive/DriveLM/data.zip -d DriveLM

mkdir: cannot create directory ‘DriveLM’: File exists
mkdir: cannot create directory ‘DriveLM/results’: File exists
replace DriveLM/data/multi_frame/multi_frame_test.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
class SingleFrameDataset(Dataset):
    """Question/answer pairs over single camera frames from DriveLM.

    The input JSON is an array of ``[qa, img_path]`` entries where ``qa`` has
    ``'Q'`` and ``'A'`` keys and ``img_path`` is a (possibly backslash-separated)
    path relative to the ``DriveLM`` folder.
    """

    def __init__(self, input_file, processor, custom_train=True):
        """
        :param input_file: Path to the JSON file of ``[qa, img_path]`` entries.
        :param processor: BLIP / BLIP-2 processor used to encode images and text.
        :param custom_train: When True, collated tensors are moved to ``device``;
            when False they are left where the processor created them.
        """
        with open(input_file) as f:
            self.data = json.load(f)

        # Processor for image-question pairs and answer text
        self.processor = processor

        self.custom_train = custom_train

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(full_text, img_path, answer_text, question_prompt)`` for entry ``idx``."""
        qa, img_path = self.data[idx]
        # Normalise Windows-style separators stored in the JSON.
        img_path = os.path.join('DriveLM', '/'.join(img_path.split('\\')))
        q_text, a_text = qa['Q'], qa['A']
        new_q_text = f"Question: {q_text} Answer:"
        full_text = f"Question: {q_text} Answer: {a_text}"

        return full_text, img_path, a_text, new_q_text

    def collate_fn(self, batch):
        """Collate items for training / validation.

        :return: ``(encodings, q_enc)``. ``encodings`` holds the images plus the full
            "Question: ... Answer: ..." text and a ``labels`` entry; ``q_enc`` holds the
            tokenised question-only prompts.
        """
        full_texts, img_paths, _, q_texts = zip(*batch)
        imgs = [read_image(img_path) for img_path in img_paths]

        encodings = self.processor(images=imgs, text=full_texts, padding=True, return_tensors="pt")
        # Device placement is decided solely by custom_train below, for both encodings.
        q_enc = self.processor(text=q_texts, padding=True, return_tensors='pt')

        if self.custom_train:
            encodings = encodings.to(device)
            q_enc = q_enc.to(device)

        encodings = dict(encodings.items())
        q_enc = dict(q_enc.items())
        # Labels are a copy of the full prompt+answer input_ids.
        # NOTE(review): padding positions are not set to -100 here, so depending on the
        # model's loss they may contribute to it — confirm this is intended.
        encodings['labels'] = encodings['input_ids'].clone()

        return encodings, q_enc

    def test_collate_fn(self, batch):
        """Collate items for prompt-only evaluation.

        :return: ``(encodings, img_paths, q_texts)``. ``encodings`` holds the images plus
            the "Question: ... Answer:" prompts, and ``labels`` holds the tokenised
            answers only.
        """
        # Items are 4-tuples (see __getitem__); the full prompt+answer text is unused here.
        _, img_paths, a_texts, q_texts = zip(*batch)
        imgs = [read_image(img_path) for img_path in img_paths]

        encodings = self.processor(images=imgs, text=q_texts, padding=True, truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer(a_texts, padding=True, truncation=True, return_tensors='pt')
        # Attach labels before any device move so inputs and labels end up on the same device.
        encodings['labels'] = labels['input_ids']

        if self.custom_train:
            encodings = encodings.to(device)

        # No squeeze(): it would drop the batch dimension when the batch has one item.
        encodings = dict(encodings.items())

        return encodings, img_paths, q_texts

In [None]:
# ---- Optimisation ----
BATCH_SIZE = 2
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05
EPOCHS = 1

# ---- LoRA adapter (rank, scaling alpha, dropout) ----
LORA_DIM = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

# ---- Data loading ----
# Keep at 0: when CUSTOM_TRAIN is set, collate_fn moves tensors to the GPU, which is
# generally not safe inside DataLoader worker processes.
NUM_WORKERS = 0

# ---- Checkpointing: resume from results/<CHECKPOINT_FILE>/ when LOAD_CHECKPOINT ----
LOAD_CHECKPOINT = False
CHECKPOINT_FILE = '20240215-044630'

# ---- Run mode ----
# CUSTOM_TRAIN: manual training loop (True) vs. Hugging Face Trainer (False).
CUSTOM_TRAIN = True
LORA_ENABLED = True
# 'BLIP' selects Salesforce/blip-vqa-base; any other value selects BLIP-2 (OPT-2.7b).
MODEL_TYPE = 'BLIP2'

In [None]:
def save_model(model, model_name):
    """Write ``model`` with ``torch.save`` to ``drive/MyDrive/DriveLM/results/<timestr>/<model_name>.pth``.

    ``model`` can be anything ``torch.save`` accepts; in this notebook callers pass
    state dicts. Reads the module-level ``timestr``.
    """
    run_dir = os.path.join('drive', 'MyDrive', 'DriveLM', 'results', timestr)
    torch.save(model, os.path.join(run_dir, model_name + '.pth'))


def val_model(dloader, eval_model=None):
    """Return the summed loss of a model over every batch of ``dloader``.

    :param dloader: DataLoader yielding ``(encodings, extra)`` pairs; only
        ``encodings`` (which must include ``labels``) is passed to the model.
    :param eval_model: Model to evaluate; defaults to the module-level ``model``.
    :return: Sum of ``outputs.loss`` over all batches (divide by ``len(dloader)``
        for a per-batch mean).
    """
    if eval_model is None:
        eval_model = model

    # Remember the caller's mode so evaluation never silently leaves dropout disabled.
    was_training = eval_model.training
    eval_model.eval()
    val_loss = 0.0

    try:
        with torch.no_grad():
            for batch, _ in tqdm(dloader, total=len(dloader)):
                val_loss += eval_model(**batch).loss.item()
    finally:
        eval_model.train(was_training)

    return val_loss


def save_stats(train_loss, val_loss, epochs):
    """Dump the loss history and best losses of the current run to ``stats.json``.

    Reads the module-level ``losses``, ``val_losses`` and ``timestr``; the file is
    written to ``drive/MyDrive/DriveLM/results/<timestr>/stats.json``.

    :param train_loss: Minimum mean training loss so far.
    :param val_loss: Minimum mean validation loss so far.
    :param epochs: Number of epochs completed.
    """
    payload = dict([
        ('losses', losses),
        ('val losses', val_losses),
        ('min train loss', train_loss),
        ('min val loss', val_loss),
        ('epochs', epochs),
    ])

    # Save stats into checkpoint
    stats_path = os.path.join('drive', 'MyDrive', 'DriveLM', 'results', timestr, 'stats.json')
    with open(stats_path, 'w') as stats_file:
        json.dump(payload, stats_file)


def plot_loss(training_loss):
    """Plot one training-loss point per epoch and save it as ``loss.png``.

    Draws on matplotlib's current axes and writes the image to
    ``drive/MyDrive/DriveLM/results/<timestr>/loss.png`` (reads the module-level ``timestr``).

    :param training_loss: Sequence of per-epoch training losses.
    """
    ax = plt.gca()
    epoch_numbers = range(1, len(training_loss) + 1)

    ax.plot(epoch_numbers, training_loss, label='Training Loss')
    ax.set_title('Training Loss')
    ax.set_xlabel('Num epochs')
    ax.set_ylabel('Loss')
    ax.legend()

    plt.savefig(os.path.join('drive', 'MyDrive', 'DriveLM', 'results', timestr, 'loss.png'))


def _cpu_state_dict_copy():
    """Return a detached CPU copy of the module-level ``model``'s state dict."""
    return {
        key: value.detach().to('cpu', copy=True) if torch.is_tensor(value) else deepcopy(value)
        for key, value in model.state_dict().items()
    }


def _print_generation_sample(batch, q_enc):
    """Print generated answers for ``batch`` next to its questions and ground-truth labels.

    Uses the module-level ``model`` and ``processor``; generation is conditioned on the
    question-only prompts in ``q_enc`` and the batch's ``pixel_values``.
    """
    # Sample with dropout disabled, then switch back to training mode.
    model.eval()
    with torch.no_grad():
        generated = model.generate(**q_enc, pixel_values=batch['pixel_values'], max_length=100)
    model.train()

    text_outputs = [processor.decode(seq.to('cpu'), skip_special_tokens=True) for seq in generated]
    text_questions = [processor.decode(q.to('cpu'), skip_special_tokens=True) for q in q_enc['input_ids']]
    text_labels = [processor.decode(a.to('cpu'), skip_special_tokens=True) for a in batch['labels']]

    print()
    print('Questions:')
    print(text_questions)
    print()
    print('Generated Answers:')
    print(text_outputs)
    print()
    print('Ground Truth Answers:')
    print(text_labels)


def custom_train(train_loss, val_loss, best_model, epochs):
    """Run the manual training loop from epoch ``epochs`` up to ``EPOCHS``.

    Uses the module-level ``model``, ``processor``, ``train_dataloader``,
    ``val_dataloader``, ``losses`` and ``val_losses``. After every epoch the best
    weights are written to ``latest_model.pth`` and the stats to ``stats.json``.

    :param train_loss: Best (minimum) mean training loss so far, or None.
    :param val_loss: Best (minimum) mean validation loss so far, or None.
    :param best_model: State dict of the best model so far, or None.
    :param epochs: Number of epochs already completed (resume point).
    :return: ``(train_loss, val_loss)`` — the best mean losses seen.
    """
    # Optimise only parameters that receive gradients (the LoRA adapters when LoRA is
    # enabled), and apply WEIGHT_DECAY, which save_experiment records for the run.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    for epoch in range(epochs, EPOCHS):
        print('-------------------- EPOCH ' + str(epoch) + ' ---------------------')

        model.train()
        epoch_loss = 0

        for step, (batch, q_enc) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            # Forward pass; the model computes the loss from batch['labels'].
            loss = model(**batch).loss
            epoch_loss += loss.item()

            if step % 500 == 0:
                print()
                print('Loss: ' + str(loss.item()))
                _print_generation_sample(batch, q_enc)

            # Back-prop
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Mean train and val loss per batch
        epoch_train_loss = epoch_loss / len(train_dataloader)
        epoch_val_loss = val_model(val_dataloader) / len(val_dataloader)

        # `is None` (not falsiness) marks "no value yet", so a 0.0 loss is not treated as unset.
        if val_loss is None or epoch_val_loss <= val_loss:
            val_loss = epoch_val_loss
            # Snapshot on the CPU so keeping the best weights costs no extra GPU memory.
            best_model = _cpu_state_dict_copy()
        if train_loss is None or epoch_train_loss <= train_loss:
            train_loss = epoch_train_loss

        # Add losses to epoch list
        val_losses.append(epoch_val_loss)
        losses.append(epoch_train_loss)

        print('Training Loss: ' + str(epoch_train_loss))
        print('Validation Loss: ' + str(epoch_val_loss))
        print('---------------------------------------------')

        # Checkpoint the best weights and the stats every epoch so the run can be resumed.
        save_model(best_model, 'latest_model')
        epochs += 1
        save_stats(train_loss, val_loss, epochs)

    # Save the best model (also covers resuming when no epoch is left) and plot the loss.
    save_model(best_model, 'latest_model')
    plot_loss(losses)

    return train_loss, val_loss


def train():
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and push it to the Hub.

    Used when ``CUSTOM_TRAIN`` is False. Reads the module-level ``model``,
    ``train_dset`` and ``val_dset``.
    """
    import inspect

    # transformers renamed `evaluation_strategy` to `eval_strategy` (and later dropped the
    # old name); pass whichever keyword the installed version accepts.
    arg_names = inspect.signature(TrainingArguments).parameters
    strategy_kw = 'eval_strategy' if 'eval_strategy' in arg_names else 'evaluation_strategy'

    training_args = TrainingArguments(
        output_dir="agopalkr/EfficientDriveLM",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=WEIGHT_DECAY,
        save_strategy="epoch",
        load_best_model_at_end=True,
        # Name the label key explicitly: Trainer otherwise infers it from the model's
        # forward() signature, which may not expose `labels` when wrapped by PEFT.
        label_names=["labels"],
        **{strategy_kw: "epoch"},
    )

    def _collate(batch):
        # collate_fn returns (encodings, question_encodings); Trainer needs only the
        # dict of model inputs.
        encodings, _ = train_dset.collate_fn(batch)
        return encodings

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dset,
        eval_dataset=val_dset,
        data_collator=_collate,
    )

    trainer.train()
    model.push_to_hub("agopalkr/EfficientDriveLM")


def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    :param model: A ``torch.nn.Module`` (e.g. the base model or its PEFT wrapper).
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A parameter-less module would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"Trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


def save_experiment(statistics):
    """Save this run's hyperparameters and losses as a one-row CSV.

    Hyperparameters are read from the module-level constants (``MODEL_TYPE``,
    ``LEARNING_RATE``, ``LORA_DIM``, ...) and ``timestr``; the file is written to
    ``drive/MyDrive/DriveLM/results/<timestr>/results.csv``.

    :param statistics: Sequence ``[min_train_loss, min_val_loss, test_loss]``: the
        minimum mean training and validation losses and the mean test loss.
    """
    trial_dict = {
        'Model name': [timestr],
        'Model type': [MODEL_TYPE],
        'Learning rate': [LEARNING_RATE],
        'Weight decay': [WEIGHT_DECAY],
        'Batch size': [BATCH_SIZE],
        'Epochs': [EPOCHS],
        'LoRA Dimension': [LORA_DIM],
        'LoRA Alpha': [LORA_ALPHA],
        'LoRA Dropout': [LORA_DROPOUT],
        'Min Training Loss': [statistics[0]],
        'Min Validation Loss': [statistics[1]],
        # Column name kept for compatibility; the value is the single mean test loss.
        'Min Testing Loss': [statistics[2]],
    }

    results = pd.DataFrame(trial_dict)
    results.to_csv(os.path.join('drive', 'MyDrive', 'DriveLM', 'results', timestr, 'results.csv'), index=False, header=True)


if __name__ == '__main__':
    # Timestamp naming this run's results folder (replaced by CHECKPOINT_FILE when resuming).
    timestr = time.strftime("%Y%m%d-%H%M%S")

    # Run bookkeeping; custom_train and save_stats read losses / val_losses as globals.
    losses = []
    val_losses = []
    min_train_loss = None
    min_val_loss = None
    best_model = None
    epochs_ran = 0

    if MODEL_TYPE == 'BLIP':
        # Load processors and models
        processor = BlipProcessor.from_pretrained('Salesforce/blip-vqa-base')
        model = BlipForQuestionAnswering.from_pretrained('Salesforce/blip-vqa-base')
    else:
        processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
        # Pad batches with the tokenizer's EOS token.
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
        model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b', device_map="auto")

    data_dir = os.path.join('drive', 'MyDrive', 'DriveLM')
    results_dir = os.path.join(data_dir, 'results')

    train_dset = SingleFrameDataset(
        input_file=os.path.join(data_dir, 'sf_train.json'),
        processor=processor,
        custom_train=CUSTOM_TRAIN
    )
    val_dset = SingleFrameDataset(
        input_file=os.path.join(data_dir, 'sf_val.json'),
        processor=processor,
        custom_train=CUSTOM_TRAIN
    )
    test_dset = SingleFrameDataset(
        input_file=os.path.join(data_dir, 'sf_test.json'),
        processor=processor,
        custom_train=CUSTOM_TRAIN
    )

    print(len(train_dset), len(val_dset), len(test_dset))

    # Create Dataloaders. Evaluation loaders are not shuffled: with drop_last=True a
    # shuffled loader drops a different subset every epoch, so the per-epoch validation
    # losses used to pick the best model would not be directly comparable.
    train_dataloader = DataLoader(train_dset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                  collate_fn=train_dset.collate_fn)
    val_dataloader = DataLoader(val_dset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                collate_fn=val_dset.collate_fn, drop_last=True)
    test_dataloader = DataLoader(test_dset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                 collate_fn=test_dset.collate_fn, drop_last=True)

    print_trainable_parameters(model)

    # Create LoRA model
    if LORA_ENABLED:
        # Apply the LoRA hyperparameters configured above (and recorded by save_experiment).
        # LoftQ is not used: PEFT only honours a `loftq_config` together with
        # init_lora_weights="loftq", otherwise it is ignored with a warning.
        lora_config = LoraConfig(r=LORA_DIM, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT)
        model = get_peft_model(model, lora_config)
        print_trainable_parameters(model)

    model.to(device)

    if CUSTOM_TRAIN:

        # Load checkpoint if necessary
        if LOAD_CHECKPOINT:

            print('Loading model from ' + CHECKPOINT_FILE)
            checkpoint_dir = os.path.join(results_dir, CHECKPOINT_FILE)

            # Load the model and stats from the checkpoint
            model.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'latest_model.pth')))
            best_model = deepcopy(model.state_dict())

            with open(os.path.join(checkpoint_dir, 'stats.json'), 'r') as f:
                stats = json.load(f)

            min_train_loss = stats['min train loss']
            min_val_loss = stats['min val loss']
            losses = stats['losses']
            val_losses = stats['val losses']
            epochs_ran = stats['epochs']
            print(f'Minimum Training Loss: {min_train_loss}')
            print(f'Training Losses: {losses}')
            print(f'Minimum Validation Loss: {min_val_loss}')
            print(f'Validation Losses: {val_losses}')
            print(f'Epochs ran: {epochs_ran}')
            timestr = CHECKPOINT_FILE
        else:
            # makedirs also creates the parent results folder if it does not exist yet.
            os.makedirs(os.path.join(results_dir, timestr), exist_ok=True)

        min_train_loss, min_val_loss = custom_train(min_train_loss, min_val_loss, best_model, epochs_ran)

        # custom_train leaves the best weights in latest_model.pth. Load them into the
        # existing `model` (the model val_model evaluates) rather than building a second
        # full model that would never actually be evaluated.
        model.load_state_dict(torch.load(os.path.join(results_dir, timestr, 'latest_model.pth')))
        test_loss = val_model(test_dataloader) / len(test_dataloader)
        statistics = [min_train_loss, min_val_loss, test_loss]
        save_experiment(statistics)
    else:
        train()

In [None]:
# Unassign (release) the Colab runtime once the run has finished.
from google.colab import runtime as colab_runtime

colab_runtime.unassign()

In [None]:
# processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b', padding_side='left')
# model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b', torch_dtype=torch.float16, device_map="auto")
# model.to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# text = ['Question: What is the ego state of the vehicle? Answer: The ego-vehicle is moving']
# enc = processor(read_image(train_dset[0][1]).to(device), text, padding=True, truncation=True, return_tensors='pt')
# enc.to(device)

# outputs = model(**enc, labels=enc['input_ids'])
# loss = outputs.loss
# print(outputs.loss.item())

# # Back-prop
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()

# outputs = model(**enc, labels=enc['input_ids'])
# print(outputs.loss.item())
