In [2]:
import os
from typing import Dict, List, Tuple

import tensorflow as tf
from PIL import Image
from pydantic import FilePath

from clipkit.cliplayers import ClipMe
from train import image_model, proj_dim, text_max_len, text_model, tokenizer


def prepare_image(image_path: FilePath, size: int = 224) -> tf.Tensor:
    """Read an image file and return it as a batch of one padded RGB image.

    Args:
        image_path: Path to the image file (``str`` or path-like).
        size: Target height and width; the image is resized with aspect-ratio
            preserving padding to ``size x size``. Defaults to 224.

    Returns:
        A tensor of shape ``(1, size, size, 3)``. ``resize_with_pad`` returns
        float values on the original pixel scale; whether the model expects a
        different normalization is not visible here — TODO confirm.
    """
    # tf.io.read_file expects a string tensor, so convert Path objects first.
    raw_bytes = tf.io.read_file(os.fspath(image_path))
    # expand_animations=False keeps the result 3-D (first frame for GIFs);
    # otherwise an animated file yields a 4-D tensor and a 5-D batch below.
    image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.resize_with_pad(image, size, size)
    # Add the leading batch dimension expected by the model.
    return image[tf.newaxis, ...]


def clip_data_prep(image_path: FilePath, captions: List[str]) -> Dict:
    """Assemble the model input dict for one image and a list of captions.

    Args:
        image_path: Path to the image file, passed to ``prepare_image``.
        captions: Caption strings to tokenize.

    Returns:
        Dict with keys ``"pixel_values"`` (the prepared image batch),
        ``"input_ids"`` and ``"attention_mask"`` (NumPy arrays from the
        tokenizer, padded/truncated to ``text_max_len``).
    """
    pixel_batch = prepare_image(image_path)
    tokenized = tokenizer(
        captions,
        return_tensors="np",
        max_length=text_max_len,
        truncation=True,
        padding="max_length",
    )
    return {
        "pixel_values": pixel_batch,
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
    }


def get_embeddings(
    image_path: FilePath, captions: List[str], model=None
) -> Tuple[tf.Tensor, tf.Tensor]:
    """Embed one image and a list of captions, L2-normalizing both outputs.

    Args:
        image_path: Path to the image file to embed.
        captions: Caption strings to embed alongside the image.
        model: Model to call on the prepared inputs. Defaults to the global
            ``CLIPME`` created in the checkpoint-loading cell; pass a model
            explicitly to avoid depending on that global.

    Returns:
        ``(image_vectors, text_vectors)``, each L2-normalized along the last
        axis, so their dot products are cosine similarities.
    """
    if model is None:
        # Looked up at call time: the checkpoint cell must have run first.
        model = CLIPME
    model_inputs = clip_data_prep(image_path=image_path, captions=captions)
    prediction = model(model_inputs, training=False)
    # Assumes the model returns (image_embeddings, text_embeddings) in this
    # order — TODO confirm against ClipMe.
    image_vectors = tf.math.l2_normalize(prediction[0], axis=-1)
    text_vectors = tf.math.l2_normalize(prediction[1], axis=-1)
    return image_vectors, text_vectors


def compute_scores(
    image_vector: tf.Tensor,
    text_vector: tf.Tensor,
    captions: List[str],
    top_pred_count: int = 3,
) -> Dict[str, float]:
    """Rank captions by similarity to the first image and keep the best ones.

    Args:
        image_vector: Image embeddings; only the first row (index 0) is scored.
        text_vector: Caption embeddings, one row per entry of ``captions``.
        captions: Caption strings aligned with the rows of ``text_vector``.
        top_pred_count: Maximum number of captions to return; a value larger
            than ``len(captions)`` returns all of them.

    Returns:
        Mapping caption -> similarity score, inserted from highest to lowest
        score. Duplicate captions collapse into a single key.

    Raises:
        ValueError: If ``top_pred_count`` is negative.
    """
    if top_pred_count < 0:
        # A negative slice bound would silently mean "all but the last k".
        raise ValueError(f"top_pred_count must be >= 0, got {top_pred_count}")
    # Similarity matrix (images x captions); keep only the first image's row.
    similarities = (image_vector @ tf.transpose(text_vector))[0]
    top_indices = tf.argsort(similarities, direction="DESCENDING")[
        :top_pred_count
    ].numpy()
    top_scores = similarities.numpy()[top_indices].tolist()
    return {captions[idx]: score for idx, score in zip(top_indices, top_scores)}

E0000 00:00:1747381946.515077   12711 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747381946.519412   12711 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747381946.530232   12711 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747381946.530248   12711 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747381946.530249   12711 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747381946.530250   12711 computation_placer.cc:177] computation placer already registered. Please check linka

#### Load the Model Checkpoint

In [None]:
# Directory containing the checkpoints written during training.
checkpoint_path = "breed_model/"

CLIPME = ClipMe(image_model_id=image_model, text_model_id=text_model, proj_dim=proj_dim)

# One forward pass on dummy inputs so the model's variables exist before the
# restore; presumably ClipMe builds its weights lazily on first call — TODO confirm.
_ = CLIPME(
    {
        "pixel_values": tf.ones((1, 224, 224, 3)),
        "input_ids": tf.ones((1, 12), tf.int32),
        "attention_mask": tf.ones((1, 12), tf.int32),
    }
)

latest_ckpt = tf.train.latest_checkpoint(checkpoint_path)
if latest_ckpt is None:
    # Checkpoint.restore(None) is a no-op, which would leave freshly initialized
    # weights in place and make every prediction below meaningless.
    raise FileNotFoundError(f"No checkpoint found in {checkpoint_path!r}")

ckpt = tf.train.Checkpoint(net=CLIPME)
# expect_partial() suppresses warnings about checkpointed values (e.g. optimizer
# state) that this inference-only object graph does not use.
ckpt.restore(latest_ckpt).expect_partial()

#### Define the candidate captions for zero-shot classification

In [None]:
# Candidate captions for zero-shot classification: the image embedding is
# compared against each caption embedding and the best matches are reported.
# NOTE(review): several breed names look misspelled ("shitzu",
# "norweight_elkahound", "irish_grayhound", "rotweiller"). They are kept verbatim
# because the text changes the embeddings — confirm against the training labels.
captions = [
    f"the breed is {breed}"
    for breed in (
        "shitzu",
        "norweight_elkahound",
        "Maltese",
        "irish_grayhound",
        "japanese_spaniel",
        "bloodhound",
        "rotweiller",
        "Komondor",
        "redbone",
    )
]

In [None]:
# Image to classify. Set the CLIP_IMAGE_PATH environment variable to use another
# file instead of editing this machine-specific default path.
image_id = os.environ.get(
    "CLIP_IMAGE_PATH", "/home/anish/Desktop/68768d392e81a9864575a1678707565b.jpg"
)
image_vect, text_vect = get_embeddings(image_path=image_id, captions=captions)
predictions = compute_scores(
    image_vector=image_vect, text_vector=text_vect, captions=captions, top_pred_count=3
)
print(predictions)

# Show a small thumbnail of the classified image as the cell's output.
im = Image.open(image_id)
im.thumbnail((250, 250))
im