### Installing and Importing necessary libraries

In [None]:
!pip install byol_pytorch -q q q

In [None]:
# Standard library imports
import json
import math
import os
import random

# Third-party library imports
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
import pandas as pd
from PIL import Image
from scipy.spatial.distance import cdist
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

# Deep learning frameworks
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split
import torchvision
from torchvision import models, transforms as T

# Specialized model imports
from byol_pytorch import BYOL
from transformers import CLIPProcessor, CLIPModel

# Select the compute device, then load the pretrained CLIP processor and model
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

## Dataloader and Dataset

In [None]:
class BaseArtDataset(Dataset):
    """Base dataset that indexes the image files found directly inside a folder.

    This class only discovers the files and loads them as RGB PIL images;
    subclasses provide ``__getitem__`` and their own transforms.

    Args:
        folder_path (str): Directory whose regular files with an image
            extension become the dataset items (sub-directories are ignored).
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, folder_path):
        self.folder_path = folder_path
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # index -> file mapping reproducible across instances and runs, which
        # matters when several datasets are split with the same index lists.
        self.image_files = sorted(
            name for name in os.listdir(folder_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder_path, name))
        )

    def __len__(self):
        """Return the number of image files that were discovered."""
        return len(self.image_files)

    def _load_image(self, idx):
        """Load the image at position ``idx``.

        Returns:
            tuple: ``(image, img_path)`` where ``image`` is an RGB PIL image
            and ``img_path`` is the full path it was read from.
        """
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            return img.convert('RGB'), img_path


class ByolTrainDataset(BaseArtDataset):
    """Training dataset that yields one deterministic view per image.

    Only resizing, tensor conversion and normalisation are applied here; the
    original author notes that BYOL handles the augmentations itself.
    """

    def __init__(self, folder_path):
        super().__init__(folder_path)
        resize = T.Resize((256, 256))
        to_tensor = T.ToTensor()
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = T.Compose([resize, to_tensor, normalize])

    def __getitem__(self, idx):
        """Return ``(image_tensor, img_path)`` for the item at ``idx``."""
        pil_image, img_path = self._load_image(idx)
        return self.transform(pil_image), img_path


class MetricEvalDataset(BaseArtDataset):
    """Evaluation dataset that yields two random augmentations of each image.

    The pair is used as a positive pair by the alignment/uniformity metrics.
    """

    def __init__(self, folder_path):
        super().__init__(folder_path)

        # Stochastic building blocks, named so the pipeline below reads top-down
        color_jitter = T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.5)
        gaussian_blur = T.RandomApply([T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.2)

        self.transform = T.Compose([
            T.RandomResizedCrop(256, scale=(0.5, 1.0)),
            T.RandomHorizontalFlip(),
            color_jitter,
            T.RandomGrayscale(p=0.1),
            gaussian_blur,
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __getitem__(self, idx):
        """Return ``(view1, view2, img_path)`` for the item at ``idx``."""
        image, img_path = self._load_image(idx)
        # The pipeline is random, so two passes over the same image give two views
        first_view, second_view = (self.transform(image) for _ in range(2))
        return first_view, second_view, img_path

In [None]:
def create_data_loaders(dataset_path, batch_size=64, train_size=18000, val_size=1000, test_size=1000,
                        num_workers=4, seed=None):
    """
    Create train, validation, and test data loaders with different configurations:
    - Train loader: BYOL-style single images
    - Val/Test loaders: Pairs of augmented images for metric evaluation

    The three splits are disjoint random subsets of the files in ``dataset_path``.

    Args:
        dataset_path (str): Path to dataset directory
        batch_size (int): Batch size for the train loader; val/test loaders use
            half of it (at least 1) because each item carries two image views
        train_size (int): Number of samples for training
        val_size (int): Number of samples for validation
        test_size (int): Number of samples for testing
        num_workers (int): Number of workers for data loading
        seed (int | None): If given, the split is shuffled with a private RNG
            seeded with this value so it is reproducible. If None, the global
            ``random`` module state is used (the previous behaviour).

    Returns:
        dict: Dictionary containing train, val, test loaders and datasets

    Raises:
        ValueError: If no image files are found in ``dataset_path``.
    """
    # First, get the total list of files and create indices for splitting
    base_dataset = BaseArtDataset(dataset_path)
    dataset_size = len(base_dataset)
    if dataset_size == 0:
        raise ValueError(f"No image files found in {dataset_path!r}")

    # Shrink the splits proportionally when the folder holds too few images
    required_size = train_size + val_size + test_size
    if dataset_size < required_size:
        print(f"Warning: Dataset only has {dataset_size} samples, which is less than the requested {required_size}")
        train_size = int(dataset_size * (train_size / required_size))
        val_size = int(dataset_size * (val_size / required_size))
        test_size = dataset_size - train_size - val_size

    # Shuffle once, then carve consecutive slices so the splits cannot overlap
    indices = list(range(dataset_size))
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(indices)

    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:train_size + val_size + test_size]

    # Create the actual datasets with appropriate transformations
    train_dataset = ByolTrainDataset(dataset_path)
    val_dataset = MetricEvalDataset(dataset_path)
    test_dataset = MetricEvalDataset(dataset_path)

    # Each dataset lists the directory on its own. Pin all of them to the one
    # listing the indices were drawn from, so an index always refers to the
    # same file and the splits stay disjoint.
    for dataset in (train_dataset, val_dataset, test_dataset):
        dataset.image_files = base_dataset.image_files

    train_subset = Subset(train_dataset, train_indices)
    val_subset = Subset(val_dataset, val_indices)
    test_subset = Subset(test_dataset, test_indices)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines
    loader_kwargs = {
        'num_workers': num_workers,
        'pin_memory': torch.cuda.is_available(),
    }

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, **loader_kwargs)

    # Using a smaller batch size for validation/test due to memory concerns
    # (2 copies of each image). Never let it drop to 0, which DataLoader rejects.
    eval_batch_size = max(1, batch_size // 2)

    val_loader = DataLoader(val_subset, batch_size=eval_batch_size, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_subset, batch_size=eval_batch_size, shuffle=False, **loader_kwargs)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_subset,
        'val_dataset': val_subset,
        'test_dataset': test_subset
    }

In [None]:
# Build the loaders for the NGA image folder and keep handles on train/val
data = create_data_loaders(
    "/kaggle/input/nga-unlablled-dataset/nga_images",
    batch_size=32,
    train_size=18000,
    val_size=1000,
    test_size=1000,
)

train_loader, val_loader = data['train_loader'], data['val_loader']

## Evaluation metrics (alignment and uniformity) from the paper https://arxiv.org/abs/2005.10242, used with the validation and test dataloaders

In [None]:
# Alignment and uniformity metric functions
def align_loss(x, y, alpha=2):
    """Return the mean of the row-wise L2 distances between x and y, raised to ``alpha``.

    Rows of ``x`` and ``y`` at the same position are treated as a positive pair.
    """
    pair_distances = torch.norm(x - y, p=2, dim=1)
    return torch.mean(pair_distances ** alpha)


def uniform_loss(x, t=2):
    """Return the log of the mean Gaussian kernel over all pairs of rows in x.

    Lower values mean the rows are spread out more evenly; the metric is
    intended for embeddings that lie on the unit hypersphere.
    """
    pairwise_distances = torch.pdist(x, p=2)
    kernel_values = torch.exp(-t * pairwise_distances ** 2)
    return torch.log(kernel_values.mean())


def evaluate_intrinsic_metrics(model, test_loader, device, alpha=2, t=2, lam=1.0):
    """
    Evaluate intrinsic metrics (alignment and uniformity) on a test loader

    Embeddings are L2-normalised before the metrics are computed, because both
    metrics are defined for points on the unit hypersphere. The model is left
    in eval mode when this function returns.

    Args:
        model: Feature extractor model; called as
            ``model(images, return_embedding=True)`` and expected to return a
            2-tuple whose second element is the embedding batch
        test_loader: DataLoader yielding (img1, img2, path) batches
        device: torch device
        alpha: Power for alignment loss
        t: Temperature for uniformity loss
        lam: Weight to balance uniformity loss

    Returns:
        Tuple of (alignment, uniformity, total_metric), each averaged over
        batches. Uniformity (and therefore the total) is NaN if no batch had at
        least two samples.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    alignment_losses = []
    uniform_losses = []

    with torch.no_grad():
        for img1, img2, _ in test_loader:
            img1 = img1.to(device)
            img2 = img2.to(device)

            # Get embeddings from the model
            _, emb1 = model(img1, return_embedding=True)
            _, emb2 = model(img2, return_embedding=True)

            # Project onto the unit hypersphere. Without this, the large norms
            # of raw backbone features can make exp(-t * d^2) underflow to 0
            # and the uniformity term collapse to -inf.
            emb1 = torch.nn.functional.normalize(emb1, dim=1)
            emb2 = torch.nn.functional.normalize(emb2, dim=1)

            # Compute alignment loss between positive pairs
            align = align_loss(emb1, emb2, alpha=alpha)
            alignment_losses.append(align.item())

            # Uniformity needs at least one pair of distinct samples; a
            # single-sample batch would produce NaN, so it is skipped.
            if emb1.shape[0] > 1:
                unif1 = uniform_loss(emb1, t=t)
                unif2 = uniform_loss(emb2, t=t)
                uniform_losses.append(((unif1 + unif2) / 2).item())

    if not alignment_losses:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    avg_align = sum(alignment_losses) / len(alignment_losses)
    avg_unif = sum(uniform_losses) / len(uniform_losses) if uniform_losses else float('nan')

    # The total intrinsic "loss" (or metric) is given by:
    total_metric = avg_align + lam * avg_unif

    return avg_align, avg_unif, total_metric

