This notebook fine-tunes the PaliGemma model on the RISCM dataset and tests some suggestions for improving model performance.

Install the required packages with the following commands.

In [None]:
# Upgrade the Python packaging toolchain first.
!pip install -U pip setuptools wheel

# JAX / jaxlib / Flax / TensorFlow / Keras are pinned together; the remaining
# packages (data handling, logging, caption-evaluation metrics) are unpinned.
# The line below is presumably the alternative JAX requirement for a CPU-only runtime:
# "jax[cpu]"==0.4.20
# NOTE(review): `transformers` is imported later but not listed here; presumably it
# arrives as a dependency of bert-score. Pin it explicitly if its version matters.
# NOTE(review): re-installing `ipython` inside a running Colab kernel may be
# unnecessary. Confirm it is needed.
!pip install \
    "jax[cuda12]"==0.4.20 \
    jaxlib==0.4.20 \
    flax==0.7.5 \
    tensorflow==2.15.0 \
    keras==2.15.0 \
    einops~=0.7 \
    pillow \
    scikit-image \
    matplotlib \
    sentencepiece \
    overrides \
    ml_collections \
    kagglehub \
    polars \
    wandb \
    ipython \
    gensim \
    bert-score \
    rouge_score \
    textstat \
    language-tool-python \
    pycocoevalcap \
    hf_xet \
    numpy \
    scipy \
    pot

# Install Java 17 and make it the default java/javac. This is presumably required
# by the Java-based evaluation tools (language-tool-python / pycocoevalcap). Confirm.
!apt-get update
!apt-get install -y openjdk-17-jdk
!update-alternatives --set java /usr/lib/jvm/java-17-openjdk-amd64/bin/java
!update-alternatives --set javac /usr/lib/jvm/java-17-openjdk-amd64/bin/javac

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive. The dataset location read later from the
# DATA_PATH secret presumably points inside this mount.
drive.mount('/content/drive')

Clone the modified big_vision repository (into `big_vision_repo`) to use its PaliGemma utilities.

In [None]:
import os
import sys

# Clone the (modified) big_vision repository if it is not already present in the
# working directory. NOTE: this cell only clones the repository. It does not
# install any dependencies; those come from the pip cell.
if not os.path.exists("big_vision_repo"):
    !git clone --quiet --branch=master --depth=1 https://github.com/ErenNarin/modified-big-vision big_vision_repo  # TODO: merge to master

# Append the repository root to the import path. This is presumably so the repository's
# own `big_vision...` imports resolve. The notebook itself imports via
# `big_vision_repo.big_vision...`, which relies on the working directory being importable.
if "big_vision_repo" not in sys.path:
    sys.path.append("big_vision_repo")

Import the required packages and check which JAX devices are available.

In [None]:
import base64
import functools
import html
import io
import warnings
import json

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import polars as pl

import tensorflow as tf
from tensorflow.keras import mixed_precision
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

import kagglehub

from skimage.segmentation import slic

from google.colab import userdata

# Import model definition from big_vision
from big_vision_repo.big_vision.models.proj.paligemma import paligemma
from big_vision_repo.big_vision.trainers.proj.paligemma import predict_fns
from big_vision_repo.big_vision.custom_evaluation.evaluation import evaluation_score

# Import big vision utilities
from big_vision_repo.big_vision.datasets.jsonl import DataSource
from big_vision_repo.big_vision.utils import tree_map_with_names, reshard, create_learning_rate_schedule, \
    tree_flatten_with_names
from big_vision_repo.big_vision.sharding import infer_sharding

# Backwards-compatibility shim: this notebook calls jax.tree.map / flatten /
# unflatten / leaves, but the `jax.tree` namespace only exists in JAX releases
# newer than the pinned 0.4.20. Alias it to `jax.tree_util` only when it is
# missing, so a native `jax.tree` is never clobbered on newer JAX versions.
import jax.tree_util
if not hasattr(jax, "tree"):
    jax.tree_util.map = jax.tree_util.tree_map
    jax.tree_util.flatten = jax.tree_util.tree_flatten
    jax.tree_util.unflatten = jax.tree_util.tree_unflatten
    jax.tree_util.leaves = jax.tree_util.tree_leaves
    jax.tree = jax.tree_util

