In [1]:
import os
from collections import defaultdict
import re
import pandas as pd
import json
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import open_clip
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
import numpy as np
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import precision_recall_curve
import random
from sklearn.metrics.pairwise import cosine_similarity

# Root directory holding the pickled pair lists. Defaults to the original lab
# path; set the PMC_BASE_PATH environment variable to run the notebook elsewhere.
base_path = os.environ.get("PMC_BASE_PATH", '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/largeListsGuy')


In [2]:
# Load your data
# labeled_img_pairs is unpacked below as a list of ((img_path1, img_path2), label).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(base_path + "/retrieval_labeled_img_pairs.pkl", "rb") as f:
    labeled_img_pairs = pickle.load(f)

def extract_uid(img_path):
    """
    Return the name of the directory that directly contains ``img_path``.

    The path is normalised first. The component just before the file name is
    treated as the uid, e.g. ``root/<uid>/image.png`` -> ``<uid>``.
    A bare file name with no parent directory raises IndexError.
    """
    # Splitting from the right at most twice yields [prefix, parent_dir, file_name]
    # (or fewer pieces for short paths); the parent dir is always second-to-last.
    pieces = os.path.normpath(img_path).rsplit(os.sep, 2)
    return pieces[-2]

def dfs(uid, visited, component, graph=None):
    """
    Collect every uid reachable from ``uid`` into ``component``.

    Implemented iteratively. The previous recursive version raised
    RecursionError once a connected component was deeper than Python's
    recursion limit (about 1000 frames by default).

    Args:
        uid: start node; always added to ``visited`` and ``component``.
        visited: set of already-explored uids; updated in place.
        component: set receiving every uid reached from ``uid``; updated in place.
        graph: adjacency mapping uid -> iterable of neighbour uids. Defaults to
            the module-level ``uid_graph`` so existing call sites keep working.
    """
    if graph is None:
        graph = uid_graph
    visited.add(uid)
    component.add(uid)
    stack = [uid]
    while stack:
        node = stack.pop()
        # .get avoids inserting new keys when graph is a defaultdict that the
        # caller is iterating over.
        for neighbor in graph.get(node, ()):
            if neighbor not in visited:
                visited.add(neighbor)
                component.add(neighbor)
                stack.append(neighbor)

# Build the uid_graph: an undirected adjacency map in which two uids are linked
# whenever any labeled pair joins an image of one to an image of the other
# (regardless of the pair's label).
uid_graph = defaultdict(set)

# Build the graph
for (img_path1, img_path2), label in labeled_img_pairs:
    uid1 = extract_uid(img_path1)
    uid2 = extract_uid(img_path2)
    # Edges are inserted in both directions so the adjacency map is symmetric.
    uid_graph[uid1].add(uid2)
    uid_graph[uid2].add(uid1)

# Find connected components. Each component is a set of uids that the split
# below assigns wholesale to either train or test.
visited = set()
components = []

for uid in uid_graph:
    if uid not in visited:
        component = set()
        dfs(uid, visited, component)
        components.append(component)

# Step 1: Build UID to component index mapping
uid_to_component_idx = {}

for idx, component in enumerate(components):
    for uid in component:
        uid_to_component_idx[uid] = idx

# Step 2: Count samples per component
# Both uids of a pair share a component (they are linked in uid_graph), so
# looking up uid1 alone is enough to attribute each pair to its component.
component_sample_counts = [0] * len(components)

for (img_path1, img_path2), label in labeled_img_pairs:
    uid1 = extract_uid(img_path1)
    component_idx = uid_to_component_idx[uid1]
    component_sample_counts[component_idx] += 1

# Step 3: Sort components by sample count (largest first; list.sort is
# stable, so ties keep their discovery order)
components_with_counts = list(zip(components, component_sample_counts))
components_with_counts.sort(key=lambda x: x[1], reverse=True)

# Step 4: Assign components to training and test sets
# Greedy fill: whole components go to train until at least 80% of the pairs
# are covered, then every remaining component goes to test. The train share
# can overshoot 80% by up to the size of the last component added.
total_samples = len(labeled_img_pairs)
desired_train_samples = int(total_samples * 0.8)

