In [2]:
import os
import re
import pandas as pd
import numpy as np
import json
import pickle
from PIL import Image
# from sklearn.metrics.pairwise import cosine_similarity

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPModel

In [3]:
import open_clip
from open_clip import create_model_from_pretrained, get_tokenizer # For BiomedCLIP model
from transformers import AutoTokenizer, AutoModel # For BioLinkBERT model
from sentence_transformers import SentenceTransformer # For MPNet model

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
biomedclip_model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
biomedclip_model.to(device)
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

In [5]:
with open("largeListsGuy/gpt4o_response_list.json", 'r') as f:
    gpt4o_response_list = json.load(f)

gpt4o_dict = {idx: response for idx, response in enumerate(gpt4o_response_list) if not "Other/None/Unknown" in response}

filtered_csv_path = r'/cs/labs/tomhope/dhtandguy21/envs/pmc_project/21_sep_filtered_data.csv'
filtered_df = pd.read_csv(filtered_csv_path)
filtered_df.loc[list(gpt4o_dict.keys()), "gpt4o_response"] = list(gpt4o_dict.values())
def normalize(text):
    return re.sub(r'\s+', ' ', text.strip())
    
filtered_df["gpt4o_response"] = filtered_df["gpt4o_response"].apply(normalize)


In [6]:
with open("largeListsGuy/retrieval_labeled_img_pairs.pkl", "rb") as f:
    labeled_img_pairs = pickle.load(f)

In [7]:
len(labeled_img_pairs)

1321

In the next few cells, we will split the pairs into training and test set.

In order to prevent data leakage, it is necessary that images corresponding to the same patient will not appear in both training and test sets.

In [8]:
from collections import defaultdict
import os

def extract_uid(img_path):
    """
    Extracts the uid from the image path.
    Assumes that the uid is the last directory before the image filename.
    """
    # Split the path into parts
    path_parts = os.path.normpath(img_path).split(os.sep)
    # Get the uid (second last part)
    uid = path_parts[-2]
    return uid

def dfs(uid, visited, component):
    visited.add(uid)
    component.add(uid)
    for neighbor in uid_graph[uid]:
        if neighbor not in visited:
            dfs(neighbor, visited, component)

# Build the uid_graph
uid_graph = defaultdict(set)

# Build the graph
for (img_path1, img_path2), label in labeled_img_pairs:
    uid1 = extract_uid(img_path1)
    uid2 = extract_uid(img_path2)
    uid_graph[uid1].add(uid2)
    uid_graph[uid2].add(uid1)

# Find connected components
visited = set()
components = []

for uid in uid_graph:
    if uid not in visited:
        component = set()
        dfs(uid, visited, component)
        components.append(component)

# Step 1: Build UID to component index mapping
uid_to_component_idx = {}

for idx, component in enumerate(components):
    for uid in component:
        uid_to_component_idx[uid] = idx

# Step 2: Count samples per component
component_sample_counts = [0] * len(components)

for (img_path1, img_path2), label in labeled_img_pairs:
    uid1 = extract_uid(img_path1)
    component_idx = uid_to_component_idx[uid1]
    component_sample_counts[component_idx] += 1

# Step 3: Sort components by sample count
components_with_counts = list(zip(components, component_sample_counts))
components_with_counts.sort(key=lambda x: x[1], reverse=True)

# Step 4: Assign components to training and test sets
total_samples = len(labeled_img_pairs)
desired_train_samples = int(total_samples * 0.8)

train_uids = set()
test_uids = set()

accumulated_train_samples = 0

for component, sample_count in components_with_counts:
    if accumulated_train_samples < desired_train_samples:
        train_uids.update(component)
        accumulated_train_samples += sample_count
    else:
        test_uids.update(component)

# Step 5: Assign samples to training and test sets, discard cross-set samples
train_data = []
test_data = []
discarded_samples = []

for sample in labeled_img_pairs:
    (img_path1, img_path2), label = sample
    uid1 = extract_uid(img_path1)
    uid2 = extract_uid(img_path2)

    if uid1 in train_uids and uid2 in train_uids:
        train_data.append(sample)
    elif uid1 in test_uids and uid2 in test_uids:
        test_data.append(sample)
    else:
        # Discard cross-set samples to maintain UID exclusivity
        discarded_samples.append(sample)

# Step 6: Verify the split ratio
train_sample_count = len(train_data)
test_sample_count = len(test_data)
total_sample_count = train_sample_count + test_sample_count

train_ratio = train_sample_count / total_sample_count
test_ratio = test_sample_count / total_sample_count

print(f"Training samples: {train_sample_count} ({train_ratio:.2%})")
print(f"Test samples: {test_sample_count} ({test_ratio:.2%})")
print(f"Discarded samples: {len(discarded_samples)}")

# Step 7: Ensure UID exclusivity in the test set
uids_in_training_samples = set()
for (img_path1, img_path2), label in train_data:
    uids_in_training_samples.update([extract_uid(img_path1), extract_uid(img_path2)])

uids_in_test_samples = set()
for (img_path1, img_path2), label in test_data:
    uids_in_test_samples.update([extract_uid(img_path1), extract_uid(img_path2)])

overlap_uids = uids_in_test_samples.intersection(uids_in_training_samples)
assert len(overlap_uids) == 0, "Overlap detected between training and test UIDs!"

print("UID exclusivity between training and test sets is maintained.")


Training samples: 1056 (79.94%)
Test samples: 265 (20.06%)
Discarded samples: 0
UID exclusivity between training and test sets is maintained.


In [9]:
# Count cross-set samples
cross_set_sample_count = 0

for sample in labeled_img_pairs:
    (img_path1, img_path2), label = sample
    uid1 = extract_uid(img_path1)
    uid2 = extract_uid(img_path2)

    if (uid1 in train_uids and uid2 in test_uids) or (uid1 in test_uids and uid2 in train_uids):
        cross_set_sample_count += 1

