In [1]:
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from transformers import (
    AutoProcessor,
    PaliGemmaForConditionalGeneration
)

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image, ImageDraw, ImageFont

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier passed to the Hugging Face `from_pretrained` loaders.
model_id = "google/paligemma-3b-mix-224"

# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still be executed (slowly) on a machine without a usable GPU instead of
# failing as soon as a tensor or the model is placed on "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Training hyper-parameters.
batch_size = 8
learning_rate = 5e-5

processor = AutoProcessor.from_pretrained(model_id)

In [3]:
import pandas as pd

# Each split is stored as a JSON Lines file (one record per line).
# Every variable is loaded from the file whose name matches its split;
# previously the three files were rotated (train <- test, validation <- train,
# test <- valid), which would train on the test split and evaluate on the
# training data.
training_objects = pd.read_json(path_or_buf="_annotations.train.jsonl", lines=True)
validation_objects = pd.read_json(path_or_buf="_annotations.valid.jsonl", lines=True)
test_objects = pd.read_json(path_or_buf="_annotations.test.jsonl", lines=True)

In [4]:
# NOTE(review): these loaders are built directly on pandas DataFrames. A
# DataFrame's `[]` operator looks up column labels, so the integer sample
# indices produced by the sampler will presumably not resolve to rows, and no
# `collate_fn` turns records into model inputs. A Dataset wrapper plus a
# processor-based collate function is probably required -- confirm against the
# annotation schema before relying on these loaders.

