## Task 2: Multi-Modality Similarity Search

The **second task** we defined for our **Game Recommendation Assistant** is the following:

**Text-based retrieval**: Given a game description or textual query such as "open-world fantasy adventure", the system retrieves game covers (images), trailers (videos), and descriptions (texts) of games with similar themes, genres, or narrative elements.

**Image-based retrieval**: Given a game cover or in-game snapshot, the system retrieves visually and semantically related games, including similar covers (images), trailers (videos), and descriptions (texts) that share comparable art styles, visual motifs, or atmosphere.

**Video-based retrieval**: Given a game trailer, the system retrieves trailers (videos), covers (images), and descriptions (texts) of games with a similar visual tone, gameplay style, or mood.

In [None]:
# Importing useful dependencies
import io
import boto3
import torch
import requests
import imageio
import chromadb
import open_clip
import numpy as np
from PIL import Image
import torch.nn as nn
from io import BytesIO
import ipywidgets as widgets
from IPython.display import display
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, AutoTokenizer, CLIPTextModelWithProjection

In [None]:
# Setup S3 client for MinIO (MinIO implements the Amazon S3 API).
# Endpoint and credentials can be overridden with environment variables so they do
# not have to live in the notebook; the defaults match a local MinIO dev setup.
import os

s3 = boto3.client(
    "s3",
    endpoint_url=os.environ.get("MINIO_ENDPOINT", "http://127.0.0.1:9000"),  # MinIO API endpoint
    aws_access_key_id=os.environ.get("MINIO_ACCESS_KEY", "minioadmin"),  # User name
    aws_secret_access_key=os.environ.get("MINIO_SECRET_KEY", "minioadmin"),  # Password
)

In [None]:
# Connect to the ChromaDB server (Docker container) holding the embeddings of our files
client = chromadb.HttpClient(host="localhost", port=8000)

# Open each collection, creating it first if it does not exist yet.
# embedding_function=None: this notebook always passes vectors explicitly
# (query_embeddings=...), so Chroma never needs to embed anything itself.
collection_texts_images = client.create_collection(
    name="texts_images",
    get_or_create=True,
    embedding_function=None,
)
collection_texts_images_videos = client.create_collection(
    name="texts_images_videos",
    get_or_create=True,
    embedding_function=None,
)

In [None]:
# Run the models on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Let's define the models we will be using.

In [None]:
# Load the OpenAI-pretrained "ViT-B-16" CLIP model (open_clip) for images and texts
model_it, _, preprocess_it = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
tokenizer_it = open_clip.get_tokenizer("ViT-B-16")  # Tokenizer for texts
model_it.to(device)
# open_clip hands back models in training mode; switch to inference mode so any
# dropout / stochastic-depth style layers are disabled while embedding.
model_it.eval()

# =================================================================================================

# Load "Searchium-ai/clip4clip-webvid150k" for texts, images and videos
CLIP4CLIP_ID = "Searchium-ai/clip4clip-webvid150k"

# === Text encoder ===
text_tokenizer = AutoTokenizer.from_pretrained(CLIP4CLIP_ID)
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP4CLIP_ID)
text_model.to(device)
text_model.eval()  # explicit inference mode (harmless if already set)

# === Image / video frame encoder ===
image_processor = CLIPImageProcessor.from_pretrained(CLIP4CLIP_ID)
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP4CLIP_ID)
vision_model.to(device)
vision_model.eval()  # explicit inference mode (harmless if already set)

### Using **ViT-B-16**

In the following cells we implement helper functions to compute embeddings for given data and to fetch files from MinIO.

In [None]:
# We can use this function to retrieve a text from our bucket
def get_text(bucket, key):
    """Download object ``key`` from ``bucket`` and return its content decoded as UTF-8."""
    raw = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
    return raw.decode("utf-8")

# We can use this function to retrieve an image from our bucket in PIL Image format
def get_image(bucket, key):
    """Download object ``key`` from ``bucket`` and open its bytes as a PIL Image."""
    payload = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
    return Image.open(io.BytesIO(payload))

