In [14]:
#pip install torch torchvision transformers faiss-cpu


In [15]:
import os
import json
import torch
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [16]:
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [17]:
# Define your custom dataset
class AnimalDataset(Dataset):
    def __init__(self, json_file, transform=None):
        with open(json_file) as f:
            self.data = json.load(f)["images"]
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]["file"]
        image = Image.open(img_path).convert("RGB")
        label = self.data[idx]["label"]
        breed = self.data[idx].get("product", None)

        if self.transform:
            image = self.transform(image)

        return image, label, breed

In [18]:
# Define your transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your dataset
dataset = AnimalDataset('datasets/data_local.json', transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [19]:
from transformers import AutoFeatureExtractor, AutoModel
#from datasets import load_dataset, concatenate_datasets, load_from_disk
from PIL import Image
import numpy as np

In [20]:
model_ckpt = "google/vit-base-patch16-224"

extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt)

hidden_dim = model.config.hidden_size

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# Example for two classes (cat and dog)
num_classes = 2  # Change to number of breeds if needed
classifier = Classifier(768, num_classes)  # ViT base outputs 768-dim embeddings

In [22]:
import faiss
import numpy as np

# Prepare to store embeddings
embeddings = []
labels = []

# Extract embeddings
with torch.no_grad():
    for images, label, breed in dataloader:
        outputs = model(images).last_hidden_state[:, 0, :]  # Get the [CLS] token output
        embeddings.append(outputs.numpy())
        labels.extend(label)

# Convert to numpy arrays
embeddings = np.vstack(embeddings).astype('float32')
labels = np.array(labels)

# Create a FAISS index
index = faiss.IndexFlatL2(embeddings.shape[1])  # L2 distance index
index.add(embeddings)  # Add embeddings to the index


# Save the index to local storage
faiss.write_index(index, "faiss_index.index")

In [23]:
# Create a mapping of index to (label, breed)
file_mapping = {
    i: {
        "file": img_info["file"],
        "label": img_info["label"],
        "product": img_info.get("product", "N/A")  # Use "N/A" if breed is not available
    }
    for i, img_info in enumerate(dataset.data)
}

file_mapping[0]

{'file': 'C:\\Users\\ravik\\Ravi\\Projects\\Gold_ImageSimilarity\\datasets\\Kasumalai\\Lakshmi Necklace\\1.png',
 'label': 'Kasumalai',
 'product': 'Lakshmi Necklace'}

In [39]:
def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension

    _k=5

    with torch.no_grad():
        output = model(image).last_hidden_state[:, 0, :].numpy()

    # Search in the FAISS index
    D, I = index.search(output.astype('float32'), k=_k)  # Get top 5 similar images
    print("Distances: ", D)

    # Define a threshold and calculate matching percentage
    threshold = 0.5  # Example threshold
    matches = (D < threshold).astype(int)  # 1 for match, 0 for no match

    print(matches)

    # Calculate the percentage of matches
    #matching_percentage = np.mean(matches) * 100  # Mean across all queries
    

    # Calculate percentage of matches for each index
    matching_percentages = (matches.sum(axis=1) / _k) * 100  # Mean matches per query
    print(f"Matching Percentage: {matching_percentages:.2f}%")
    
    print(I)
    # Retrieve metadata based on indices
    predicted_metadata = [
        {
            "file": file_mapping[idx]["file"],
            "label": file_mapping[idx]["label"],
            "product": file_mapping[idx]["product"],
            "Matching_Percent": (matches.sum(axis=1) / _k) * 100

        }
        for idx in I[0] # Get the first query results
    ]

    return predicted_metadata

In [40]:
# Example prediction
predicted_results = predict(r'C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Rings\Casuals\2.jpg')
for result in predicted_results:
    print(f"File: {result['file']}, Label: {result['label']}, Product: {result['product']}")

Distances:  [[  0.      443.49084 485.0205  498.6761  513.2071 ]]
[[1 0 0 0 0]]
Matching Percentage: 20.00%
[[17  7  3 10 22]]
File: C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Bracelets\Chain\3.jpg, Label: Bracelets, Product: Chain
File: C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Bracelets\Cuff\1.jpg, Label: Bracelets, Product: Cuff
File: C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Kasumalai\Lakshmi Necklace\4.png, Label: Kasumalai, Product: Lakshmi Necklace
File: C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Bracelets\Cuff\3.jpg, Label: Bracelets, Product: Cuff
File: C:\Users\ravik\Ravi\Projects\Gold_ImageSimilarity\datasets\Rings\Casuals\1.jpg, Label: Rings, Product: Casuals
