In [1]:
!pip install ml_collections
import os
import sys
import csv
import pandas as pd
import numpy as np
from PIL import Image
import tensorflow as tf
import sentencepiece
import jax
import jax.numpy as jnp
import ml_collections
from sklearn.model_selection import train_test_split
import base64
import functools
import html
import io
import warnings
import time
import pickle  # Added for model saving
from IPython.core.display import display, HTML

# Check for TPUs
if "COLAB_TPU_ADDR" in os.environ:
    # Python 3 only allows raising BaseException subclasses; raising a bare
    # string would itself fail with a TypeError and hide this message.
    raise RuntimeError("It seems you are using Colab with remote TPUs which is not supported.")

# Make the big_vision sources importable: shallow-clone the repository the
# first time, then put the checkout on the module search path.
_BIG_VISION_DIR = "big_vision_repo"

if not os.path.exists(_BIG_VISION_DIR):
    os.system(
        "git clone --quiet --branch=main --depth=1 "
        f"https://github.com/google-research/big_vision {_BIG_VISION_DIR}"
    )

if _BIG_VISION_DIR not in sys.path:
    sys.path.append(_BIG_VISION_DIR)

# Import big_vision modules
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
import big_vision.utils
import big_vision.sharding

# Hide both accelerator kinds from TensorFlow by giving it an empty device list.
for _device_kind in ("GPU", "TPU"):
    tf.config.set_visible_devices([], _device_kind)

# Report which JAX version, platform and device count this session uses.
backend = jax.extend.backend.get_backend()
print(
    f"JAX version:  {jax.__version__}",
    f"JAX platform: {backend.platform}",
    f"JAX devices:  {jax.device_count()}",
    sep="\n",
)

# Locate the model checkpoint. If the local file is absent, fetch it through
# kagglehub and rebind MODEL_PATH to whatever path the download returns.
MODEL_PATH = "./pt_224_128.params.f16.npz"
if not os.path.exists(MODEL_PATH):
    print("Downloading the checkpoint from Kaggle...")
    import kagglehub
    _kaggle_handle = 'google/paligemma/jax/paligemma-3b-pt-224'
    _checkpoint_file = 'paligemma-3b-pt-224.f16.npz'
    MODEL_PATH = kagglehub.model_download(_kaggle_handle, _checkpoint_file)
    print(f"Model path: {MODEL_PATH}")

# Fetch the tokenizer model from the "big_vision" GCS bucket when it is not
# already present next to the notebook.
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
    print("Downloading the model tokenizer...")
    from google.cloud import storage

    def download_from_gcs(bucket_name, source_blob_name, destination_file_name):
        """Copy one blob from a bucket to a local file using an anonymous client."""
        anonymous_client = storage.Client.create_anonymous_client()
        source_blob = anonymous_client.bucket(bucket_name).blob(source_blob_name)
        source_blob.download_to_filename(destination_file_name)
        print(f"Downloaded {source_blob_name} to {destination_file_name}")

    download_from_gcs("big_vision", "paligemma_tokenizer.model", TOKENIZER_PATH)
    print(f"Tokenizer path: {TOKENIZER_PATH}")

# Dataset layout: a captions file and an image folder under one root.
DATA_DIR = "/kaggle/input/flickr8k"
CAPTIONS_FILE = os.path.join(DATA_DIR, "captions.txt")
IMAGES_DIR = os.path.join(DATA_DIR, "Images")

# Read the captions table. The first attempt only warns about malformed lines;
# if parsing still fails, retry while skipping them.
_csv_options = dict(sep=',', quoting=csv.QUOTE_MINIMAL)
try:
    df = pd.read_csv(CAPTIONS_FILE, on_bad_lines='warn', **_csv_options)
    print("Successfully loaded captions.txt")
except pd.errors.ParserError as e:
    print(f"Error parsing captions.txt: {e}")
    print("Attempting to load with skipping bad lines...")
    df = pd.read_csv(CAPTIONS_FILE, on_bad_lines='skip', **_csv_options)
    print("Loaded with skipped lines. Inspect the data for issues.")

# Show what was loaded.
print("Dataframe head:\n", df.head())
print("Dataframe columns:", df.columns)

# The rest of the notebook reads the 'image' and 'caption' columns.
_missing_columns = {'image', 'caption'} - set(df.columns)
if _missing_columns:
    raise ValueError("Expected columns 'image' and 'caption' in captions.txt")

# Limit the dataset to at most 1,500 randomly sampled rows (fixed seed for
# reproducibility). Each row is one (image, caption) pair, so the same image
# file can appear in more than one sampled row. The sample size is clamped to
# the table length because DataFrame.sample raises when asked for more rows
# than exist (without replacement).
df = df.sample(n=min(1500, len(df)), random_state=42).reset_index(drop=True)
print(f"Reduced dataset to {len(df)} images")

# Build parallel lists: the full path of each row's image and its caption.
image_paths = [os.path.join(IMAGES_DIR, file_name) for file_name in df['image']]
captions = df['caption'].tolist()

