# Introduction

This notebook will contain a workflow for `ViT-g-14` with pretrained weights `laion2B-s12B-b42k`. This workflow will also include tiling to capture entire image as the model has a restriction of (1024, 1024, 3).

## Embedding Generation

In [1]:
import os
import json
import numpy as np
import torch
import faiss
import cv2
from PIL import Image
import open_clip

# ===================== #
#   Image Preprocessing #
# ===================== #
def preprocess_sketch(img_path):
    """Enhance edges, clean noise, and produce normalized binary image."""
    img = cv2.imread(img_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.bilateralFilter(gray, 9, 75, 75)
    norm = cv2.normalize(blur, None, 0, 255, cv2.NORM_MINMAX)
    edges = cv2.Canny(norm, 50, 150)
    _, binary = cv2.threshold(edges, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((3, 3), np.uint8)
    refined = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    processed = cv2.bitwise_not(refined)
    processed_rgb = cv2.cvtColor(processed, cv2.COLOR_GRAY2RGB)
    return processed_rgb

# def preprocess_sketch(img_path):
#     """
#     Dummy preprocessing function.
#     (Placeholder to maintain workflow compatibility ‚Äî returns RGB image as-is.)
#     """
#     img = cv2.imread(img_path)
#     if img is None:
#         raise ValueError(f"Could not read image: {img_path}")
#     img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     return img_rgb

# ===================== #
#   Tiling for CLIP     #
# ===================== #
def tile_image(image, tile_size=1024, overlap=0.2):
    """Split image into overlapping tiles for high-resolution processing."""
    w, h = image.size
    step = int(tile_size * (1 - overlap))
    tiles = []
    for y in range(0, h, step):
        for x in range(0, w, step):
            box = (x, y, min(x + tile_size, w), min(y + tile_size, h))
            tile = image.crop(box)
            tiles.append(tile)
    return tiles

# ===================== #
#   Build FAISS Index   #
# ===================== #
def build_faiss_index(
    image_folder: str,
    index_path: str = "./sketch_index.faiss",
    mapping_path: str = "./id_mapping.json",
    model_name: str = "ViT-g-14",
    pretrained: str = "laion2B-s12B-b42K",
    distance_metric: str = "cosine",
):
    """
    Build FAISS index from large sketch images using LAION CLIP (ViT-g-14).
    Automatically tiles and preprocesses large images before embedding.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    model = model.to(device)
    tokenizer = open_clip.get_tokenizer(model_name)

    embeddings, filenames = [], []

    for fname in os.listdir(image_folder):
        if fname.lower().endswith((".png", ".jpg", ".jpeg")):
            path = os.path.join(image_folder, fname)
            try:
                pre_img = preprocess_sketch(path)
                pil_img = Image.fromarray(pre_img)

                # --- Tile Large Image ---
                tiles = tile_image(pil_img, tile_size=1024, overlap=0.2)
                tile_embeds = []

                for tile in tiles:
                    tile_input = preprocess(tile).unsqueeze(0).to(device)
                    with torch.no_grad(), torch.cuda.amp.autocast():
                        emb = model.encode_image(tile_input)
                    emb = emb / emb.norm(dim=-1, keepdim=True)
                    tile_embeds.append(emb.cpu().numpy())

                # --- Pool Tile Embeddings ---
                image_embed = np.mean(np.vstack(tile_embeds), axis=0)
                embeddings.append(image_embed.astype("float32"))
                filenames.append(fname)

                print(f"Processed {fname} ({len(tiles)} tiles)")
            except Exception as e:
                print(f"Skipping {fname}: {e}")

    if not embeddings:
        raise ValueError(f"No valid images found in {image_folder}")

    embeddings = np.stack(embeddings, axis=0)
    print(f"Total embeddings computed: {embeddings.shape}")

    # --- Build FAISS Index ---
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim) if distance_metric == "cosine" else faiss.IndexFlatL2(dim)
    index.add(embeddings)

    faiss.write_index(index, index_path)
    print(f"FAISS index saved at: {index_path}")

    id_mapping = {i: filenames[i] for i in range(len(filenames))}
    with open(mapping_path, "w") as f:
        json.dump(id_mapping, f, indent=2)
    print(f"Mapping saved at: {mapping_path}")

    return index, id_mapping

# ===================== #
#       Example Run     #
# ===================== #
image_folder = "/home/ayushkum/archimera/inputs/input_png"
index_path = "/home/ayushkum/archimera/vit-g-14/sketch_index.faiss"
mapping_path = "/home/ayushkum/archimera/vit-g-14/id_mapping.json"

index, mapping = build_faiss_index(
    image_folder=image_folder,
    index_path=index_path,
    mapping_path=mapping_path,
    model_name="ViT-g-14",
    pretrained="/home/ayushkum/archimera/vit-g-14/open_clip_pytorch_model.bin",
    distance_metric="cosine"
)


Using device: cuda


  with torch.no_grad(), torch.cuda.amp.autocast():


Processed pdf5.png (63 tiles)
Processed pdf8.png (63 tiles)
Processed pdf4.png (63 tiles)
Processed pdf3.png (63 tiles)
Processed pdf1.png (63 tiles)
Processed pdf6.png (63 tiles)
Processed pdf7.png (63 tiles)
Processed pdf2.png (63 tiles)
Total embeddings computed: (8, 1024)
FAISS index saved at: /home/ayushkum/archimera/vit-g-14/sketch_index.faiss
Mapping saved at: /home/ayushkum/archimera/vit-g-14/id_mapping.json


## Query Search

In [2]:
import os
import json
import numpy as np
import torch
import faiss
import cv2
from PIL import Image
import open_clip


# ===================== #
#   Image Preprocessing #
# ===================== #
def preprocess_sketch(img_path):
    """Enhance edges, clean noise, and produce normalized binary image."""
    img = cv2.imread(img_path)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.bilateralFilter(gray, 9, 75, 75)
    norm = cv2.normalize(blur, None, 0, 255, cv2.NORM_MINMAX)
    edges = cv2.Canny(norm, 50, 150)
    _, binary = cv2.threshold(edges, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((3, 3), np.uint8)
    refined = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    processed = cv2.bitwise_not(refined)
    processed_rgb = cv2.cvtColor(processed, cv2.COLOR_GRAY2RGB)
    return processed_rgb

# def preprocess_sketch(img_path):
#     """
#     Dummy preprocessing function.
#     (Placeholder to maintain workflow compatibility ‚Äî returns RGB image as-is.)
#     """
#     img = cv2.imread(img_path)
#     if img is None:
#         raise ValueError(f"Could not read image: {img_path}")
#     img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     return img_rgb

# ===================== #
#   Tiling for CLIP     #
# ===================== #
def tile_image(image, tile_size=1024, overlap=0.2):
    """Split large query image into overlapping tiles."""
    w, h = image.size
    step = int(tile_size * (1 - overlap))
    tiles = []
    for y in range(0, h, step):
        for x in range(0, w, step):
            box = (x, y, min(x + tile_size, w), min(y + tile_size, h))
            tile = image.crop(box)
            tiles.append(tile)
    return tiles


# ===================== #
#   Search Function     #
# ===================== #
def search_similar_sketches(
    query_path: str,
    index_path: str = "./sketch_index.faiss",
    mapping_path: str = "./id_mapping.json",
    model_name: str = "ViT-g-14",
    pretrained: str = "laion2B-s12B-b42K",
    top_k: int = 5,
    distance_metric: str = "cosine",
):
    """
    Search for similar sketches using LAION CLIP embeddings and FAISS index.

    ---
    Parameters:
        query_path (str): Path to query image (large architectural sketch).
        index_path (str): Path to FAISS index file.
        mapping_path (str): Path to JSON mapping (ID -> filename).
        model_name (str): OpenCLIP model name (default ViT-g-14).
        pretrained (str): Pretrained weights for OpenCLIP.
        top_k (int): Number of results to return.
        distance_metric (str): 'cosine' or 'L2'.

    ---
    Returns:
        list[dict]: ranked list of matches with filenames and scores.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load FAISS index and mapping
    index = faiss.read_index(index_path)
    with open(mapping_path, "r") as f:
        id_mapping = json.load(f)

    # Load OpenCLIP model
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    model = model.to(device)
    model.eval()

    # Preprocess + Tile Query Image
    pre_img = preprocess_sketch(query_path)
    pil_img = Image.fromarray(pre_img)
    tiles = tile_image(pil_img, tile_size=1024, overlap=0.2)

    tile_embeds = []
    for tile in tiles:
        tile_input = preprocess(tile).unsqueeze(0).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            emb = model.encode_image(tile_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        tile_embeds.append(emb.cpu().numpy())

    # Mean-pool tile embeddings ‚Üí one vector for whole query image
    query_emb = np.mean(np.vstack(tile_embeds), axis=0).astype("float32")
    query_emb = np.expand_dims(query_emb, axis=0)

    # Search
    D, I = index.search(query_emb, top_k)

    # Prepare results
    results = []
    for rank, (idx, dist) in enumerate(zip(I[0], D[0]), start=1):
        fname = id_mapping.get(str(idx)) or id_mapping.get(idx)
        score = dist if distance_metric.lower() == "cosine" else (1 / (1 + dist))
        results.append({
            "rank": rank,
            "filename": fname,
            "score": float(score),
        })

    return results


# ===================== #
#     Example Run       #
# ===================== #
if __name__ == "__main__":
    query_folder = "/home/ayushkum/archimera/query_png"
    index_path = "/home/ayushkum/archimera/vit-g-14/sketch_index.faiss"
    mapping_path = "/home/ayushkum/archimera/vit-g-14/id_mapping.json"

    top_k = 5
    distance_metric = "cosine"

    for filename in os.listdir(query_folder):
        if filename.lower().endswith((".png", ".jpg", ".jpeg")):
            query_image_path = os.path.join(query_folder, filename)
            print(f"\nüîç Query: {filename}")
            results = search_similar_sketches(
                query_path=query_image_path,
                index_path=index_path,
                mapping_path=mapping_path,
                top_k=top_k,
                distance_metric=distance_metric,
                model_name="ViT-g-14",
                pretrained="/home/ayushkum/archimera/vit-g-14/open_clip_pytorch_model.bin"
            )

            for r in results:
                print(f"{r['rank']}. {r['filename']} ‚Äî score: {round(r['score'] * 100, 2)}%")



üîç Query: pdf3_SIM.png
Using device: cuda


  with torch.no_grad(), torch.cuda.amp.autocast():


1. pdf2.png ‚Äî score: 86.01%
2. pdf8.png ‚Äî score: 85.17%
3. pdf4.png ‚Äî score: 84.83%
4. pdf7.png ‚Äî score: 84.48%
5. pdf3.png ‚Äî score: 83.77%
