# Multimodal LLMs with Database Constrained Decoding for Recycling Classification

This is the Google Colab notebook accompanying the repo https://github.com/acluous/recycling-database-constrained-decoding.git.

**Clone repo and install requirements**

In [None]:
FOLDER_ROOT = "/content/recycling-database-constrained-decoding"
!git clone https://github.com/acluous/recycling-database-constrained-decoding.git
!git clone https://huggingface.co/datasets/acluous/waste-wizard-materials-list
!pip install -r {FOLDER_ROOT}/requirements.txt

**Restart runtime, then import libraries and configure quantization**

Make sure to select Runtime > Change runtime type > T4 GPU.

In [2]:
from collections import OrderedDict
import requests
import json
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm

from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModel, BitsAndBytesConfig
from datasets import load_dataset

import sys
FOLDER_ROOT = "/content/recycling-database-constrained-decoding"
sys.path.append(FOLDER_ROOT)
import decoding_utils

# load in 4bit to reduce memory consumption
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)
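
As a rough illustration of why the 4-bit setting matters on a T4 (16 GB), the arithmetic below estimates the memory taken by the weights alone. The parameter count is an assumed round number based on the "8b" in the model name, and the estimate ignores activations, quantization constants, and any layers left unquantized, so real usage will be higher.

```python
# Back-of-the-envelope weight memory; the parameter count is an assumed round number.
ASSUMED_PARAMS = 8e9  # assumption: "8b" in the model name read as ~8 billion weights


def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


fp16_gb = weight_memory_gb(ASSUMED_PARAMS, 16)
nf4_gb = weight_memory_gb(ASSUMED_PARAMS, 4)
print(f"fp16: {fp16_gb:.1f} GB, 4-bit: {nf4_gb:.1f} GB")
```

Under these assumptions, half-precision weights would fill the whole GPU before any activations are allocated, while 4-bit weights leave headroom.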

**Load Waste Wizard Dataset**

In [39]:
image_dataset = load_dataset("acluous/waste-wizard-materials-list")['train']

# change city_name to whatever city you want to evaluate next
# ["toy", "davis", "mountain-view", "waverley", "waterloo"]
city_name = "mountain-view"
image_labels_file = f"/content/waste-wizard-materials-list/data/image-labels/{city_name}.json"
city_database_file = f"/content/waste-wizard-materials-list/data/city-databases/{city_name}.json"

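query = "What is this item?"
with open(city_database_file) as f:
  raw_database = json.load(f)
if city_name == "toy":
  # the toy database is used exactly as loaded
  database = raw_database
else:
  # other cities: take each entry's "text" field and normalize casing and whitespace
  database = [ann["text"].title().strip() for ann in raw_database]
with open(image_labels_file) as f:
  image_labels_map = json.load(f)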

Using the latest cached version of the dataset since acluous/waste-wizard-materials-list couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /root/.cache/huggingface/datasets/acluous___waste-wizard-materials-list/default/0.0.0/ecb0cc3995fa35fbb6c50cd5d3d8aa144ae06bae (last modified on Tue Jun 18 04:48:28 2024).
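
To make the normalization in the cell above concrete, here is a small sketch of what `.title().strip()` does to entries shaped like the city databases. The records are made-up examples, not taken from the dataset. Note that Python's `str.title()` also capitalizes the letter after an apostrophe, which may matter if database entries contain possessives.

```python
# Hypothetical records mimicking the {"text": ...} shape read from the city JSON files.
records = [
    {"text": "  glass jar "},
    {"text": "ALUMINUM foil"},
    {"text": "children's toys"},
]

# Same normalization as the notebook: title-case, then trim surrounding whitespace.
database = [r["text"].title().strip() for r in records]
print(database)
```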


**Load Idefics2**

In [None]:
model_id = "HuggingFaceM4/idefics2-8b"
processor = AutoProcessor.from_pretrained(
    model_id,
    do_image_splitting=False
)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

**Run Database Constrained Decoding**

In [40]:
preds, labels, all_database_chosen = decoding_utils.mllm_classification(model, processor, image_dataset, database, query, mode="dcd")
acc = decoding_utils.get_accuracy(preds, labels, image_labels_map)
print("Classification Accuracy, Idefics2 - Database Constrained Decoding", acc)

100%|██████████| 100/100 [05:50<00:00,  3.51s/it]


Classification Accuracy, Idefics2 - Database Constrained Decoding 0.62
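
The general idea behind constraining decoding to a database can be sketched with a token trie: at each step, only tokens that continue some database entry are allowed, so the output is always an entry. The code below is a minimal, self-contained illustration and is not the implementation in `decoding_utils`. The word-level vocabulary, the `END` sentinel, and the fixed scores standing in for model logits are all hypothetical.

```python
from typing import Callable, Dict, List, Sequence

END = -1  # hypothetical sentinel marking that a database entry may end here


def build_trie(sequences: Sequence[Sequence[int]]) -> Dict:
    """Build a nested-dict trie over token-id sequences."""
    root: Dict = {}
    for seq in sequences:
        node = root
        for tok in seq:
            node = node.setdefault(tok, {})
        node[END] = {}
    return root


def allowed_next_tokens(trie: Dict, prefix: Sequence[int]) -> List[int]:
    """Return the token ids that keep `prefix` on a path to some entry."""
    node = trie
    for tok in prefix:
        if tok not in node:
            return []  # the prefix has already left the database
        node = node[tok]
    return sorted(node.keys())


def constrained_greedy_decode(
    trie: Dict, score_fn: Callable[[List[int], int], float], max_len: int = 16
) -> List[int]:
    """Greedy decoding that only ever considers trie-allowed tokens."""
    out: List[int] = []
    for _ in range(max_len):
        options = allowed_next_tokens(trie, out)
        if not options:
            break
        best = max(options, key=lambda t: score_fn(out, t))
        if best == END:
            break
        out.append(best)
    return out


# Toy word-level "tokenizer" and database (both hypothetical).
vocab = {"Glass": 0, "Bottle": 1, "Jar": 2, "Plastic": 3, "Bag": 4}
id_to_word = {i: w for w, i in vocab.items()}
toy_database = ["Glass Bottle", "Glass Jar", "Plastic Bag"]
trie = build_trie([[vocab[w] for w in entry.split()] for entry in toy_database])

# Stand-in for model logits: a fixed score per token, ignoring the prefix.
toy_scores = {0: 2.0, 1: 0.5, 2: 1.5, 3: 1.0, 4: 3.0, END: 0.0}
decoded = constrained_greedy_decode(trie, lambda prefix, tok: toy_scores[tok])
print(" ".join(id_to_word[t] for t in decoded))
```

In this toy run, "Bag" has the highest raw score but is never selected first, because no database entry starts with it. With Hugging Face `generate`, a function like `allowed_next_tokens` could plausibly be supplied through the `prefix_allowed_tokens_fn` argument, though whether the repo does so is not shown here.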


**Run Normalized Sequence Likelihood**

In [None]:
if city_name != "toy":
  print("Warning: NSL will take a very long time to run on this database. Switch to the toy database instead.")
preds, labels, all_database_chosen = decoding_utils.mllm_classification(model, processor, image_dataset, database, query, mode="nsl")
acc = decoding_utils.get_accuracy(preds, labels, image_labels_map)
print("Classification Accuracy, Idefics2 - Normalized Sequence Likelihood", acc)
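
For intuition about the scoring, the sketch below assumes that "normalized" means dividing an entry's total token log-probability by its token count, and that every database entry is scored separately. That per-entry scoring is presumably why the cell warns about long runtimes on large databases. The log-probabilities are invented, and none of this is the `decoding_utils` code.

```python
import math
from typing import Dict, List


def normalized_log_likelihood(token_logprobs: List[float]) -> float:
    """Average per-token log-probability of one candidate entry."""
    if not token_logprobs:
        raise ValueError("candidate must contain at least one token")
    return sum(token_logprobs) / len(token_logprobs)


def pick_entry(candidates: Dict[str, List[float]]) -> str:
    """Return the candidate with the highest length-normalized score."""
    return max(candidates, key=lambda name: normalized_log_likelihood(candidates[name]))


# Hypothetical per-token log-probabilities for two database entries.
candidates = {
    "Glass Jar": [math.log(0.5), math.log(0.5)],
    "Aluminum Foil Tray": [math.log(0.6), math.log(0.6), math.log(0.6)],
}

unnormalized_choice = max(candidates, key=lambda name: sum(candidates[name]))
normalized_choice = pick_entry(candidates)
print(unnormalized_choice, "|", normalized_choice)
```

Without normalization the shorter entry wins simply because it has fewer tokens to multiply. Dividing by length removes that bias, and in this toy case the choice flips to the longer entry.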

**Load OpenCLIP**

In [31]:
clip_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
clip_processor = AutoProcessor.from_pretrained(clip_id)
clip_model = AutoModel.from_pretrained(
    clip_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config
)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


**Run Contrastive Classification**

In [41]:
preds, labels, all_database_chosen = decoding_utils.contrastive_classification(clip_model, clip_processor, image_dataset, database)
acc = decoding_utils.get_accuracy(preds, labels, image_labels_map)
print("Classification Accuracy, OpenCLIP", acc)

100%|██████████| 100/100 [00:04<00:00, 21.11it/s]
100%|██████████| 470/470 [00:14<00:00, 32.22it/s]
100%|██████████| 100/100 [00:00<00:00, 13300.89it/s]


Classification Accuracy, OpenCLIP 0.53
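
Contrastive classification with a CLIP-style model generally works by embedding the image and every database entry in a shared space and picking the entry whose text embedding is most similar to the image embedding. The sketch below shows that selection step with cosine similarity on tiny made-up vectors. The embeddings and labels are hypothetical, and `decoding_utils.contrastive_classification` may differ in its details.

```python
import math
from typing import Dict, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def classify(image_embedding: Sequence[float],
             text_embeddings: Dict[str, Sequence[float]]) -> str:
    """Return the label whose text embedding best matches the image."""
    return max(
        text_embeddings,
        key=lambda label: cosine_similarity(image_embedding, text_embeddings[label]),
    )


# Hypothetical 3-dimensional embeddings; real CLIP embeddings are much larger.
text_embeddings = {
    "Glass Jar": [1.0, 0.0, 0.0],
    "Plastic Bag": [0.0, 1.0, 0.0],
    "Cardboard Box": [0.0, 0.0, 1.0],
}
image_embedding = [0.9, 0.1, 0.0]
prediction = classify(image_embedding, text_embeddings)
print(prediction)
```

Because the text embeddings depend only on the database, they can be computed once and reused for every image, which helps explain why this approach runs much faster than the generative methods above.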