# Training batches are reshuffled on every pass.
train_dataloader = DataLoader(
    training_objects,
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation loaders keep a fixed order so that the "first batch" inspected
# later is the same on every run and results are comparable before/after
# training.
validation_dataloader = DataLoader(
    validation_objects,
    batch_size=batch_size,
    shuffle=False,
)

test_dataloader = DataLoader(
    test_objects,
    batch_size=batch_size,
    shuffle=False,
)

In [5]:
# Loading options are collected first so they can be read at a glance:
# the weights are requested from the "bfloat16" revision of the checkpoint,
# cast to `dtype`, and placed on `device`.
model_load_kwargs = {
    "torch_dtype": dtype,
    "device_map": device,
    "revision": "bfloat16",
}
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, **model_load_kwargs)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading shards: 100%|██████████| 2/2 [01:55<00:00, 57.62s/it]
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards:   0%|          | 0/2 [00:15<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 275.40 MiB is free. Of the allocated memory 2.97 GiB is allocated by PyTorch, and 43.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# A parameter stays trainable exactly when its qualified name contains
# "attn"; every other parameter is frozen.
for param_name, weight in model.named_parameters():
    weight.requires_grad = "attn" in param_name

# Report the resulting trainability flag and dtype of every parameter.
for param_name, weight in model.named_parameters():
    print(f"{param_name:<70}: requires_grad={weight.requires_grad}, dtype={weight.dtype}")

In [None]:
# Take the first batch from the validation loader, then from the test loader
# (the generator is consumed in that order by the tuple unpacking).
val_batch, test_batch = (
    next(iter(loader)) for loader in (validation_dataloader, test_dataloader)
)

In [None]:
# Show which fields each batch carries, one batch per output line.
print(
    f"{val_batch.keys()=}",
    f"{test_batch.keys()=}",
    sep="\n",
)

In [None]:
# Run one forward pass on the validation batch and display the loss the model
# reports for it.
outputs = model(**val_batch)
loss_report = f"{outputs.loss=}"
print(loss_report)

In [None]:
# Deterministic decoding (no sampling), capped at 100 newly generated tokens.
generate_options = {"max_new_tokens": 100, "do_sample": False}

with torch.inference_mode():
    generation = model.generate(**test_batch, **generate_options)
    decoded = processor.batch_decode(generation, skip_special_tokens=True)

In [None]:
for element in decoded:
    # The location text is read from the second newline-separated segment of
    # the decoded string (presumably prompt first, generated answer second --
    # confirm against the processor's prompt format). A decoded string with no
    # newline is treated as "nothing detected" instead of raising IndexError.
    segments = element.split("\n")
    location = segments[1] if len(segments) > 1 else ""
    print(location if location else "No bbox found")

In [None]:
# Hand the optimizer only the parameters that are still trainable. Parameters
# with `requires_grad=False` never receive a gradient, so registering them
# only adds bookkeeping; the updates applied to trainable weights are the same.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
model.train()

for epoch in range(1):
    for idx, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        # Log the loss on the first iteration and every 500 iterations after.
        if idx % 500 == 0:
            print(f"Epoch: {epoch} Iter: {idx} Loss: {loss.item():.4f}")

        # Standard step: back-propagate, update, then clear gradients so they
        # do not accumulate into the next iteration.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Matches one detection at the start of a string:
#   group 1 - any text preceding the location tokens (matched lazily),
#   group 2 - exactly four consecutive `<locDDDD>` tokens (four digits each),
#   group 3 - the label that follows, up to the next ';', '<' or '>'.
# An optional "; " separator after the label is consumed as part of the match.
# The location group must contain the `<loc\d{4}>` token pattern; an empty
# repeated group would match nothing but the empty string and leave no
# coordinates for callers to parse.
DETECT_RE = re.compile(
    r"(.*?)" + r"((?:<loc\d{4}>){4})\s*" + r"([^;<>]+) ?(?:; )?",
)

In [None]:
def extract_objects(detection_string, image_width, image_height, unique_labels=False):
    """Parse a detection string into a list of labelled pixel-space boxes.

    The string is consumed from the left, one `DETECT_RE` match at a time.
    Each match supplies four 4-digit location values, read in the order
    (y1, x1, y2, x2), divided by 1024 and scaled by the image height/width,
    then rounded to integers.

    Args:
        detection_string: Text containing zero or more detections.
        image_width: Width used to scale the x coordinates.
        image_height: Height used to scale the y coordinates.
        unique_labels: When true, a label that has already been emitted gets
            apostrophes appended until it no longer collides with an earlier
            one (e.g. "plate", "plate'", "plate''").

    Returns:
        A list of dicts with keys "xyxy" (an ``(x1, y1, x2, y2)`` tuple of
        ints) and "name" (the stripped label). Parsing stops at the first
        position where `DETECT_RE` no longer matches, so the list may be empty.
    """
    objects = []
    seen_labels = set()

    while detection_string:
        match = DETECT_RE.match(detection_string)
        if not match:
            break

        # The text before the location tokens is not used.
        _prefix, locations, label = match.groups()
        location_values = [int(loc) for loc in re.findall(r"\d{4}", locations)]
        y1, x1, y2, x2 = [value / 1024 for value in location_values]
        y1, x1, y2, x2 = map(
            round,
            (y1 * image_height, x1 * image_width, y2 * image_height, x2 * image_width),
        )

        label = label.strip()  # Remove surrounding spaces from the label

        # Keep appending apostrophes until the label is genuinely new; a
        # single `if` would give the third duplicate the same name as the
        # second.
        while unique_labels and label in seen_labels:
            label = (label or "") + "'"
        seen_labels.add(label)

        objects.append(dict(xyxy=(x1, y1, x2, y2), name=label))

        # Drop the consumed match and continue with the remainder.
        detection_string = detection_string[len(match.group()) :]

    return objects

In [None]:
def draw_bbox(image, objects):
    """Display `image` with a red rectangle and caption for every object.

    Args:
        image: Anything accepted by ``Axes.imshow``.
        objects: Iterable of dicts with an "xyxy" entry holding
            ``(x1, y1, x2, y2)`` and, optionally, a "name" entry used as the
            caption. When "name" is missing or empty the caption falls back
            to "plate".

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    for obj in objects:
        x1, y1, x2, y2 = obj["xyxy"]
        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)
        # Caption with the detected label rather than a fixed string, and draw
        # it on the same axes as the rectangle; it sits 10 units above the
        # box's top-left corner.
        caption = obj.get("name") or "plate"
        ax.text(x1, y1 - 10, caption, color="red", fontsize=12, weight="bold")
    plt.show()

In [None]:
# Greedy decoding of the test batch, limited to 100 new tokens; special
# tokens are stripped when converting the ids back to text.
generation_kwargs = dict(max_new_tokens=100, do_sample=False)
decode_kwargs = dict(skip_special_tokens=True)

with torch.inference_mode():
    generation = model.generate(**test_batch, **generation_kwargs)
    decoded = processor.batch_decode(generation, **decode_kwargs)

In [None]:
# Iterate over what was actually decoded rather than `range(batch_size)`, so
# a batch shorter than `batch_size` does not raise IndexError.
for index, element in enumerate(decoded):
    # Reorder the tensor's three axes with permute(1, 2, 0) for imshow
    # (assumes pixel_values[index] is channel-first -- TODO confirm).
    image = test_batch["pixel_values"][index].permute(1, 2, 0).cpu().float()
    image_height, image_width = image.shape[:2]

    # The detection text is the second newline-separated segment; fall back to
    # an empty string (no boxes) when the decoded text has no newline.
    segments = element.split("\n")
    detection_string = segments[1] if len(segments) > 1 else ""

    # Scale boxes to the size of the image being displayed instead of a
    # hard-coded 224 x 224.
    objects = extract_objects(
        detection_string, image_width, image_height, unique_labels=False
    )
    draw_bbox(image, objects)