print(f"Number of cross-set samples: {cross_set_sample_count}")


Number of cross-set samples: 0


In [10]:
overlap_uids = uids_in_test_samples.intersection(uids_in_training_samples)
assert len(overlap_uids) == 0, "Overlap detected between training and test UIDs!"


In [11]:
total_samples_accounted = train_sample_count + test_sample_count + len(discarded_samples)
original_total_samples = len(labeled_img_pairs)
assert total_samples_accounted == original_total_samples, "Mismatch in total sample count!"

print(f"Total samples accounted for: {total_samples_accounted}")
print(f"Original total samples: {original_total_samples}")


Total samples accounted for: 1321
Original total samples: 1321


In [12]:
counter = 0
for (img_pair, label) in labeled_img_pairs:
    if label == 1:
        counter += 1
print(counter)


1321


In [13]:
counter = 0
for (img_pair, label) in train_data:
    if label == 1:
        counter += 1
print(counter)


1056


In [14]:
counter = 0
for (img_pair, label) in test_data:
    if label == 1:
        counter += 1
print(counter)


265


In [15]:
import torchvision.transforms as transforms

# Define a set of augmentations
positive_augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
    # Do not include ToTensor() or Normalize() here
])


In [16]:
# Dataset Class
class TestImagePairDataset(Dataset):
    def __init__(self, pairs_list, transform=None):
        self.pairs_list = pairs_list
        self.transform = transform

    def __len__(self):
        return len(self.pairs_list)

    def __getitem__(self, idx):
        (img_path1, img_path2), label = self.pairs_list[idx]
        img1 = Image.open(img_path1).convert('RGB')
        img2 = Image.open(img_path2).convert('RGB')

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        label = torch.tensor(label, dtype=torch.float32)
        return (img1, img2), label


In [17]:
class TrainImagePairDataset(Dataset):
    def __init__(self, pairs_list, transform=None, augmentation=None):
        self.pairs_list = pairs_list
        self.transform = transform
        self.augmentation = augmentation
        self.expanded_pairs_list = self.expand_positive_pairs()
        
    def expand_positive_pairs(self):
        expanded_list = []
        for pair, label in self.pairs_list:
            expanded_list.append(((pair[0], pair[1]), label))  # Original pair
            if label == 1:
                expanded_list.append(((pair[0], pair[1]), label))  # Augmented pair
        return expanded_list

    def __len__(self):
        return len(self.expanded_pairs_list)

    def __getitem__(self, idx):
        (img_path1, img_path2), label = self.expanded_pairs_list[idx]
        img1 = Image.open(img_path1).convert('RGB')
        img2 = Image.open(img_path2).convert('RGB')

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Apply augmentation to the augmented copies
        if label == 1 and idx % 2 == 1:  # Every second positive pair is augmented
            if self.augmentation:
                img1 = self.augmentation(img1)
                img2 = self.augmentation(img2)


        label = torch.tensor(label, dtype=torch.float32)
        return (img1, img2), label


In [18]:
with open("largeListsGuy/visual_labeled_img_negative_pairs.pkl", "rb") as f:
    visual_labeled_img_negative_pairs = pickle.load(f)

In [19]:
len(visual_labeled_img_negative_pairs)

1727

In [20]:
train_data.extend(visual_labeled_img_negative_pairs)

In [21]:
len(train_data)

2783

In [22]:
# Determine the batch size
batch_size = 16

# Dataloader for training set
# train_dataset = TrainImagePairDataset(train_data, transform=preprocess, augmentation=positive_augmentations)
train_dataset = TrainImagePairDataset(train_data, transform=preprocess, augmentation=None)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Dataloader for test set
test_dataset = TestImagePairDataset(test_data, transform=preprocess)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [23]:
# Begin: Determining golden image for each image in the test set #

In [24]:
# Recall - test_data is a list of tuples: [((img_path1, img_path2), label), ...]
# and labels are either 0 (negative pair) or 1 (positive pair)

In [25]:
# Step 1: Collect all unique images from test_data
test_paths_set = set()
for (img_path1, img_path2), label in test_data:
    test_paths_set.update([img_path1, img_path2])

candidate_image_paths  = list(test_paths_set)

In [26]:
len(test_paths_set)

530

In [27]:
candidate_image_paths[0:5]

['/cs/labs/tomhope/yuvalbus/pmc/pythonProject/data2/PMC6336654/6336654_1/6336654_1_5.jpg',
 '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/data2/PMC8211554/8211554_1/8211554_1_1.jpg',
 '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/data2/PMC5438232/5438232_1/5438232_1_4.jpg',
 '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/data2/PMC2577102/2577102_1/2577102_1_3.jpg',
 '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/data2/PMC4557154/4557154_1/4557154_1_2.jpg']

In [28]:
# Step 2: Create the ground_truth mapping
ground_truth = {}

for (img_path1, img_path2), label in test_data:
    if label == 1:
        if img_path1 not in ground_truth:
            ground_truth[img_path1] = img_path2
        if img_path2 not in ground_truth:
            ground_truth[img_path2] = img_path1


In [29]:
len(ground_truth)

530

In [30]:
class ImageEmbeddingModel(nn.Module):
    def __init__(self, model):
        super(ImageEmbeddingModel, self).__init__()
        self.model = model 

    def forward(self, x):
        # Use encode_image to get the embeddings
        embeddings = self.model.encode_image(x)
        return embeddings

In [31]:
# In order to check Retrieval Metrics for the original BiomedCLIP
image_embedding_model = ImageEmbeddingModel(biomedclip_model)
image_embedding_model.to(device)
image_embedding_model.eval();