# Hold out 20% of the rows for validation (seeded split).
train_images, val_images, train_captions, val_captions = train_test_split(
    image_paths,
    captions,
    test_size=0.2,
    random_state=42,
)
print(f"Train images: {len(train_images)}, Validation images: {len(val_images)}")

# Number of full batches that fit into the training split.
batch_size = 10
train_steps_per_epoch, _ = divmod(len(train_images), batch_size)
print(f"Train steps per epoch: {train_steps_per_epoch}")

# Model configuration: one sub-config for the language model, one for the
# image encoder.
_llm_config = {"vocab_size": 257_152}
_img_config = {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
model_config = ml_collections.FrozenConfigDict({"llm": _llm_config, "img": _img_config})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Read the checkpoint weights.
params = paligemma.load(None, MODEL_PATH, model_config)

# Bind the decode function to the local devices and the tokenizer's EOS id.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(
    decode_fn,
    devices=jax.devices(),
    eos_token=tokenizer.eos_id(),
)

# Create trainable params mask
def is_trainable_param(name, param):
    """Return True only for parameters whose name is under "llm/layers/attn/".

    Every other name under "llm/" or "img/" is frozen (False). A name outside
    those two prefixes raises ValueError. `param` is accepted but not used.
    """
    if name.startswith(("llm/", "img/")):
        return name.startswith("llm/layers/attn/")
    raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# One mesh over all local devices with a single axis. The axis name is passed
# as a plain string.
_DATA_AXIS = "data"
mesh = jax.sharding.Mesh(jax.devices(), _DATA_AXIS)
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(_DATA_AXIS)
)
params_sharding = big_vision.sharding.infer_sharding(
    params,
    strategy=[('.*', 'fsdp(axis="data")')],
    mesh=mesh,
)

# Silence the warning emitted when donated buffers cannot be reused.
warnings.filterwarnings("ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    """Cast leaves of `params` to float32 where `trainable` is truthy.

    `trainable` is a static argument with the same tree structure as `params`;
    leaves whose flag is falsy are returned unchanged. The `params` buffers
    are donated to the jitted call.
    """
    def _cast_leaf(leaf, is_trainable):
        if is_trainable:
            return leaf.astype(jnp.float32)
        return leaf

    return jax.tree.map(_cast_leaf, params, trainable)

# Load parameters with sharding
# The parameter tree is flattened so leaves can be handled one at a time. The
# sharding and trainable-mask leaves are zipped positionally with the parameter
# leaves, which assumes all three trees share the same structure and order.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    # Place this leaf according to its inferred sharding.
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    # Trainable leaves are cast to float32; the input buffer is donated.
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    # Finish this leaf before touching the next one
    # (presumably to bound peak memory — TODO confirm).
    params[idx].block_until_ready()
# Rebuild the original tree structure from the processed leaves.
params = jax.tree.unflatten(treedef, params)

# Print parameter overview
def parameter_overview(params):
    """Print one aligned line per parameter: its path, shape and dtype."""
    named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
    for path, arr in named_leaves:
        shape_text = str(arr.shape)
        print(f"{path:80s} {shape_text:22s} {arr.dtype}")


print(" == Model params == ")
parameter_overview(params)

# Image preprocessing
def preprocess_image(image, size=224):
    """Turn an image into a (size, size, 3) float array scaled by /127.5 - 1.

    A 2-D (single-channel) input is replicated to three channels, and any
    channels beyond the first three are dropped. Resizing is bilinear with
    antialiasing. Inputs in [0, 255] therefore map to [-1, 1].
    """
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = np.stack([pixels, pixels, pixels], axis=-1)
    pixels = pixels[..., :3]
    assert pixels.shape[-1] == 3
    resized = tf.image.resize(
        tf.constant(pixels), (size, size), method='bilinear', antialias=True
    )
    return resized.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]

# Token preprocessing
def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prompt (and optional target) and build the matching masks.

    The prompt is `prefix` (with BOS) followed by a newline separator; its
    positions get 0 in `mask_ar` and `mask_loss`. When `suffix` is given it is
    encoded with EOS, appended, and its positions get 1 in both masks.
    `mask_input` is 1 for every real token. With `seqlen`, all four sequences
    are cut to that length and right-padded with zeros.

    Returns (tokens, mask_ar, mask_loss, mask_input) after
    `jax.tree.map(np.array, ...)`.
    """
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode("\n")
    prompt_len = len(tokens)
    mask_ar = [0] * prompt_len
    mask_loss = [0] * prompt_len

    if suffix:
        suffix_ids = tokenizer.encode(suffix, add_eos=True)
        tokens.extend(suffix_ids)
        mask_ar.extend([1] * len(suffix_ids))
        mask_loss.extend([1] * len(suffix_ids))

    mask_input = [1] * len(tokens)

    if seqlen:
        pad = [0] * max(0, seqlen - len(tokens))
        tokens, mask_ar, mask_loss, mask_input = [
            seq[:seqlen] + pad for seq in (tokens, mask_ar, mask_loss, mask_input)
        ]

    # NOTE(review): Python lists are pytree containers, so this likely wraps
    # each element rather than each sequence; callers re-wrap the results with
    # np.asarray. Confirm before changing.
    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

# Token postprocessing
def postprocess_tokens(tokens):
    """Decode token ids to text, stopping before the first EOS id if present."""
    ids = tokens.tolist()
    eos = tokenizer.eos_id()
    if eos in ids:
        ids = ids[:ids.index(eos)]
    return tokenizer.decode(ids)

# Length that token sequences are truncated/padded to by the data iterators,
# and the default maximum decode length used by make_predictions.
SEQLEN = 128

# Data iterators for Flickr8k
def train_data_iterator():
    """Yield training examples indefinitely, in shuffled order.

    Each example is a dict with "image" (preprocessed pixels), "text" (token
    ids for the prompt "caption en" plus the lower-cased caption), "mask_ar"
    and "mask_loss". An example that fails to load or preprocess is reported
    on stdout and skipped, so one bad file does not stop training.
    """
    dataset = tf.data.Dataset.from_tensor_slices((train_images, train_captions))
    dataset = dataset.shuffle(1000).repeat()
    for image_path, caption in dataset.as_numpy_iterator():
        try:
            # Close the file as soon as the pixels are read. This generator
            # never ends, so unclosed handles would pile up.
            with Image.open(image_path.decode()) as pil_image:
                image = preprocess_image(pil_image)
            prefix = "caption en"
            suffix = caption.decode().lower()
            tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)
            yield {
                "image": np.asarray(image),
                "text": np.asarray(tokens),
                "mask_ar": np.asarray(mask_ar),
                "mask_loss": np.asarray(mask_loss),
            }
        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print(f"Error processing image {image_path.decode()}: {e}")
            continue