# Only surface errors (not warnings/info) from the transformers library.
from transformers import logging
logging.set_verbosity_error()

# Hide every accelerator from TensorFlow so it stays on the CPU. It is only used
# for input preprocessing here, which leaves the GPU/TPU memory to JAX.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

# NOTE(review): this sets the *Keras* global dtype policy. The model here is built
# and run with JAX, so it presumably has no effect on it. Confirm it is still needed.
mixed_precision.set_global_policy("mixed_float16")

devices = jax.devices()
if devices:
    backend_platform = devices[0].platform  # platform reported by the first device
else:
    backend_platform = "None"
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend_platform}")
print(f"JAX devices:  {jax.device_count()}")

Download the pre-trained model to be fine-tuned. We will use the PaliGemma-2 3B model.

In [None]:
# Expose the Kaggle credentials stored as Colab secrets as environment variables.
# kagglehub presumably reads them from there for the checkpoint download below.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

In [None]:
# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"

"""
# Use these for PaliGemma 1:
LLM_VARIANT = "gemma_2b"
MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"
"""

if not os.path.exists(MODEL_PATH):
    print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
    MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
    print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "model/paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
    print("Downloading the model tokenizer...")
    !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
    print(f"Tokenizer path: {TOKENIZER_PATH}")

Define the model with the pre-trained weights and default config.

In [None]:
# Define model

# IMPORTANT: Gemma-2 has a "final_logits_softcap" property. Set it to 0.0
# for better transfer results.
# NOTE(review): "dtype_mm": "float16" presumably sets the image encoder's matmul
# dtype, and "pool_type": "none" presumably leaves the image tokens unpooled. Confirm.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
# The partial binds the devices and the tokenizer's EOS id, so callers only pass
# the params, the batch and decoding options (e.g. max_decode_len, sampler).
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

Mark which parameters are trainable, then load all model parameters with sharding so the model can be trained on multiple GPUs, if available.

In [None]:
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
    """Decide whether the parameter at path `name` is fine-tuned.

    Only the LLM attention weights are trained. Every other LLM parameter and the
    whole image encoder stay frozen. `param` is ignored; it is accepted so the
    function can be passed to `tree_map_with_names`.

    Raises:
      ValueError: if `name` belongs to neither the "llm/" nor the "img/" subtree.
    """
    if name.startswith("llm/layers/attn/"):
        return True
    if name.startswith(("llm/", "img/")):
        return False
    raise ValueError(f"Unexpected param name {name}")


trainable_mask = tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
# axis_names must be a tuple: note the trailing comma, since ("data") is just a str.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Buffer donation (e.g. in maybe_cast_to_f32) is best-effort. Silence the warning
# emitted when some donated buffers cannot be reused.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")


@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    """Cast `params` leaf-wise: float32 where `trainable` is True, float16 elsewhere.

    Frozen weights use float16 rather than bfloat16 since some GPUs don't support
    bf16. `params` is donated to the jitted call. `trainable` is a static argument,
    so it must be hashable (e.g. a single bool per leaf, as used when loading).
    """
    def _cast_leaf(leaf, is_trainable):
        target_dtype = jnp.float32 if is_trainable else jnp.float16
        return leaf.astype(target_dtype)

    return jax.tree.map(_cast_leaf, params, trainable)


# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default.
# Instead, do it param by param.
# NOTE: `params` is rebound to the flat list of leaves rather than given a new
# name. This is presumably deliberate: the original tree is not kept alive, and
# overwriting params[idx] drops the reference to each un-sharded leaf as it is moved.
# The zip below assumes params, params_sharding and trainable_mask flatten to
# leaves in the same order (the latter two are derived from params).
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    params[idx] = reshard(params[idx], sharding)  # place the leaf on its target sharding
    params[idx] = maybe_cast_to_f32(params[idx], trainable)  # f32 if trainable, else f16 (input donated)
    params[idx].block_until_ready()  # finish this leaf before starting the next one
params = jax.tree.unflatten(treedef, params)


# Print params to show what the model is made of.
def parameter_overview(params):
    """Print one row per parameter leaf: its path name, shape and dtype."""
    named_leaves = tree_flatten_with_names(params)[0]
    for leaf_name, leaf in named_leaves:
        shape_str = str(leaf.shape)
        print(f"{leaf_name:80s} {shape_str:22s} {leaf.dtype}")