In [32]:
# # Reload the trained model
# base_path = '/cs/labs/tomhope/yuvalbus/pmc/pythonProject/largeListsGuy'
# image_embedding_model = torch.load(base_path+'/image_embedding_model_full.pth')
# image_embedding_model.to(device)
# image_embedding_model.eval();

In [33]:
# Create a Dataset class for candidate images
class CandidateImageDataset(Dataset):
    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return img_path, image


In [34]:
# Create a candidate dataset
candidate_dataset = CandidateImageDataset(candidate_image_paths, transform=preprocess)

# Create a candidate dataloader
candidate_dataloader = DataLoader(candidate_dataset, batch_size=64, shuffle=False, num_workers=2)

In [35]:
# Compute candidate embeddings
candidate_embeddings = {}
with torch.no_grad():
    for img_paths, images in candidate_dataloader:
        images = images.to(device)
        embeddings = image_embedding_model(images)
        embeddings = embeddings.cpu().numpy()
        for img_path, embedding in zip(img_paths, embeddings):
            candidate_embeddings[img_path] = embedding


In [36]:
len(candidate_dataset)

530

In [37]:
# Create a Dataset class for query images
class QueryImageDataset(Dataset):
    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return img_path, image


In [38]:
query_image_paths = list(ground_truth.keys())

# Create a query dataset
query_dataset = QueryImageDataset(query_image_paths, transform=preprocess)

# Create a query dataloader
query_dataloader = DataLoader(query_dataset, batch_size=64, shuffle=False, num_workers=2)

In [39]:
# Compute query embeddings
query_embeddings = {}
with torch.no_grad():
    for img_paths, images in query_dataloader:
        images = images.to(device)
        embeddings = image_embedding_model(images)
        embeddings = embeddings.cpu().numpy()
        for img_path, embedding in zip(img_paths, embeddings):
            query_embeddings[img_path] = embedding


In [40]:
# Retrieval Evaluation with Exclusion of Query Images

hits_at_k = {1: 0, 3: 0, 5: 0}
reciprocal_ranks = []
num_queries = len(query_embeddings)

# To ensure the alignment of the keys with their values
candidate_ids = list(candidate_embeddings.keys())
candidate_emb_matrix = np.array([candidate_embeddings[cid] for cid in candidate_ids])

for query_id, query_emb in query_embeddings.items():
    # Exclude the query image from the candidate set
    adjusted_candidate_ids = [cid for cid in candidate_ids if cid != query_id]
    adjusted_candidate_emb_matrix = np.array([candidate_embeddings[cid] for cid in adjusted_candidate_ids])

    # Compute similarities between the query and adjusted candidates
    query_emb_vector = np.expand_dims(query_emb, axis=0) # Shape (1, 512)
    similarities = cosine_similarity(query_emb_vector, adjusted_candidate_emb_matrix)[0] # (1, 512) x (512, 5112) -> (1, 5112)                                  

    # Pair adjusted_cadidate lists with similarities
    similarities_ids_pair_list = list(zip(adjusted_candidate_ids, similarities))

    # Sort in descending order the similarities of the query with respect to each candidate
    ranked_similarities = sorted(similarities_ids_pair_list, key=lambda x: x[1], reverse=True)
    ranked_candidate_ids = [candidate_id for candidate_id, _ in ranked_similarities]

    # Get the golden image for the query
    golden_image = ground_truth[query_id]

    # Compute Hits@K
    for K in hits_at_k:
        if golden_image in ranked_candidate_ids[:K]:
            hits_at_k[K] += 1

    # Compute Reciprocal Rank
    if golden_image in ranked_candidate_ids:
        golden_img_idx = ranked_candidate_ids.index(golden_image)
        golden_img_rank = golden_img_idx + 1
        reciprocal_ranks.append(1.0 / golden_img_rank)
    else:
        reciprocal_ranks.append(0.0)


In [41]:
# Compute and Display Metrics
for K in hits_at_k:
    hits_at_k[K] = hits_at_k[K] / num_queries
    print(f"Hits@{K}: {hits_at_k[K]:.4f}")

reciprocal_ranks_sum = sum(reciprocal_ranks)
mrr = reciprocal_ranks_sum / num_queries
print(f"MRR: {mrr:.4f}")


Hits@1: 0.4283
Hits@3: 0.6472
Hits@5: 0.7472
MRR: 0.5702


In [None]:
Hits@1: 0.1396
Hits@3: 0.2775
Hits@5: 0.3633
MRR: 0.2542


CURRENTLY THE BEST.

Before Fine-Tuning pos > 0.86, 0.7<vis_neg<0.8 in range(1, 130), no augmentation:

Hits@1: 0.2321

Hits@3: 0.4291

Hits@5: 0.5218

MRR: 0.3706

After Fine-Tuning:

Adam:

RMSprop:

SGD:


Before Fine-Tuning best f1-score model:

Hits@1: 0.1429

Hits@3: 0.2857

Hits@5: 0.3829

MRR: 0.2589


Before Fine-Tuning pos > 0.9, 0.55<neg<0.65 & vis_neg > 0.7 in range(1, 130):

Hits@1: 0.2978

Hits@3: 0.5393

Hits@5: 0.6461

MRR: 0.4553

After Fine-Tuning pos > 0.9, 0.55<neg<0.65 & vis_neg > 0.7 in range(1, 130):

This is after 1 epoch, it got worse afterwards:

MRR: 0.458105

Hits@1: 0.3146

Hits@3: 0.5337

Hits@5: 0.6292

Before Fine-Tuning pos > 0.8, 0.55<neg<0.65 & vis_neg > 0.7 in range(1, 130):

Hits@1: 0.0909

Hits@3: 0.1878

Hits@5: 0.2695

MRR: 0.1866


Before Fine-Tuning pos > 0.9, 0.60<neg<0.65 & 0.7 < vis_neg < 0.8 in range(1, 130):
MRR: 0.3751
Hits@1: 0.2188
Hits@3: 0.4467
Hits@5: 0.5515