train_uids = set()
test_uids = set()

accumulated_train_samples = 0

for component, sample_count in components_with_counts:
    if accumulated_train_samples < desired_train_samples:
        train_uids.update(component)
        accumulated_train_samples += sample_count
    else:
        test_uids.update(component)

# Step 5: Assign samples to training and test sets, discard cross-set samples
# A pair is kept only when both of its uids fall on the same side of the split.
# uid_graph links the two uids of every pair, so they always share a component
# and discarded_samples is expected to stay empty.
train_data = []
test_data = []
discarded_samples = []

for sample in labeled_img_pairs:
    (img_path1, img_path2), label = sample
    uid1 = extract_uid(img_path1)
    uid2 = extract_uid(img_path2)

    if uid1 in train_uids and uid2 in train_uids:
        train_data.append(sample)
    elif uid1 in test_uids and uid2 in test_uids:
        test_data.append(sample)
    else:
        # Discard cross-set samples to maintain UID exclusivity
        discarded_samples.append(sample)

# Step 6: Verify the split ratio
# NOTE(review): the ratios below divide by zero if no samples were kept.
train_sample_count = len(train_data)
test_sample_count = len(test_data)
total_sample_count = train_sample_count + test_sample_count

train_ratio = train_sample_count / total_sample_count
test_ratio = test_sample_count / total_sample_count

print(f"Training samples: {train_sample_count} ({train_ratio:.2%})")
print(f"Test samples: {test_sample_count} ({test_ratio:.2%})")
print(f"Discarded samples: {len(discarded_samples)}")

# Step 7: Ensure UID exclusivity in the test set
# Recompute the uid sets from the final sample lists (rather than trusting
# train_uids/test_uids) and fail loudly if any uid appears on both sides.
uids_in_training_samples = set()
for (img_path1, img_path2), label in train_data:
    uids_in_training_samples.update([extract_uid(img_path1), extract_uid(img_path2)])

uids_in_test_samples = set()
for (img_path1, img_path2), label in test_data:
    uids_in_test_samples.update([extract_uid(img_path1), extract_uid(img_path2)])

overlap_uids = uids_in_test_samples.intersection(uids_in_training_samples)
assert len(overlap_uids) == 0, "Overlap detected between training and test UIDs!"

print("UID exclusivity between training and test sets is maintained.")

def _count_label(pairs, target_label):
    """Return how many ((img_path1, img_path2), label) entries in ``pairs`` have label == target_label."""
    return sum(1 for _pair, pair_label in pairs if pair_label == target_label)


# Label balance per split (replaces four copy-pasted counting loops).
print(f"Total Positive Pairs {_count_label(labeled_img_pairs, 1)}")
print(f"Training Positive Pairs {_count_label(train_data, 1)}")
print(f"Training Negative Pairs (before the hard visual negatives) {_count_label(train_data, 0)}")
print(f"Test Positive Pairs {_count_label(test_data, 1)}")

# Load additional negative pairs (hard visual negatives, training only).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(base_path + "/visual_labeled_img_negative_pairs.pkl", "rb") as f:
    visual_labeled_img_negative_pairs = pickle.load(f)

# The UID-exclusivity check above ran before these pairs existed. Re-apply it
# here: a hard negative touching a test-set uid would leak test images into
# training, so such pairs are dropped rather than appended.
hard_negatives_kept = []
hard_negatives_leaking = []
for neg_sample in visual_labeled_img_negative_pairs:
    (neg_path1, neg_path2), _neg_label = neg_sample
    if extract_uid(neg_path1) in test_uids or extract_uid(neg_path2) in test_uids:
        hard_negatives_leaking.append(neg_sample)
    else:
        hard_negatives_kept.append(neg_sample)
if hard_negatives_leaking:
    print(f"Dropped {len(hard_negatives_leaking)} hard negatives that reference test UIDs")