## Model script

In [None]:
# ImageNet-pretrained ResNet-50 backbone wrapped by BYOL.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` - confirm against the installed version before changing it.
resnet = models.resnet50(pretrained=True)

# use_momentum=False turns off momentum in the target encoder
learner = BYOL(resnet, image_size=256, hidden_layer='avgpool', use_momentum=False)
learner = learner.to(device)

optimizer = torch.optim.Adam(learner.parameters(), lr=3e-4)

## Training script

In [None]:
# Training schedule
epochs = 100           # total number of passes over the training loader
warmup_epochs = 10     # epochs of linear learning-rate warmup
learning_rate = 3e-4   # peak learning rate reached at the end of warmup

# Validation cadence
eval_every = 10        # run validation metrics every this many epochs

In [None]:
# Results tracking
# metrics_history collects one dict per evaluation; best_metric holds the lowest
# validation total seen so far (lower is treated as better below).
metrics_history = []
best_metric = float('inf')

# Training loop
for epoch in range(epochs):
    # Adjust learning rate with warmup and cosine decay:
    # linear ramp up to learning_rate over the first warmup_epochs epochs,
    # then a half-cosine decay over the remaining epochs.
    if epoch < warmup_epochs:
        lr = learning_rate * (epoch + 1) / warmup_epochs
    else:
        lr = learning_rate * 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

    # The schedule is applied by hand, overriding the lr the optimizer was built with
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Training step
    # Switch back to train mode; evaluate_intrinsic_metrics puts the learner in eval mode
    learner.train()
    train_losses = []

    # Add tqdm progress bar
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    # The train loader yields (images, paths); the paths are not needed here
    for images, _ in progress_bar:
        images = images.to(device)

        # Calling the learner on a batch returns the loss to minimise
        loss = learner(images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # Update progress bar with current loss
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.6f}")

    avg_train_loss = sum(train_losses) / len(train_losses)
    print(f"Epoch {epoch+1}/{epochs}: Train Loss = {avg_train_loss:.4f}, LR = {lr:.6f}")

    # Evaluate metrics periodically (epochs 0, eval_every, 2*eval_every, ...) and on the last epoch
    if epoch % eval_every == 0 or epoch == epochs - 1:

        # Evaluate on validation set
        align, uniform, total = evaluate_intrinsic_metrics(
            learner, val_loader, device
        )

        print(f"Validation Metrics - Alignment: {align:.4f}, Uniformity: {uniform:.4f}, Total: {total:.4f}")

        metrics_history.append({
            'epoch': epoch,
            'train_loss': avg_train_loss,
            'alignment': align,
            'uniformity': uniform,
            'total_metric': total
        })

        # Save if best model by total metric
        # Only the ResNet backbone weights are stored, not the whole BYOL learner
        if total < best_metric:
            best_metric = total
            torch.save({
                'epoch': epoch,
                'model_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'metrics': {
                    'alignment': align,
                    'uniformity': uniform,
                    'total': total
                }
            }, 'best_byol_model.pt')

