In [1]:
import os
import sqlite3
import hashlib
import json
from PIL import Image
from transformers import ViTModel, ViTFeatureExtractor
import torch
from config import db_path

Database path: C:\Users\Admin\Desktop\Multimodal\database\database\multimodal_rag.db


In [2]:
# Load the ViT backbone and its matching image preprocessor from the same
# pretrained checkpoint so their expected input format cannot drift apart.
checkpoint_name = 'google/vit-base-patch16-224'
model = ViTModel.from_pretrained(checkpoint_name)
processor = ViTFeatureExtractor.from_pretrained(checkpoint_name)

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Folder that holds the input images; create it (and any missing parents)
# up front so the notebook also runs on a fresh checkout.
image_dir = os.path.join('input', 'images')
os.makedirs(name=image_dir, exist_ok=True)

In [4]:
# Open the SQLite database at db_path (imported from config). The cells below
# run their SELECT/INSERT statements and the final commit through this handle.
conn = sqlite3.connect(database=db_path)

In [5]:
# Embed every image in `image_dir` with the ViT model and store one row per
# new image in the `embeddings` table. Images whose SHA-256 content hash is
# already present are skipped, so the cell is safe to re-run.

# Width of the vectors the model produces. The mean-pooled embedding below must
# have exactly this many columns.
# NOTE(review): if the project defines its own expected embedding size for the
# `embeddings` table, use that constant here instead.
expected_dimension = model.config.hidden_size

image_extensions = ('.png', '.jpg', '.jpeg')

# Sorted so that the processing order (and the printed log) is reproducible.
for filename in sorted(os.listdir(image_dir)):
    image_path = os.path.join(image_dir, filename)

    # Skip non-image files
    if not filename.lower().endswith(image_extensions):
        print(f"Skipping non-image file: {filename}")
        continue

    try:
        # Generate content hash to avoid duplicates
        with open(image_path, "rb") as f:
            content_hash = hashlib.sha256(f.read()).hexdigest()

        # Check if image already exists in database
        existing_row = conn.execute(
            'SELECT id FROM embeddings WHERE content_hash = ?', (content_hash,)
        ).fetchone()
        if existing_row is not None:
            print(f"Image already processed: {filename}")
            continue

        # Process image. The context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            dimensions = img.size
            # PNG/JPEG files may be palette, grayscale or RGBA; normalise to
            # 3-channel RGB before handing the image to the preprocessor.
            inputs = processor(images=img.convert("RGB"), return_tensors="pt")

        # Mean-pool the token embeddings into a single vector per image.
        with torch.no_grad():
            embedding = model(**inputs).last_hidden_state.mean(dim=1).numpy()

        # Ensure embedding matches expected dimension
        if embedding.shape[1] != expected_dimension:
            raise ValueError(f"Embedding dimension mismatch: {embedding.shape[1]} != {expected_dimension}")

        # Insert into database (committed by the final cell of the notebook)
        metadata = json.dumps({
            "filename": filename,
            "path": image_path,
            "dimensions": dimensions
        })
        conn.execute('''
            INSERT INTO embeddings (modality, content_hash, embedding, metadata)
            VALUES (?, ?, ?, ?)
        ''', ('image', content_hash, embedding.tobytes(), metadata))

        print(f"Processed and stored: {filename}")

    except Exception as e:
        # Best-effort ingestion: report the failure and move on to the next file.
        print(f"Error processing {filename}: {str(e)}")


Image already processed: beach.jpg
Image already processed: cat.jpg
Image already processed: dog.jpg
Image already processed: spiderman.jpg
Image already processed: sunset.jpg


In [6]:
# Persist the rows inserted through this connection, and release the database
# handle even if the commit itself raises.
try:
    conn.commit()
finally:
    conn.close()
print("Image processing complete!")

Image processing complete!
