In [26]:
# Importing useful dependencies
import io
import os
import boto3
import torch
import chromadb
from transformers import BertTokenizer, BertModel
import numpy as np
from chromadb.config import Settings
import torch.nn.functional as F

In [None]:
# Setup S3 client for MinIO (MinIO implements Amazon S3 API)
s3 = boto3.client(
    "s3",
    endpoint_url="http://127.0.0.1:9000", # MinIO API endpoint
    aws_access_key_id="minioadmin", # User name
    aws_secret_access_key="minioadmin", # Password
)

In [None]:
# Connect to the server (Docker Container)
client = chromadb.HttpClient(host="localhost", port=8000)
# Although we set a path for persistent directory when defining the Docker Container
# It actually stores the embeddings inside the container

# We can use the following line to remove all the stored data in a collection
#client.delete_collection(name="images")

# Create or get the collection named "images"
collection = client.create_collection(name="texts", get_or_create=True, embedding_function=None)

In [None]:
# Just in case our device has gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
model = BertModel.from_pretrained("bert-large-cased")
model.to(device)

In [49]:
# We can use this function to retrieve an text from our bucket
def get_text(bucket, key):
    resp = s3.get_object(Bucket=bucket, Key=key)
    body = resp["Body"].read()
    text = body.decode("utf-8")
    return text
# The next function returns the embedding of the given text
def embed_text(tokenizer, model, text):
    encoded_input = tokenizer(text, return_tensors='pt',truncation=True,max_length=512)
    with torch.no_grad():
        output = model(**encoded_input)
        feats = output.pooler_output
        feats = feats / feats.norm(dim=-1, keepdim=True)
        feats_np = feats.cpu().numpy().squeeze()
    return  feats_np

In [57]:
def texts_to_embeddings(src_bucket, collection,tokenizer , model, src_prefix=""):

    # Incremental id assigned to each image embedding
    id_counter = 0
    
    paginator = s3.get_paginator("list_objects_v2") # It returns objects in pages and not all at once.
    for page in paginator.paginate(Bucket=src_bucket, Prefix=src_prefix):

        # List of paths (meta_data)
        texts_paths = []
        # List of embeddings
        embeddings = []
        # List of unique IDs for each embedding
        ids = []
        
        for obj in page.get("Contents", []):

            key = obj["Key"]

            if obj['Size'] == 0 and key.endswith("/"): # skip the folder itself
                continue

            id_counter += 1

            # Download the image
            text = get_text(src_bucket, key)
            
            # Compute embedding
            vector = embed_text(tokenizer, model, text) # A numerical vector of size 1024

            print(f"Created embedding for {key} ({len(embeddings)} items in current batch).")

            # Storing data
            texts_paths.append(f"{src_bucket}/{key}")
            embeddings.append(vector)
            ids.append(f"text_{id_counter}")

        # Store the images of a page at once
        collection.add(
                ids=ids,
                documents=texts_paths,
                embeddings=embeddings
        )

        print(f"All embeddings in the current batch are store successfully in the collection {collection.name}.")


In [None]:
texts_to_embeddings(src_bucket = "trusted-zone", src_prefix = "texts/", collection = collection, tokenizer = tokenizer, model = model)

In [None]:
# Function that prints the embeddings stored in a collection
def print_stored_embeddings(collection, x=None): # x is the maximum number of files to print
    results = collection.get(include=["documents", "embeddings"])
    for i in range(len(results["documents"])):
        print("ID:", results['ids'][i])
        print("Document:", results["documents"][i])
        print("Embedding (first 5 dims):", results["embeddings"][i][:5])
        print("---")
        if x and (x-1) == i:
            break

# We can use this function to print the embeddings stored in chromaDB
print_stored_embeddings(collection, x = 10)

In [54]:
# We can now perform a similarity search to test it