In [None]:
@torch.no_grad()
def embed_text(model, tokenizer, texts: str):
    """Return the L2-normalised embedding of one text string as a NumPy vector.

    ``texts`` is wrapped in a one-element batch, tokenised, encoded with
    ``model.encode_text`` and the single resulting row is returned.
    """
    batch = tokenizer([texts]).to(device)
    embedding = model.encode_text(batch)
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)  # unit length
    return embedding.cpu().numpy()[0]

# The next function returns the embedding of the given PIL Image
@torch.no_grad()
def embed_image(preprocess, model, pil_img):
    """Return the L2-normalised embedding of a PIL image as a NumPy vector.

    The image is transformed with ``preprocess``, batched as a single item,
    encoded with ``model.encode_image`` and the batch axis is squeezed away.
    """
    pixels = preprocess(pil_img).unsqueeze(0).to(device)
    embedding = model.encode_image(pixels)
    embedding = embedding / embedding.norm(dim=-1, keepdim=True)  # unit length
    return embedding.cpu().numpy().squeeze()

Below, we implement functions to retrieve similar data across modalities (texts, images and videos).

In [None]:
def print_top_k_files(res, k=5, videos=False):
    """Print/display the best ``k`` matches per modality from a Chroma query result.

    Args:
        res: Result of ``collection.query(...)`` for a single query embedding,
            with ``"documents"`` and ``"distances"`` included. Documents are
            expected to be object paths prefixed with ``"trusted-zone/"``.
        k: Maximum number of results shown per modality.
        videos: When True, ``.mp4`` results are shown too (frames via
            ``get_video``); otherwise only ``.txt`` and ``.png`` results are used.

    Results keep Chroma's order and are numbered consecutively; every printed
    distance is the distance of the document printed next to it.
    """
    if k <= 0:
        return

    # Extensions (modalities) to show and how many of each have been shown so far
    wanted = ("txt", "png", "mp4") if videos else ("txt", "png")
    shown = dict.fromkeys(wanted, 0)
    rank = 0

    # Pair each document with its own distance. Indexing the distances with the
    # display counter drifts as soon as a document is skipped because its
    # modality already reached k results.
    for doc, dist in zip(res["documents"][0], res["distances"][0]):
        ext = doc.split(".")[-1]
        if ext not in shown or shown[ext] >= k:
            continue

        rank += 1
        key = doc.replace("trusted-zone/", "", 1)
        print(f"{rank}. Distance: {dist:.4f}")
        print("Content:", doc)
        if ext == "txt":
            print(get_text("trusted-zone", key))
        elif ext == "png":
            display(get_image("trusted-zone", key))
        else:  # "mp4", only reachable when videos=True
            for frame in get_video("trusted-zone", key):
                display(frame)
        print("-" * 40)
        shown[ext] += 1

        # Stop early once every requested modality has k results
        if all(count >= k for count in shown.values()):
            break

**Text-based retrieval**

In [None]:
# Example: query by another game's description
query_text = "Games similar to Nier: Automata"
q_vec = embed_text(model_it, tokenizer_it, query_text).tolist()

# The nearest neighbours of a text query tend to be texts, so we over-fetch
# neighbours to make sure enough images show up among the results as well.
res = collection_texts_images.query(
    query_embeddings=[q_vec],
    n_results=2000,
    include=["documents", "distances"],
)

print_top_k_files(res, k=5)

**Image-based retrieval**

In [None]:
# Widget to pick a single image file from local storage
uploader = widgets.FileUpload(
    accept="image/*",
    multiple=False,
)
display(uploader)

In [None]:
# Extract the uploaded file (pick a file in the upload widget above first)
if not uploader.value:
    raise ValueError(
        "No image uploaded yet: select a file with the upload widget above, then re-run this cell."
    )
image_data = uploader.value[0].content
img_example = Image.open(BytesIO(image_data))

# Create embeddings for the Image
img_example_emb = embed_image(preprocess_it, model_it, img_example)
img_example

In [None]:
# Example: query by another game's cover or snapshot
# .tolist() converts the NumPy vector to plain floats, like the text-query cell,
# so we do not rely on the HTTP client accepting NumPy arrays.
q_vec = embed_image(preprocess_it, model_it, img_example).tolist()

