In [1]:
import os
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, ConcatDataset, Subset
from sklearn.model_selection import train_test_split
from Dataset import SiameseDataset
from scripts.utils import get_files_in_directory
from model.SNN import SiameseNN
from scripts.constants import POS_PATH, NEG_PATH, ANC_PATH, NUM_FILES, WEIGHTS_PATH

In [9]:
# Get file paths for anchor, positive, and negative images
anchor_files = get_files_in_directory(ANC_PATH, NUM_FILES)
positive_files = get_files_in_directory(POS_PATH, NUM_FILES)
negative_files = get_files_in_directory(NEG_PATH, NUM_FILES)

# Create datasets for positive and negative pairs
positive_dataset = SiameseDataset(anchor_files, positive_files, "POS")
negative_dataset = SiameseDataset(anchor_files, negative_files, "NEG")

# Concatenate positive and negative datasets (positives first, then negatives)
dataset = ConcatDataset([positive_dataset, negative_dataset])

SEED = 42         # shared by the split and the loader shuffling
TEST_SIZE = 0.25  # fraction of pairs held out for testing
BATCH_SIZE = 8

# Split dataset into train and test sets.
# NOTE(review): the pre-trained weights loaded later were presumably trained on the
# train part of this exact split; keep TEST_SIZE and SEED in sync with the training
# run, otherwise training samples can leak into the test set.
indices = list(range(len(dataset)))
train_indices, test_indices = train_test_split(indices, test_size=TEST_SIZE, random_state=SEED)

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

# Seeded generators make the shuffled batch order (and therefore the batch that
# is inspected later) reproducible across kernel restarts.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          generator=torch.Generator().manual_seed(SEED))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True,
                         generator=torch.Generator().manual_seed(SEED))

In [10]:
# Pull one batch out of the test loader for visual inspection
test_batches = iter(test_loader)
batch = next(test_batches)

In [11]:
# Load pre-trained Siamese Neural Network model.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; nothing in this notebook moves the model or the batches off the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model = SiameseNN()
state_dict = torch.load(os.path.join(WEIGHTS_PATH, 'SiameseModel.pt'), map_location="cpu")
model.load_state_dict(state_dict)
model.eval();  # evaluation mode; trailing ';' suppresses the module repr in the notebook

In [27]:
def show_images(batch, model, threshold=0.8):
    """
    Display pairs of images from a batch along with their similarity score as predicted by the model.

    For every pair, the two images are drawn side by side. The figure title shows the
    model output, plus the label when the batch carries a third element. "True" or
    "False" is printed depending on whether any output value exceeds ``threshold``.

    Args:
    - batch (tuple): A tuple containing batches of anchor images, siamese images, and labels.
    - model: Pre-trained Siamese Neural Network model.
    - threshold (float): Threshold for considering a pair similar.

    Returns:
    - None
    """
    anchors, others = batch[0], batch[1]
    labels = batch[2] if len(batch) > 2 else None

    for i in range(len(anchors)):
        image1 = anchors[i].float()
        image2 = others[i].float()

        # Inference only: no need to build an autograd graph.
        # NOTE(review): images are passed without a batch dimension (as before);
        # confirm SiameseNN accepts unbatched input.
        with torch.no_grad():
            outputs = model(image1, image2)

        # assumes images are channels-first (C, H, W); matplotlib wants (H, W, C)
        image1_np = image1.cpu().numpy().transpose((1, 2, 0))
        image2_np = image2.cpu().numpy().transpose((1, 2, 0))

        fig, axes = plt.subplots(1, 2)
        for ax, image_np, title in zip(axes, (image1_np, image2_np), ('Image 1', 'Image 2')):
            ax.imshow(image_np)
            ax.set_title(title)
            ax.axis('off')

        # Show the model output (the promised similarity score) and, if available, the label.
        scores = ", ".join(f"{score:.3f}" for score in outputs.flatten().tolist())
        heading = f"Similarity: {scores}"
        if labels is not None:
            label = labels[i]
            label_value = label.tolist() if torch.is_tensor(label) else label
            heading += f" | Label: {label_value}"
        fig.suptitle(heading)

        print("True" if (outputs > threshold).any() else "False")

        plt.show()
        # Release the figure so looping over many pairs does not accumulate open figures.
        plt.close(fig)

In [None]:
# Visualise every pair in the sampled test batch with the default threshold
show_images(batch=batch, model=model)