# CAPPA demo

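This notebook shows how to load a trained CAPPA captioning model from its W&B artifact and Orbax checkpoint, and use it to generate captions for a downloaded image and for samples from a dataset.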

In [None]:
from functools import partial
from io import BytesIO

import jax
import jax.numpy as jnp
import numpy as np
import orbax
import requests
import wandb
from flax.training import orbax_utils
from PIL import Image
from transformers import AutoTokenizer

from clip_jax import CLIPModel
from clip_jax.data import Dataset, image_to_logits, logits_to_image
from clip_jax.utils import load_config

## Loading a trained model

In [None]:
# load tokenizer (used to decode generated token ids back into text)
tokenizer_name = "xxxx"  # placeholder: set to the tokenizer the model was trained with
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [None]:
# load model
# `config_name` is a W&B artifact reference; the same name is reused below to look up
# the checkpoint location, so it must point at the run whose weights you want.
config_name = "entity/project/config-run_id:latest"
config = load_config(config_name)
# the loaded config is passed as keyword arguments to the model constructor
model = CLIPModel(**config)

In [None]:
# initialize model
# Only the parameter *structure* is needed: `jax.eval_shape` traces `init_weights`
# abstractly (no real weights are computed), and zero-filled arrays with the same
# shapes/dtypes are then built to act as the restore target for the checkpoint.
rng = jax.random.PRNGKey(0)
logical_shape = jax.eval_shape(lambda rng: model.init_weights(rng), rng)["params"]
# `jax.tree_map` is deprecated and removed in recent JAX releases; use `jax.tree_util.tree_map`.
params = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), logical_shape)

In [None]:
# get model checkpoint
# The W&B artifact's metadata records which training step was saved and the
# output directory holding the checkpoints.
api = wandb.Api()
artifact = api.artifact(config_name)
step, model_path = artifact.metadata["step"], artifact.metadata["output_dir"]
# show where (and at which step) the checkpoint will be restored from
model_path, step

In [None]:
# restore checkpoint
# `ckpt` is the restore target: its tree structure (with the zero-filled `params`
# leaves built above) tells orbax which tree to rebuild.
ckpt = {"params": params}
# per-leaf restore arguments derived from the target tree
restore_args = orbax_utils.restore_args_from_target(ckpt)
# NOTE(review): only `import orbax` is done explicitly; `orbax.checkpoint` is presumably
# available because `flax.training.orbax_utils` imports it as a side effect — confirm.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
orbax_options = orbax.checkpoint.CheckpointManagerOptions()
# checkpoints are read from the output directory recorded in the W&B artifact metadata
checkpoint_manager = orbax.checkpoint.CheckpointManager(model_path, orbax_checkpointer, orbax_options)
# restore the step recorded in the artifact metadata. Passing `transforms` (even empty)
# presumably makes orbax restore into the target structure and skip checkpoint entries
# outside it — TODO confirm against the orbax version in use.
ckpt = checkpoint_manager.restore(step, ckpt, restore_kwargs={"restore_args": restore_args, "transforms": {}})
params = ckpt["params"]

## Inference

In [None]:
@partial(
    jax.jit,
    static_argnames=(
        "num_beams",
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "max_length",
        "num_return_sequences",
    ),
)
def generate_caption(pixel_values, *args, **kwargs):
    """Run `model.generate` under `jax.jit`.

    The decoding options listed in `static_argnames` are compile-time constants, so
    every new combination of their values triggers a fresh compilation. All other
    arguments (e.g. `pixel_values`, `params`) are traced.

    Returns whatever `model.generate` returns; `caption` reads its `.sequences`.
    """
    return model.generate(pixel_values, *args, **kwargs)


def caption(*args, **kwargs):
    """Generate captions and decode them into strings.

    All arguments are forwarded unchanged to `generate_caption`. The generated token
    ids are decoded with the module-level `tokenizer`, dropping special tokens.

    Returns:
        A list of decoded caption strings, one per generated sequence.
    """
    token_ids = generate_caption(*args, **kwargs).sequences
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

In [None]:
# image data
img_url = "https://pics.craiyon.com/2023-06-23/3b050d2ebfcc47e7a2d25265ffc6b588.webp"

# Download the demo image. The timeout prevents the cell from hanging forever, and
# raise_for_status surfaces HTTP errors instead of handing an error page to PIL.
response = requests.get(img_url, timeout=30)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
# Convert before resizing: Pillow resizes palette ("P") images with nearest-neighbour
# sampling, so resizing in RGB gives a properly filtered result.
img = img.convert("RGB").resize((256, 256))
# preprocess into model inputs and add a leading batch axis
pixel_values = image_to_logits(img)
pixel_values = pixel_values[np.newaxis, ...]
img

In [None]:
# beam search with 4 beams
caption(pixel_values, params=params, num_beams=4)

In [None]:
# sampling with temperature 0.7. No PRNG key is passed here, so the randomness presumably
# comes from `model.generate`'s default key — TODO confirm, and pass an explicit key for varied samples.
caption(pixel_values, params=params, do_sample=True, temperature=0.7)

In [None]:
# beam search asking for 4 returned sequences (presumably the top-4 beams for the image — confirm in CLIPModel.generate)
caption(pixel_values, params=params, num_beams=4, num_return_sequences=4)

## Test on a dataset

In [None]:
ds_folder = "xxx"  # placeholder: must use the same format as the training data (at least tfrecords)
# batch size 1 so each `next(ds)` yields a single example; images presumably cropped/resized to 256
ds = Dataset(train_folder=ds_folder, train_batch_size=1, image_crop_resize=256).train

In [None]:
# Caption one dataset example and compare it with its ground-truth caption.
# Distinct names are used so the demo image's `pixel_values` / `img` from the
# Inference section are not overwritten: re-running those cells afterwards would
# otherwise silently caption this dataset image instead.
ds_pixel_values, captions = next(ds)
generated_captions = caption(ds_pixel_values, params=params, num_beams=4)
ds_img = Image.fromarray(logits_to_image(ds_pixel_values[0]))
display(ds_img)
# ground-truth captions come back as bytes, so decode them for printing
print("caption:", captions[0].decode("utf-8"))
print("generated caption:", generated_captions[0])