After Fine-Tuning pos > 0.9, 0.60<neg<0.65 & 0.7 < vis_neg < 0.8 in range(1, 130):
MRR: 0.3781
Hits@1: 0.2316
Hits@3: 0.4375
Hits@5: 0.5368

Before Fine-Tuning pos > 0.9, 0.55<neg<0.65 & 0.7<vis_neg<0.8 in range(1, 130):


In [None]:
# End: Determining golden image for each image in the test set #

In [25]:
print(preprocess)

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7f3c39fcb240>
    ToTensor()
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
)


In [42]:
class ImageEmbeddingModel(nn.Module):
    def __init__(self, model):
        super(ImageEmbeddingModel, self).__init__()
        self.model = model 

    def forward(self, x):
        # Use encode_image to get the embeddings
        embeddings = self.model.encode_image(x)
        return embeddings

In [43]:
# Instantiate the model
image_embedding_model = ImageEmbeddingModel(biomedclip_model)

# Set model to train mode and move to device
image_embedding_model.train()
image_embedding_model.to(device)
# Ensure parameters are trainable
for param in image_embedding_model.parameters():
    param.requires_grad = True


In [28]:
# class FocalContrastiveLoss(nn.Module):
#     def __init__(self, margin, gamma=2.0):
#         super(FocalContrastiveLoss, self).__init__()
#         self.margin = margin
#         self.gamma = gamma

#     def forward(self, embedding1, embedding2, label):
#         cos_sim = nn.functional.cosine_similarity(embedding1, embedding2)
#         cos_dist = 1 - cos_sim
#         pos_loss = label * cos_dist.pow(2) * (1 - cos_sim).pow(self.gamma)
#         neg_loss = (1 - label) * nn.functional.relu(self.margin - cos_dist).pow(2) * cos_sim.pow(self.gamma)
#         loss = pos_loss + neg_loss
#         return loss.mean()


In [29]:
# margin = 0.7
# criterion = FocalContrastiveLoss(margin=margin)

# # 6. Optimizer
# optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=5e-7)


In [30]:
# 5. Loss Function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin, pos_weight):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.pos_weight = pos_weight

    def forward(self, embedding1, embedding2, label):
        cos_sim = nn.functional.cosine_similarity(embedding1, embedding2)
        cos_dist = 1 - cos_sim
        # loss = label * cos_dist.pow(2) + \
        #        (1 - label) * nn.functional.relu(self.margin - cos_dist).pow(2)
        loss = self.pos_weight * label * cos_dist.pow(2) + \
       (1 - label) * nn.functional.relu(self.margin - cos_dist).pow(2)
        return loss.mean()

margin = 0.55
criterion = ContrastiveLoss(margin=margin, pos_weight=9)

# 6. Optimizer
optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=2e-7)



In [32]:
# class FocalContrastiveLoss(nn.Module):
#     def __init__(self, margin, gamma=2.0):
#         super(FocalContrastiveLoss, self).__init__()
#         self.margin = margin
#         self.gamma = gamma

#     def forward(self, embedding1, embedding2, label):
#         cos_sim = nn.functional.cosine_similarity(embedding1, embedding2)
#         cos_dist = 1 - cos_sim

#         # Calculate focal factors
#         pos_focal = (1 - cos_sim).pow(self.gamma)
#         neg_focal = cos_sim.pow(self.gamma)

#         # Compute losses
#         pos_loss = label * pos_focal * cos_dist.pow(2)
#         neg_loss = (1 - label) * neg_focal * nn.functional.relu(self.margin - cos_dist).pow(2)

#         loss = pos_loss + neg_loss
#         return loss.mean()
        
# margin = 0.7
# gamma = 3.0  # Start with gamma = 2.0
# criterion = FocalContrastiveLoss(margin=margin, gamma=gamma)

# # 6. Optimizer
# optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=5e-7)


In [32]:
# class FocalContrastiveLoss(nn.Module):
#     def __init__(self, margin, gamma=2.0, pos_weight=1.0):
#         super(FocalContrastiveLoss, self).__init__()
#         self.margin = margin
#         self.gamma = gamma
#         self.pos_weight = pos_weight

#     def forward(self, embedding1, embedding2, label):
#         cos_sim = nn.functional.cosine_similarity(embedding1, embedding2)
#         cos_dist = 1 - cos_sim

#         # Calculate focal factors
#         pos_focal = (1 - cos_sim).pow(self.gamma)
#         neg_focal = cos_sim.pow(self.gamma)

#         # Compute losses with pos_weight applied to positive loss
#         pos_loss = self.pos_weight * label * pos_focal * cos_dist.pow(2)
#         neg_loss = (1 - label) * neg_focal * nn.functional.relu(self.margin - cos_dist).pow(2)

#         loss = pos_loss + neg_loss
#         return loss.mean()
        
# margin = 0.5
# gamma = 3.0  
# pos_weight=8

# criterion = FocalContrastiveLoss(margin=margin, gamma=gamma, pos_weight=pos_weight)

# # 6. Optimizer
# optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=5e-7)


In [33]:
# 7. Training Loop
num_epochs = 12
train_losses = []
test_losses = []
all_cos_sims = []
all_labels = []

