In [1]:
pip install transformers


Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install torch torchvision


Note: you may need to restart the kernel to use updated packages.


In [3]:
import pandas as pd
import requests
import torch
from PIL import Image
from io import BytesIO
import numpy as np
from transformers import ViTModel, ViTImageProcessor
from sklearn.metrics.pairwise import cosine_similarity

# Load the Excel file
file_path = 'C:\\Users\\occid\\cleaned_data_no_duplicates.xlsx'
sheet_name = 'Sheet1'
df = pd.read_excel(file_path, sheet_name=sheet_name)

# Load the Vision Transformer (ViT) model and image processor
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTModel.from_pretrained("google/vit-base-patch16-224")
model.eval()  # Set the model to evaluation mode

# Function to fetch and preprocess an image from a URL
def process_image_from_url(url):
    try:
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert('RGB')
            return image
        else:
            print(f"Failed to fetch image from {url}")
            return None
    except Exception as e:
        print(f"Error fetching image from {url}: {e}")
        return None

# Function to extract feature vector from an image using ViT
def extract_features(image):
    try:
        # Preprocess the image
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the CLS token's embedding as the feature vector
        features = outputs.last_hidden_state[:, 0, :].detach().numpy()
        return features.squeeze(0)  # Remove batch dimension
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None

# List of image URLs from the dataset
image_urls = df['image'].tolist()
barcodes = df['barcode'].tolist()
features = []

# Process each image URL and extract features
valid_barcodes = []
for url, barcode in zip(image_urls, barcodes):
    image = process_image_from_url(url)
    if image is not None:
        feature_vector = extract_features(image)
        if feature_vector is not None:
            features.append(feature_vector)
            valid_barcodes.append(barcode)

# Flatten feature vectors for similarity computation
features = np.vstack(features)

# Calculate similarity matrix
similarity_matrix = cosine_similarity(features)

# Create a similarity matrix DataFrame using valid barcodes
similarity_df = pd.DataFrame(similarity_matrix, columns=valid_barcodes, index=valid_barcodes)

# Save the similarity matrix to a file for review
output_file = "C:\\Users\\occid\\Image_Similarity_Matrix_ViT.xlsx"
similarity_df.to_excel(output_file)

print(f"Image similarity matrix saved to {output_file}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Failed to fetch image from https://cdn-images.kiotviet.vn/khotonghuynhphuong/aac760da977740ddb985ae8292ae6415.png
Error fetching image from https://sapo.dktcdn.net/100/566/185/variants/617c945a-cc1b-46c9-9be6-eba66aa9a3ba.jpg: HTTPSConnectionPool(host='sapo.dktcdn.net', port=443): Read timed out. (read timeout=10)
Failed to fetch image from https://bizweb.dktcdn.net/100/363/802/files/8847e9b9dgjki.jpg
Error fetching image from https://api.balance.ari.com.vn/api/v1/supermarket/util/download?key=supermarket-service/product/lcg9rf0v/nuoc-cot-gung-mat-ong-350ml-jXPySN.png: HTTPSConnectionPool(host='api.balance.ari.com.vn', port=443): Max retries exceeded with url: /api/v1/supermarket/util/download?key=supermarket-service/product/lcg9rf0v/nuoc-cot-gung-mat-ong-350ml-jXPySN.png (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x0000018F10F170A0>: Failed to establish a new connection: [Errno 11001] getaddrinfo failed'))
Failed to fetch image from https://cdn-images

Failed to fetch image from https://cdn-images.kiotviet.vn/nppannhung/c2c636590859487d9bf2b3dec370ebb0.jpeg
Error fetching image from https://cdn.nhanh.vn/cdn/store1/36027/ps/20201010/1010202091014_120882918_3545153138879361_5625452291125109538_n.jpg: HTTPSConnectionPool(host='cdn.nhanh.vn', port=443): Max retries exceeded with url: /cdn/store1/36027/ps/20201010/1010202091014_120882918_3545153138879361_5625452291125109538_n.jpg (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x0000018F111F0700>: Failed to establish a new connection: [Errno 11001] getaddrinfo failed'))
Error fetching image from http://douongcaocap.vn/wp-content/uploads/2017/10/Vang-Montes-Limited-Selection-Pinot-Noir-13.png: HTTPConnectionPool(host='no.access', port=80): Max retries exceeded with url: / (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x0000018F11202220>: Failed to establish a new connection: [Errno 11001] getaddrinfo failed'))
Error fetching image 

In [6]:
import pandas as pd
import numpy as np

# Load the full similarity matrix
file_path = 'C:\\Users\\occid\\Image_Similarity_Matrix_ViT.xlsx'
similarity_matrix = pd.read_excel(file_path, index_col=0)

# Define similarity threshold
similarity_threshold = 0.8

# Extract image IDs
image_ids = similarity_matrix.columns.astype(str).tolist()

# Convert the matrix to a NumPy array for efficient computation
matrix_values = similarity_matrix.values

# Find pairs of images with similarity above the threshold
similar_pairs = []
for i in range(len(matrix_values)):
    for j in range(i + 1, len(matrix_values)):  # Upper triangle only
        if matrix_values[i, j] > similarity_threshold:
            similar_pairs.append((image_ids[i], image_ids[j], matrix_values[i, j]))

# Create a DataFrame for better visualization
similar_pairs_df = pd.DataFrame(similar_pairs, columns=["Image ID 1", "Image ID 2", "Similarity Score"])

# Save or display results
similar_pairs_df.to_csv('similar_image_pairs_ViT0.8.csv', index=False)
print("Similar pairs saved to 'similar_image_pairs.csv'")


Similar pairs saved to 'similar_image_pairs.csv'