In [None]:
# Score the trained learner once on the held-out test split
test_align, test_uniform, test_total = evaluate_intrinsic_metrics(
    model=learner,
    test_loader=data['test_loader'],
    device=device,
)
print(
    f"Final Test Metrics - Alignment: {test_align:.4f}, "
    f"Uniformity: {test_uniform:.4f}, Total: {test_total:.4f}"
)

### Visual evaluation and use of CLIP Score

In [None]:
def find_top_similar_pairs(model, test_loader, device, top_k=5):
    """
    Find the top_k most similar image pairs in the test dataset.

    Similarity is the cosine similarity between the embeddings of the first
    augmented view of each image. Each unordered pair is reported once, with
    ``image1_idx < image2_idx``, in order of decreasing similarity.

    Args:
        model: Your trained BYOL model; called as
            ``model(images, return_embedding=True)``
        test_loader: Your test dataloader that yields (img1, img2, path)
        device: Device to run inference on ('cuda' or 'cpu')
        top_k: Number of most similar pairs to return

    Returns:
        List of dictionaries containing the most similar pairs. The list is
        shorter than ``top_k`` when fewer distinct pairs with a finite
        similarity exist, and empty when the loader yields nothing.
    """
    model.eval()

    # Collect all embeddings and paths
    all_embeddings = []
    all_paths = []

    print("Extracting embeddings from test dataset...")
    with torch.no_grad():
        for img1, _, rest in tqdm(test_loader):
            if len(rest) > 0:
                paths = rest  # Get image paths
            else:
                # If paths are not provided, create dummy paths
                paths = [f"image_{len(all_paths) + i}" for i in range(len(img1))]

            # Only the first augmentation is embedded; the second view is not
            # needed for finding similar pairs.
            _, emb1 = model(img1.to(device), return_embedding=True)

            all_embeddings.append(emb1.cpu().numpy())
            all_paths.extend(paths)

    print(f"Extracted embeddings for {len(all_paths)} images")
    if not all_embeddings:
        # np.vstack rejects an empty list; there is nothing to compare anyway
        return []

    all_embeddings = np.vstack(all_embeddings)
    num_images = all_embeddings.shape[0]

    # Calculate similarity matrix (cosine similarity)
    print("Calculating similarity matrix...")
    similarity_matrix = 1 - cdist(all_embeddings, all_embeddings, 'cosine')

    # Keep only the strict upper triangle: that drops self-comparisons and
    # counts each unordered pair once.
    rows, cols = np.triu_indices(num_images, k=1)
    pair_scores = similarity_matrix[rows, cols]

    # Zero-norm embeddings give NaN cosine similarity; such pairs are dropped
    # rather than allowed to win the ranking.
    usable = np.isfinite(pair_scores)
    rows, cols, pair_scores = rows[usable], cols[usable], pair_scores[usable]

    print(f"Finding top {top_k} most similar pairs...")
    # Never ask for more pairs than exist
    num_pairs = max(0, min(int(top_k), pair_scores.size))
    # A stable sort keeps tied pairs in row-major order
    best = np.argsort(-pair_scores, kind='stable')[:num_pairs]

    # Plain Python ints/floats keep the result easy to print and serialise
    return [
        {
            'image1_idx': int(rows[pos]),
            'image2_idx': int(cols[pos]),
            'image1_path': all_paths[int(rows[pos])],
            'image2_path': all_paths[int(cols[pos])],
            'learner similarity': float(pair_scores[pos])
        }
        for pos in best
    ]

