<a href="https://colab.research.google.com/github/AniketGurav/CLIP/blob/main/genshin_finetuner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Genshin Impact! Fine-tuning CLIP for anime search

Today, let's build an anime search system. We will use text as the query and get images as results. Traditionally, this requires manually annotating each image with tags, an approach often referred to as `TBIR` (Text/Tag-based Image Retrieval). Instead, in this example we will use [OpenAI CLIP](https://github.com/openai/CLIP).

CLIP is a powerful embedding model that maps text and images into a shared embedding space, so the similarity between a text and an image can be measured directly. While it delivers great results in general, its zero-shot performance can still be improved on domain-specific data.

We'll build our search system using [Jina](https://docs.jina.ai/), with CLIP as the encoder. Since CLIP's zero-shot performance can be improved on domain data, we will fine-tune the model with [Finetuner](https://github.com/jina-ai/finetuner). To make sure fine-tuning actually improves our results, we'll first examine how the original CLIP performs on an anime dataset, then fine-tune CLIP on half of the dataset and check whether the search results improve.

## The data

![Albedo](https://storage.googleapis.com/kaggle-datasets-images/2071434/3438247/7cae3f8ae6d0e09f93df60bea523f7aa/dataset-cover.jpg?t=2022-04-10-16-04-06)

The dataset is a public dataset from [Kaggle](https://www.kaggle.com/datasets/just1ce5/genshin-impact-characters-dataset) based on the popular game [Genshin Impact](https://genshin.hoyoverse.com/pc-launcher/?utm_source=EU_google_DE_search_20220720&mhy_trace_channel=ga_channel&new=1&gclid=Cj0KCQjwz96WBhC8ARIsAATR250k1qTcxh8i1saR-tXXDt3SFlk2XGV93oMz3DZEF8T7Zs8RrTxT0LIaAgFUEALw_wcB#/GI008). It contains five categories: four Genshin characters and one for "none of them":

- Albedo
- Ayaka
- Hu Tao
- Kokomi
- Neither (none of these characters)

Each category has 100 images, for a total of 500 images. We'll fine-tune `CLIP` with only 250 of them, i.e. half of the dataset.
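The notebook below splits the data with a random shuffle, so each category's share of the 250 training images will only be approximately 50. If you want exactly 50 images per category, a stratified split is one option. The following is a plain-Python sketch, not part of the original notebook; the helper `stratified_split` and the generated paths are hypothetical and only mimic the `/content/dataset/<label>/<file>` layout.

```python
import random
from collections import defaultdict


def stratified_split(items, label_of, train_fraction=0.5, seed=0):
    """Split items so that every label contributes the same fraction to training."""
    by_label = defaultdict(list)
    for item in items:
        by_label[label_of(item)].append(item)

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train, test = [], []
    for label in sorted(by_label):
        group = by_label[label][:]  # copy before shuffling
        rng.shuffle(group)
        cut = int(len(group) * train_fraction)
        train.extend(group[:cut])
        test.extend(group[cut:])
    return train, test


# Hypothetical paths mimicking the dataset layout: 5 folders x 100 images.
labels = ['Albedo', 'Ayaka', 'Hu Tao', 'Kokomi', 'Neither']
paths = [f'/content/dataset/{label}/{i}.jpg' for label in labels for i in range(100)]

# The label is the folder name, i.e. index 3 of '/content/dataset/<label>/<file>'.
train, test = stratified_split(paths, label_of=lambda p: p.split('/')[3])
print(len(train), len(test))  # 250 250
```

With a 50% fraction, each of the five folders contributes 50 images to `train` and 50 to `test`.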

## Download data and install dependencies

In [1]:
!pip install -U "finetuner[full]" # finetuner[full] also installs ML libraries such as torchvision and transformers

!pip install gdown
!pip install git+https://github.com/openai/CLIP.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting finetuner[full]
  Downloading finetuner-0.7.6.tar.gz (36 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting docarray[common]<0.30.0
  Downloading docarray-0.21.0.tar.gz (658 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting finetuner-stubs==0.13.4
  Downloading finetuner_stubs-0.13.4-py3-none-any.whl (15 kB)
Collecting jina-hubble-sdk==0.33.1
  Downloading jina_hubble_sdk-0.33.1-py3-none-any.whl (67 kB)

In [2]:
!pip show finetuner

Name: finetuner
Version: 0.7.6
Summary: Task-oriented finetuning for better embeddings on neural search.
Home-page: https://github.com/jina-ai/finetuner/
Author: Jina AI
Author-email: hello@jina.ai
License: Apache 2.0
Location: /usr/local/lib/python3.9/dist-packages
Requires: docarray, finetuner-stubs, jina-hubble-sdk, trimesh
Required-by: 


In [3]:
!gdown https://drive.google.com/uc?id=1i4UdF-DwiR00HUNFYsyXI5bmOqCdLLtJ
!unzip /content/dataset.zip

Downloading...
From: https://drive.google.com/uc?id=1i4UdF-DwiR00HUNFYsyXI5bmOqCdLLtJ
To: /content/dataset.zip
100% 42.0M/42.0M [00:00<00:00, 108MB/s] 
Archive:  /content/dataset.zip
   creating: dataset/
  inflating: __MACOSX/._dataset      
   creating: dataset/Neither/
  inflating: __MACOSX/dataset/._Neither  
   creating: dataset/Ayaka/
  inflating: __MACOSX/dataset/._Ayaka  
   creating: dataset/Albedo/
  inflating: __MACOSX/dataset/._Albedo  
   creating: dataset/Hu Tao/
  inflating: __MACOSX/dataset/._Hu Tao  
   creating: dataset/Kokomi/
  inflating: __MACOSX/dataset/._Kokomi  
  inflating: dataset/Neither/63.jpg  
  inflating: __MACOSX/dataset/Neither/._63.jpg  
  inflating: dataset/Neither/77.jpg  
  inflating: __MACOSX/dataset/Neither/._77.jpg  
  inflating: dataset/Neither/88.jpg  
  inflating: __MACOSX/dataset/Neither/._88.jpg  
  inflating: dataset/Neither/89.jpg  
  inflating: __MACOSX/dataset/Neither/._89.jpg  
  inflating: dataset/Neither/76.jpg  
  inflating: __MACOSX

## Zero-Shot CLIP Retrieval

The dataset is now in the `/content/dataset/` folder. We'll use `docarray` to build the search system, with pre-trained `CLIP ViT-B/32` as the encoder for both the images and the text queries.

We'll construct some text queries about the Genshin characters, such as `Ayaka dancing`, `Hu Tao fighting`, and `Albedo flying`, to see whether zero-shot CLIP works.

In [4]:
from docarray import Document, DocumentArray

da = DocumentArray.from_files('/content/dataset/*/*.*')

def assign_labels(d: Document):
    d.tags['finetuner_label'] = d.uri.split('/')[3]
    return d

da.apply(assign_labels, show_progress=True)
# shuffle and train-test-split to 50-50
da = da.shuffle()
train_da = da[:250]
test_da = da[250:]

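`assign_labels` takes the label from `d.uri.split('/')[3]`, which works only because every path has exactly the shape `/content/dataset/<label>/<file>`; moving the dataset to a deeper or shallower folder would shift that index. A slightly more robust alternative, sketched here with the standard-library `pathlib` (the helper name `label_from_uri` is hypothetical), takes the name of the folder that directly contains the file:

```python
from pathlib import PurePosixPath


def label_from_uri(uri: str) -> str:
    """Return the name of the folder that directly contains the file."""
    return PurePosixPath(uri).parent.name


print(label_from_uri('/content/dataset/Hu Tao/63.jpg'))      # Hu Tao
print(label_from_uri('/data/genshin/dataset/Ayaka/1.jpg'))   # Ayaka
```

Inside `assign_labels`, the label line would then read `d.tags['finetuner_label'] = label_from_uri(d.uri)`.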

In [5]:
test_da[0]

In [6]:
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def preprocess_and_encode_image(d: Document):
    """Preprocess image and extract embeddings from CLIP image encoder"""
    d.tensor = preprocess(Image.open(d.uri)).unsqueeze(0).to(device)
    d.embedding = model.encode_image(d.tensor).cpu().detach().numpy().squeeze()
    d.pop('tensor')
    return d

def preprocess_and_encode_text(d: Document):
    """Tokenize text and extract embeddings from CLIP text encoder"""
    d.tensor = clip.tokenize(d.text).to(device)
    d.embedding = model.encode_text(d.tensor).cpu().detach().numpy().squeeze()
    d.pop('tensor')
    return d

100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 55.5MiB/s]


Now we have our data and preprocessing functions ready. The next steps are:

1. Extract image embeddings for every Document in `test_da`.

2. Encode some text queries and retrieve the most similar image Documents.
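The retrieval step in the following cells calls `query_docs.match(test_da, metric='cosine', limit=5)`. Conceptually, that ranks every image embedding by its cosine similarity to the query embedding and keeps the top five. The NumPy sketch below illustrates the idea; it is not docarray's implementation, and the random 512-dimensional vectors (the embedding size of CLIP ViT-B/32) merely stand in for real CLIP embeddings.

```python
import numpy as np


def top_k_cosine(query: np.ndarray, candidates: np.ndarray, k: int = 5):
    """Rank candidate embeddings by cosine similarity to one query embedding."""
    q = query / np.linalg.norm(query)
    c = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    sims = c @ q                    # cosine similarity of each candidate with the query
    order = np.argsort(-sims)[:k]   # indices of the k most similar candidates
    return order, sims[order]


rng = np.random.default_rng(0)
image_embeddings = rng.normal(size=(250, 512))  # stand-in for 250 CLIP image embeddings
# A stand-in "text" embedding that lies close to image 42.
text_embedding = image_embeddings[42] + 0.1 * rng.normal(size=512)

indices, scores = top_k_cosine(text_embedding, image_embeddings, k=5)
print(indices[0])  # 42
```

Because cosine similarity ignores vector length, normalizing both sides and taking a dot product is all the ranking needs.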

In [7]:
# embed all test documents
test_da.apply(preprocess_and_encode_image, show_progress=True)


In [8]:
# embed a user query
query = 'Hu Tao fighting'
query_docs = DocumentArray([Document(content=query)])
query_docs.apply(preprocess_and_encode_text)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Hu Tao fighting at position 0 is Ayaka
The matched document against query Hu Tao fighting at position 1 is Hu Tao
The matched document against query Hu Tao fighting at position 2 is Ayaka
The matched document against query Hu Tao fighting at position 3 is Neither
The matched document against query Hu Tao fighting at position 4 is Neither


In [9]:
# embed a user query
query = 'Ayaka dancing'
query_docs = DocumentArray([Document(content=query)])
query_docs.apply(preprocess_and_encode_text)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Ayaka dancing at position 0 is Kokomi
The matched document against query Ayaka dancing at position 1 is Ayaka
The matched document against query Ayaka dancing at position 2 is Ayaka
The matched document against query Ayaka dancing at position 3 is Ayaka
The matched document against query Ayaka dancing at position 4 is Ayaka

In [10]:
# embed a user query
query = 'Albedo flying'
query_docs = DocumentArray([Document(content=query)])
query_docs.apply(preprocess_and_encode_text)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Albedo flying at position 0 is Neither
The matched document against query Albedo flying at position 1 is Hu Tao
The matched document against query Albedo flying at position 2 is Hu Tao
The matched document against query Albedo flying at position 3 is Hu Tao
The matched document against query Albedo flying at position 4 is Ayaka

We tried three user queries:

1. `Hu Tao fighting`: Only one of the top five results is Hu Tao; two are Ayaka and two are Neither.
2. `Ayaka dancing`: Here, zero-shot CLIP returns good results: four of the top five are Ayaka. One likely reason is that [Ayaka's cutscene](https://youtu.be/g-o-l8j6d8Q) in Genshin Impact shows her dancing, which is rare for other characters.
3. `Albedo flying`: This is a hard query. As expected, none of the returned matches is Albedo, although most of the images are indeed related to flying.
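To put numbers on these observations, here is a small sketch that computes precision@5, the fraction of the top five matches showing the queried character, from the labels printed above. The helper `precision_at_k` is hypothetical, and the metric only checks character identity, not whether the action (fighting, dancing, flying) matches.

```python
def precision_at_k(match_labels, target, k=5):
    """Fraction of the top-k matches whose label equals the queried character."""
    top = match_labels[:k]
    return sum(label == target for label in top) / len(top)


# Labels of the top-5 zero-shot matches, copied from the outputs above.
zero_shot = {
    ('Hu Tao fighting', 'Hu Tao'): ['Ayaka', 'Hu Tao', 'Ayaka', 'Neither', 'Neither'],
    ('Ayaka dancing', 'Ayaka'): ['Kokomi', 'Ayaka', 'Ayaka', 'Ayaka', 'Ayaka'],
    ('Albedo flying', 'Albedo'): ['Neither', 'Hu Tao', 'Hu Tao', 'Hu Tao', 'Ayaka'],
}

for (query, target), labels in zero_shot.items():
    print(f'{query}: precision@5 = {precision_at_k(labels, target):.1f}')
```

This prints 0.2, 0.8 and 0.0 for the three queries, matching the qualitative summary.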

So the results are OK-ish. Now let's fine-tune CLIP to see if we can do better.

## CLIP Fine-tuning

Fine-tuning the CLIP model is not easy. It involves two models, the CLIP image encoder and the CLIP text encoder, which are optimized jointly under the CLIP loss. Luckily, the `finetuner` package makes it easy!
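To illustrate what optimizing the two encoders jointly means, here is a minimal NumPy sketch of the symmetric contrastive loss described in the CLIP paper: in a batch of N image-text pairs, each image should score highest against its own caption, and each caption against its own image. This is not Finetuner's `CLIPLoss` implementation; the fixed temperature of 0.07 (CLIP's initial value, which the paper then learns) and the random embeddings are assumptions for illustration.

```python
import numpy as np


def clip_loss(image_emb: np.ndarray, text_emb: np.ndarray, temperature: float = 0.07) -> float:
    """Symmetric contrastive loss over a batch of N matching (image, text) pairs.

    Row i of image_emb and row i of text_emb form the positive pair; every
    other row in the batch serves as a negative.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (N, N) scaled cosine similarities

    def cross_entropy_diag(l: np.ndarray) -> float:
        # Cross-entropy where the correct class for row i is column i.
        l = l - l.max(axis=1, keepdims=True)  # subtract row max for numerical stability
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -np.mean(np.diag(log_probs))

    # Average the image->text (rows) and text->image (columns) directions.
    return (cross_entropy_diag(logits) + cross_entropy_diag(logits.T)) / 2


rng = np.random.default_rng(0)
images = rng.normal(size=(8, 512))
aligned = images + 0.05 * rng.normal(size=(8, 512))  # each "text" lies close to its image
shuffled = aligned[::-1]                              # same texts, wrong pairing
print(clip_loss(images, aligned) < clip_loss(images, shuffled))  # True
```

Gradients of this loss flow into both encoders at once, which is why the image and text towers have to be trained together.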

The GitHub repo: https://github.com/jina-ai/finetuner

The documentation: https://finetuner.jina.ai/

In short, `finetuner` takes a `DocumentArray` as training data and outputs an `artifact`. For CLIP, the `artifact` contains two ONNX files (to speed up inference). We'll come back to this later.

In [15]:
import finetuner

finetuner.login(force=True)

Now we need to prepare the training data for CLIP. As you may know, CLIP was trained on pairs of text and images. Each pair can be organized as a `Document` with two `chunks` inside, one for the text and one for the image. For example:

```python
from docarray import Document, DocumentArray

train_da = DocumentArray([
    Document(
        chunks=[
            Document(
                content='pencil skirt slim fit available for sale',
                modality='text',
            ),
            Document(
                uri='https://...skirt-1.png',
                modality='image',
            ),
        ],
    ),
    # ... more documents
])
```

We just need to convert our `train_da` into the above form. Note that when an image is stored locally, as ours are, you can load it into a `tensor`:

In [16]:
pairs = DocumentArray()  # initialize a DocumentArray as final training data

prompt = 'This is a photo of '
for doc in train_da:
    pair = Document()
    # load the local image into a 224x224 `tensor`; the call returns the Document itself
    img_chunk = doc.load_uri_to_image_tensor(224, 224)
    img_chunk.modality = 'image'
    # caption = prompt + character label, e.g. 'This is a photo of Ayaka'
    txt_chunk = Document(content=prompt + doc.tags['finetuner_label'])
    txt_chunk.modality = 'text'
    pair.chunks.extend([img_chunk, txt_chunk])
    # add pair to pairs
    pairs.append(pair)

# Let's see the first item of the pairs
pairs[0]

In [17]:
len(pairs)

250

As you can see above, we now have a `DocumentArray` of pairs. Each pair contains:

1. An image chunk (`img_chunk`): the image `tensor`, with `modality='image'`.
2. A text chunk (`txt_chunk`): the prompt followed by the character's name (e.g. `This is a photo of Ayaka`), with `modality='text'`.

Now, we are ready to train our model.

Before training, there are a few basic concepts to be aware of:

1. `experiment` and `run`: these organize your different settings. An experiment contains multiple runs, each with different hyper-parameter settings for your fine-tuning job.
2. Once you call `finetuner.fit`, it schedules the CPU/GPU/memory and all other computing resources in the cloud (for free!).

In [19]:
run = finetuner.fit(
    model='openai/clip-vit-base-patch32', # fine-tune CLIP
    train_data=pairs,   
    learning_rate=1e-5,
    loss='CLIPLoss',
    cpu=False,
)


Model fine-tuning might take some time. You can follow the progress of your job by streaming its logs:

In [None]:
for entry in run.stream_logs():
    print(entry)


[20:09:31] INFO     Starting finetuner run ...                                                           __main__.py:113
           DEBUG    Found Jina AI Cloud authentication token                                             __main__.py:125
           DEBUG    Running in online mode                                                               __main__.py:126
           INFO     Reading config ...                                                                   __main__.py:133
           DEBUG    Reading config from stream                                                           __main__.py:145
           INFO     Parsing config ...                                                                   __main__.py:148
           INFO     Config loaded 📜                                                                     __main__.py:150
           INFO     Run name: brave-lamport                                                              __main__.py:152
           INFO     Experiment na

In [12]:
# save the artifact
artifact = run.save_artifact('/content/')


In [13]:
clip_text_encoder = finetuner.get_model(artifact=artifact, select_model='clip-text')
clip_image_encoder = finetuner.get_model(artifact=artifact, select_model='clip-vision')


In [14]:
finetuner.encode(model=clip_image_encoder, data=test_da)


In [None]:
# embed a user query; in this case, the query keywords are "Hu Tao flying"
query = 'Hu Tao flying'
query_docs = DocumentArray([Document(content=query)])

finetuner.encode(model=clip_text_encoder, data=query_docs)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Hu Tao flying at position 0 is Hu Tao
The matched document against query Hu Tao flying at position 1 is Hu Tao
The matched document against query Hu Tao flying at position 2 is Hu Tao
The matched document against query Hu Tao flying at position 3 is Hu Tao
The matched document against query Hu Tao flying at position 4 is Hu Tao

In [None]:
# embed a user query
query = 'Ayaka dancing'
query_docs = DocumentArray([Document(content=query)])

finetuner.encode(model=clip_text_encoder, data=query_docs)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Ayaka dancing at position 0 is Ayaka
The matched document against query Ayaka dancing at position 1 is Ayaka
The matched document against query Ayaka dancing at position 2 is Ayaka
The matched document against query Ayaka dancing at position 3 is Ayaka
The matched document against query Ayaka dancing at position 4 is Ayaka

In [None]:
# embed a user query
query = 'Albedo flying'
query_docs = DocumentArray([Document(content=query)])

finetuner.encode(model=clip_text_encoder, data=query_docs)
# find top 5 matches
query_docs.match(test_da, metric='cosine', limit=5)
# plot matches
for idx, match in enumerate(query_docs[0].matches):
    print(f'The matched document against query \033[1m{query}\033[0m at position \033[1m{idx}\033[0m is \033[1m{match.tags["finetuner_label"]}\033[0m')
    match.display()

The matched document against query Albedo flying at position 0 is Albedo
The matched document against query Albedo flying at position 1 is Albedo
The matched document against query Albedo flying at position 2 is Albedo
The matched document against query Albedo flying at position 3 is Albedo
The matched document against query Albedo flying at position 4 is Albedo

Let's see the results!

1. `Hu Tao flying`: The fine-tuned cell above queries `Hu Tao flying` rather than `Hu Tao fighting`, so it is not a direct comparison with the zero-shot run; for this query, all five top results are Hu Tao.
2. `Ayaka dancing`: Zero-shot CLIP already did well here (four of five). After fine-tuning, all five matches are Ayaka, and they also relate to her dancing scene!
3. `Albedo flying`: This is a hard query; previously, none of the matches were Albedo. Now all five images are Albedo, and some of them show Albedo flying!
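Only `Ayaka dancing` and `Albedo flying` were run with identical query text before and after fine-tuning (the fine-tuned Hu Tao cell queries `Hu Tao flying`), so a like-for-like comparison uses just those two. The sketch below counts, from the printed labels, how many of the top five matches show the queried character; the dictionary layout and the helper `hits` are illustrative, not part of the notebook.

```python
# (target character, top-5 labels before fine-tuning, top-5 labels after),
# copied from the outputs above for the two queries run both times.
results = {
    'Ayaka dancing': ('Ayaka', ['Kokomi', 'Ayaka', 'Ayaka', 'Ayaka', 'Ayaka'], ['Ayaka'] * 5),
    'Albedo flying': ('Albedo', ['Neither', 'Hu Tao', 'Hu Tao', 'Hu Tao', 'Ayaka'], ['Albedo'] * 5),
}


def hits(labels, target):
    """Number of matches whose label equals the queried character."""
    return sum(label == target for label in labels)


for query, (target, before, after) in results.items():
    print(f'{query}: {hits(before, target)}/5 -> {hits(after, target)}/5')

# Mean precision@5 over the two comparable queries.
mean_before = sum(hits(b, t) for t, b, _ in results.values()) / (5 * len(results))
mean_after = sum(hits(a, t) for t, _, a in results.values()) / (5 * len(results))
print(f'mean precision@5: {mean_before:.1f} -> {mean_after:.1f}')  # 0.4 -> 1.0
```

Two queries and 250 test images are a very small sample, so treat these numbers as an illustration rather than a benchmark.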

End thoughts:

1. Fine-tuning CLIP with only 250 samples can already improve its retrieval quality a lot!