res = collection_texts_images.query(
    query_embeddings=[q_vec],
    n_results=2000,
    include=["documents", "distances"],
)

print_top_k_files(res, k=5)

### Using **clip4clip-webvid150k**

In [None]:
# Layer to project 512-d embeddings to 768-d.
# NOTE(review): this layer is never trained, so its weights are just the random
# initialisation of nn.Linear. Unseeded, every kernel restart produced different
# weights and query vectors could not line up with embeddings stored in ChromaDB
# by an earlier session. The seed below makes the projection reproducible; the
# indexing code must build it the same way (same seed), or the collection must be
# re-indexed with these weights.
PROJECTION_SEED = 0
with torch.random.fork_rng(devices=[]):  # restore the global CPU RNG state afterwards
    torch.manual_seed(PROJECTION_SEED)
    projection = nn.Linear(512, 768)
# Inference only: move to the target device, disable training behaviour and gradients
projection = projection.to(device).eval().requires_grad_(False)

In [None]:
temp_file = "temp_video_in.mp4"
# We can use the following function to retrieve a video from our MinIO database
# We are only extracting frames of the video, ignoring the audio content of the video
def get_video(bucket = None, key = None, max_frames = 16, url = None):
    """Download a video and return up to ``max_frames`` of its frames as PIL Images.

    The video comes from ``url`` when given, otherwise from MinIO object ``key``
    in ``bucket``. It is written to ``temp_file`` (overwritten on each call) and
    decoded with imageio's ffmpeg reader; audio is ignored.

    When the frame count is known, every ``total_frames // max_frames``-th frame
    (at least every frame) is taken from the start, keeping the first
    ``max_frames`` indices; frames that fail to decode are skipped. Otherwise the
    first ``max_frames`` frames are read sequentially.

    Returns:
        list of PIL.Image.Image (empty if ``max_frames <= 0``).

    Raises:
        ValueError: if neither ``url`` nor both ``bucket`` and ``key`` are given.
        requests.HTTPError: if downloading ``url`` fails.
    """
    if max_frames <= 0:
        return []

    if url:
        # Stream to disk in chunks; the context manager releases the connection.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(temp_file, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    else:
        if bucket is None or key is None:
            raise ValueError("get_video needs either `url` or both `bucket` and `key`")
        resp = s3.get_object(Bucket=bucket, Key=key)
        body = resp["Body"].read()
        with open(temp_file, "wb") as f:
            f.write(body)

    frames = []
    reader = imageio.get_reader(temp_file, format="ffmpeg")
    try:
        total_frames = reader.count_frames()
        if total_frames and total_frames > 0:
            # NOTE(review): when total_frames < 2 * max_frames this samples only the
            # beginning of the video; kept as-is, presumably to match how the stored
            # video embeddings were produced — confirm before changing.
            step = max(1, total_frames // max_frames)
            for i in range(0, total_frames, step)[:max_frames]:
                try:
                    frames.append(Image.fromarray(reader.get_data(i)))
                except Exception:
                    # Best effort: skip frames that cannot be decoded
                    continue
        else:
            # Fallback: frame count unknown, collect the first max_frames frames
            for frame in reader:
                frames.append(Image.fromarray(frame))
                if len(frames) >= max_frames:
                    break
    finally:
        # Release the reader even if decoding fails part-way
        reader.close()
    return frames

In [None]:
@torch.no_grad()
def embed_image_second_model(processor, model, pil_img):
    """Return the projected, L2-normalised clip4clip embedding of a PIL image.

    The image is preprocessed with ``processor`` and encoded by ``model``. Its
    ``image_embeds`` are normalised, passed through the module-level
    ``projection`` layer (512-d -> 768-d) and normalised again. The batch axis is
    squeezed away and a NumPy array is returned.
    """
    batch = processor(images=pil_img, return_tensors="pt").to(device)
    image_vec = model(**batch).image_embeds
    image_vec = image_vec / image_vec.norm(dim=-1, keepdim=True)

    # Project to 768-d and renormalise
    projected = projection(image_vec)
    projected = projected / projected.norm(dim=-1, keepdim=True)

    return projected.cpu().numpy().squeeze()

@torch.no_grad()
def embed_text_second_model(model, text: str, tokenizer=text_tokenizer):
    """Return the projected, L2-normalised clip4clip embedding of ``text``.

    The text is tokenised (padded, truncated to 77 tokens) and encoded by
    ``model``. Its ``text_embeds`` are normalised, passed through the
    module-level ``projection`` layer (512-d -> 768-d) and normalised again.
    The first row of the batch is returned as a NumPy array.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,  # CLIP's maximum token length
    ).to(device)
    text_vec = model(**encoded).text_embeds
    text_vec = text_vec / text_vec.norm(dim=-1, keepdim=True)

    # Project to 768-d and renormalise
    projected = projection(text_vec)
    projected = projected / projected.norm(dim=-1, keepdim=True)
    return projected.cpu().numpy()[0]

# The next function returns the embedding of the given video
def embed_video(tokenizer, model, frames_video):
    """Return one embedding for a video given as a list of frames (PIL Images).

    ``tokenizer`` is used as an image processor: it is called with ``images=``
    and must return ``pixel_values``. Each frame is encoded by ``model``; the
    per-frame vectors (``pooler_output`` if the output has it, otherwise the
    token-wise mean of ``last_hidden_state``) are averaged over frames.

    NOTE(review): the output of CLIPVisionModelWithProjection does not appear to
    expose ``pooler_output``, so with ``vision_model`` this likely always uses the
    ``last_hidden_state`` mean. That vector is neither projected through
    ``projection`` nor L2-normalised, unlike ``embed_image_second_model``. Verify
    before relying on video-to-text/image distances; changing it would also
    require re-indexing the stored video embeddings.
    """
    pixel_values = tokenizer(images=frames_video, return_tensors="pt")["pixel_values"].to(device)  # shape (num_frames, 3, H, W)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        per_frame = getattr(outputs, "pooler_output", None)
        if per_frame is None:
            per_frame = outputs.last_hidden_state.mean(dim=1)
        # Average the frame vectors into a single video vector
        return per_frame.mean(dim=0).cpu().numpy()

**Text-based retrieval**

In [None]:
# Example: query by another game's description
query_text = "Games similar to Nier: Automata"
q_vec = embed_text_second_model(text_model, query_text).tolist()

# The nearest neighbours of a text query tend to be texts, so we over-fetch
# neighbours to leave room for images and videos among the results.
res = collection_texts_images_videos.query(
    query_embeddings=[q_vec],
    n_results=2000,
    include=["documents", "distances"],
)

print_top_k_files(res, k=5, videos=True)

**Image-based retrieval**

In [None]:
# Example: query by another game's cover or snapshot (using the same photo as the one uploaded previously)
# .tolist() converts the NumPy vector to plain floats, like the text-query cells,
# so we do not rely on the HTTP client accepting NumPy arrays.
q_vec = embed_image_second_model(image_processor, vision_model, img_example).tolist()

res = collection_texts_images_videos.query(
    query_embeddings=[q_vec],
    n_results=2000,
    include=["documents", "distances"],
)

print_top_k_files(res, k=5, videos=True)

**Video-based retrieval**

In [None]:
# Sample video: a Steam trailer, reduced by get_video to a list of PIL frames
video_example = get_video(
    url="https://cdn.akamai.steamstatic.com/steam/apps/256853884/movie_max.mp4?t=1633085092"
)

# Turn the sampled frames into a single video embedding
video_example_emb = embed_video(image_processor, vision_model, video_example)

# Show the frames the embedding was computed from
for frame in video_example:
    display(frame)

In [None]:
# Example: query by a list of frames from a game trailer
# .tolist() converts the NumPy vector to plain floats, like the text-query cells,
# so we do not rely on the HTTP client accepting NumPy arrays.
res = collection_texts_images_videos.query(
    query_embeddings=[video_example_emb.tolist()],
    n_results=2000,
    include=["documents", "distances"],
)

print_top_k_files(res, k=5, videos=True)