Gradio Demo for Fine-tuned PaliGemma LLM:
<br>1) Run all cells
<br>2) Enter a demo image URL in the Gradio interface

In [1]:
!pip install gradio
!pip install transformers
!pip install datasets
!pip install pandas
!pip install pillow
!pip install requests
!pip install gdown

Collecting gradio
  Downloading gradio-4.36.1-py3-none-any.whl (12.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.2.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==1.0.1 (from gradio)
  Downloading gradio_client-1.0.1-py3-none-any.whl (318 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.1/318.1 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━

In [2]:
# @title Fetch big_vision code and install dependencies.
import os
import sys

# TPUs with
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for ml_collections (setup.py) ... [?25l[?25hdone


Configure your API key to access Kaggle
To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the Account tab of your Kaggle user profile and select Create New Token. This will trigger the download of a kaggle.json file containing your API credentials.
In Colab, select Secrets (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.
To be able to download the model, you must also accept the PaliGemma Terms and Conditions at:

https://www.kaggle.com/models/google/paligemma/

In [3]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

# os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
# os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

In [4]:
# XLA client memory settings. They are written to the environment here, before
# the cell that imports `jax`.
# NOTE(review): JAX documents XLA_PYTHON_CLIENT_MEM_FRACTION as a fraction
# (e.g. ".50"); the value "50" is kept exactly as in the original - confirm.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "50",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
})

In [5]:

import base64
import functools
import html
import io
import os
import warnings
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece


from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

# Report the JAX setup through the public API. `jax.default_backend()` returns
# the platform name, so the internal `jax.lib.xla_bridge` path (deprecated in
# newer JAX releases) is no longer needed.
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {jax.default_backend()}")
print(f"JAX devices:  {jax.device_count()}")


JAX version:  0.4.26
JAX platform: cpu
JAX devices:  1


In [6]:
# @title Download checkpoint, tokenizer and dataset to local filesystem.
#
import os
import kagglehub
import gdown

MODEL_PATH = './my-custom-paligemma-ckpt.npz'
if not os.path.exists(MODEL_PATH):
  weights_link = "https://drive.google.com/file/d/1iL1Gli0ft4LA-kPU2HIoC5ah9eNsl-Xt/view?usp=sharing"
  file_path = 'my-custom-paligemma-ckpt.npz'
  gdown.download(weights_link, file_path, quiet=False,fuzzy=True)

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")



Downloading...
From (original): https://drive.google.com/uc?id=1iL1Gli0ft4LA-kPU2HIoC5ah9eNsl-Xt
From (redirected): https://drive.google.com/uc?id=1iL1Gli0ft4LA-kPU2HIoC5ah9eNsl-Xt&confirm=t&uuid=e7c0fd80-2f5d-4373-9b5e-15d446cc9c9e
To: /content/my-custom-paligemma-ckpt.npz
100%|██████████| 6.19G/6.19G [02:01<00:00, 50.9MB/s]


Downloading the model tokenizer...
Copying gs://big_vision/paligemma_tokenizer.model...
/ [1 files][  4.1 MiB/  4.1 MiB]                                                
Operation completed over 1 objects/4.1 MiB.                                      
Tokenizer path: ./paligemma_tokenizer.model


In [7]:
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

In [8]:
# @title Construct model and load params into RAM.

# Model definition: settings for the language model ("llm") and the image
# encoder ("img"), wrapped in a FrozenConfigDict.
llm_config = {"vocab_size": 257_152}
img_config = {
    "variant": "So400m/14",
    "pool_type": "none",
    "scan": True,
    "dtype_mm": "float16",
}
model_config = ml_collections.FrozenConfigDict(
    {"llm": llm_config, "img": img_config})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# `decode` samples outputs from the model; the device list and the tokenizer's
# EOS id are bound once here so callers only pass the params and the batch.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(
    decode_fn,
    devices=jax.devices(),
    eos_token=tokenizer.eos_id(),
)

In [9]:
# @title Move params to GPU/TPU memory.
#
# To keep HBM usage low and fit in a T4 GPU (16GB HBM), the fine-tuning recipe
# trains only a part of the parameters, keeping the frozen params in float16
# and casting trainable ones to float32. This demo only runs inference, so the
# mask below marks every parameter as frozen.

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  """Tells whether the parameter called `name` is trainable.

  Every parameter under the "llm/" or "img/" prefixes is reported as frozen.

  Args:
    name: Slash-separated parameter name.
    param: The parameter value (unused).

  Returns:
    Always False for recognised parameter names.

  Raises:
    ValueError: If `name` starts with neither "llm/" nor "img/".
  """
  known_prefixes = ("llm/", "img/")
  if not name.startswith(known_prefixes):
    raise ValueError(f"Unexpected param name {name}")
  return False
# Pytree of booleans with one entry per parameter, computed from its name.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

#
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_partition = jax.sharding.PartitionSpec("data")
data_sharding = jax.sharding.NamedSharding(mesh, data_partition)

fsdp_rules = [('.*', 'fsdp(axis="data")')]
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=fsdp_rules, mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  """Casts the leaves of `params` flagged in `trainable` to float32.

  The input buffers are donated to the jitted computation, and `trainable` is
  treated as a static argument.

  Args:
    params: Pytree of arrays (a single array is a valid pytree).
    trainable: Pytree of booleans with the same structure as `params`.

  Returns:
    A pytree like `params`; flagged leaves are float32, the others unchanged.
  """
  def _cast_leaf(leaf, is_trainable):
    if is_trainable:
      return leaf.astype(jnp.float32)
    return leaf

  return jax.tree.map(_cast_leaf, params, trainable)

# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
# `params` is temporarily rebound to the flat list of leaves; the tree
# structure is restored by the `unflatten` call after the loop.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  # Reshard one leaf, optionally cast it, and wait for it to be ready before
  # moving on to the next leaf.
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  """Prints one line per parameter with its name, shape and dtype."""
  named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
  for path, arr in named_leaves:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")


print(" == Model params == ")
parameter_overview(params)

 == Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (2

In [10]:
# @title Define preprocess functions to create inputs to the model.

from PIL import Image
from io import BytesIO
import requests

def download_image(url, timeout=30):
    """Downloads an image over HTTP and opens it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        The image opened from the response body as a PIL image.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without
            this check an HTML error page would be handed to PIL.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    return img

def preprocess_image(image, size=224):
  """Turns an image into a `size` x `size` float array scaled to [-1, 1].

  Model has been trained to handle images of different aspect ratios resized
  to 224x224 in the range [-1, 1]. Bilinear and antialias resize options are
  helpful to improve quality in some tasks.

  Args:
    image: Anything `np.asarray` can convert; 2-D inputs are replicated to
      three channels and channels beyond the third are dropped.
    size: Side length of the square output.

  Returns:
    The resized image as a numpy array, mapped from [0, 255] to [-1, 1].
  """
  pixels = np.asarray(image)
  if pixels.ndim == 2:  # No channel axis: replicate into three channels.
    pixels = np.stack([pixels, pixels, pixels], axis=-1)
  pixels = pixels[..., :3]  # Remove alpha layer.
  assert pixels.shape[-1] == 3

  resized = tf.image.resize(
      tf.constant(pixels), (size, size), method='bilinear', antialias=True)
  return resized.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  """Tokenizes a prefix (and optional suffix) and builds the model masks.

  Model has been trained to handle tokenized text composed of a prefix with
  full attention and a suffix with causal attention.

  Args:
    prefix: Prompt text; encoded with BOS and followed by a "\n" separator.
    suffix: Optional target text; encoded with EOS when non-empty.
    seqlen: Optional fixed length; sequences are truncated or zero-padded.

  Returns:
    A tuple (tokens, mask_ar, mask_loss, mask_input) as produced by
    `jax.tree.map(np.array, ...)` over the four Python lists.
  """
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode("\n")
  prefix_len = len(tokens)
  mask_ar = [0] * prefix_len    # 0 to use full attention for prefix.
  mask_loss = [0] * prefix_len  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix_ids = tokenizer.encode(suffix, add_eos=True)
    suffix_ones = [1] * len(suffix_ids)
    tokens.extend(suffix_ids)
    mask_ar.extend(suffix_ones)    # 1 to use causal attention for suffix.
    mask_loss.extend(suffix_ones)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)   # 1 if its a token, 0 if padding.
  if seqlen:
    pad = [0] * max(0, seqlen - len(tokens))
    tokens, mask_ar, mask_loss, mask_input = [
        seq[:seqlen] + pad
        for seq in (tokens, mask_ar, mask_loss, mask_input)
    ]

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  """Decodes a token array to text, dropping the first EOS and what follows.

  Args:
    tokens: Array of token ids (anything with a `tolist()` method).

  Returns:
    The text decoded by the module-level `tokenizer`.
  """
  ids = tokens.tolist()  # np.array to list[int]
  eos = tokenizer.eos_id()
  if eos in ids:  # Remove tokens at and after EOS if any.
    ids = ids[:ids.index(eos)]
  return tokenizer.decode(ids)


In [11]:
# Fixed token sequence length used for the prompt and for decoding.
SEQLEN = 128
def example_processing(raw_image, prefix="caption en"):
  """Builds the model input dictionary for a single image.

  Args:
    raw_image: Image accepted by `preprocess_image`.
    prefix: Task prompt passed to the model. Defaults to "caption en", the
      value that used to be hard-coded here.

  Returns:
    A dict with the keys "image", "text", "mask_ar" and "mask_input", each
    holding a numpy array.
  """
  image = preprocess_image(raw_image)

  # The loss mask is not needed for inference, so it is discarded.
  tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

  return {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input)
    }

In [66]:
def model_inference(params,example,seqlen=SEQLEN,sampler='greedy'):
  """Runs the model on one preprocessed example and returns the decoded text.

  Args:
    params: Model parameters, passed to `decode` under the "params" key.
    example: Dict of arrays for one example (see `example_processing`). It is
      not modified.
    seqlen: Maximum number of tokens to decode.
    sampler: Sampler name forwarded to `decode`.

  Returns:
    The detokenized prediction for the example.
  """
  # Add the "_mask" flag on a shallow copy so the caller's dict stays untouched.
  example = {**example, "_mask": np.array(True)}

  # Stack every field along a new leading axis to form a batch of one.
  test_batch = jax.tree_util.tree_map(lambda *x: np.stack(x), example)
  test_batch = big_vision.utils.reshard(test_batch, data_sharding)

  # Make model predictions
  tokens = decode({"params": params}, batch=test_batch,
                  max_decode_len=seqlen, sampler=sampler)

  # Fetch model predictions to the host and detokenize.
  tokens, mask = jax.device_get((tokens, test_batch["_mask"]))
  tokens = tokens[mask]  # remove padding examples.
  return postprocess_tokens(tokens[0])

In [67]:
def model_generate_caption(img_url):
  """Downloads the image at `img_url` and captions it with the model.

  Args:
    img_url: URL of the image to caption.

  Returns:
    A tuple (raw_image, caption): the downloaded PIL image and the text
    produced by `model_inference`.
  """
  # Reuse the notebook's download helper instead of duplicating its logic.
  raw_image = download_image(img_url)
  example = example_processing(raw_image)
  caption = model_inference(params,example)
  return raw_image,caption


In [68]:
def generate_caption(img_url):
  """Placeholder captioner: ignores `img_url` and returns a fixed string."""
  placeholder = "Hole"
  return placeholder

In [113]:

# Gradio UI: one text input for the image URL. `model_generate_caption`
# returns (image, caption), so two output components are declared to match;
# with a single Textbox the whole tuple would be rendered as text.
iface = gr.Interface(
    fn=model_generate_caption,
    inputs=gr.Textbox(label="Image URL"),
    outputs=[
        gr.Image(label="Fetched Image"),
        gr.Textbox(label="Generated Caption"),
    ],
    title="Image Caption Generator",
    description="Provide an image URL to generate a caption."
)

In [115]:
# Start the Gradio app defined in the previous cell.
iface.launch()

If the Gradio interface does not work, uncomment the last line of the cell below and run that cell.


In [116]:
# Sample image URL for trying the captioning function without the Gradio UI.
demo_url = 'https://storage.googleapis.com/docci/data/images_aar/aar_test_04600.jpg'
# raw_image,caption = model_generate_caption(demo_url)