train_data.extend(hard_negatives_kept)
counter = sum(1 for _pair, pair_label in train_data if pair_label == 0)
print(f"Training Negative Pairs (after the hard visual negatives) {counter}")


Training samples: 1056 (79.94%)
Test samples: 265 (20.06%)
Discarded samples: 0
UID exclusivity between training and test sets is maintained.
Total Positive Pairs 1321
Training Positive Pairs 1056
Training Negative Pairs (before the hard visual negatives) 0
Test Positive Pairs 265
Training Negative Pairs (after the hard visual negatives) 1727


In [3]:
# Dataset Classes
class TestImagePairDataset(Dataset):
    """
    Map-style dataset over labeled image pairs (evaluation variant).

    Each item is ``((img1, img2), label)``. Both images are opened as RGB and
    passed through ``transform`` when one is provided. ``label`` is returned as
    a float32 tensor.
    """

    def __init__(self, pairs_list, transform=None):
        self.pairs_list = pairs_list
        self.transform = transform

    def __len__(self):
        return len(self.pairs_list)

    def __getitem__(self, idx):
        (first_path, second_path), raw_label = self.pairs_list[idx]
        # Load both images before transforming either, as in the original order.
        images = [Image.open(p).convert('RGB') for p in (first_path, second_path)]
        if self.transform:
            images = [self.transform(im) for im in images]
        return (images[0], images[1]), torch.tensor(raw_label, dtype=torch.float32)

class TrainImagePairDataset(Dataset):
    """
    Map-style dataset over labeled image pairs (training variant).

    Yields ``((img1, img2), label)``. Images are RGB-converted and optionally
    transformed; the label becomes a float32 tensor.
    """

    def __init__(self, pairs_list, transform=None):
        self.pairs_list = pairs_list
        self.transform = transform

    def __len__(self):
        return len(self.pairs_list)

    def __getitem__(self, idx):
        (left_path, right_path), target = self.pairs_list[idx]
        left = Image.open(left_path).convert('RGB')
        right = Image.open(right_path).convert('RGB')
        if not self.transform:
            return (left, right), torch.tensor(target, dtype=torch.float32)
        left, right = self.transform(left), self.transform(right)
        return (left, right), torch.tensor(target, dtype=torch.float32)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ImageEmbeddingModel(nn.Module):
    """
    Wrap a model exposing ``encode_image`` so ``forward`` returns image embeddings.

    Args:
        model: object with an ``encode_image(x)`` method (the BiomedCLIP model
            in this notebook).
        dropout_rate: dropout probability applied to the embeddings.
        normalize: if True, L2-normalise each embedding along the last dim,
            after dropout. Defaults to False, which preserves the original
            behaviour.

    NOTE(review): training uses nn.TripletMarginLoss, which measures Euclidean
    distance on these raw embeddings, while retrieval evaluation ranks by cosine
    similarity. normalize=True makes the training geometry match the evaluation.
    """

    def __init__(self, model, dropout_rate, normalize=False):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout(p=dropout_rate)
        self.normalize = normalize

    def forward(self, x):
        # Use encode_image to get the embeddings
        embeddings = self.model.encode_image(x)
        embeddings = self.dropout(embeddings)
        if self.normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)
        return embeddings

# Function to initialize the model
def initialize_biomedclip_model(model_name='hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'):
    """
    Load an open_clip hub model (BiomedCLIP by default) and move it to ``device``.

    Args:
        model_name: open_clip model identifier. The same id is used for the
            model/preprocess pair and for the tokenizer, so the two cannot drift
            apart.

    Returns:
        (model, preprocess, tokenizer) from open_clip's
        create_model_from_pretrained and get_tokenizer.
    """
    biomedclip_model, preprocess = create_model_from_pretrained(model_name)
    biomedclip_model.to(device)
    tokenizer = get_tokenizer(model_name)
    return biomedclip_model, preprocess, tokenizer




# Set random seed for reproducibility across the Python, NumPy and torch RNGs.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Seeding alone does not make cuDNN deterministic, because it may select
# non-deterministic algorithms. Pin it so repeated runs produce the same numbers.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [4]:
import random
from collections import defaultdict
from torch.utils.data import Dataset
from PIL import Image
import networkx as nx  # For building graph components