for epoch in range(num_epochs):
    image_embedding_model.train()
    train_loss = 0

    for (img1, img2), labels in train_dataloader:
        img1 = img1.to(device)
        img2 = img2.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        embeddings1 = image_embedding_model(img1)
        embeddings2 = image_embedding_model(img2)

        loss = criterion(embeddings1, embeddings2, labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_train_loss = train_loss / len(train_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.6f}')

    # Evaluate on the test set after each epoch
    image_embedding_model.eval()  # Set model to evaluation mode
    total_samples = 0
    test_loss = 0
    all_cos_sims = []
    all_labels = []
    with torch.no_grad():  # Disable gradient computation
        for (img1, img2), labels in test_dataloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device)

            embeddings1 = image_embedding_model(img1)
            embeddings2 = image_embedding_model(img2)

            loss = criterion(embeddings1, embeddings2, labels)
            test_loss += loss.item()

            # Calculate cosine similarity
            cos_sim = nn.functional.cosine_similarity(embeddings1, embeddings2)
            
            # Collect all cosine similarities and labels
            all_cos_sims.extend(cos_sim.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    avg_test_loss = test_loss / len(test_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {avg_test_loss:.6f}')



Epoch [1/5], Training Loss: 0.219748
Epoch [1/5], Test Loss: 0.209443
Epoch [2/5], Training Loss: 0.158338
Epoch [2/5], Test Loss: 0.217761
Epoch [3/5], Training Loss: 0.120097
Epoch [3/5], Test Loss: 0.246501
Epoch [4/5], Training Loss: 0.087535
Epoch [4/5], Test Loss: 0.301915
Epoch [5/5], Training Loss: 0.061370
Epoch [5/5], Test Loss: 0.345229


In [34]:
# Calculate the optimal threshold based on precision-recall curve
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(all_labels, all_cos_sims)
f1_scores = 2 * precision * recall / (precision + recall + 1e-6)
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx]
print(f"Optimal Threshold: {optimal_threshold}")


# Apply optimal threshold to generate predictions
all_predictions = (torch.tensor(all_cos_sims) > optimal_threshold).float()

# Generate classification metrics
cm = confusion_matrix(all_labels, all_predictions)
print("Confusion Matrix:")
print(cm)

report = classification_report(all_labels, all_predictions, digits=4)
print("Classification Report:")
print(report)


Optimal Threshold: 0.6327815055847168
Confusion Matrix:
[[2018  224]
 [ 153  351]]
Classification Report:
              precision    recall  f1-score   support

         0.0     0.9295    0.9001    0.9146      2242
         1.0     0.6104    0.6964    0.6506       504

    accuracy                         0.8627      2746
   macro avg     0.7700    0.7983    0.7826      2746
weighted avg     0.8710    0.8627    0.8661      2746



In [29]:
# NO FINE-TUNING
from sklearn.metrics import precision_recall_curve, f1_score, confusion_matrix, classification_report

# Collect all true labels and cosine similarities
all_labels = []
all_cos_sims = []

with torch.no_grad():
    for (img1, img2), labels in test_dataloader:
        img1 = img1.to(device)
        img2 = img2.to(device)
        labels = labels.to(device)

        embeddings1 = biomedclip_model.encode_image(img1)
        embeddings2 = biomedclip_model.encode_image(img2)

        cos_sim = nn.functional.cosine_similarity(embeddings1, embeddings2)
        
        all_cos_sims.extend(cos_sim.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Convert lists to arrays for compatibility
all_cos_sims = np.array(all_cos_sims)
all_labels = np.array(all_labels)

# Compute precision, recall, and F1 score across different thresholds
precision, recall, thresholds = precision_recall_curve(all_labels, all_cos_sims)
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-6)  # Avoid division by zero

# Find the threshold that gives the highest F1 score
optimal_idx = np.argmax(f1_scores)
optimal_threshold = thresholds[optimal_idx]

print(f"Optimal Threshold for F1 Score: {optimal_threshold:.4f}")

# Now, using the optimal threshold to evaluate performance
optimal_predictions = (all_cos_sims > optimal_threshold).astype(float)

# Compute confusion matrix and classification report
cm = confusion_matrix(all_labels, optimal_predictions)
print("Confusion Matrix at Optimal Threshold:")
print(cm)

report = classification_report(all_labels, optimal_predictions, digits=4)
print("Classification Report at Optimal Threshold:")
print(report)


Optimal Threshold for F1 Score: 0.7050
Confusion Matrix at Optimal Threshold:
[[2020  222]
 [ 159  345]]
Classification Report at Optimal Threshold:
              precision    recall  f1-score   support

         0.0     0.9270    0.9010    0.9138      2242
         1.0     0.6085    0.6845    0.6443       504

    accuracy                         0.8613      2746
   macro avg     0.7677    0.7928    0.7790      2746
weighted avg     0.8686    0.8613    0.8643      2746



In [31]:
# torch.save({
#     'model_lreightToMinusSeven_posWeightEight_marginSix_fiveEpochs_augmentation_dict': image_embedding_model.state_dict(),
#     'optimizer_lreightToMinusSeven_posWeightEight_MarginSix_fiveEpochs_augmentation_dict': optimizer.state_dict(),
# }, 'model_lreightToMinusSeven_posWeightEight_MarginSix_fiveEpochs_augmentation.pth')


In [32]:
# # Load the saved dictionary
# checkpoint = torch.load('posWeight_lrTenToMinusFour_PredZeroPointEight.pth')

# # Create the model instance
# image_embedding_model = ImageEmbeddingModel(image_encoder)

# # Load the model state dictionary from the checkpoint
# image_embedding_model.load_state_dict(checkpoint['model_posWeight_lrTenToMinusFour_PredZeroPointEight_dict'])

# # Set the model to evaluation mode
# image_embedding_model.eval()
# image_embedding_model.to(device)

# # If you want to load the optimizer state too, you can initialize the optimizer and load its state
# optimizer.load_state_dict(checkpoint['optimizer_posWeight_lrTenToMinusFour_PredZeroPointEight_dict'])


In [33]:
# from sklearn.metrics import confusion_matrix, classification_report

# # Collect all true labels and predictions
# all_labels = []
# all_predictions = []
# cos_sims = []

# with torch.no_grad():
#     for (img1, img2), labels in test_dataloader:
#         img1 = img1.to(device)
#         img2 = img2.to(device)
#         labels = labels.to(device)

#         embeddings1 = image_embedding_model(img1)
#         embeddings2 = image_embedding_model(img2)

#         # Normalize embeddings
#         embeddings1_norm = embeddings1 / embeddings1.norm(dim=1, keepdim=True)
#         embeddings2_norm = embeddings2 / embeddings2.norm(dim=1, keepdim=True)
        
#         # Compute cosine similarity per pair
#         cos_sim = nn.functional.cosine_similarity(embeddings1_norm, embeddings2_norm)

#         # Apply threshold
#         predictions = (cos_sim > 0.6).float()

#         # Convert to CPU and numpy for metrics
#         all_predictions.extend(predictions.cpu().numpy())
#         all_labels.extend(labels.cpu().numpy())
#         cos_sims.extend(cos_sim.cpu().numpy())

# # Compute confusion matrix
# cm = confusion_matrix(all_labels, all_predictions)
# print("Confusion Matrix:")
# print(cm)

# # Classification report
# report = classification_report(all_labels, all_predictions, digits=4)
# print("Classification Report:")
# print(report)

# # Optional: Print cosine similarities if needed
# # print("Cosine Similarities:")
# # print(cos_sims)


In [None]:
# using optuna to determine best hyperparameters
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import f1_score
import optuna
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import precision_recall_curve
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ImageEmbeddingModel(nn.Module):
    def __init__(self, model, dropout_rate):
        super(ImageEmbeddingModel, self).__init__()
        self.model = model 
        self.dropout = nn.Dropout(p=dropout_rate)


    def forward(self, x):
        # Use encode_image to get the embeddings
        embeddings = self.model.encode_image(x)
        embeddings = self.dropout(embeddings)
        return embeddings


# Function to initialize the model
def initialize_biomedclip_model():
    # Replace with your actual model initialization code
    # Using your provided code
    biomedclip_model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
    biomedclip_model.to(device)
    tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
    return biomedclip_model, preprocess, tokenizer

# Define a callback function to log the best trial so far
def log_and_save_best_trial(study, trial):
    if study.best_trial == trial:
        # Log the best trial so far
        print(f"[{trial.datetime_start}] New best trial: Trial {trial.number}")
        print(f"  F1-Score: {trial.value}")
        print(f"  Parameters: {trial.params}")

        # Save the best trial to a file
        with open("best_trial_so_far.txt", "w") as f:
            f.write(f"Best trial so far: Trial {trial.number}\n")
            f.write(f"  F1-Score: {trial.value}\n")
            for key, value in trial.params.items():
                f.write(f"  {key}: {value}\n")

# Define the Focal Contrastive Loss
class FocalContrastiveLoss(nn.Module):
    def __init__(self, margin, gamma=2.0, pos_weight=1.0):
        super(FocalContrastiveLoss, self).__init__()
        self.margin = margin
        self.gamma = gamma
        self.pos_weight = pos_weight
        
    def forward(self, embedding1, embedding2, label):
        cos_sim = F.cosine_similarity(embedding1, embedding2)
        cos_dist = 1 - cos_sim

        pos_focal = (1 - cos_sim).pow(self.gamma)
        neg_focal = cos_sim.pow(self.gamma)

        pos_loss = self.pos_weight * label * pos_focal * cos_dist.pow(2)
        neg_loss = (1 - label) * neg_focal * F.relu(self.margin - cos_dist).pow(2)

        loss = pos_loss + neg_loss
        return loss.mean()


# Define the Optuna objective function with discrete options and pruning
def objective(trial):
    # Set random seed for reproducibility
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Suggest discrete hyperparameters based on previous results
    margin = trial.suggest_categorical('margin', [0.5, 0.55])
    pos_weight = trial.suggest_categorical('pos_weight', [6.0, 8.0, 11.0])
    learning_rate = trial.suggest_categorical('lr', [1.2e-07, 1.5e-07, 1.8e-07, 2.2e-07, 2.5e-07, 3.5e-07])
    gamma = trial.suggest_categorical('gamma', [0])
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])
    dropout_rate = trial.suggest_categorical('dropout_rate', [0.0, 0.1, 0.3])


    # Initialize the model within the objective function
    biomedclip_model, preprocess, _ = initialize_biomedclip_model()

    # Initialize the image embedding model
    image_embedding_model = ImageEmbeddingModel(biomedclip_model, dropout_rate=dropout_rate).to(device)

    # Data augmentation parameters
    def get_augmentation(trial):
        rotation_degree = trial.suggest_categorical('rotation_degree', [0, 10, 15])
        horizontal_flip = trial.suggest_categorical('horizontal_flip', [True, False])
        vertical_flip = trial.suggest_categorical('vertical_flip', [False])  # Generally avoid vertical flip in medical images
        brightness = trial.suggest_categorical('brightness', [0.0, 0.1, 0.2])
        contrast = trial.suggest_categorical('contrast', [0.0, 0.1, 0.2])
        saturation = trial.suggest_categorical('saturation', [0.0, 0.1, 0.2])
        hue = trial.suggest_categorical('hue', [0.0, 0.05, 0.1])

        transform_list = []

        if horizontal_flip:
            transform_list.append(transforms.RandomHorizontalFlip(p=0.5))
        if vertical_flip:
            transform_list.append(transforms.RandomVerticalFlip(p=0.5))
        if rotation_degree > 0:
            transform_list.append(transforms.RandomRotation(degrees=rotation_degree))
        if any([brightness > 0.0, contrast > 0.0, saturation > 0.0, hue > 0.0]):
            transform_list.append(transforms.ColorJitter(
                brightness=brightness,
                contrast=contrast,
                saturation=saturation,
                hue=hue))
        if transform_list is not None:
            transform = transforms.Compose(transform_list)
        else:
            transform = None
            
        return transform

    positive_augmentation = get_augmentation(trial)

    # Dataloader for training set
    train_dataset = TrainImagePairDataset(train_data, transform=preprocess, augmentation=positive_augmentation)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    # Dataloader for test set
    test_dataset = TestImagePairDataset(test_data, transform=preprocess)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Define the criterion and optimizer
    criterion = FocalContrastiveLoss(margin=margin, gamma=gamma, pos_weight=pos_weight)

    # Regularization
    weight_decay = trial.suggest_categorical('weight_decay', [0.0, 1e-5, 1e-4, 1e-3])

    # Initialize optimizer
    if optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_name == 'SGD':
        optimizer = torch.optim.SGD(image_embedding_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    elif optimizer_name == 'RMSprop':
        optimizer = torch.optim.RMSprop(image_embedding_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Initialize Learning Rate Scheduler
    scheduler_name = trial.suggest_categorical('scheduler', ['None', 'StepLR', 'CosineAnnealingLR'])
    # For StepLR rate scheduler
    step_size = trial.suggest_int('step_size', 2, 5)
    gamma_scheduler = trial.suggest_categorical('gamma_scheduler', [0.1, 0.5])
    # Initialize scheduler
    if scheduler_name == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_scheduler)
    elif scheduler_name == 'CosineAnnealingLR':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    else:
        scheduler = None  # No scheduler used


    num_epochs = 11  # Adjust as needed

    for epoch in range(num_epochs):
        image_embedding_model.train()
        train_loss = 0

        for (img1, img2), labels in train_dataloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            embeddings1 = image_embedding_model(img1)
            embeddings2 = image_embedding_model(img2)

            loss = criterion(embeddings1, embeddings2, labels)
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

        avg_train_loss = train_loss / len(train_dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.6f}')
        
        # In the training loop, at the end of each epoch
        if scheduler is not None:
            scheduler.step()


        # Validation loop
        image_embedding_model.eval()
        all_cos_sims = []
        all_labels = []
        with torch.no_grad():
            for (img1, img2), labels in test_dataloader:
                img1 = img1.to(device)
                img2 = img2.to(device)
                labels = labels.to(device)

                embeddings1 = image_embedding_model(img1)
                embeddings2 = image_embedding_model(img2)

                cos_sim = F.cosine_similarity(embeddings1, embeddings2)
                all_cos_sims.extend(cos_sim.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Convert to numpy arrays
        all_cos_sims = np.array(all_cos_sims)
        all_labels = np.array(all_labels)

        # Compute F1-score for positive class
        precision, recall, thresholds = precision_recall_curve(all_labels, all_cos_sims)
        f1_scores = 2 * precision * recall / (precision + recall + 1e-6)
        optimal_idx = np.argmax(f1_scores)
        optimal_threshold = thresholds[optimal_idx]
        f1_positive = f1_scores[optimal_idx]
        print(f'Epoch [{epoch+1}/{num_epochs}], F1 Positive: {f1_positive:.6f}')

        # Store the optimal threshold
        trial.set_user_attr('optimal_threshold', optimal_threshold)

        # Report intermediate result to Optuna
        trial.report(f1_positive, epoch)

        # Check if the trial should be pruned
        if trial.should_prune():
            raise optuna.TrialPruned()

    return f1_positive  # Optuna will maximize this value

# Create an Optuna study with MedianPruner
# The pruner waits until at least 3 trials have fully completed before starting to prune subsequent trials
# The pruner ignores pruning for the first 3 epochs of each trial

study = optuna.create_study(
    direction='maximize',
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3)
)

# Optimize the objective function
study.optimize(objective, n_trials=50, callbacks=[log_and_save_best_trial])

# Display the best hyperparameters
print('Best trial:')
trial = study.best_trial

print(f'  F1-Score (Positive Class): {trial.value}')
print('  Best hyperparameters:')
for key, value in trial.params.items():
    print(f'    {key}: {value}')


[I 2024-11-23 11:19:30,391] A new study created in memory with name: no-name-a80077f5-0545-43fb-b14c-fa1a8f6e3ec2


Epoch [1/11], Training Loss: 0.653435
Epoch [1/11], F1 Positive: 0.645463
Epoch [2/11], Training Loss: 0.648978
Epoch [2/11], F1 Positive: 0.646067
Epoch [3/11], Training Loss: 0.649246
Epoch [3/11], F1 Positive: 0.646067
Epoch [4/11], Training Loss: 0.644916
Epoch [4/11], F1 Positive: 0.646616


In [1]:
# # using optuna to determine best hyperparameters
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
# from PIL import Image
# import numpy as np
# from sklearn.metrics import f1_score
# import optuna
# from sklearn.metrics import confusion_matrix, classification_report
# from sklearn.metrics import precision_recall_curve
# import random

# # Function to initialize the model
# def initialize_biomedclip_model():
#     # Replace with your actual model initialization code
#     # Using your provided code
#     biomedclip_model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
#     biomedclip_model.to(device)
#     tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
#     return biomedclip_model, preprocess, tokenizer

# # Define a callback function to log the best trial so far
# def log_and_save_best_trial(study, trial):
#     if study.best_trial == trial:
#         # Log the best trial so far
#         print(f"[{trial.datetime_start}] New best trial: Trial {trial.number}")
#         print(f"  F1-Score: {trial.value}")
#         print(f"  Parameters: {trial.params}")

#         # Save the best trial to a file
#         with open("best_trial_so_far.txt", "w") as f:
#             f.write(f"Best trial so far: Trial {trial.number}\n")
#             f.write(f"  F1-Score: {trial.value}\n")
#             for key, value in trial.params.items():
#                 f.write(f"  {key}: {value}\n")

# # Define the Focal Contrastive Loss
# class FocalContrastiveLoss(nn.Module):
#     def __init__(self, margin, gamma=2.0, pos_weight=1.0):
#         super(FocalContrastiveLoss, self).__init__()
#         self.margin = margin
#         self.gamma = gamma
#         self.pos_weight = pos_weight

#     def forward(self, embedding1, embedding2, label):
#         cos_sim = F.cosine_similarity(embedding1, embedding2)
#         cos_dist = 1 - cos_sim

#         pos_focal = (1 - cos_sim).pow(self.gamma)
#         neg_focal = cos_sim.pow(self.gamma)

#         pos_loss = self.pos_weight * label * pos_focal * cos_dist.pow(2)
#         neg_loss = (1 - label) * neg_focal * F.relu(self.margin - cos_dist).pow(2)

#         loss = pos_loss + neg_loss
#         return loss.mean()

# # Define the Optuna objective function
# def objective(trial):
#     # Set random seed for reproducibility
#     seed = 42
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)

