<a href="https://colab.research.google.com/github/JanjaTomic/image_comparison/blob/main/image_comparison.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import numpy as np
import cv2
import torch.nn.functional as F
import sys

def preprocess_image(image_path, use_edges=False, transform=None):
    """Load an image and turn it into a batched model-input tensor.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        use_edges: If True, replace the image with its Canny edge map
            (thresholds 100/200), expanded back to three channels so it goes
            through the same transform as the plain RGB image.
        transform: Callable mapping a PIL image to a tensor. Defaults to the
            module-level ``preprocess`` transform, which is only defined when
            this file runs as a script; pass one explicitly when importing
            this module from elsewhere.

    Returns:
        The transformed tensor with a leading batch dimension of size 1.
    """
    if transform is None:
        transform = preprocess

    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so `img` remains valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert('RGB')

    if use_edges:
        # cv2.Canny yields a single-channel edge map; replicate it to three
        # channels so it has the same channel count as the RGB path.
        edges = cv2.Canny(np.array(img), threshold1=100, threshold2=200)
        img = Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))

    return transform(img).unsqueeze(0)

def extract_features(image_path):
    """Return an equal-weight blend of the model's outputs for two views of an image.

    The image is fed to the global ``model`` twice: once as the plain RGB
    image and once as its Canny edge map (both prepared by
    ``preprocess_image``). The two outputs are averaged with weight 0.5 each,
    and size-1 dimensions (including the batch dimension) are squeezed out.
    With the torchvision resnet50 set up by the script entry point, these
    outputs are the final fully-connected layer's scores, not pooled embeddings.
    """
    plain_input = preprocess_image(image_path)
    edge_input = preprocess_image(image_path, use_edges=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        plain_out = model(plain_input)
        edge_out = model(edge_input)

    return (0.5 * plain_out + 0.5 * edge_out).squeeze()

def compare_images(image1_path, image2_path):
    """Return the Euclidean (L2) distance between two images' feature vectors.

    Features come from ``extract_features``. A smaller value means the two
    images look more alike to the model; the script entry point treats a
    distance below its threshold as a pass.
    """
    feats_a, feats_b = extract_features(image1_path), extract_features(image2_path)
    return torch.dist(feats_a, feats_b, p=2).item()

if __name__ == "__main__":
    # Usage: python image_comparison.py BASELINE_IMAGE TEST_IMAGE
    # Prints "pass" when the feature distance is below PASS_THRESHOLD, else "fail".

    # Validate arguments before the comparatively expensive model load.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} BASELINE_IMAGE TEST_IMAGE")

    # ImageNet mean/std normalization expected by torchvision's pretrained
    # classifiers. Assigned at module level so preprocess_image() finds it.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # IMAGENET1K_V1 is the checkpoint the deprecated `pretrained=True` flag
    # loaded. This uses the `weights=` API available since torchvision 0.13.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model.eval()  # inference mode: BatchNorm uses its running statistics

    baseline_image, test_image = sys.argv[1], sys.argv[2]

    # NOTE(review): the origin of this cut-off is not documented here;
    # presumably tuned empirically for this feature blend.
    PASS_THRESHOLD = 8.00

    distance_score = compare_images(baseline_image, test_image)
    print("pass" if distance_score < PASS_THRESHOLD else "fail")