print(" == Model params == ")
parameter_overview(params)

Load the data and convert it into the JSONL format read by big_vision's `DataSource`. While doing this, drop duplicate or partially unavailable samples.

In [None]:
data_path = userdata.get('DATA_PATH')  # Add data path as a Colab secret or give the direct path of the data
images_folder = "resized/"
captions_path = os.path.join(data_path, "captions.csv")

# Set to False to reuse *_captions.jsonl files that already exist instead of
# regenerating every split on each run.
REGENERATE_CAPTIONS = True


def _caption_records(df_split):
    """Yield one DataSource record per usable caption of the rows in `df_split`.

    A row is skipped entirely when its image file does not exist under
    `data_path`/`images_folder`. Within a row, captions that are missing (None/
    non-string) or blank are dropped. Captions that repeat an earlier caption of the
    same image are also dropped, ignoring case and surrounding whitespace, since
    references are lower-cased before use anyway.
    """
    for row in df_split.iter_rows(named=True):
        image_rel_path = f"{images_folder}{row['image']}"
        if not os.path.isfile(os.path.join(data_path, image_rel_path)):
            continue  # Referenced image does not exist.
        seen_captions = set()
        for i in range(1, 6):
            caption = row[f"caption_{i}"]
            if not isinstance(caption, str) or not caption.strip():
                continue  # Would otherwise become a null/empty suffix.
            dedup_key = caption.strip().lower()
            if dedup_key in seen_captions:
                continue  # Duplicate caption of the same image.
            seen_captions.add(dedup_key)
            yield {
                "prefix": "",
                "image": image_rel_path,
                "suffix": caption,
            }


df_input = pl.read_csv(captions_path, separator=",", glob=False)
splits = df_input.select('split').unique()['split'].to_list()
for split in splits:
    filename = os.path.join(data_path, f"{split}_captions.jsonl")
    if REGENERATE_CAPTIONS or not os.path.isfile(filename):
        print(f"Processing {split}_captions...")
        df_split = df_input.filter(pl.col('split') == split)
        with open(filename, 'w', encoding='utf-8') as f:
            for record in _caption_records(df_split):
                f.write(json.dumps(record) + '\n')
    else:
        print(f"{split}_captions is already processed, skipping...")

Define preprocess and postprocess methods.

In [None]:
def preprocess_image(image, size=224):
    """Convert an image into model input: RGB, `size` x `size`, values in [-1, 1].

    Model has been trained to handle images of different aspect ratios resized to
    224x224 in the range [-1, 1]. Bilinear and antialias resize options are
    helpful to improve quality in some tasks.

    Args:
      image: a PIL image, or an array of shape (H, W) or (H, W, C) with C >= 3
        and values assumed to be in [0, 255].
      size: output height and width.

    Returns:
      np.ndarray of shape (size, size, 3) scaled to [-1, 1].
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # Let PIL do the colour conversion. For palette ("P"), CMYK or grey+alpha
        # ("LA") images the raw array channels are not RGB values.
        image = image.convert("RGB")

    image = np.asarray(image)
    if image.ndim == 2:  # Greyscale (H, W): replicate into three channels.
        image = np.stack((image,) * 3, axis=-1)
    image = image[..., :3]  # Drop an alpha channel, if any.
    assert image.shape[-1] == 3

    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
    return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]


def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prompt (and optional target) into model inputs.

    Model has been trained to handle tokenized text composed of a prefix with
    full attention and a suffix with causal attention.

    Args:
      prefix: prompt text; encoded with BOS and followed by a "\\n" separator.
      suffix: optional target text, encoded with EOS. Falsy values (None, "")
        produce a prompt-only sequence.
      seqlen: if given, every output is truncated or zero-padded to this length.

    Returns:
      Tuple of 1-D np.ndarrays (tokens, mask_ar, mask_loss, mask_input), where:
        mask_ar is 0 for full attention (prefix) and 1 for causal attention (suffix);
        mask_loss is 1 for tokens that contribute to the loss (suffix only);
        mask_input is 1 for real tokens and 0 for padding.
    """
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)  # 0 to use full attention for prefix.
    mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

    if suffix:
        suffix = tokenizer.encode(suffix, add_eos=True)
        tokens += suffix
        mask_ar += [1] * len(suffix)  # 1 to use causal attention for suffix.
        mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

    mask_input = [1] * len(tokens)  # 1 if it's a token, 0 if padding.
    if seqlen:
        padding = [0] * max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + padding
        mask_ar = mask_ar[:seqlen] + padding
        mask_loss = mask_loss[:seqlen] + padding
        mask_input = mask_input[:seqlen] + padding

    # Convert each list into one 1-D array. jax.tree.map(np.array, ...) would treat
    # the lists as pytree nodes and produce lists of 0-d arrays, one per token.
    return tuple(np.array(seq) for seq in (tokens, mask_ar, mask_loss, mask_input))