def validation_data_iterator():
    """Yield each validation example once, in dataset order.

    Each example is a dict with "image" (preprocessed pixels), "text" (token
    ids for the prompt "caption en" only, no caption), "mask_ar" and
    "mask_input". An example that fails to load or preprocess is reported on
    stdout and skipped.
    """
    dataset = tf.data.Dataset.from_tensor_slices((val_images, val_captions))
    for image_path, caption in dataset.as_numpy_iterator():
        try:
            # Close the file as soon as the pixels are read instead of leaving
            # the handle open until garbage collection.
            with Image.open(image_path.decode()) as pil_image:
                image = preprocess_image(pil_image)
            prefix = "caption en"
            tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
            yield {
                "image": np.asarray(image),
                "text": np.asarray(tokens),
                "mask_ar": np.asarray(mask_ar),
                "mask_input": np.asarray(mask_input),
            }
        except Exception as e:
            # Deliberately best-effort: report the failure and move on.
            print(f"Error processing image {image_path.decode()}: {e}")
            continue

# Render examples
def render_inline(image, resize=(128, 128)):
    """Encode an image array as a JPEG `data:` URI, resized to `resize`.

    Args:
        image: array accepted by `Image.fromarray` (callers pass uint8 RGB).
        resize: target (width, height) in pixels.

    Returns:
        A "data:image/jpeg;base64,..." string usable as an <img> src.
    """
    # Image.resize returns a new image rather than resizing in place, so the
    # result has to be kept for the requested size to take effect.
    image = Image.fromarray(image).resize(resize)
    with io.BytesIO() as buffer:
        image.save(buffer, format='jpeg')
        image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
        return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
    """Return an HTML snippet showing one image next to its escaped caption.

    `image` is mapped back to uint8 via (image + 1) / 2 * 255 before being
    embedded as an inline JPEG.
    """
    pixels = ((image + 1)/2 * 255).astype(np.uint8)
    thumbnail_src = render_inline(pixels, resize=(64,64))
    safe_caption = html.escape(caption)
    return f"""
        <div style="display: inline-flex; align-items: center; justify-content: center;">
            <img style="width:128px; height:128px;" src="{thumbnail_src}" />
            <p style="width:256px; margin:10px; font-size:small;">{safe_caption}</p>
        </div>
    """

# Preview the first eight training examples with their decoded captions.
_preview_cards = []
for _, example in zip(range(8), train_data_iterator()):
    caption = postprocess_tokens(example["text"])
    # Drop the prompt so that only the caption text is shown.
    if caption.startswith("caption en\n"):
        caption = caption[len("caption en\n"):]
    _preview_cards.append(render_example(example["image"], caption))
html_out = "".join(_preview_cards)

print("Training examples")
display(HTML(html_out))