class TripletImageDataset(Dataset):
    """
    Triplet (anchor, positive, negative) dataset built from labeled image pairs.

    pair_dataset: A list of (((img_path1, img_path2), label), ...).
                  label=1 means img_path1 and img_path2 belong to the same class.
                  Any other label only registers both images as nodes; they form
                  singleton classes unless some label=1 pair links them.
    transform: Transform applied to all images.
    augmentation: Additional augmentation applied to the positive image only,
                  after ``transform``.

    The class:
    1) Constructs a graph of images connected by edges where label=1.
    2) Treats each connected component as one class.
    3) Draws one triplet per image in every class with >= 2 images. The image is
       the anchor, the positive is a random other member, and the negative is a
       random member of a random different class.

    Triplets are drawn once, in __init__, with the global ``random`` module, so
    the same triplets are reused every epoch. Component members are sorted so the
    draw is reproducible under a fixed seed; iterating a set of strings otherwise
    changes order between interpreter runs because of hash randomisation.
    """

    def __init__(self, pair_dataset, transform=None, augmentation=None):
        self.transform = transform
        self.augmentation = augmentation

        # Extract classes from pairs
        self.class_to_images = self._group_images_into_classes(pair_dataset)

        # Generate triplets
        self.triplets = self._generate_triplets()

    def _group_images_into_classes(self, pair_dataset):
        """Return {class_id: sorted list of image paths}, one class per connected component."""
        # Build a graph where each node is an image path.
        # Edges with label=1 connect images of the same class.
        G = nx.Graph()
        for (img1, img2), label in pair_dataset:
            G.add_node(img1)
            G.add_node(img2)
            if label == 1:
                G.add_edge(img1, img2)

        # Each connected component is a class. Sorting the members removes the
        # dependence on set iteration order, which varies with PYTHONHASHSEED.
        return {
            idx: sorted(comp)
            for idx, comp in enumerate(nx.connected_components(G))
        }

    def _generate_triplets(self):
        """Return a list of (anchor_path, positive_path, negative_path) tuples."""
        triplets = []

        # Get list of classes
        classes = list(self.class_to_images.keys())

        # We need at least 2 classes to form a negative example
        if len(classes) < 2:
            print("Not enough classes to form triplets.")
            return triplets

        for cls in classes:
            img_paths = self.class_to_images[cls]
            # Need at least 2 images to form anchor-positive pair
            if len(img_paths) < 2:
                continue

            # Invariant for this class: build once instead of once per anchor.
            negative_classes = [c for c in classes if c != cls]

            for anchor_idx, anchor_path in enumerate(img_paths):
                # Positive candidates are all other images in the same class;
                # the list is never empty because the class has >= 2 images.
                positive_candidates = img_paths[:anchor_idx] + img_paths[anchor_idx + 1:]
                positive_path = random.choice(positive_candidates)

                # Choose a negative class, then a negative image inside it
                negative_cls = random.choice(negative_classes)
                negative_path = random.choice(self.class_to_images[negative_cls])

                triplets.append((anchor_path, positive_path, negative_path))

        return triplets

    def __len__(self):
        return len(self.triplets)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file promptly."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        anchor_path, positive_path, negative_path = self.triplets[idx]

        anchor_img = self._load_rgb(anchor_path)
        positive_img = self._load_rgb(positive_path)
        negative_img = self._load_rgb(negative_path)

        if self.transform:
            anchor_img = self.transform(anchor_img)
            positive_img = self.transform(positive_img)
            negative_img = self.transform(negative_img)

        # Apply augmentation to the positive image if desired
        if self.augmentation:
            positive_img = self.augmentation(positive_img)

        return anchor_img, positive_img, negative_img


In [5]:
# Set hyperparameters
# -- triplet loss --
margin = 0.1          # margin passed to nn.TripletMarginLoss
gamma_loss = 60       # NOTE(review): not referenced by any visible cell
# -- optimisation --
optimizer_name = 'Adam'
learning_rate = 3e-7
weight_decay = 1e-5
num_epochs = 5
batch_size = 16
# -- model --
dropout_rate = 0.0