def postprocess_tokens(tokens):
    """Detokenize a sequence of token ids, ignoring everything from the first EOS on.

    Args:
      tokens: np.ndarray of token ids (a single 1-D sequence is expected).

    Returns:
      The decoded text.
    """
    token_ids = tokens.tolist()  # np.array -> list[int]
    eos_id = tokenizer.eos_id()
    if eos_id in token_ids:
        # Drop the EOS token and everything after it.
        del token_ids[token_ids.index(eos_id):]
    return tokenizer.decode(token_ids)

Define the validation dataset and its data iterator.

In [None]:
SEQLEN = 128

val_dataset = DataSource(
    os.path.join(data_path, "val_captions.jsonl"),
    fopen_keys={"image": data_path})


def validation_data_iterator():
    """Single iterator over validation examples, yielding model-ready dicts.

    Each dict holds:
      "image": the preprocessed image;
      "text", "mask_ar", "mask_input": the prompt-only tokens and masks fed to the decoder;
      "annotation": prompt + lower-cased reference caption tokens, used as ground truth.
    """
    prefix = "caption en"  # Could also be a different prefix per example.
    # The prompt is identical for every example, so tokenize it only once.
    prompt_tokens, prompt_mask_ar, _, prompt_mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
    prompt_tokens = np.asarray(prompt_tokens)
    prompt_mask_ar = np.asarray(prompt_mask_ar)
    prompt_mask_input = np.asarray(prompt_mask_input)

    for example in val_dataset.get_tfdata().as_numpy_iterator():
        image = preprocess_image(Image.open(io.BytesIO(example["image"])))

        suffix = example["suffix"].decode().lower()
        annotation_tokens, _, _, _ = preprocess_tokens(prefix, suffix, SEQLEN)

        yield {
            "image": np.asarray(image),
            "text": prompt_tokens,
            "annotation": np.asarray(annotation_tokens),
            "mask_ar": prompt_mask_ar,
            "mask_input": prompt_mask_input,
        }

Define helpers that render an example (image thumbnail plus captions) as inline HTML.

In [None]:
def render_inline(image, resize=(128, 128)):
    """Convert image into inline html: a base64 JPEG data URI.

    Args:
      image: uint8 image array accepted by PIL.Image.fromarray.
      resize: (width, height) of the encoded thumbnail.

    Returns:
      A "data:image/jpeg;base64,..." string usable as an <img> src.
    """
    # Image.resize returns a new image; keep the result, or no resizing happens.
    thumbnail = Image.fromarray(image).resize(resize)
    with io.BytesIO() as buffer:
        thumbnail.save(buffer, format='jpeg')
        image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{image_b64}"