In [None]:
def display_similar_pairs(similar_pairs):
    """
    Print a numbered, three-line summary for every similar pair.

    Args:
        similar_pairs: List of dictionaries holding the keys
            'learner similarity', 'image1_path' and 'image2_path'
    """
    for rank, pair in enumerate(similar_pairs, start=1):
        print(f"{rank}. Similarity: {pair['learner similarity']:.4f}")
        print(f"   Image 1: {pair['image1_path']}")
        print(f"   Image 2: {pair['image2_path']}")

In [None]:
def find_and_display_similar_pairs(model, test_loader, device='cuda', top_k=5):
    """
    Convenience wrapper: search for the most similar pairs, print them, and
    hand the list back to the caller.

    Args:
        model: Your trained BYOL model
        test_loader: Your test dataloader
        device: Device to run inference on
        top_k: Number of most similar pairs to return

    Returns:
        List of dictionaries containing the most similar pairs
    """
    top_pairs = find_top_similar_pairs(model, test_loader, device, top_k)
    display_similar_pairs(top_pairs)
    return top_pairs

In [None]:
# Search the test split for the 5 closest image pairs and print them
similar_pairs = find_and_display_similar_pairs(learner, data['test_loader'], device=device, top_k=5)

In [None]:
# Bare expression: the notebook shows the raw pair records as the cell output
similar_pairs 