# Initialize the model (the tokenizer is not needed for image-only training)
biomedclip_model, preprocess, _ = initialize_biomedclip_model()

# Wrap the CLIP model so forward() yields image embeddings, then move it to device
image_embedding_model = ImageEmbeddingModel(
    biomedclip_model,
    dropout_rate=dropout_rate,
).to(device)

# Ensure all parameters are trainable (full fine-tuning); ';' hides the module repr
image_embedding_model.requires_grad_(True);


In [6]:
# Training triplets are reshuffled every epoch; batch size comes from the
# hyperparameter cell.
train_dataset = TripletImageDataset(train_data, transform=preprocess)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Test triplets keep a fixed order.
# NOTE(review): the retrieval evaluation builds its own loaders from test_data;
# test_dataloader is not used by any visible cell.
test_dataset = TripletImageDataset(test_data, transform=preprocess)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)


In [7]:
# Triplet loss on the raw embeddings (Euclidean distance, margin from the config cell).
criterion = nn.TripletMarginLoss(margin=margin)
# Initialize the optimizer from the configured name instead of hard-coding Adam,
# so optimizer_name is honoured ('Adam' -> torch.optim.Adam). An unknown name
# raises AttributeError.
optimizer_cls = getattr(torch.optim, optimizer_name)
optimizer = optimizer_cls(image_embedding_model.parameters(), lr=learning_rate, weight_decay=weight_decay)


In [8]:
class CandidateImageDataset(Dataset):
    """
    Single-image dataset used to embed retrieval candidates in batches.

    Each item is ``(img_path, image)``. The image is RGB-converted and passed
    through ``transform`` when one is provided.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        rgb = Image.open(path).convert('RGB')
        return path, (self.transform(rgb) if self.transform else rgb)

class QueryImageDataset(Dataset):
    """
    Single-image dataset used to embed retrieval queries in batches.

    Yields ``(img_path, image)`` with the image converted to RGB and, if given,
    passed through ``transform``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        query_path = self.image_paths[idx]
        picture = Image.open(query_path).convert('RGB')
        if not self.transform:
            return query_path, picture
        return query_path, self.transform(picture)


In [9]:
# Training and Validation Loop
#
# Retrieval setup. It is loop-invariant, so it is built once rather than every epoch:
#   * candidates: every unique image path in test_data;
#   * queries: every image that appears in a label==1 test pair;
#   * ground_truth[q]: the partner of the FIRST positive pair that q appears in.
# NOTE(review): an image with several positive partners is scored only against
# that first partner; retrieving one of its other partners counts as a miss.
# Confirm this is the intended protocol.
test_paths_set = set()
for (img_path1, img_path2), _ in test_data:
    test_paths_set.update([img_path1, img_path2])
# Sorted so candidate order (and therefore tie-breaking) is reproducible; a plain
# list(set) of strings changes order between interpreter runs.
test_paths_list = sorted(test_paths_set)