#     # Suggest hyperparameters
#     margin = trial.suggest_float('margin', 0.5, 0.6, step=0.05)
#     pos_weight = trial.suggest_float('pos_weight', 6.0, 10.0, step=1.0)
#     learning_rate = trial.suggest_float('lr', 1e-7, 1e-6, log=True)
#     gamma = trial.suggest_int('gamma', 0, 0)

#     # Initialize the model within the objective function
#     biomedclip_model, preprocess, tokenizer = initialize_biomedclip_model()

#     # Initialize the image embedding model
#     image_embedding_model = ImageEmbeddingModel(biomedclip_model).to(device)

#     # Define the criterion and optimizer
#     criterion = FocalContrastiveLoss(margin=margin, gamma=gamma, pos_weight=pos_weight)
#     optimizer = torch.optim.Adam(image_embedding_model.parameters(), lr=learning_rate)

#     num_epochs = 11  # Adjust as needed

#     for epoch in range(num_epochs):
#         image_embedding_model.train()
#         train_loss = 0

#         for (img1, img2), labels in train_dataloader:
#             img1 = img1.to(device)
#             img2 = img2.to(device)
#             labels = labels.to(device)

#             optimizer.zero_grad()

#             embeddings1 = image_embedding_model(img1)
#             embeddings2 = image_embedding_model(img2)