In [None]:
def calculate_clip_similarity(data, clip_model=None, clip_processor=None):
    """
    Calculate the CLIP-based cosine similarity for each image pair in the input data.

    Each dictionary in ``data`` is updated in place with a 'clip_similarity'
    entry. Pairs that cannot be processed get ``None`` and an error line is
    printed; the remaining pairs are still processed.

    Args:
        data (list): List of dictionaries with image pair info; each needs the
            keys 'image1_path' and 'image2_path'
        clip_model: Optional already-loaded CLIP model to reuse. If None, the
            "openai/clip-vit-base-patch32" checkpoint is loaded.
        clip_processor: Optional already-loaded CLIP processor to reuse. If
            None, the processor of the same checkpoint is loaded.

    Returns:
        list: The same list with added CLIP similarity scores
    """
    # Load the CLIP model and processor only when the caller did not supply them
    if clip_model is None or clip_processor is None:
        print("Loading CLIP model...")
        if clip_model is None:
            clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        if clip_processor is None:
            clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # Process each pair
    print("Processing image pairs...")
    for pair in tqdm(data):
        image1_path = pair['image1_path']
        image2_path = pair['image2_path']

        try:
            # convert() returns loaded copies, so both files can be closed
            # as soon as the with-block ends.
            with Image.open(image1_path) as raw1, Image.open(image2_path) as raw2:
                images = [raw1.convert("RGB"), raw2.convert("RGB")]

            # Preprocess and embed both images in one batch of two
            inputs = clip_processor(images=images, return_tensors="pt").to(clip_model.device)
            with torch.no_grad():
                image_features = clip_model.get_image_features(**inputs)

            # cosine_similarity normalises its inputs, so no manual
            # normalisation of the features is needed.
            clip_similarity = torch.nn.functional.cosine_similarity(
                image_features[0:1], image_features[1:2]
            ).item()

            # Add CLIP similarity to the pair data
            pair['clip_similarity'] = clip_similarity

        except Exception as e:
            # Deliberately best-effort: one unreadable image must not abort the
            # whole batch of pairs.
            print(f"Error processing {image1_path} and {image2_path}: {e}")
            pair['clip_similarity'] = None

    return data