candidate_dataset = CandidateImageDataset(test_paths_list, transform=preprocess)
candidate_dataloader = DataLoader(candidate_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

ground_truth = {}
for (img_path1, img_path2), label in test_data:
    if label == 1:
        ground_truth.setdefault(img_path1, img_path2)
        ground_truth.setdefault(img_path2, img_path1)

query_image_paths = list(ground_truth.keys())
query_dataset = QueryImageDataset(query_image_paths, transform=preprocess)
query_dataloader = DataLoader(query_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

candidate_ids = test_paths_list
candidate_pos = {cid: i for i, cid in enumerate(candidate_ids)}


def _embed_paths(dataloader):
    """Return {img_path: embedding ndarray} for every (img_paths, images) batch of ``dataloader``."""
    path_to_embedding = {}
    with torch.no_grad():
        for img_paths, images in dataloader:
            embeddings = image_embedding_model(images.to(device)).cpu().numpy()
            path_to_embedding.update(zip(img_paths, embeddings))
    return path_to_embedding


for epoch in range(num_epochs):
    # ---- training ----
    image_embedding_model.train()
    train_loss = 0

    # The dataloader returns triplets: anchor, positive, negative
    for anchor_img, positive_img, negative_img in train_dataloader:
        anchor_img = anchor_img.to(device)
        positive_img = positive_img.to(device)
        negative_img = negative_img.to(device)

        optimizer.zero_grad()

        anchor_embeddings = image_embedding_model(anchor_img)
        positive_embeddings = image_embedding_model(positive_img)
        negative_embeddings = image_embedding_model(negative_img)

        # Compute triplet loss
        loss = criterion(anchor_embeddings, positive_embeddings, negative_embeddings)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    # max(..., 1) guards against an empty training dataloader.
    avg_train_loss = train_loss / max(len(train_dataloader), 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.6f}')

    # ---- retrieval evaluation ----
    image_embedding_model.eval()
    candidate_embeddings = _embed_paths(candidate_dataloader)
    query_embeddings = _embed_paths(query_dataloader)

    hits_at_k = {1: 0, 3: 0, 5: 0}
    reciprocal_ranks = []
    num_queries = len(query_embeddings)

    if num_queries:
        candidate_emb_matrix = np.stack([candidate_embeddings[cid] for cid in candidate_ids])
        query_ids = list(query_embeddings.keys())
        query_emb_matrix = np.stack([query_embeddings[qid] for qid in query_ids])
        # One (num_queries x num_candidates) similarity matrix, instead of
        # rebuilding the candidate matrix once per query.
        similarity_matrix = cosine_similarity(query_emb_matrix, candidate_emb_matrix)

        for row, query_id in enumerate(query_ids):
            # Stable descending sort: ties keep candidate order, as
            # sorted(..., reverse=True) did. The query itself is then removed
            # from the ranking.
            ranking = np.argsort(-similarity_matrix[row], kind='stable')
            ranking = ranking[ranking != candidate_pos[query_id]]

            golden_pos = candidate_pos[ground_truth[query_id]]
            matches = np.flatnonzero(ranking == golden_pos)
            if matches.size:
                golden_img_rank = int(matches[0]) + 1
                for K in hits_at_k:
                    if golden_img_rank <= K:
                        hits_at_k[K] += 1
                reciprocal_ranks.append(1.0 / golden_img_rank)
            else:
                # Only possible when the golden image is the query itself
                # (a self-pair), which was excluded from the ranking.
                reciprocal_ranks.append(0.0)

    # Compute final metrics (guarded against an empty query set)
    denominator = max(num_queries, 1)
    for K in hits_at_k:
        hits_at_k[K] = hits_at_k[K] / denominator

    mrr = sum(reciprocal_ranks) / denominator

    print(f'Epoch [{epoch+1}/{num_epochs}], MRR: {mrr:.6f}')
    print(f'Hits@1: {hits_at_k[1]:.4f}')
    print(f'Hits@3: {hits_at_k[3]:.4f}')
    print(f'Hits@5: {hits_at_k[5]:.4f}')


Epoch [1/5], Training Loss: 0.060176
Epoch [1/5], MRR: 0.570493
Hits@1: 0.4283
Hits@3: 0.6491
Hits@5: 0.7453
Epoch [2/5], Training Loss: 0.035246
Epoch [2/5], MRR: 0.570903
Hits@1: 0.4283
Hits@3: 0.6509
Hits@5: 0.7472
Epoch [3/5], Training Loss: 0.015603
Epoch [3/5], MRR: 0.571211
Hits@1: 0.4283
Hits@3: 0.6509
Hits@5: 0.7491
Epoch [4/5], Training Loss: 0.008347
Epoch [4/5], MRR: 0.572075
Hits@1: 0.4302
Hits@3: 0.6509
Hits@5: 0.7472
Epoch [5/5], Training Loss: 0.003862
Epoch [5/5], MRR: 0.572080
Hits@1: 0.4302
Hits@3: 0.6509
Hits@5: 0.7472