#             loss = criterion(embeddings1, embeddings2, labels)
#             train_loss += loss.item()

#             loss.backward()
#             optimizer.step()

#         avg_train_loss = train_loss / len(train_dataloader)
#         print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.6f}')

#         # Validation loop
#         image_embedding_model.eval()
#         all_cos_sims = []
#         all_labels = []
#         with torch.no_grad():
#             for (img1, img2), labels in test_dataloader:
#                 img1 = img1.to(device)
#                 img2 = img2.to(device)
#                 labels = labels.to(device)

#                 embeddings1 = image_embedding_model(img1)
#                 embeddings2 = image_embedding_model(img2)

#                 cos_sim = F.cosine_similarity(embeddings1, embeddings2)
#                 all_cos_sims.extend(cos_sim.cpu().numpy())
#                 all_labels.extend(labels.cpu().numpy())

#         # Convert to numpy arrays
#         all_cos_sims = np.array(all_cos_sims)
#         all_labels = np.array(all_labels)

#         # Compute F1-score for positive class
#         precision, recall, thresholds = precision_recall_curve(all_labels, all_cos_sims)
#         f1_scores = 2 * precision * recall / (precision + recall + 1e-6)
#         optimal_idx = np.argmax(f1_scores)
#         optimal_threshold = thresholds[optimal_idx]
#         f1_positive = f1_scores[optimal_idx]
#         print(f'Epoch [{epoch+1}/{num_epochs}], F1 Positive: {f1_positive:.6f}')