In [None]:
# Add a 'clip_similarity' entry to every pair found above; the function returns the list it was given
result_data = calculate_clip_similarity(similar_pairs)

In [None]:
# Bare expression: the notebook shows the pair records with CLIP scores as the cell output
result_data

In [None]:
def visualize_image_pairs(data, output_dir=None, num_pairs=5):
    """
    Visualize pairs of images with their similarity scores.

    Each figure shows the two images side by side and, below them, one
    horizontal bar per available similarity score. A pair whose CLIP score is
    missing or ``None`` is still drawn, with only the model bar.

    Args:
        data (list): List of dictionaries with image pair info and similarity
            scores. Required keys: 'image1_path', 'image2_path', 'image1_idx',
            'image2_idx', 'learner similarity'; optional: 'clip_similarity'.
        output_dir (str): Directory to save visualizations (if None, just displays them)
        num_pairs (int): Number of pairs to visualize
    """
    # Create output directory if specified (no error if it already exists)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Limit to the requested number of pairs
    for i, pair in enumerate(data[:num_pairs]):
        image1_path = pair['image1_path']
        image2_path = pair['image2_path']

        # Only scores that are actually present get a bar
        candidate_scores = (
            ('Model', pair['learner similarity']),
            ('CLIP', pair.get('clip_similarity')),
        )
        scores = [(label, float(value)) for label, value in candidate_scores if value is not None]

        fig = None
        try:
            # convert() returns loaded copies, so the files can be closed right away
            with Image.open(image1_path) as raw1, Image.open(image2_path) as raw2:
                img1 = raw1.convert('RGB')
                img2 = raw2.convert('RGB')

            fig = plt.figure(figsize=(12, 8))

            # Row 0: images, row 1: empty spacer, row 2: score bars
            gs = GridSpec(3, 2, height_ratios=[4, 1, 2], hspace=0.1)

            # Display first image
            ax1 = fig.add_subplot(gs[0, 0])
            ax1.imshow(img1)
            ax1.set_title(f"Image 1 (idx: {pair['image1_idx']})")
            ax1.axis('off')

            # Display second image
            ax2 = fig.add_subplot(gs[0, 1])
            ax2.imshow(img2)
            ax2.set_title(f"Image 2 (idx: {pair['image2_idx']})")
            ax2.axis('off')

            # Similarity bars span the full width of the bottom row
            ax3 = fig.add_subplot(gs[2, :])

            labels = [label for label, _ in scores]
            values = [value for _, value in scores]
            colors = [plt.cm.viridis(value) for value in values]
            y_pos = np.arange(len(labels))

            # Plot horizontal bars
            bars = ax3.barh(y_pos, values, color=colors, height=0.4)

            # Add labels and formatting
            ax3.set_yticks(y_pos)
            ax3.set_yticklabels(labels)
            ax3.set_xlim(0, 1.0)
            ax3.set_xticks(np.arange(0, 1.1, 0.1))
            ax3.set_title('Similarity Scores')

            # Add value labels to the end of each bar
            for bar, value in zip(bars, values):
                ax3.text(value + 0.01, bar.get_y() + bar.get_height()/2,
                         f'{value:.4f}', va='center')

            fig.suptitle(f"Image Pair {i+1}", fontsize=16)
            fig.tight_layout()

            # Save or display
            if output_dir:
                fig.savefig(os.path.join(output_dir, f"pair_{i+1}.png"), dpi=150)
                plt.close(fig)
            else:
                plt.show()

        except Exception as e:
            # Best-effort per pair: report the problem and move on, but do not
            # leave a half-built figure open.
            print(f"Error visualizing pair {i+1}: {e}")
            if fig is not None:
                plt.close(fig)
In [None]:
# Visualize image pairs
# No output_dir is passed, so the figures are displayed instead of saved
visualize_image_pairs(result_data)