def render_example(image, caption, annotation=None):
    """Build an HTML snippet with an image thumbnail, its caption and an optional annotation.

    `image` is expected in the model's [-1, 1] range and is mapped back to uint8
    pixels for display. The caption has " ." tidied to "." and is capitalized.
    Both texts are HTML-escaped.
    """
    caption = caption.replace(" .", ".").capitalize()
    pixels = ((image + 1) / 2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
    thumbnail_src = render_inline(pixels, resize=(64, 64))

    annotation_div = ""
    if annotation:
        annotation_div = f"""
            <br/>
            <p style="width:256px; margin:10px; font-size:small;">{html.escape(annotation)}</p>
        """
    return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{thumbnail_src}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
        {annotation_div}
    </div>
    """

Define the inference (prediction) loop.

In [None]:
# Evaluation/inference loop.
def make_predictions(data_iterator, *, params, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    """Decode batches from `data_iterator` and collect (image, response, caption) triples.

    Examples are pulled `batch_size` at a time. A final partial batch is padded by
    repeating its last example, and the padding rows are dropped after decoding.
    `caption` is the detokenized "annotation" with its "caption en\\n" prefix
    stripped. When `num_examples` is reached mid-batch, the function returns
    immediately; the rest of that batch has still been consumed from the iterator.

    Args:
      data_iterator: iterator of example dicts ("image", "annotation" and the
        model inputs read by `decode`).
      params: model parameters passed to `decode`.
      num_examples: stop after this many outputs (None/0 = exhaust the iterator).
      batch_size: number of examples per decode call.
      seqlen: maximum decode length.
      sampler: sampler name forwarded to `decode`.

    Returns:
      List of (image, response, caption) tuples.
    """
    outputs = []
    while True:
        batch_examples = []
        exhausted = False
        for _ in range(batch_size):
            try:
                example = next(data_iterator)
            except StopIteration:
                exhausted = True
                break
            example["_mask"] = np.array(True)  # Real (non-padding) example.
            batch_examples.append(example)
        if exhausted and not batch_examples:
            return outputs

        # Pad an incomplete batch by repeating its last example.
        n_padding = -len(batch_examples) % batch_size
        for _ in range(n_padding):
            filler = dict(batch_examples[-1])
            filler["_mask"] = np.array(False)  # Padding example.
            batch_examples.append(filler)

        # Stack the per-example dicts into one dict of arrays and place it on the devices.
        batch = jax.tree.map(lambda *leaves: np.stack(leaves), *batch_examples)
        batch = reshard(batch, data_sharding)

        tokens = decode({"params": params}, batch=batch,
                        max_decode_len=seqlen, sampler=sampler)

        # Copy predictions and the padding mask back from the devices, drop padding rows.
        tokens, is_real = jax.device_get((tokens, batch["_mask"]))
        responses = [postprocess_tokens(row) for row in tokens[is_real]]

        for example, response in zip(batch_examples, responses):
            # Detokenize the reference (model input) and strip the prompt prefix.
            caption = postprocess_tokens(example["annotation"])[len("caption en\n"):]
            outputs.append((example["image"], response, caption))
            if num_examples and len(outputs) >= num_examples:
                return outputs

Log in to W&B and set the project name used for the runs.

In [None]:
import wandb

# Authenticate with the API key stored as the Colab secret "WANDB_KEY".
wandb.login(key=userdata.get('WANDB_KEY'))

PROJECT_NAME = "DI725_project_2389088"  # W&B project that runs are logged to

Start a W&B run and evaluate the base (not yet fine-tuned) model on the validation set, logging the average score at each step.

In [None]:
wandb.init(project=PROJECT_NAME)
batch_size = 4
# Keep num_examples a multiple of batch_size: make_predictions returns mid-batch,
# and the remaining examples of that batch are consumed from the iterator but not scored.
num_examples = 4
base_validation_step = 580
validation_params = params.copy()
# One iterator shared by all steps, so each step scores the *next* `num_examples`
# validation samples. Creating a fresh iterator per step would start a new pass
# over the split (and rebuild the input pipeline) every time.
val_iterator = validation_data_iterator()
for step in range(1, base_validation_step):  # NOTE(review): runs base_validation_step - 1 steps
    html_out = ""
    sum_score = 0
    predictions = make_predictions(
        val_iterator, params=validation_params, num_examples=num_examples, batch_size=batch_size)
    if not predictions:
        print(f"Validation data exhausted after {step - 1} steps.")
        break
    for image, caption, annotation in predictions:
        # `caption` is the model output, `annotation` the reference caption.
        caption = caption.replace(" .", ".").capitalize()
        annotation = annotation.replace(" .", ".").capitalize()
        score = evaluation_score(annotation, caption)
        sum_score += score
        html_out += render_example(image, caption, annotation)
    # Average over the examples actually scored (the last step may have fewer).
    avg_score = sum_score / len(predictions)
    wandb.log({"base_step": step, "base_avg_score": avg_score})
    print(f"Base average score at step {step}: {avg_score}")
    display(HTML(html_out))