# Data Preprocessing & Feature Extraction

In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (1

In [2]:
import os
import re
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
from datasets import Dataset, DatasetDict
import pandas as pd
import os
import requests
import matplotlib.pyplot as plt
import random
import torch
import numpy as np
import torchvision.transforms as transforms
from google.colab import drive
drive.mount('/content/drive')

def set_seed(seed_value=42):
    """Set seed for reproducibility for PyTorch and NumPy.

    Args:
        seed_value (int): The seed value to set for random number generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# 1) Project path
path_project = '/content/drive/MyDrive/Project'

# 2) Load + preprocess + split
def vqa_rad_setup(path_project):
    json_file = "VQA_RAD Dataset Public.json"
    image_folder = "VQA_RAD Image Folder"
    json_path = os.path.join(path_project, json_file)
    image_dir  = os.path.join(path_project, image_folder)



    df = pd.read_json(json_path)
    # build full image paths
    df['image_path'] = df['image_name'].apply(lambda fn: os.path.join(image_dir, fn))

    total_rows     = len(df)
    print(f"Total rows in df_binary:      {total_rows}")


    num_organs = df['image_organ'].nunique()
    print(f"There are {num_organs} distinct organs.")

    # 3) if you want to see them all:
    print("Organ list:", df['image_organ'].unique())

    # filter to yes/no and map to 0/1
    df['answer'] = df['answer'].str.strip().str.lower()
    df_binary = df[df['answer'].isin(['yes', 'no'])].copy()
    df_binary['label'] = df_binary['answer'].map({'yes': 1, 'no': 0})

    # shuffle, select columns, print
    df_binary = (
        df_binary[['image_path', 'question', 'label']]
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )

    total_rows     = len(df_binary)
    print(f"Total rows in df_binary:      {total_rows}")
    return df_binary


In [4]:
df_binary = vqa_rad_setup(path_project)

df_binary.head(10)

Total rows in df_binary:      2248
There are 3 distinct organs.
Organ list: ['HEAD' 'CHEST' 'ABD']
Total rows in df_binary:      1193


Unnamed: 0,image_path,question,label
0,/content/drive/MyDrive/Project/VQA_RAD Image F...,Are the patients' ribs symmetric on both sides?,0
1,/content/drive/MyDrive/Project/VQA_RAD Image F...,Are there cilia present at the level of alveoli?,0
2,/content/drive/MyDrive/Project/VQA_RAD Image F...,Is this coronal plane?,1
3,/content/drive/MyDrive/Project/VQA_RAD Image F...,Is the patient lying down?,1
4,/content/drive/MyDrive/Project/VQA_RAD Image F...,Do you see a cavitary lesion in this chest xray?,1
5,/content/drive/MyDrive/Project/VQA_RAD Image F...,Is there free air under the diaphragm?,0
6,/content/drive/MyDrive/Project/VQA_RAD Image F...,is there tracheal deviation?,0
7,/content/drive/MyDrive/Project/VQA_RAD Image F...,Is this in the lumbar vertebral level?,1
8,/content/drive/MyDrive/Project/VQA_RAD Image F...,Does this patient have a pneumothorax?,0
9,/content/drive/MyDrive/Project/VQA_RAD Image F...,Was this patient given IV contrast?,1


Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 225MB/s]


In [None]:
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm

# 3) Instantiate your MedCLIP processor
##model_name = "openai/clip-vit-base-patch32"  # You can choose different CLIP model variants
#processor = CLIPProcessor.from_pretrained(model_name)


def densenet_processor():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

def transform_image_pre(image_path):
    image = Image.open(image_path)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return densenet_processor()(image)


# 6) Batch‑load all train images into one big tensor
image_tensors = []
failed_images = []

for img_path in tqdm(df_binary['image_path'], desc="Processing images"):
    try:
        tensor = transform_image_pre(img_path)
        image_tensors.append(tensor)
    except Exception as e:
        print(f"❌ Skipped {img_path}: {e}")
        failed_images.append(img_path)

full_image_tensor = torch.stack(image_tensors, dim=0)
print(f"\nProcessed {len(image_tensors)} images; skipped {len(failed_images)}.")
print("full_image_tensor.shape =", full_image_tensor.shape)


Processing images: 100%|██████████| 1193/1193 [00:12<00:00, 97.26it/s]



Processed 1193 images; skipped 0.
full_image_tensor.shape = torch.Size([1193, 3, 224, 224])


In [None]:
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict



def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1,
                                           bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1,
                                           bias=False)),
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet121(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_featuremaps (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_featuremaps=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False,
                 grayscale=False):

        super(DenseNet121, self).__init__()

        # First convolution
        if grayscale:
            in_channels=1
        else:
            in_channels=3

        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(in_channels=in_channels, out_channels=num_init_featuremaps,
                                kernel_size=7, stride=2,
                                padding=3, bias=False)), # bias is redundant when using batchnorm
            ('norm0', nn.BatchNorm2d(num_features=num_init_featuremaps)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_featuremaps
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        logits = self.classifier(out)
        probas = F.softmax(logits, dim=1)
        return logits, probas

In [None]:
import torch
import numpy as np
from transformers import BertTokenizer

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size
print(f"Loaded BERT tokenizer with vocab size: {vocab_size}")


# Set max sequence length for tokenization
max_seq_length = 64  # or any other value appropriate for your dataset

# Tokenize all questions
all_tokenized_texts = []
for description in df_binary['question']:
    tokenized_text = tokenizer(
        description,
        padding='max_length',
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt"
    )
    all_tokenized_texts.append(tokenized_text)

print(f"\nTokenized {len(all_tokenized_texts)} questions.")

# Stack all input_ids and attention_mask tensors
Text_tensor = torch.cat([item['input_ids'] for item in all_tokenized_texts], dim=0)  # [N, max_seq_length]
Text_attention_mask_tensor = torch.cat([item['attention_mask'] for item in all_tokenized_texts], dim=0)

# Final checks
print(f"\n✅ Shape of Text_tensor: {Text_tensor.shape}")                             # [N, max_seq_length]
print(f"✅ Shape of Attention Mask Tensor: {Text_attention_mask_tensor.shape}")       # [N, max_seq_length]
print(f"ℹ️  Max token ID: {Text_tensor.max().item()} < Vocab Size: {vocab_size}? {'Yes ✅' if Text_tensor.max().item() < vocab_size else 'No ❌'}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Loaded BERT tokenizer with vocab size: 30522

Tokenized 1193 questions.

✅ Shape of Text_tensor: torch.Size([1193, 64])
✅ Shape of Attention Mask Tensor: torch.Size([1193, 64])
ℹ️  Max token ID: 29561 < Vocab Size: 30522? Yes ✅


In [None]:
# we use torch for modeling neural networks including CNNs
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

In [None]:
import torch
import torch.nn as nn

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0
        )
        # project bidirectional hidden to single vector
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, input_ids, attention_mask):
        # input_ids: [B, T], attention_mask: [B, T]
        embeds = self.embedding(input_ids)                      # [B, T, emb_dim]
        lengths = attention_mask.sum(dim=1).cpu()               # actual lengths
        packed = nn.utils.rnn.pack_padded_sequence(
            embeds, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, (h_n, _) = self.lstm(packed)
        # h_n: [num_layers*2, B, hidden_dim]
        # take last forward & backward layers
        h_fwd = h_n[-2]                                         # [B, hidden_dim]
        h_bwd = h_n[-1]                                         # [B, hidden_dim]
        h_cat = torch.cat([h_fwd, h_bwd], dim=1)                # [B, hidden_dim*2]
        out = self.fc(h_cat)                                    # [B, hidden_dim]
        return out


In [None]:
from torch.utils.data import TensorDataset, DataLoader

In [None]:

y_train_tensor = torch.tensor(df_binary['label'], dtype=torch.float32)

# 1) Instantiate your BiLSTM text encoder
hidden_dim = 256
text_model = BiLSTM(
    vocab_size=vocab_size,
    emb_dim=256,
    hidden_dim=hidden_dim,
    num_layers=2,
    dropout=0.2
).to(device).eval()

# 2) Load DenseNet121 as image feature extractor
from torchvision import models
densenet = models.densenet121(pretrained=True).to(device).eval()
feature_extractor = nn.Sequential(
    densenet.features,
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(start_dim=1)
).to(device).eval()

# 3) Feature extraction function
def data_all_features(image_data, text_data, text_attention):
    image_features = []
    text_features = []

    for img, input_ids, attention_mask in tqdm(zip(image_data, text_data, text_attention), total=len(image_data)):
        img = img.unsqueeze(0).to(device)
        input_ids = input_ids.unsqueeze(0).to(device)
        attention_mask = attention_mask.unsqueeze(0).to(device)

        with torch.no_grad():
            # Image: [1, 3, 224, 224] → [1, 1024]
            img_feat = feature_extractor(img)

            # Text: [1, seq_len] → [1, hidden_dim]
            txt_feat = text_model(input_ids, attention_mask)

        image_features.append(img_feat.squeeze(0))
        text_features.append(txt_feat.squeeze(0))

    return torch.stack(image_features), torch.stack(text_features)

# 4) Run feature extraction
image_features, text_features = data_all_features(
    full_image_tensor, Text_tensor, Text_attention_mask_tensor
)

# 5) Check shapes
print(f"✅ image_features shape: {image_features.shape}")  # [N, 1024]
print(f"✅ text_features shape: {text_features.shape}")    # [N, 256]

N = 75

Image_train = image_features[N:]
Image_test = image_features[:N]
Text_train = text_features[N:]
Text_test = text_features[:N]
y_train = y_train_tensor[N:]
y_test = y_train_tensor[:N]

## Dataloader with (image, text, labels)
batch_size = 64
train_dataset = TensorDataset(Image_train,Text_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = TensorDataset(Image_test,Text_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)



Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 134MB/s]
100%|██████████| 1193/1193 [00:26<00:00, 45.71it/s]

✅ image_features shape: torch.Size([1193, 1024])
✅ text_features shape: torch.Size([1193, 256])





In [None]:
print(train_loader.dataset.tensors[0].shape)
print(train_loader.dataset.tensors[1].shape)
print(train_loader.dataset.tensors[2].shape)

torch.Size([1118, 1024])
torch.Size([1118, 256])
torch.Size([1118])


In [None]:
import torch
import torch.nn as nn

class CoAttentionFusionClassifier(nn.Module):
    def __init__(
        self,
        image_dim: int,
        text_dim: int,
        fusion_dim: int = 512,
        num_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        # project each modality into the same hidden space
        self.img_proj = nn.Linear(image_dim, fusion_dim)
        self.txt_proj = nn.Linear(text_dim, fusion_dim)

        # cross‑attention layers
        # image queries, text keys/values
        self.attn_img2txt = nn.MultiheadAttention(
            embed_dim=fusion_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        # text queries, image keys/values
        self.attn_txt2img = nn.MultiheadAttention(
            embed_dim=fusion_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )

        # classification head
        self.dropout    = nn.Dropout(dropout)
        self.classifier = nn.Linear(fusion_dim, 1)

    def forward(self, img_feat: torch.Tensor, txt_feat: torch.Tensor):
        """
        img_feat: (B, image_dim)
        txt_feat: (B, text_dim)
        """
        # 1) project and add sequence dim → (B, 1, fusion_dim)
        img = self.img_proj(img_feat).unsqueeze(1)
        txt = self.txt_proj(txt_feat).unsqueeze(1)

        # 2) co‑attention
        # img attends to text:
        img2txt, _ = self.attn_img2txt(query=img, key=txt, value=txt)
        # text attends to image:
        txt2img, _ = self.attn_txt2img(query=txt, key=img, value=img)

        # 3) fuse by averaging both attended outputs and pooling the sequence dim
        # each is (B,1,fusion_dim) → stack → (B,2,fusion_dim)
        fused_seq = torch.cat([img2txt, txt2img], dim=1)
        fused     = fused_seq.mean(dim=1)    # (B, fusion_dim)

        # 4) classification
        x      = self.dropout(fused)
        logit  = self.classifier(x).squeeze(1)  # (B,)
        return logit


In [None]:

image_dim = image_features.shape[1]
text_dim  = text_features.shape[1]



In [None]:
from sklearn.metrics import roc_auc_score

rs = [1e-3, 1e-4, 1e-5]
dropouts = [0.1, 0.2, 0.3]
n_epochs = 50 # fewer epochs for quick tuning
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) helper to train & eval one config
def run_trial(lr, dropout, img2txt_mult, txt2img_mult):
    # re‑init model
    model = CoAttentionFusionClassifier(
        image_dim=image_dim,
        text_dim=text_dim,
        fusion_dim=512,
        num_heads=64,
        dropout=dropout
    ).to(device)

    # 2) Define Parameter Groups (this is where you change the LR multipliers):
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'attn_img2txt' in n], 'lr': lr * img2txt_mult}, # Fixed multiplier
        {'params': [p for n, p in model.named_parameters() if 'attn_txt2img' in n], 'lr': lr * txt2img_mult}, # Fixed multiplier
        {'params': [p for n, p in model.named_parameters() if 'attn' not in n]}, # Default LR for other params
    ]

    # 3) Initialize Optimizer with Parameter Groups:
    opt = optim.Adam(param_groups, lr=lr)
    crit = nn.BCEWithLogitsLoss()
    best_f1 = 0.0
    best_acc = 0.0
    best_auroc = 0.0

    for epoch in range(1, n_epochs+1):
        # train one epoch
        model.train()
        for imgs, txts, labs in train_loader:
            imgs, txts = imgs.to(device), txts.to(device)
            labs = labs.to(device).float()
            opt.zero_grad()
            logits = model(imgs, txts)
            loss = crit(logits, labs)
            loss.backward()
            opt.step()

        # eval
        model.eval()
        all_p, all_l = [], []
        with torch.no_grad():
            for imgs, txts, labs in test_loader:
                imgs, txts = imgs.to(device), txts.to(device)
                logits = model(imgs, txts)
                preds = (torch.sigmoid(logits) > 0.5).long()
                all_p.extend(preds.cpu().tolist())
                all_l.extend(labs.tolist())

        f1 = f1_score(all_l, all_p)
        acc = accuracy_score(all_l, all_p)
        auroc = roc_auc_score(all_l, all_p)
        best_f1 = max(best_f1, f1)
        best_acc = max(best_acc, acc)
        best_auroc = max(best_auroc, auroc)

    return best_f1, best_acc, best_auroc

# 3) grid‐search
# You also need to define img2txt_mults and txt2img_mults
img2txt_mults = [0.5, 1.0, 2.0, 3.0]
txt2img_mults = [0.5, 1.0, 2.0, 5.0]

results = []
for lr, dp, img2txt_mult, txt2img_mult in product(lrs, dropouts, img2txt_mults, txt2img_mults):
    f1, acc, auroc = run_trial(lr, dp, img2txt_mult, txt2img_mult)  # Get all metrics
    print(f" → lr={lr:.0e}, dropout={dp:.1f}, img2txt_mult={img2txt_mult}, txt2img_mult={txt2img_mult} → best Val F1 = {f1:.4f}, best Val Acc = {acc:.4f}, best Val AUROC = {auroc:.4f}\n")
    results.append((f1, acc, auroc, lr, dp, img2txt_mult, txt2img_mult))  # Store all metrics

# 4) pick best
best_result = max(results, key=lambda x: x[0])  # Find best by F1 score
best_f1, best_acc, best_auroc, best_lr, best_dp, best_img2txt_mult, best_txt2img_mult = best_result

# Print Best Configuration with LR Multipliers and AUROC:
print(f">>> Best config by F1: lr={best_lr:.0e}, dropout={best_dp:.1f}, img2txt_mult={best_img2txt_mult}, txt2img_mult={best_txt2img_mult} with Val F1={best_f1:.4f} (Acc={best_acc:.4f}, AUROC={best_auroc:.4f})")

 → lr=1e-03, dropout=0.1, img2txt_mult=0.5, txt2img_mult=0.5 → best Val F1 = 0.7073, best Val Acc = 0.6800, best Val AUROC = 0.6848

 → lr=1e-03, dropout=0.1, img2txt_mult=0.5, txt2img_mult=1.0 → best Val F1 = 0.7059, best Val Acc = 0.6800, best Val AUROC = 0.6795

 → lr=1e-03, dropout=0.1, img2txt_mult=0.5, txt2img_mult=2.0 → best Val F1 = 0.7473, best Val Acc = 0.6933, best Val AUROC = 0.7030

 → lr=1e-03, dropout=0.1, img2txt_mult=0.5, txt2img_mult=5.0 → best Val F1 = 0.7253, best Val Acc = 0.7067, best Val AUROC = 0.7073

 → lr=1e-03, dropout=0.1, img2txt_mult=1.0, txt2img_mult=0.5 → best Val F1 = 0.7273, best Val Acc = 0.7067, best Val AUROC = 0.7073

 → lr=1e-03, dropout=0.1, img2txt_mult=1.0, txt2img_mult=1.0 → best Val F1 = 0.7294, best Val Acc = 0.6933, best Val AUROC = 0.6998

 → lr=1e-03, dropout=0.1, img2txt_mult=1.0, txt2img_mult=2.0 → best Val F1 = 0.7111, best Val Acc = 0.6800, best Val AUROC = 0.6795

 → lr=1e-03, dropout=0.1, img2txt_mult=1.0, txt2img_mult=5.0 → best V