In [56]:
from urllib.parse import quote_plus

In [57]:
# Build the MongoDB connection string without storing credentials in the notebook.
# The username and password are read from the MONGODB_USERNAME and
# MONGODB_PASSWORD environment variables; each falls back to an empty string
# when the variable is not set.
import os

username = os.environ.get("MONGODB_USERNAME", "")
password = os.environ.get("MONGODB_PASSWORD", "")

# quote_plus percent-encodes characters such as '@', ':' and '/' so that they
# cannot be mistaken for URI delimiters inside the credentials.
uri = "mongodb+srv://{username}:{password}@visiontransformer-based.7aqyzdx.mongodb.net/?retryWrites=true&w=majority&appName=VisionTransformer-BasedImageRetrievalSystem".format(
    username=quote_plus(username),
    password=quote_plus(password)
)

In [20]:
# Create the MongoDB client for `uri`, requesting server API version '1'.
from pymongo import MongoClient, server_api

client = MongoClient(uri, server_api=server_api.ServerApi('1'))

# Ping the server to check the connection; a failure is printed, not raised.
try:
    client.admin.command('ping')
except Exception as err:
    print("Error connecting to MongoDB:", err)
else:
    print("MongoDB connection successful.")

MongoDB connection successful.


In [21]:
# Handles to the database and to the collection that stores the image feature documents.
db = client.get_database('Vision_transformer')
features_collection = db.get_collection('image_retrieval')

In [58]:
from transformers import ViTFeatureExtractor, ViTModel
import os
import torch
from PIL import Image

In [6]:
# Load the Vision Transformer preprocessing object and model from one checkpoint.
_VIT_CHECKPOINT = 'google/vit-base-patch16-224'
feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
model = ViTModel.from_pretrained(_VIT_CHECKPOINT)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [59]:
# Extract feature vectors from an image using the Vision Transformer.
def extract_features(image_path):
    """Return the ViT CLS-token embedding of the image stored at ``image_path``.

    Parameters
    ----------
    image_path : str or path-like
        Location of an image file that PIL can open.

    Returns
    -------
    numpy.ndarray
        The first-token (CLS) slice of ``last_hidden_state``, flattened to 1-D.

    Raises
    ------
    Exception
        Anything raised by ``Image.open``, the feature extractor or the model
        propagates to the caller unchanged.
    """
    # The context manager closes the underlying file as soon as the pixels are
    # copied. convert("RGB") gives grayscale, palette and RGBA images the
    # three channels the preprocessor is fed with.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only, no autograd graph needed
        outputs = model(**inputs)

    # Index 0 along the token axis is the CLS token, used here as the
    # whole-image representation.
    features = outputs.last_hidden_state[:, 0, :].detach().numpy()

    return features.flatten()

In [60]:
def process_images(main_folder_path, extractor=None):
    """Extract features for every image found one level below ``main_folder_path``.

    Each immediate sub-directory of ``main_folder_path`` is scanned for files
    ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive). Files sitting
    directly in ``main_folder_path`` and deeper nesting levels are ignored.
    Directory listings are sorted so the output order is the same on every run.

    Parameters
    ----------
    main_folder_path : str
        Directory whose sub-directories contain the images.
    extractor : callable, optional
        Function mapping an image path to a feature vector. Defaults to
        ``extract_features``.

    Returns
    -------
    tuple[list, list[str]]
        ``(features_list, image_names)``: two parallel lists holding the
        extracted features and the full path of the image they came from.
        Images whose extraction raised are reported on stdout and skipped.
    """
    if extractor is None:
        extractor = extract_features

    features_list = []
    image_names = []

    for foldername in sorted(os.listdir(main_folder_path)):
        subfolder_path = os.path.join(main_folder_path, foldername)
        if not os.path.isdir(subfolder_path):
            continue

        for filename in sorted(os.listdir(subfolder_path)):
            if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            image_path = os.path.join(subfolder_path, filename)
            # Best effort: one unreadable image must not abort the whole run.
            try:
                features = extractor(image_path)
            except Exception as e:
                print(f"Failed to process {filename} in {foldername}: {str(e)}")
                continue

            features_list.append(features)
            image_names.append(image_path)

    return features_list, image_names

In [61]:
# Dataset root. Set the IMAGE_DATASET_DIR environment variable to point the
# notebook at another location; when it is unset, the original local path is used.
main_folder_path = os.environ.get(
    'IMAGE_DATASET_DIR',
    '/Users/anantha_padmanaban/Documents/Academic/Spring24/web_mining/project/dataset'
)
features, names = process_images(main_folder_path)

In [19]:
# Storing the data in the database:
# Build every document first, then write them with one bulk call instead of
# issuing a separate insert_one per image.
# NOTE: re-running this cell inserts the same images again; nothing here
# de-duplicates on 'image_name'.
documents = [
    {
        'image_name': name,
        'feature_vector': feature.tolist()  # numpy array -> plain Python list
    }
    for name, feature in zip(names, features)
]

# insert_many rejects an empty list, so skip the write when nothing was extracted.
if documents:
    features_collection.insert_many(documents)