<a href="https://colab.research.google.com/github/Seyviour/enel645-g12/blob/main/645_assignment2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Approach**

We fine-tuned PaliGemma, a 3-billion-parameter vision-language model developed by Google, using LoRA to make the training process more efficient. We framed the task as a multiple-choice question-answering problem and achieved an accuracy of ~86% on the test set.

Final training was done on the ARC cluster with a batch size of 8 and a learning rate of 1e-5, which we found experimentally to give the best results. Inference was done on Ziheng's 4070.



In [9]:
#Imports

#Utils
import os
from PIL import Image
from pathlib import Path
import string
import random
import re
#Modelling and computation
import torch
import numpy as np
from transformers import Trainer
from transformers import PaliGemmaForConditionalGeneration
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig
from transformers import PaliGemmaProcessor
#Huggingface
from huggingface_hub import login
#Dataset and Transforms
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms
from transformers import TrainingArguments
#Evaluation
# import evaluate
from sklearn import metrics
import matplotlib.pyplot as plt
#Quantization
from transformers import BitsAndBytesConfig

In [None]:
# Locations of the three dataset splits (train / test / validation) on disk.
train_folder, test_folder, val_folder = (
    "/home/ziheng.chang/garbage_data/garbage_data/CVPR_2024_dataset_Train",
    "/home/ziheng.chang/garbage_data/garbage_data/CVPR_2024_dataset_Test",
    "/home/ziheng.chang/garbage_data/garbage_data/CVPR_2024_dataset_Val",
)

In [None]:
# Setup: global switches, Hugging Face authentication, device and base model.

# NOTE(review): RANDOM_SEED is defined here but not referenced in the cells
# visible in this notebook -- confirm whether explicit seeding was intended.
RANDOM_SEED = 42
UAT = ""  # Set Hugging Face access token here (or export HF_TOKEN instead)

# Hugging Face account name used when building the hub repository string.
HUGGINGFACE_USER_NAME = "seyviour"

# Set to True for a quick trial run: the dataset cell then caps how many
# files are loaded from each class folder.
IS_TEST_JOB = False

# Prefer the token typed above; otherwise fall back to the environment.
UAT = UAT or os.getenv("HF_TOKEN")
# `not UAT` (rather than `is None`) also rejects an empty HF_TOKEN variable,
# which would otherwise be passed straight to `login`.
if not UAT:
    raise Exception("Hugging Face Token is not set")
os.environ["HF_TOKEN"] = UAT  # make the token visible to libraries that read the env var
login(token=UAT)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Checkpoint identifier of the pretrained PaliGemma model to fine-tune.
model_id = "google/paligemma-3b-pt-224"