# Training update function
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
    """Run one gradient-descent step and return (new_params, loss).

    Args:
        params: model parameter tree; its buffers are donated to this call.
        batch: dict with "image", "text", "mask_ar" and "mask_loss" entries.
        learning_rate: step size applied to the trainable leaves.

    Returns:
        The updated parameter tree and the scalar mean loss over the batch.
    """
    imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]
    def loss_fn(params):
        # Feed all tokens but the last; the targets below are shifted by one.
        text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
        logp = jax.nn.log_softmax(text_logits, axis=-1)
        mask_loss = batch["mask_loss"][:, 1:]
        targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])
        # Log-probability assigned to each target token.
        token_pplx = jnp.sum(logp * targets, axis=-1)
        # Negative log-likelihood summed over the positions selected by mask_loss...
        example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)
        # ...then averaged per example; the clip avoids dividing by zero.
        example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)
        return jnp.mean(example_loss)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    def apply_grad(param, gradient, trainable):
        # Frozen leaves pass through unchanged; trainable ones take an SGD step.
        if not trainable: return param
        return param - learning_rate * gradient
    # trainable_mask is the module-level mask built alongside params.
    params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
    return params, loss

# Evaluation/inference loop
def make_predictions(data_iterator, *, num_examples=None, batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    """Decode a caption for each example drawn from `data_iterator`.

    Examples are consumed in batches of `batch_size`. A short final batch is
    padded with copies of its last example, flagged with `_mask` False so
    their outputs are dropped. Stops when the iterator is exhausted or, if
    `num_examples` is set, once that many results are collected.

    Returns:
        A list of (image, decoded_text) tuples.
    """
    outputs = []
    while True:
        examples = []
        exhausted = False
        for _ in range(batch_size):
            try:
                example = next(data_iterator)
            except StopIteration:
                exhausted = True
                break
            example["_mask"] = np.array(True)
            examples.append(example)
        if exhausted and not examples:
            return outputs

        # Fill the batch up with masked-out copies of the last example.
        while len(examples) % batch_size:
            filler = dict(examples[-1])
            filler["_mask"] = np.array(False)
            examples.append(filler)

        batch = jax.tree.map(lambda *leaves: np.stack(leaves), *examples)
        batch = big_vision.utils.reshard(batch, data_sharding)
        decoded = decode({"params": params}, batch=batch, max_decode_len=seqlen, sampler=sampler)
        decoded, valid = jax.device_get((decoded, batch["_mask"]))

        # Padding sits at the end of the batch, so the rows kept by the mask
        # line up with the leading (real) examples.
        for example, row in zip(examples, decoded[valid]):
            outputs.append((example["image"], postprocess_tokens(row)))
            if num_examples and len(outputs) >= num_examples:
                return outputs

# Training loop
# Runs `epochs` rounds of `train_steps_per_epoch` update steps each and prints
# the loss and time estimates every 100 steps.
learning_rate = 5e-5
epochs = 10
iteration_times = []  # Track iteration times for averaging
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # A fresh generator per epoch. It shuffles and repeats, so an "epoch" here
    # is a fixed number of batches rather than an exact pass over the data.
    train_iter = train_data_iterator()
    iteration_times.clear()  # Reset for each epoch
    for step in range(train_steps_per_epoch):
        start_time = time.time()
        batch_examples = []
        try:
            # Assemble one batch by pulling examples from the generator.
            for _ in range(batch_size):
                batch_examples.append(next(train_iter))
            # Stack per-example arrays into batched arrays, then place them on device.
            batch = jax.tree.map(lambda *x: np.stack(x), *batch_examples)
            batch = big_vision.utils.reshard(batch, data_sharding)
            # params is donated to update_fn and rebound to the returned tree.
            params, loss = update_fn(params, batch, learning_rate)
            # NOTE(review): measured as soon as update_fn returns; if dispatch is
            # asynchronous this may not include the full device time — confirm.
            iteration_time = time.time() - start_time
            iteration_times.append(iteration_time)
            if step % 100 == 0:
                # iteration_times was just appended to, so the fallback branch is never taken.
                avg_iteration_time = sum(iteration_times) / len(iteration_times) if iteration_times else iteration_time
                epoch_remaining_steps = train_steps_per_epoch - step - 1
                total_remaining_steps = epoch_remaining_steps + (epochs - epoch - 1) * train_steps_per_epoch
                # Estimates extrapolate from this epoch's average step time so far.
                epoch_remaining_time = epoch_remaining_steps * avg_iteration_time
                total_remaining_time = total_remaining_steps * avg_iteration_time
                print(f"Step {step}, Loss: {loss:.4f}, "
                      f"Iteration Time: {iteration_time:.2f}s, "
                      f"Epoch Remaining: {epoch_remaining_time/60:.2f}min, "
                      f"Total Remaining: {total_remaining_time/60:.2f}min")
        except Exception as e:
            # Any failure in a step is reported and the loop moves on to the next step.
            print(f"Error in training step {step}: {e}")
            continue

# Caption the first 20 validation examples and display them.
val_predictions = make_predictions(validation_data_iterator(), num_examples=20, batch_size=4)
_val_cards = []
for image, caption in val_predictions:
    # Drop the prompt if the decoded text still starts with it.
    if caption.startswith("caption en\n"):
        caption = caption[len("caption en\n"):]
    _val_cards.append(render_example(image, caption))
html_out = "".join(_val_cards)
print("Validation examples")
display(HTML(html_out))

# Custom save_params implementation to replace big_vision.utils.save_params
def save_params(params, path):
    """Save a JAX parameter tree to disk.

    The tree is flattened and every leaf is copied to host memory. The primary
    format is one compressed .npz holding `param_<i>` arrays plus a `treedef`
    entry with the pickled tree structure. If that fails, a fallback writes
    `<path>.treedef` (pickle) and one `<path>.<i>.npy` file per leaf.

    Because the tree structure is pickled, the files must be read back with
    pickle enabled and should only be loaded from trusted sources.

    Args:
        params: parameter tree to save.
        path: destination file path for the .npz archive.

    Raises:
        Exception: whatever the fallback writer raised, if both the primary
            and the fallback method fail, so callers never report a save that
            did not happen.
    """
    params_flat, treedef = jax.tree.flatten(params)

    # device_get already yields host arrays; np.asarray wraps them without the
    # extra full copy that np.array would make.
    params_flat_np = [np.asarray(jax.device_get(param)) for param in params_flat]

    # Save directly without temporary file
    try:
        arrays_dict = {f'param_{i}': param for i, param in enumerate(params_flat_np)}
        arrays_dict['treedef'] = np.array([pickle.dumps(treedef)], dtype=np.object_)
        np.savez_compressed(path, **arrays_dict)
        print(f"Parameters saved to {path}")
    except Exception as e:
        print(f"Error saving parameters: {e}")

        # Fallback to direct saving without dictionary
        try:
            print("Attempting fallback save method...")
            # Save the tree definition separately
            with open(f"{path}.treedef", "wb") as f:
                pickle.dump(treedef, f)

            # Save parameters as separate files (np.save appends ".npy")
            for i, param in enumerate(params_flat_np):
                np.save(f"{path}.{i}", param)

            print(f"Parameters saved using fallback method to {path}.* files")
        except Exception as e2:
            print(f"Fallback save also failed: {e2}")
            # Nothing was written: surface the failure instead of returning
            # as if the save had succeeded.
            raise

# Custom load_params implementation for future use
def load_params(path):
    """Rebuild a parameter tree from an .npz written in the `param_<i>` format.

    The archive is expected to hold a `treedef` entry (pickled tree structure)
    and one `param_<i>` array per leaf.

    NOTE: this unpickles data from the file; only load files you trust.
    """
    with np.load(path, allow_pickle=True) as archive:
        structure = pickle.loads(archive['treedef'][0])
        # Every entry except 'treedef' is a leaf.
        leaf_count = len(archive.files) - 1
        leaves = [archive[f'param_{i}'] for i in range(leaf_count)]
        return jax.tree.unflatten(structure, leaves)

# Write the fine-tuned parameters into the Kaggle working directory.
output_dir = "/kaggle/working/PaliGemma_Fine_Tune_Flickr8k"
os.makedirs(output_dir, exist_ok=True)
_params_file = os.path.join(output_dir, "paligemma_flickr8k.params.f16.npz")
save_params(params, _params_file)
print(f"Model saved to {output_dir}")

Collecting ml_collections
  Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)