#         # Store the optimal threshold
#         trial.set_user_attr('optimal_threshold', optimal_threshold)

#         # Report intermediate result to Optuna
#         trial.report(f1_positive, epoch)

#     return f1_positive  # Optuna will maximize this value

# # Create an Optuna study
# # study = optuna.create_study(direction='maximize', pruner=optuna.pruners.MedianPruner())
# study = optuna.create_study(direction='maximize')

# # Optimize the objective function
# study.optimize(objective, n_trials=50, callbacks=[log_and_save_best_trial])

# # Display the best hyperparameters
# print('Best trial:')
# trial = study.best_trial

# print(f'  F1-Score (Positive Class): {trial.value}')
# print('  Best hyperparameters:')
# for key, value in trial.params.items():
#     print(f'    {key}: {value}')


In [None]:
# Retrain your model on the full training set
num_epochs = 10  # Increase the number of epochs for final training

for epoch in range(num_epochs):
    image_embedding_model.train()
    train_loss = 0

    for (img1, img2), labels in train_dataloader:
        img1 = img1.to(device)
        img2 = img2.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        embeddings1 = image_embedding_model(img1)
        embeddings2 = image_embedding_model(img2)

        loss = criterion(embeddings1, embeddings2, labels)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss/len(train_dataloader):.6f}')

# Evaluate on the validation set
image_embedding_model.eval()
all_labels = []
all_predictions = []
with torch.no_grad():
    for (img1, img2), labels in test_dataloader:
        img1 = img1.to(device)
        img2 = img2.to(device)
        labels = labels.to(device)

        embeddings1 = image_embedding_model(img1)
        embeddings2 = image_embedding_model(img2)

        cos_sim = F.cosine_similarity(embeddings1, embeddings2)
        predictions = (cos_sim > 0.5).float()  # Use the threshold that worked best

        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predictions.cpu().numpy())

# Compute final F1-score for positive class
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)
f1_positive = f1_score(all_labels, all_predictions, pos_label=1, average='binary')

print(f'Final F1-Score for Positive Class: {f1_positive:.4f}')