# The following function searches the top k most similar images in ChromaDB using the embeddings of an text
def find_similar_texts(collection, query_emb: np.ndarray, top_k: int = 5):
    # Chroma expects list-of-lists for query_embeddings
    query_vector = query_emb.tolist()

    results = collection.query(
        query_embeddings=[query_vector],
        n_results=top_k,
        include=["documents", "distances"]
    )

    # Extract first query results
    ids = results.get("ids", [[]])[0]
    docs = results.get("documents", [[]])[0]
    dists = results.get("distances", [[]])[0]

    print(f"Top {top_k} similar texts:")
    for rank, (doc_id, doc, dist) in enumerate(zip(ids, docs, dists), start=1):
        similarity = 1 - dist  
        print(f"{rank}. id={doc_id}, distance={dist:.4f} (similarity={similarity:.4f})")
        print(f"   text: {doc[:200]}{'...' if len(doc) > 200 else ''}")

    return results

In [56]:
# Search for similar texts in ChromaDB
results = find_similar_texts(collection, emb, top_k=5) # The first one is always the target texts itself


Top 5 similar texts:
1. id=img_1, distance=0.0000 (similarity=1.0000)
   text: trusted-zone/texts/text_1760137269318.txt
2. id=img_64, distance=0.0036 (similarity=0.9964)
   text: trusted-zone/texts/text_1760137274219.txt
3. id=img_160, distance=0.0041 (similarity=0.9959)
   text: trusted-zone/texts/text_1760137281032.txt
4. id=img_217, distance=0.0041 (similarity=0.9959)
   text: trusted-zone/texts/text_1760137285646.txt
5. id=img_106, distance=0.0044 (similarity=0.9956)
   text: trusted-zone/texts/text_1760137277237.txt
{'ids': [['img_1', 'img_64', 'img_160', 'img_217', 'img_106']], 'distances': [[0.0, 0.003573833, 0.0041247564, 0.004138467, 0.0043580746]], 'embeddings': None, 'metadatas': None, 'documents': [['trusted-zone/texts/text_1760137269318.txt', 'trusted-zone/texts/text_1760137274219.txt', 'trusted-zone/texts/text_1760137281032.txt', 'trusted-zone/texts/text_1760137285646.txt', 'trusted-zone/texts/text_1760137277237.txt']], 'uris': None, 'data': None, 'included': ['documents

In [58]:
def get_query_texts(results, bucket="trusted-zone"):
    """
    给定 ChromaDB 的 query() 结果，返回对应的文本内容。
    """
    docs = results.get("documents", [[]])[0]  # 取第一个 query 的 documents 列表
    texts = []

    for doc_path in docs:
        # 从路径中提取出 key，比如 'trusted-zone/texts/text_1760137269318.txt'
        # 去掉前缀 'trusted-zone/' 部分
        key = doc_path.split("trusted-zone/")[-1]
        try:
            text = get_text(bucket, key)
            texts.append({"path": doc_path, "content": text})
        except Exception as e:
            print(f"⚠️ Failed to read {key}: {e}")
            texts.append({"path": doc_path, "content": None})
    return texts
texts = get_query_texts(results)
for i, t in enumerate(texts, start=1):
    print(f"\n{i}. 📄 {t['path']}")
    print(t['content'][:300], "..." if len(t['content']) > 300 else "")


1. 📄 trusted-zone/texts/text_1760137269318.txt
Galactic Bowling is an exaggerated and stylized bowling game with an intergalactic twist. Players will engage in fast-paced single and multi-player competition while being submerged in a unique new universe filled with over-the-top humor, wild characters, unique levels, and addictive game play. The  ...

2. 📄 trusted-zone/texts/text_1760137274219.txt
This ain't your Grandma's Tetris! Vetrix is a puzzle game inspired by Tetris, but with its own original mechanics built for virtual reality. The player will have to demonstrate dexterity, organization, and good skills in the geometry of space as well as proprioception to achieve the best possible sc ...

3. 📄 trusted-zone/texts/text_1760137281032.txt
At Cub Gym the player will have a huge challenge ahead, several crazy games are waiting for him trying to take him down, test his skills and make it to the end. With a completely random level creation system, Cub Gym promises different routes for