<a href="https://colab.research.google.com/github/Stability-AI/model-demo-notebooks/blob/main/japanese_stable_clip.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Japanese Stable CLIP Demo
This is a demo of text retrieval (picking the category label that best matches an input image) using [Japanese Stable CLIP](https://huggingface.co/stabilityai/japanese-stable-clip-vit-l-16) from [Stability AI](https://stability.ai/).

- Blog: https://ja.stability.ai/blog/japanese-stable-clip
- Twitter: https://twitter.com/StabilityAI_JP
- Discord: https://discord.com/invite/StableJP


In [None]:
#@title Setup
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi
# Install the Python dependencies for this notebook.
!pip install ftfy regex tqdm gradio transformers sentencepiece 'accelerate>=0.12.0' 'bitsandbytes>=0.31.5'
# Download the sample images for the demo. The source files are JPEGs but are
# saved under the .png names that the demo's example list refers to.
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/2/29/JAPANPOST-DSC00250.JPG/500px-JAPANPOST-DSC00250.JPG -O sample1.png
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/1/1c/Search_and_rescue_at_Unosumai%2C_Kamaishi%2C_-17_Mar._2011_a.jpg/500px-Search_and_rescue_at_Unosumai%2C_Kamaishi%2C_-17_Mar._2011_a.jpg -O sample2.png
!wget https://upload.wikimedia.org/wikipedia/commons/thumb/6/60/Policeman_at_Tokyo.jpg/500px-Policeman_at_Tokyo.jpg -O sample3.png

In [None]:
# @title Login HuggingFace
# Log in to the Hugging Face Hub from the command line.
# NOTE(review): presumably needed so the model download in the next cell is
# authorized; confirm whether the checkpoint actually requires a login.
!huggingface-cli login

In [None]:
#@title Load Japanese Stable CLIP
from typing import Union, List
import ftfy, html, re, io
import requests
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, BatchFeature


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint id shared by the model, the tokenizer and the image processor.
model_path = "stabilityai/japanese-stable-clip-vit-l-16"

# trust_remote_code=True allows transformers to use the modeling code that
# ships with the checkpoint repository.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model = model.eval()  # switch to evaluation mode
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoImageProcessor.from_pretrained(model_path)


# taken from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/tokenizer.py#L65C8-L65C8
def basic_clean(text):
    """Run ``ftfy.fix_text`` on *text*, decode HTML entities, and trim the ends.

    Entities are unescaped twice so that doubly-escaped input such as
    ``&amp;lt;`` is fully decoded.
    """
    cleaned = ftfy.fix_text(text)
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()


def whitespace_clean(text):
    """Collapse each run of whitespace into one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()


def tokenize(
    texts: Union[str, List[str]],
    max_seq_len: int = 77,
):
    """
    Counterpart of the ``tokenize`` helper in the original CLIP code.
    https://github.com/openai/CLIP/blob/main/clip/clip.py#L195

    Accepts a single string or a list of strings. Each text is cleaned,
    encoded without special tokens, padded/truncated to ``max_seq_len - 1``
    ids, and then prefixed with the BOS token id. Returns a ``BatchFeature``
    holding ``input_ids``, ``attention_mask`` and ``position_ids`` as long
    tensors.
    """
    batch = [texts] if isinstance(texts, str) else texts
    batch = [whitespace_clean(basic_clean(item)) for item in batch]

    # Leave one slot free for the BOS token that is prepended below.
    encoded = tokenizer(
        batch,
        max_length=max_seq_len - 1,
        padding="max_length",
        truncation=True,
        add_special_tokens=False,
    )

    bos_id = tokenizer.bos_token_id
    input_ids = [[bos_id, *ids] for ids in encoded["input_ids"]]
    # The prepended BOS position is always marked as attended.
    attention_mask = [[1, *mask] for mask in encoded["attention_mask"]]
    # Every row uses the same 0..len-1 positions, taken from the first row.
    positions = list(range(len(input_ids[0])))
    position_ids = [positions for _ in batch]

    columns = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    return BatchFeature(
        {key: torch.tensor(rows, dtype=torch.long) for key, rows in columns.items()}
    )


def compute_text_embeddings(text):
    """Return L2-normalized text embeddings as a CPU tensor.

    Args:
        text: A single string or a list of strings to embed.

    Returns:
        The output of ``model.get_text_features`` for the tokenized input,
        scaled to unit L2 norm along the last dimension and moved to the CPU.
    """
    if isinstance(text, str):
        text = [text]
    inputs = tokenize(texts=text).to(device)
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass. The result therefore never requires grad.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    return text_features.cpu()

def compute_image_embeddings(image):
    """Return L2-normalized image embeddings as a CPU tensor.

    Args:
        image: The image (or images) to pass to ``processor(images=...)``.

    Returns:
        The output of ``model.get_image_features`` for the processed input,
        scaled to unit L2 norm along the last dimension and moved to the CPU.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass. The result therefore never requires grad.
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    return image_features.cpu()


In [None]:
#@title Prepare for the demo
#@markdown Please feel free to change `categories` for your usage.
# Candidate labels (Japanese) that each input image is scored against.
categories = [
    "配達員",  # delivery person
    "営業",  # sales
    "消防士",  # firefighter
    "救急隊員",  # paramedic
    "自衛隊",  # Self-Defense Forces
    "スポーツ選手",  # athlete
    "警察官",  # police officer
]

# Pre-compute the text embeddings once so they can be reused for every image.
text_embeds = compute_text_embeddings(categories)

In [None]:
# @title Launch the demo
import gradio as gr

# Number of candidate labels defined in `categories`.
num_categories = len(categories)
# Passed as `num_top_classes` to the Label output component below.
TOP_K = 3


def inference_fn(img):
    """Score an image against every entry in ``categories``.

    Args:
        img: The image supplied by the Gradio input component, or ``None``
            when "Run" was pressed without providing one.

    Returns:
        A dict mapping each category label to its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Report a readable message in the UI instead of failing inside the
        # image processor.
        raise gr.Error("Please upload or select an image first.")

    image_embeds = compute_image_embeddings(img)
    # Both embeddings are L2-normalized, so the product is a cosine
    # similarity; it is scaled by 100 before the softmax over categories.
    probabilities = (100.0 * image_embeds @ text_embeds.T).softmax(dim=-1)
    return dict(zip(categories, probabilities[0].tolist()))


# Build the demo UI: a title, a short description, an image input next to a
# label output, a "Run" button, and a set of clickable example images.
with gr.Blocks() as demo:
    gr.Markdown("# Japanese Stable CLIP Demo")
    gr.Markdown(
        """[Japanese Stable CLIP](https://huggingface.co/stabilityai/japanese-stable-clip-vit-l-16) is a [CLIP](https://arxiv.org/abs/2103.00020) model by [Stability AI](https://ja.stability.ai/).
                - Blog: https://ja.stability.ai/blog/japanese-stable-clip
                - Twitter: https://twitter.com/StabilityAI_JP
                - Discord: https://discord.com/invite/StableJP"""
    )
    # One row with two columns: input image on the left, predictions on the right.
    with gr.Row():
      with gr.Column():
        # type="pil" requests that the handler receive the image as a PIL image.
        inp = gr.Image(type="pil")
      with gr.Column():
        out = gr.Label(num_top_classes=TOP_K)

    # Clicking "Run" sends the current image to inference_fn and shows the
    # returned label-to-score dict in the Label component.
    btn = gr.Button("Run")
    btn.click(fn=inference_fn, inputs=inp, outputs=out)
    # Example images (downloaded in the setup cell) that populate the image input.
    examples = gr.Examples(
        examples=[
            # https://ja.wikipedia.org/wiki/%E9%83%B5%E4%BE%BF
            "sample1.png",
            # https://ja.wikipedia.org/wiki/%E6%97%A5%E6%9C%AC%E3%81%AE%E6%B6%88%E9%98%B2
            "sample2.png",
            # https://ja.wikipedia.org/wiki/%E6%97%A5%E6%9C%AC%E3%81%AE%E8%AD%A6%E5%AF%9F%E5%AE%98
            "sample3.png",
        ],
        inputs=inp
    )

# Start the Gradio app when this code is run directly.
# NOTE(review): per Gradio's launch() options, share=True should request a
# public share link and debug=True should keep the cell attached so errors
# are printed; confirm against the installed Gradio version.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)


### Credit

* The mail carrier photograph was taken by Ryouta0411 and is provided under the Creative Commons Attribution-Share Alike 3.0 Unported License.
  * Source: https://commons.wikimedia.org/wiki/File:JAPANPOST-DSC00250.JPG
  * License: https://creativecommons.org/licenses/by-sa/3.0/
* The firefighter photograph was taken by Master Sgt. Jeremy Lock and is in the public domain.
  * Source: https://commons.wikimedia.org/wiki/File:Search_and_rescue_at_Unosumai,_Kamaishi,_-17_Mar._2011_a.jpg
* The police officer photograph was taken by Spaz Tacular and is provided under the Creative Commons Attribution 2.0 Generic License.
  * Source: https://commons.wikimedia.org/wiki/File:Policeman_at_Tokyo.jpg
  * License: https://creativecommons.org/licenses/by/2.0/deed.en