Downloading ml_collections-1.1.0-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ml_collections
Successfully installed ml_collections-1.1.0


2025-05-06 07:55:48.072688: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746518148.257711      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746518148.313327      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


JAX version:  0.4.33
JAX platform: gpu
JAX devices:  1
Downloading the checkpoint from Kaggle...
Model path: /kaggle/input/paligemma/jax/paligemma-3b-pt-224/1/paligemma-3b-pt-224.f16.npz
Downloading the model tokenizer...
Downloaded paligemma_tokenizer.model to ./paligemma_tokenizer.model
Tokenizer path: ./paligemma_tokenizer.model
Successfully loaded captions.txt
Dataframe head:
                        image  \
0  1000268201_693b08cb0e.jpg   
1  1000268201_693b08cb0e.jpg   
2  1000268201_693b08cb0e.jpg   
3  1000268201_693b08cb0e.jpg   
4  1000268201_693b08cb0e.jpg   

                                             caption  
0  A child in a pink dress is climbing up a set o...  
1              A girl going into a wooden building .  
2   A little girl climbing into a wooden playhouse .  
3  A little girl climbing the stairs to her playh...  
4  A little girl in a pink dress going into a woo...  
Dataframe columns: Index(['image', 'caption'], dtype='object')
Reduced dataset to 1500 images

Epoch 1/10
Step 0, Loss: 2.3377, Iteration Time: 25.17s, Epoch Remaining: 49.91min, Total Remaining: 502.90min
Step 100, Loss: 2.2891, Iteration Time: 6.07s, Epoch Remaining: 1.97min, Total Remaining: 114.00min
Epoch 2/10
Step 0, Loss: 1.9370, Iteration Time: 6.08s, Epoch Remaining: 12.06min, Total Remaining: 109.31min
Step 100, Loss: 1.8288, Iteration Time: 6.07s, Epoch Remaining: 1.91min, Total Remaining: 98.49min
Epoch 3/10
Step 0, Loss: 2.2333, Iteration Time: 6.08s, Epoch Remaining: 12.05min, Total Remaining: 97.10min
Step 100, Loss: 1.9253, Iteration Time: 6.07s, Epoch Remaining: 1.91min, Total Remaining: 86.44min
Epoch 4/10
Step 0, Loss: 2.3051, Iteration Time: 6.07s, Epoch Remaining: 12.04min, Total Remaining: 84.90min
Step 100, Loss: 1.8010, Iteration Time: 6.07s, Epoch Remaining: 1.91min, Total Remaining: 74.35min
Epoch 5/10
Step 0, Loss: 2.0471, Iteration Time: 6.07s, Epoch Remaining: 12.04min, Total Remaining: 72.76min
Step 100, Loss: 1.9308, Iteration Time: 6.07s, Epoch Re