In [6]:
class ImageTextDataset(Dataset):
    """
    Dataset over the images of one class folder, framed as question answering.

    Each item is a dict with three keys: ``'image'`` (the loaded, transformed
    image), ``'question'`` (a prompt built from the image's file name) and
    ``'multiple_choice_answer'`` (the label given at construction time).

    The dataset is laid out as ``len(aug_transforms) + 1`` consecutive passes
    over the file list: indices ``[0, n)`` return the un-augmented images,
    ``[n, 2n)`` return the images with ``aug_transforms[0]`` applied, and so
    on, where ``n`` is the number of files.

    Attributes:
    ----------
    base_folder : str
        The root directory containing the class folders.
    sub_folder : str
        The class folder inside `base_folder` whose files are loaded.
    label : str
        The answer string returned for every item of this dataset.
    base_transform : callable, optional
        Transform applied to every image, after any augmentation.
    aug_transforms : list of callables, optional
        Augmentations; each one adds a full extra pass over the files.
    encoder : callable, optional
        Stored on the instance; not used by this class itself.
    max_size : int, optional
        Maximum number of files (in sorted order) to keep from `sub_folder`.
        A falsy value (None or 0) keeps all files.
    """
    def __init__(self, base_folder, sub_folder, label, base_transform=None,
                 aug_transforms=None, encoder=None, max_size=None
                 ):

        self.aug_transforms = [] if aug_transforms is None else aug_transforms
        self.base_folder = base_folder
        self.sub_folder = sub_folder
        self.base_transform = base_transform
        self._path = os.path.join(self.base_folder, self.sub_folder)
        # Sorted so that indices map to the same files on every run.
        self._file_names = sorted(os.listdir(self._path))
        self.max_size = max_size
        if self.max_size:
            self._file_names = self._file_names[:self.max_size]
        self.encoder = encoder
        self.label = label

    def _get_image_path(self, idx):
        """Return the full path of the file at position `idx` in the file list."""
        return os.path.join(self._path, self._file_names[idx])

    def __len__(self):
        """Return the number of files times the number of passes (1 + augmentations)."""
        return len(self._file_names) * (len(self.aug_transforms) + 1)

    def to_description(self, file_name):
        """Turn a file name into text: drop the extension, underscores become spaces, digits are removed."""
        file_name_no_ext, _ = os.path.splitext(file_name)
        text = file_name_no_ext.replace('_', ' ')
        text_without_digits = re.sub(r'\d+', '', text)
        return text_without_digits.strip()

    def _split_index(self, idx):
        """
        Map a dataset index to ``(file_idx, transform_idx)``.

        ``transform_idx`` is 0 for the un-augmented pass and ``k`` for the pass
        that uses ``aug_transforms[k - 1]``. Negative indices count from the
        end. Raises IndexError when `idx` is outside the dataset, which is
        also what lets plain ``for item in dataset`` iteration terminate.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} is out of range for dataset of size {total}")
        # Divide by the number of files (not files + 1): pass k covers exactly
        # the indices [k * n, (k + 1) * n).
        transform_idx, file_idx = divmod(idx, len(self._file_names))
        return file_idx, transform_idx

    def _get_image_at_idx(self, idx):
        """Load the image for `idx` as RGB, apply its augmentation (if any), then the base transform."""
        file_idx, transform_idx = self._split_index(idx)
        filepath = self._get_image_path(file_idx)
        with Image.open(filepath) as raw_image:
            image = raw_image.convert('RGB')  # Convert to RGB
        if transform_idx:
            # Indices past the first pass get the augmentation of their pass.
            image = self.aug_transforms[transform_idx - 1](image)
        if self.base_transform:
            image = self.base_transform(image)
        return image

    def _get_text_at_idx(self, idx):
        """Return the description derived from the file name behind `idx`."""
        file_idx, _ = self._split_index(idx)
        return self.to_description(self._file_names[file_idx])

    def __getitem__(self, idx):
        """Return a dict with the image, question, and label for the given index."""
        description = self._get_text_at_idx(idx)
        image = self._get_image_at_idx(idx)
        return {
            'multiple_choice_answer': self.label,
            'question': f'What type of garbage is this {description}?',
            'image': image
        }


In [None]:
#Helper for creating datasets
def make_garbage_dataset(basefolder:str, key, kwargs)->ConcatDataset:
    """
    Build one dataset covering every class folder found under `basefolder`.

    Parameters
    ----------
    basefolder : str
        Directory whose sub-directories each hold the files of one class.
    key : mapping
        Maps a lower-cased folder name to the label used for that class.
        Folders missing from the mapping use their own name as the label.
    kwargs : dict
        Extra keyword arguments forwarded to every `ImageTextDataset`.

    Returns
    -------
    ConcatDataset
        The per-class datasets concatenated in sorted folder-name order.
    """
    # Sorted for a stable class order. Hidden entries are skipped, and so are
    # stray files: each entry is later listed as a directory by the dataset.
    class_folders = sorted(
        name for name in os.listdir(basefolder)
        if not name.startswith('.') and os.path.isdir(os.path.join(basefolder, name))
    )
    individual_datasets = [
        ImageTextDataset(basefolder, name, key.get(name.lower(), name), **kwargs)
        for name in class_folders
    ]
    return ConcatDataset(individual_datasets)

In [None]:
# Names of the class folders (presumably as they appear on disk -- the
# dataset helper lower-cases folder names before looking them up below).
folder_names = ['Black', 'Green', 'Blue', 'TTR']

# Maps a lower-cased class folder name to the answer text used as its label.
# Bound exactly once: earlier list-valued assignments to this name were
# immediately overwritten and never read.
garbage_types = {
    "black": "landfill",
    "green": "compostable",
    "blue": "recyclable",
    "ttr": "other",
}

In [None]:
# Image transformations. The PaliGemma processor applies its own preparation
# to the images when batches are built, so only basic transforms are applied
# here.

# Base transform: resize, take a 224x224 centre crop, then flip horizontally
# with probability 0.5.
# NOTE(review): this same pipeline is passed as `base_transform` for the
# evaluation datasets too, so the random flip is also active at eval time --
# confirm that this is intended.
preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(0.5)
        ])

# Augmentation: each call picks one of two pipelines -- colour jitter or
# Gaussian blur -- each followed by a random rotation between 0 and 45 degrees.
# NOTE(review): only `torchvision.transforms` is imported above; accessing
# `transforms.v2` relies on that submodule already being loaded -- verify.
aug_transform = transforms.v2.RandomChoice([
    transforms.Compose([
        transforms.ColorJitter(brightness=.5, hue=.3),
        transforms.RandomRotation(degrees = (0,45))
        ]
    ),
    transforms.Compose([
        transforms.GaussianBlur(kernel_size = (5,5)),
        transforms.RandomRotation(degrees = (0,45))
        ]
    )
])

In [None]:
# In a test job, load at most 100 files per class folder; otherwise load all.
max_len = 100 if IS_TEST_JOB else None

# Keyword arguments forwarded to every per-class dataset. The size cap must be
# passed as `max_size`, the parameter name `ImageTextDataset.__init__` accepts.
kwargs_train = {
    "base_transform": preprocess,
    "aug_transforms": [aug_transform],  # one extra augmented pass over the training files
    "max_size": max_len,
}

kwargs_eval = {
    "base_transform": preprocess,
    "aug_transforms": None,  # no augmentation for evaluation data
    "max_size": max_len,
}

train_data = make_garbage_dataset(train_folder, garbage_types, kwargs_train)
test_data = make_garbage_dataset(test_folder, garbage_types, kwargs_eval)
# Validation uses its own split, so model selection never sees the test images.
val_data = make_garbage_dataset(val_folder, garbage_types, kwargs_eval)

In [None]:
# Quantization and adapter configuration: the base model is loaded in 4-bit
# precision and fine-tuned through LoRA adapters (QLoRA-style) to reduce the
# cost of training.
# Shell command: installs the package that provides the 4-bit quantization.
!pip install bitsandbytes
# Checkpoint identifier of the pretrained PaliGemma model.
model_id = "google/paligemma-3b-pt-224"

# 4-bit loading with the "nf4" quantization type and double quantization
# enabled; bfloat16 is used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# LoRA adapter settings: rank 8, attached to the projection modules listed below.
lora_config = LoraConfig(
    r=8,
    # PaliGemma modules to apply LoRA to
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)


In [None]:
# Processor that turns prompts, images and answers into model inputs.
processor = PaliGemmaProcessor.from_pretrained(model_id)


def collate_fn(examples):
    """
    Turn a list of dataset items into one batch for the Trainer.

    Each item must provide 'question', 'multiple_choice_answer' and 'image'.
    The question becomes the prompt, the answer is passed to the processor as
    the suffix, and the resulting batch is cast to bfloat16 and moved to
    `device`.
    """
    prompts, answers, pictures = [], [], []
    for sample in examples:
        prompts.append("<image> <bos> answer " + sample["question"])
        answers.append(sample['multiple_choice_answer'])
        pictures.append(sample["image"].convert("RGB"))

    batch = processor(
        text=prompts,
        images=pictures,
        suffix=answers,
        return_tensors="pt",
        padding="longest",
    )
    return batch.to(torch.bfloat16).to(device)

In [None]:
# Fetch the quantized base model from the Hugging Face hub, placing all of it
# on device 0, and wrap it with the LoRA adapters in a single step.
model = get_peft_model(
    PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map={"": 0},
    ),
    lora_config,
)

# Report how many parameters are trainable after adding the adapters.
model.print_trainable_parameters()



In [12]:

# Set up training: build the `TrainingArguments`, then the `Trainer`.
# Hyperparameters were chosen after experimentation:
#   - 12 epochs gives good results without running for too long.
#   - A learning rate of 4e-5 gave the best accuracy of the rates we tried.
# NOTE(review): an earlier version of this comment said a batch size of 8
# gives good speed without OOM, but the values below use a per-device batch
# size of 3 with 4 gradient-accumulation steps -- confirm which is intended.
TRAIN_BATCH_SIZE = 3
EVAL_BATCH_SIZE = 3

args=TrainingArguments(
            num_train_epochs=12,
            remove_unused_columns=False,  # collate_fn reads the raw 'question'/'image'/'multiple_choice_answer' keys
            per_device_train_batch_size=TRAIN_BATCH_SIZE,
            per_device_eval_batch_size=EVAL_BATCH_SIZE,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            learning_rate=0.00004,
            weight_decay=1e-6,
            adam_beta2=0.999,
            optim="adamw_torch",
            save_strategy="epoch",  # checkpoint once per epoch ...
            eval_strategy="epoch",  # ... and evaluate on the same schedule
            push_to_hub=True,
            save_total_limit=1,
            output_dir="paligemma_vqav2_2",
            bf16=True,
            report_to=["tensorboard"],
            dataloader_pin_memory=False,  # presumably because collate_fn already moves batches to the device
            load_best_model_at_end=True,
            eval_do_concat_batches = False
        )

# The validation set is used for the per-epoch evaluation during training.
trainer = Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        data_collator=collate_fn,
        args=args
    )




In [None]:
# Train, then persist the fine-tuned model for inference/testing.
trainer.train()

# Save a local copy first: the upload below needs the network, and a failed
# push must not prevent the trained weights from being written to disk.
trainer.save_model("model/paligemma_model_2")

# Push the model to the Hugging Face hub for future use.
# NOTE(review): the first positional parameter of `Trainer.push_to_hub` is the
# commit message, so this string is recorded as the commit message rather than
# selecting the repository. The target repository comes from the
# TrainingArguments -- set `hub_model_id` there if this name was meant to be
# the repo id.
trainer.push_to_hub(f'{HUGGINGFACE_USER_NAME}/paligemma_VQAv2_enel645_2')

In [None]:
def make_predictions(examples, options=None):
    """
    Run the model on every example and collect class indices.

    Parameters
    ----------
    examples : sized iterable of dict
        Items with 'question', 'multiple_choice_answer' and 'image' keys.
    options : list of str, optional
        The valid answer strings. Defaults to the four labels the datasets in
        this notebook are built with. The list is sorted, and an answer's
        class index is its position in the sorted list.

    Returns
    -------
    (actual, predicted) : tuple of lists of int
        Ground-truth and predicted class indices, one pair per example.
        A prediction that is not one of `options` is encoded as
        ``len(options)``. A ground-truth label that is not one of `options`
        is encoded as ``-1``, so it can never count as a correct prediction.
    """
    if options is None:
        # Must match the label strings the model was trained to output.
        options = ["landfill", "compostable", "recyclable", "other"]
    options = sorted(options)
    unknown_prediction = len(options)

    predicted = []
    actual = []
    tot_count = len(examples)
    for count, example in enumerate(examples, start=1):
        prompt = "<image> <bos> answer " + example["question"]
        label = example['multiple_choice_answer'].strip().lower()
        image = example["image"].convert("RGB")

        # Preprocessing Inputs
        inputs = processor(text=prompt, images=image, padding="longest", do_convert_rgb=True, return_tensors="pt").to(device)
        inputs = inputs.to(dtype=model.dtype)
        prompt_len = inputs["input_ids"].shape[-1]

        # Generating and Decoding Output
        with torch.no_grad():
            generated = model.generate(**inputs, max_length=496)

        # `generate` returns the prompt tokens followed by the new tokens;
        # keep only the newly generated part before decoding.
        answer = processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
        answer = answer.strip().lower()

        actual.append(options.index(label) if label in options else -1)
        if answer in options:
            predicted.append(options.index(answer))
        else:
            print('Wrong prediction detected:', answer, "\n")
            predicted.append(unknown_prediction)

        # Periodic progress report on everything processed so far.
        if count % 500 == 0 or count == tot_count:
            accuracy = metrics.accuracy_score(actual, predicted)
            print(count/tot_count*100, '% done. Accuracy so far is', accuracy*100, '%.')

    print('Completed.')
    return actual, predicted

actual, predicted = make_predictions(test_data)

In [1]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic can be used.
%load_ext tensorboard

In [None]:
%tensorboard --logdir=runs