Parameters saved to /kaggle/working/PaliGemma_Fine_Tune_Flickr8k/paligemma_flickr8k.params.f16.npz
Model saved to /kaggle/working/PaliGemma_Fine_Tune_Flickr8k


In [2]:
import os
import sys
import numpy as np
from PIL import Image
import tensorflow as tf
import sentencepiece
import jax
import jax.numpy as jnp
import ml_collections
import functools
import html
import io
import base64
import warnings
from IPython.core.display import display, HTML

# Hide both accelerator kinds from TensorFlow by giving it an empty device list.
for _device_kind in ("GPU", "TPU"):
    tf.config.set_visible_devices([], _device_kind)

# Put the big_vision checkout on the module search path if it is not there yet.
_BIG_VISION_DIR = "big_vision_repo"
if _BIG_VISION_DIR not in sys.path:
    sys.path.append(_BIG_VISION_DIR)

# Import big_vision modules
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
import big_vision.utils
import big_vision.sharding

# Report which JAX version, platform and device count this session uses.
backend = jax.extend.backend.get_backend()
print(
    f"JAX version:  {jax.__version__}",
    f"JAX platform: {backend.platform}",
    f"JAX devices:  {jax.device_count()}",
    sep="\n",
)

# Locations of the fine-tuned weights, the tokenizer and the image folder.
MODEL_PATH = "/kaggle/working/PaliGemma_Fine_Tune_Flickr8k/paligemma_flickr8k.params.f16.npz"
TOKENIZER_PATH = "./paligemma_tokenizer.model"
DATA_DIR = "/kaggle/input/flickr8k"
IMAGES_DIR = os.path.join(DATA_DIR, "Images")

# Custom load_params implementation
def load_params(path):
    """Load JAX parameters from a file.

    Tries, in order:
      1. An .npz with a pickled `treedef` entry and `param_<i>` arrays, which
         is unflattened back into a tree.
      2. A plain .npz without `treedef`, returned as a flat {name: array} dict.
      3. If reading `path` fails: the fallback layout of `<path>.treedef`
         (pickle) plus `<path>.<i>.npy` leaf files.

    NOTE: formats 1 and 3 unpickle data from disk; only load files you trust.

    Args:
        path: path of the .npz archive (also the prefix of the fallback files).

    Returns:
        The parameter tree, or a flat dict for plain .npz files.

    Raises:
        Exception: whatever the fallback loader raised, if it fails too.
    """
    import pickle

    try:
        with np.load(path, allow_pickle=True) as data:
            # Check if it's saved in the "param_X" format
            if 'treedef' in data:
                treedef = pickle.loads(data['treedef'][0])
                # Count leaves by key prefix so that unrelated entries cannot
                # throw the count off.
                num_leaves = sum(1 for key in data.files if key.startswith('param_'))
                params_flat = [data[f'param_{i}'] for i in range(num_leaves)]
                return jax.tree.unflatten(treedef, params_flat)
            # Fall back to traditional NPZ loading
            return {k: data[k] for k in data.files}
    except Exception as e:
        print(f"Error loading model: {e}")
        # Try fallback method
        print("Trying fallback loading method...")
        try:
            # Load tree definition. The file must be opened for reading:
            # write mode would truncate it and make pickle.load fail.
            with open(f"{path}.treedef", "rb") as f:
                treedef = pickle.load(f)

            # Load parameters: consecutive <path>.<i>.npy files, starting at 0
            params_flat = []
            i = 0
            while os.path.exists(f"{path}.{i}.npy"):
                params_flat.append(np.load(f"{path}.{i}.npy"))
                i += 1

            return jax.tree.unflatten(treedef, params_flat)
        except Exception as e2:
            print(f"Fallback loading also failed: {e2}")
            raise

# Model configuration: one sub-config for the language model, one for the
# image encoder.
_llm_config = {"vocab_size": 257_152}
_img_config = {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
model_config = ml_collections.FrozenConfigDict({"llm": _llm_config, "img": _img_config})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Read the fine-tuned weights: try paligemma.load first, and on any failure
# fall back to the custom loader defined in this notebook.
print(f"Loading fine-tuned model from {MODEL_PATH}")
try:
    params = paligemma.load(None, MODEL_PATH, model_config)
    print("Successfully loaded model with paligemma.load")
except Exception as e:
    print(f"Error loading model with paligemma.load: {e}")
    print("Falling back to custom load function")
    params = load_params(MODEL_PATH)
    print("Successfully loaded model with custom loader")

# Bind the decode function to the local devices and the tokenizer's EOS id.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(
    decode_fn,
    devices=jax.devices(),
    eos_token=tokenizer.eos_id(),
)

# Create trainable params mask (for completeness)
def is_trainable_param(name, param):
    """Return True only for parameters whose name is under "llm/layers/attn/".

    Every other name under "llm/" or "img/" is frozen (False). A name outside
    those two prefixes raises ValueError. `param` is accepted but not used.
    """
    if name.startswith(("llm/", "img/")):
        return name.startswith("llm/layers/attn/")
    raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# One mesh over all local devices with a single axis. The axis name is passed
# as a plain string.
_DATA_AXIS = "data"
mesh = jax.sharding.Mesh(jax.devices(), _DATA_AXIS)
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec(_DATA_AXIS)
)
params_sharding = big_vision.sharding.infer_sharding(
    params,
    strategy=[('.*', 'fsdp(axis="data")')],
    mesh=mesh,
)

# Silence the warning emitted when donated buffers cannot be reused.
warnings.filterwarnings("ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    """Cast leaves of `params` to float32 where `trainable` is truthy.

    `trainable` is a static argument with the same tree structure as `params`;
    leaves whose flag is falsy are returned unchanged. The `params` buffers
    are donated to the jitted call.
    """
    def _cast_leaf(leaf, is_trainable):
        if is_trainable:
            return leaf.astype(jnp.float32)
        return leaf

    return jax.tree.map(_cast_leaf, params, trainable)

# Load parameters with sharding
# The parameter tree is flattened so leaves can be handled one at a time. The
# sharding and trainable-mask leaves are zipped positionally with the parameter
# leaves, which assumes all three trees share the same structure and order.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    # Place this leaf according to its inferred sharding.
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    # Leaves flagged trainable are cast to float32; the input buffer is donated.
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    # Finish this leaf before touching the next one
    # (presumably to bound peak memory — TODO confirm).
    params[idx].block_until_ready()
# Rebuild the original tree structure from the processed leaves.
params = jax.tree.unflatten(treedef, params)

print("Model parameters loaded and resharded successfully")

# Image preprocessing
def preprocess_image(image, size=224):
    """Turn an image into a (size, size, 3) float array scaled by /127.5 - 1.

    A 2-D (single-channel) input is replicated to three channels, and any
    channels beyond the first three are dropped. Resizing is bilinear with
    antialiasing. Inputs in [0, 255] therefore map to [-1, 1].
    """
    pixels = np.asarray(image)
    if pixels.ndim == 2:
        pixels = np.stack([pixels, pixels, pixels], axis=-1)
    pixels = pixels[..., :3]
    assert pixels.shape[-1] == 3
    resized = tf.image.resize(
        tf.constant(pixels), (size, size), method='bilinear', antialias=True
    )
    return resized.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]

# Token preprocessing
def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prompt (and optional target) and build the matching masks.

    The prompt is `prefix` (with BOS) followed by a newline separator; its
    positions get 0 in `mask_ar` and `mask_loss`. When `suffix` is given it is
    encoded with EOS, appended, and its positions get 1 in both masks.
    `mask_input` is 1 for every real token. With `seqlen`, all four sequences
    are cut to that length and right-padded with zeros.

    Returns (tokens, mask_ar, mask_loss, mask_input) after
    `jax.tree.map(np.array, ...)`.
    """
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode("\n")
    prompt_len = len(tokens)
    mask_ar = [0] * prompt_len
    mask_loss = [0] * prompt_len

    if suffix:
        suffix_ids = tokenizer.encode(suffix, add_eos=True)
        tokens.extend(suffix_ids)
        mask_ar.extend([1] * len(suffix_ids))
        mask_loss.extend([1] * len(suffix_ids))

    mask_input = [1] * len(tokens)

    if seqlen:
        pad = [0] * max(0, seqlen - len(tokens))
        tokens, mask_ar, mask_loss, mask_input = [
            seq[:seqlen] + pad for seq in (tokens, mask_ar, mask_loss, mask_input)
        ]

    # NOTE(review): Python lists are pytree containers, so this likely wraps
    # each element rather than each sequence; callers re-wrap the results with
    # np.asarray. Confirm before changing.
    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

# Token postprocessing
def postprocess_tokens(tokens):
    """Decode token ids to text, stopping before the first EOS id if present."""
    ids = tokens.tolist()
    eos = tokenizer.eos_id()
    if eos in ids:
        ids = ids[:ids.index(eos)]
    return tokenizer.decode(ids)

# Length that prompt tokens are truncated/padded to, and the maximum decode
# length passed to `decode` in test_image_captioning.
SEQLEN = 128

# Test on new images
def test_image_captioning(image_paths, batch_size=1, sampler="greedy"):
    """Generate a caption for each image file in `image_paths`.

    Images are processed in groups of `batch_size`. A file that cannot be
    opened or preprocessed is reported on stdout and left out of the results.
    A group with fewer usable images than `batch_size` is padded with
    masked-out copies of its last example.

    Args:
        image_paths: list of image file paths.
        batch_size: number of images decoded together.
        sampler: sampler name forwarded to `decode`.

    Returns:
        A list of (PIL image, image_path, caption) tuples, one per image that
        was successfully captioned.
    """
    results = []
    for start in range(0, len(image_paths), batch_size):
        batch_examples = []
        # Paths of the images that actually made it into the batch, kept in
        # step with batch_examples. Indexing the raw path slice by output
        # position would attach captions to the wrong file whenever an
        # earlier image in the group failed to load.
        valid_paths = []

        for image_path in image_paths[start:start + batch_size]:
            try:
                # Close the file as soon as the pixels have been read.
                with Image.open(image_path) as pil_image:
                    image = preprocess_image(pil_image)
                prompt_tokens, mask_ar, _, mask_input = preprocess_tokens("caption en", seqlen=SEQLEN)
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
                continue
            batch_examples.append({
                "image": np.asarray(image),
                "text": np.asarray(prompt_tokens),
                "mask_ar": np.asarray(mask_ar),
                "mask_input": np.asarray(mask_input),
                "_mask": np.array(True)  # Valid example
            })
            valid_paths.append(image_path)

        # Every image in this group failed: nothing to decode.
        if not batch_examples:
            continue

        # Pad to batch size if needed
        while len(batch_examples) % batch_size:
            padding_example = dict(batch_examples[-1])
            padding_example["_mask"] = np.array(False)  # Mark as padding
            batch_examples.append(padding_example)

        # Stack batch and move to device
        batch = jax.tree.map(lambda *x: np.stack(x), *batch_examples)
        batch = big_vision.utils.reshard(batch, data_sharding)

        # Generate captions, then drop the rows that belong to padding
        tokens = decode({"params": params}, batch=batch, max_decode_len=SEQLEN, sampler=sampler)
        tokens, mask = jax.device_get((tokens, batch["_mask"]))
        tokens = tokens[mask]

        # Post-process captions; the valid rows line up with valid_paths
        for image_path, token_seq in zip(valid_paths, tokens):
            caption = postprocess_tokens(token_seq)
            caption = caption[len("caption en\n"):] if caption.startswith("caption en\n") else caption
            original_image = Image.open(image_path)
            results.append((original_image, image_path, caption))

    return results

# Render results 
def render_inline(image, resize=(128, 128)):
    """Encode a PIL image as a JPEG `data:` URI, resized to `resize`.

    Args:
        image: PIL image to embed.
        resize: target (width, height) in pixels.

    Returns:
        A "data:image/jpeg;base64,..." string usable as an <img> src.
    """
    # The JPEG writer rejects modes such as RGBA or P, which PNG inputs can
    # have; convert anything other than RGB or grayscale before encoding.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    image = image.resize(resize)
    with io.BytesIO() as buffer:
        image.save(buffer, format='jpeg')
        image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
        return f"data:image/jpeg;base64,{image_b64}"

def render_test_results(results):
    """Display the (image, image_path, caption) results as a wrapping grid.

    Each card shows the image as an inline JPEG, the file's base name and the
    HTML-escaped caption.
    """
    cards = []
    for image, image_path, caption in results:
        filename = os.path.basename(image_path)
        cards.append(f"""
            <div style="width: 200px; margin: 10px; text-align: center;">
                <img style="width:180px; height:auto; max-height:180px; object-fit:contain;" 
                     src="{render_inline(image, resize=(180, 180))}" />
                <p style="font-size:12px; color:#666;">{filename}</p>
                <p style="font-size:14px;">{html.escape(caption)}</p>
            </div>
        """)
    html_output = "<div style='display: flex; flex-wrap: wrap;'>" + "".join(cards) + "</div>"
    display(HTML(html_output))

# Pick up to ten random image files from the image folder.
import random
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')
all_images = [
    os.path.join(IMAGES_DIR, entry)
    for entry in os.listdir(IMAGES_DIR)
    if entry.endswith(_IMAGE_SUFFIXES)
]
_num_test_images = min(10, len(all_images))
test_images = random.sample(all_images, _num_test_images)
print(f"Testing on {len(test_images)} random images")

# Caption them two at a time.
print("Generating captions for test images...")
test_results = test_image_captioning(test_images, batch_size=2)

# Show each image with its generated caption.
print(f"Results for {len(test_results)} images:")
render_test_results(test_results)


JAX version:  0.4.33
JAX platform: gpu
JAX devices:  1
Loading fine-tuned model from /kaggle/working/PaliGemma_Fine_Tune_Flickr8k/paligemma_flickr8k.params.f16.npz
Error loading model with paligemma.load: Object arrays cannot be loaded when allow_pickle=False
Falling back to custom load function
Successfully loaded model with custom loader
Model parameters loaded and resharded successfully
Testing on 10 random images
Generating captions for test images...
Results for 10 images:
