# Concatenation Fusion Model

## Imports & Setup

In [1]:
import os
import random
import json
import numpy as np
import pandas as pd
import pydicom
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    roc_auc_score,
    average_precision_score,
    confusion_matrix
)
from transformers import (
    BertTokenizer,
    BertModel,
    get_linear_schedule_with_warmup,
    ViTModel
)
from torchvision import transforms

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


## Load Paths & Label

In [2]:
image_dir = "/mnt/e/ecs289l/mimic-cxr-download/imageData/"
report_dir = "/mnt/e/ecs289l/mimic-cxr-download/textData/"
labels_file = "../download_data/metadata/edema+pleural_effusion_samples_v2.csv"
model_name_text = 'dmis-lab/biobert-base-cased-v1.1'
model_name_vision = 'google/vit-base-patch16-224-in21k'
max_length = 128

# ---------------------
# Labels and File Loading
# ---------------------
# Load metadata
meta = pd.read_csv(labels_file, dtype={'study_id': str})
meta['study_id'] = 's' + meta['study_id']
label_map = meta.set_index('study_id')[['edema', 'effusion']].to_dict(orient='index')

# Collect image paths and labels
all_image_paths = []
for root, _, files in os.walk(image_dir):
    for f in files:
        if f.endswith('.dcm'):
            all_image_paths.append(os.path.join(root, f))
paths, labels = [], []
for p in all_image_paths:
    sid = os.path.basename(os.path.dirname(p))
    if sid in label_map:
        paths.append(p)
        labels.append(label_map[sid]['edema'] + label_map[sid]['effusion']*2)  # placeholder, we will use list below
# Actually build multi-label list
labels = [ [label_map[os.path.basename(os.path.dirname(p))]['edema'],
            label_map[os.path.basename(os.path.dirname(p))]['effusion']]
          for p in paths ]


## Dataset Tokenizer and Transforms

In [3]:
tokenizer = BertTokenizer.from_pretrained(model_name_text)
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(0.2,0.2,0.2,0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

## Data Preprocess and Caching

In [4]:
import os
import torch
import pydicom
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

cache_dir = "./cached_data"
os.makedirs(cache_dir, exist_ok=True)

def process_and_save(i, path, label):
    try:
        # Ignore if the file already exists
        cache_path = os.path.join(cache_dir, f'{i}.pt')
        if os.path.exists(cache_path):
            return
        # Image
        dcm = pydicom.dcmread(path)
        arr = dcm.pixel_array.astype(np.float32)
        img = Image.fromarray((arr / arr.max() * 255).astype(np.uint8)).convert('RGB')
        img_tensor = transform(img)

        # Text
        sid = os.path.basename(os.path.dirname(path))
        report_path = os.path.join(report_dir, sid, 'report.txt')
        with open(report_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Label
        label_tensor = torch.tensor(label, dtype=torch.float32)

        # Save all
        torch.save({
            'pixel_values': img_tensor,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label_tensor
        }, os.path.join(cache_dir, f'{i}.pt'))

    except Exception as e:
        print(f"Failed to process index {i}: {e}")

# Set number of threads based on your CPU (e.g., 8 or 16)
num_threads = 8

with ThreadPoolExecutor(max_workers=num_threads) as executor:
    futures = [
        executor.submit(process_and_save, i, path, label)
        for i, (path, label) in enumerate(zip(paths, labels))
    ]

    for _ in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
        pass


Caching: 100%|██████████| 7199/7199 [00:02<00:00, 3564.01it/s]


## Fusion Dataset

In [5]:
class FusionDataset(Dataset):
    def __init__(self, image_paths, report_dir, labels_map, tokenizer, max_length, transform=None):
        self.image_paths = image_paths
        self.report_dir = report_dir
        self.labels_map = labels_map
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        dcm = pydicom.dcmread(p)
        arr = dcm.pixel_array.astype(np.float32)
        img = Image.fromarray((arr/arr.max()*255).astype(np.uint8)).convert('RGB')
        if self.transform:
            img = self.transform(img)
        sid = os.path.basename(os.path.dirname(p))
        report_path = os.path.join(self.report_dir, sid, 'report.txt')
        with open(report_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        labels = list(self.labels_map[sid].values())
        return {
            'pixel_values': img,
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.float32)
        }

# Custom collate function to batch fusion data
def fusion_collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    return {
        'pixel_values': pixel_values,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

class CachedFusionDataset(Dataset):
    def __init__(self, cache_dir, indices):
        self.cache_dir = cache_dir
        self.indices = indices  # list of indices used in pre-split sets

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        data = torch.load(os.path.join(self.cache_dir, f'{self.indices[idx]}.pt'))
        return data


## Split and DataLoaders

In [6]:
# train_data, test_val_data, train_labels, test_val_labels = train_test_split(
#     paths, labels, test_size=0.3, random_state=SEED, shuffle=True, stratify=labels
# )
# test_data, val_data, test_labels, val_labels = train_test_split(
#     test_val_data, test_val_labels, test_size=1/3, random_state=SEED, shuffle=True, stratify=test_val_labels
# )

indices = list(range(len(paths)))
train_idx, test_val_idx, train_labels, test_val_labels = train_test_split(
    indices, labels, test_size=0.3, random_state=SEED, shuffle=True, stratify=labels
)
test_idx, val_idx, test_labels, val_labels = train_test_split(
    test_val_idx, test_val_labels, test_size=1/3, random_state=SEED, shuffle=True, stratify=test_val_labels
)

# print("Train / Val / Test sizes:", len(train_data), len(val_data), len(test_data))
print("Train / Val / Test sizes:", len(train_idx), len(val_idx), len(test_idx))


batch_size = 32
# train_ds = FusionDataset(train_data, report_dir, label_map, tokenizer, max_length, transform)
# val_ds   = FusionDataset(val_data,   report_dir, label_map, tokenizer, max_length, transform)
# test_ds  = FusionDataset(test_data,  report_dir, label_map, tokenizer, max_length, transform)

train_ds = CachedFusionDataset(cache_dir, train_idx)
val_ds = CachedFusionDataset(cache_dir, val_idx)
test_ds = CachedFusionDataset(cache_dir, test_idx)

def get_loader(ds, bs, shuffle=False):
    return DataLoader(ds, batch_size=bs, shuffle=shuffle, collate_fn=fusion_collate_fn, num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2)


Train / Val / Test sizes: 5039 720 1440


## Concatenation Fusion Model

In [7]:
class FusionModel(nn.Module):
    def __init__(self, vision_model_name, text_model_name, vision_drop=0.1, text_drop=0.1, hidden_dim=256):
        super().__init__()
        self.image_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model  = BertModel.from_pretrained(text_model_name)
        self.vision_dropout = nn.Dropout(vision_drop)
        self.text_dropout   = nn.Dropout(text_drop)
        img_dim = self.image_model.config.hidden_size
        txt_dim = self.text_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(img_dim+txt_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 2)
        )
    def forward(self, pixel_values, input_ids, attention_mask):
        img_out = self.image_model(pixel_values=pixel_values).pooler_output
        txt_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        fusion = torch.cat([self.vision_dropout(img_out), self.text_dropout(txt_out)], dim=1)
        logits = self.classifier(fusion)
        return logits


## Metrics

In [8]:
def compute_metrics(y_true, y_pred, y_probs):
    metrics = {}
    for i, name in enumerate(['edema','effusion']):
        acc = accuracy_score(y_true[:,i], y_pred[:,i])
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true[:,i], y_pred[:,i], zero_division=0
        )
        try:
            auroc = roc_auc_score(y_true[:,i], y_probs[:,i])
        except ValueError:
            auroc = float('nan')
        try:
            auprc = average_precision_score(y_true[:,i], y_probs[:,i])
        except ValueError:
            auprc = float('nan')
        tn, fp, fn, tp = confusion_matrix(y_true[:,i], y_pred[:,i]).ravel()
        sens = tp/(tp+fn) if (tp+fn)>0 else 0.0
        spec = tn/(tn+fp) if (tn+fp)>0 else 0.0
        metrics[name] = {
            'accuracy': acc,
            'precision': precision.tolist(),
            'recall': recall.tolist(),
            'f1': f1.tolist(),
            'auroc': auroc,
            'auprc': auprc,
            'sensitivity': sens,
            'specificity': spec
        }
    return metrics

## Training and Validation Loop

In [9]:
from tqdm import tqdm

def train_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    for batch in tqdm(loader, desc="Training"):
        optimizer.zero_grad()
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        logits = model(pixel_values, input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def validate_epoch(model, loader, criterion):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            logits = model(pixel_values, input_ids, attention_mask)
            loss = criterion(logits, labels)
            total_loss += loss.item()
    avg_loss = total_loss / len(loader)

    return avg_loss


## Evaluation

In [10]:
from tqdm import tqdm

def evaluate(model, loader):
    model.eval()
    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()
            logits = model(pixel_values, input_ids, attention_mask).cpu().numpy()
            probs = torch.sigmoid(torch.tensor(logits)).numpy()
            preds = (probs > 0.5).astype(int)
            all_labels.append(labels)
            all_preds.append(preds)
            all_probs.append(probs)
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    y_probs = np.vstack(all_probs)
    return compute_metrics(y_true, y_pred, y_probs)

## Hyperparameter Combinations

In [11]:
hyperparameter_combinations = []
for vision_drop in [0.1, 0.2]:
    for text_drop in [0.1, 0.2]:
        for lr in [1e-5, 5e-5, 2e-4]:
            for wd in [0, 0.01, 0.1]:
                for bs in [32]:
                    hyperparameter_combinations.append({
                        'vision_drop': vision_drop,
                        'text_drop':   text_drop,
                        'learning_rate': lr,
                        'weight_decay': wd,
                        'batch_size': bs,
                        'num_epochs': 20
                    })

results_file = 'concatenation_fusion_results.json'
if not os.path.exists(results_file):
    with open(results_file, 'w') as f:
        json.dump([], f)

def execute_hyperparameter_combo(combo):
    name = f"VD{combo['vision_drop']}_TD{combo['text_drop']}_LR{combo['learning_rate']}_WD{combo['weight_decay']}_BS{combo['batch_size']}_EP{combo['num_epochs']}"
    print(f"Running combo: {name}")

    train_loader = get_loader(train_ds, combo['batch_size'], shuffle=True)
    val_loader = get_loader(val_ds, combo['batch_size'])
    test_loader = get_loader(test_ds, combo['batch_size'])

    model = FusionModel(
        vision_model_name=model_name_vision,
        text_model_name=model_name_text,
        vision_drop=combo['vision_drop'],
        text_drop=combo['text_drop']
    ).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=combo['learning_rate'],
        weight_decay=combo['weight_decay']
    )

    total_steps = len(train_loader) * combo['num_epochs']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1*total_steps),
        num_training_steps=total_steps
    )

    best_val_loss = float('inf')
    patience = 3
    no_improve = 0

    for epoch in range(1, combo['num_epochs']+1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_loss = validate_epoch(model, val_loader, criterion)
        print(f"Epoch {epoch}/{combo['num_epochs']} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    test_metrics = evaluate(model, test_loader)
    with open(results_file, 'r') as f:
        results = json.load(f)
    results.append({'name': name, 'combo': combo, 'metrics': test_metrics})
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"Saved results for {name}\n")

In [13]:
combo = {
    "vision_drop": 0.2,
    "text_drop": 0.2,
    "weight_decay": 0.01,
    "learning_rate": 5e-05,
    "batch_size": 32,
    "num_epochs": 20,
    "patience": 3
}
execute_hyperparameter_combo(combo)

Running combo: VD0.2_TD0.2_LR5e-05_WD0.01_BS32_EP20


Training: 100%|██████████| 158/158 [12:52<00:00,  4.89s/it]
Validation: 100%|██████████| 23/23 [00:45<00:00,  1.99s/it]


Epoch 1/20 - Train Loss: 0.6946, Val Loss: 0.6949


Training: 100%|██████████| 158/158 [13:38<00:00,  5.18s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 2/20 - Train Loss: 0.6935, Val Loss: 0.6906


Training: 100%|██████████| 158/158 [13:31<00:00,  5.14s/it]
Validation: 100%|██████████| 23/23 [00:46<00:00,  2.03s/it]


Epoch 3/20 - Train Loss: 0.6882, Val Loss: 0.6839


Training: 100%|██████████| 158/158 [13:56<00:00,  5.29s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.05s/it]


Epoch 4/20 - Train Loss: 0.6792, Val Loss: 0.6695


Training: 100%|██████████| 158/158 [13:18<00:00,  5.05s/it]
Validation: 100%|██████████| 23/23 [00:46<00:00,  2.01s/it]


Epoch 5/20 - Train Loss: 0.6522, Val Loss: 0.6243


Training: 100%|██████████| 158/158 [13:54<00:00,  5.28s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.05s/it]


Epoch 6/20 - Train Loss: 0.5778, Val Loss: 0.5056


Training: 100%|██████████| 158/158 [13:49<00:00,  5.25s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 7/20 - Train Loss: 0.4604, Val Loss: 0.3999


Training: 100%|██████████| 158/158 [13:48<00:00,  5.24s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 8/20 - Train Loss: 0.3770, Val Loss: 0.3470


Training: 100%|██████████| 158/158 [13:43<00:00,  5.21s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 9/20 - Train Loss: 0.3316, Val Loss: 0.3131


Training: 100%|██████████| 158/158 [13:53<00:00,  5.28s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.08s/it]


Epoch 10/20 - Train Loss: 0.2983, Val Loss: 0.2906


Training: 100%|██████████| 158/158 [13:48<00:00,  5.24s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 11/20 - Train Loss: 0.2717, Val Loss: 0.2728


Training: 100%|██████████| 158/158 [14:39<00:00,  5.56s/it]
Validation: 100%|██████████| 23/23 [00:50<00:00,  2.18s/it]


Epoch 12/20 - Train Loss: 0.2494, Val Loss: 0.2564


Training: 100%|██████████| 158/158 [14:45<00:00,  5.61s/it]
Validation: 100%|██████████| 23/23 [00:51<00:00,  2.22s/it]


Epoch 13/20 - Train Loss: 0.2264, Val Loss: 0.2398


Training: 100%|██████████| 158/158 [14:15<00:00,  5.41s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.10s/it]


Epoch 14/20 - Train Loss: 0.2047, Val Loss: 0.2173


Training: 100%|██████████| 158/158 [14:29<00:00,  5.50s/it]
Validation: 100%|██████████| 23/23 [00:49<00:00,  2.15s/it]


Epoch 15/20 - Train Loss: 0.1736, Val Loss: 0.1923


Training: 100%|██████████| 158/158 [14:37<00:00,  5.55s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.13s/it]


Epoch 16/20 - Train Loss: 0.1441, Val Loss: 0.1827


Training: 100%|██████████| 158/158 [14:31<00:00,  5.51s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.11s/it]


Epoch 17/20 - Train Loss: 0.1228, Val Loss: 0.1606


Training: 100%|██████████| 158/158 [14:42<00:00,  5.59s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.09s/it]


Epoch 18/20 - Train Loss: 0.1039, Val Loss: 0.1556


Training: 100%|██████████| 158/158 [14:51<00:00,  5.64s/it]
Validation: 100%|██████████| 23/23 [00:50<00:00,  2.18s/it]


Epoch 19/20 - Train Loss: 0.0861, Val Loss: 0.1471


Training: 100%|██████████| 158/158 [14:21<00:00,  5.45s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.12s/it]


Epoch 20/20 - Train Loss: 0.0737, Val Loss: 0.1368


Evaluating: 100%|██████████| 45/45 [01:38<00:00,  2.18s/it]


Test Metrics for VD0.2_TD0.2_LR5e-05_WD0.01_BS32_EP20: {'edema': {'accuracy': 0.9409722222222222, 'precision': array([0.94209891, 0.93944354]), 'recall': array([0.95476773, 0.92282958]), 'f1': array([0.94839101, 0.93106245]), 'auroc': 0.982000644659156, 'auprc': 0.976772964724611, 'sensitivity': 0.9228295819935691, 'specificity': 0.9547677261613692}, 'effusion': {'accuracy': 0.95625, 'precision': array([0.9652568 , 0.94858612]), 'recall': array([0.94108984, 0.96977661]), 'f1': array([0.95302013, 0.95906433]), 'auroc': 0.98328491888241, 'auprc': 0.9767008398933996, 'sensitivity': 0.9697766097240473, 'specificity': 0.9410898379970545}}


TypeError: Object of type ndarray is not JSON serializable

In [12]:
combo = {
    "vision_drop": 0.1,
    "text_drop": 0.1,
    "weight_decay": 0.1,
    "learning_rate": 1e-05,
    "batch_size": 32,
    "num_epochs": 20,
    "patience": 3
}
execute_hyperparameter_combo(combo)

Running combo: VD0.1_TD0.1_LR1e-05_WD0.1_BS32_EP20


Training: 100%|██████████| 158/158 [13:26<00:00,  5.10s/it]
Validation: 100%|██████████| 23/23 [00:51<00:00,  2.23s/it]


Epoch 1/20 - Train Loss: 0.6944, Val Loss: 0.6949


Training: 100%|██████████| 158/158 [14:58<00:00,  5.69s/it]
Validation: 100%|██████████| 23/23 [00:51<00:00,  2.25s/it]


Epoch 2/20 - Train Loss: 0.6949, Val Loss: 0.6939


Training: 100%|██████████| 158/158 [15:13<00:00,  5.78s/it]
Validation: 100%|██████████| 23/23 [00:51<00:00,  2.24s/it]


Epoch 3/20 - Train Loss: 0.6933, Val Loss: 0.6920


Training: 100%|██████████| 158/158 [15:09<00:00,  5.76s/it]
Validation: 100%|██████████| 23/23 [00:51<00:00,  2.23s/it]


Epoch 4/20 - Train Loss: 0.6917, Val Loss: 0.6896


Training: 100%|██████████| 158/158 [15:13<00:00,  5.78s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.11s/it]


Epoch 5/20 - Train Loss: 0.6879, Val Loss: 0.6866


Training: 100%|██████████| 158/158 [13:41<00:00,  5.20s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.06s/it]


Epoch 6/20 - Train Loss: 0.6849, Val Loss: 0.6827


Training: 100%|██████████| 158/158 [16:11<00:00,  6.15s/it]
Validation: 100%|██████████| 23/23 [01:01<00:00,  2.66s/it]


Epoch 7/20 - Train Loss: 0.6803, Val Loss: 0.6771


Training: 100%|██████████| 158/158 [19:41<00:00,  7.48s/it]
Validation: 100%|██████████| 23/23 [01:03<00:00,  2.76s/it]


Epoch 8/20 - Train Loss: 0.6734, Val Loss: 0.6675


Training: 100%|██████████| 158/158 [24:10<00:00,  9.18s/it]
Validation: 100%|██████████| 23/23 [01:06<00:00,  2.87s/it]


Epoch 9/20 - Train Loss: 0.6617, Val Loss: 0.6514


Training: 100%|██████████| 158/158 [20:10<00:00,  7.66s/it]
Validation: 100%|██████████| 23/23 [00:50<00:00,  2.20s/it]


Epoch 10/20 - Train Loss: 0.6420, Val Loss: 0.6252


Training: 100%|██████████| 158/158 [15:02<00:00,  5.71s/it]
Validation: 100%|██████████| 23/23 [01:04<00:00,  2.80s/it]


Epoch 11/20 - Train Loss: 0.6092, Val Loss: 0.5826


Training: 100%|██████████| 158/158 [17:19<00:00,  6.58s/it]
Validation: 100%|██████████| 23/23 [00:40<00:00,  1.74s/it]


Epoch 12/20 - Train Loss: 0.5614, Val Loss: 0.5280


Training: 100%|██████████| 158/158 [11:49<00:00,  4.49s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.10s/it]


Epoch 13/20 - Train Loss: 0.5097, Val Loss: 0.4763


Training: 100%|██████████| 158/158 [12:56<00:00,  4.91s/it]
Validation: 100%|██████████| 23/23 [00:44<00:00,  1.92s/it]


Epoch 14/20 - Train Loss: 0.4623, Val Loss: 0.4331


Training: 100%|██████████| 158/158 [13:22<00:00,  5.08s/it]
Validation: 100%|██████████| 23/23 [00:47<00:00,  2.05s/it]


Epoch 15/20 - Train Loss: 0.4235, Val Loss: 0.4028


Training: 100%|██████████| 158/158 [14:08<00:00,  5.37s/it]
Validation: 100%|██████████| 23/23 [00:48<00:00,  2.09s/it]


Epoch 16/20 - Train Loss: 0.3919, Val Loss: 0.3846


Training: 100%|██████████| 158/158 [13:46<00:00,  5.23s/it]
Validation: 100%|██████████| 23/23 [00:46<00:00,  2.03s/it]


Epoch 17/20 - Train Loss: 0.3686, Val Loss: 0.3552


Training: 100%|██████████| 158/158 [14:30<00:00,  5.51s/it]
Validation: 100%|██████████| 23/23 [00:49<00:00,  2.17s/it]


Epoch 18/20 - Train Loss: 0.3500, Val Loss: 0.3462


Training: 100%|██████████| 158/158 [14:46<00:00,  5.61s/it]
Validation: 100%|██████████| 23/23 [00:46<00:00,  2.04s/it]


Epoch 19/20 - Train Loss: 0.3320, Val Loss: 0.3277


Training: 100%|██████████| 158/158 [13:16<00:00,  5.04s/it]
Validation: 100%|██████████| 23/23 [00:46<00:00,  2.02s/it]


Epoch 20/20 - Train Loss: 0.3174, Val Loss: 0.3155


Evaluating: 100%|██████████| 45/45 [01:32<00:00,  2.05s/it]


Saved results for VD0.1_TD0.1_LR1e-05_WD0.1_BS32_EP20



In [12]:
combo = {
    "vision_drop": 0.2,
    "text_drop": 0.1,
    "weight_decay": 0.01,
    "learning_rate": 5e-05,
    "batch_size": 32,
    "num_epochs": 20,
    "patience": 3
}
execute_hyperparameter_combo(combo)

Running combo: VD0.2_TD0.1_LR5e-05_WD0.01_BS32_EP20


Training: 100%|██████████| 158/158 [12:56<00:00,  4.91s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.48s/it]


Epoch 1/20 - Train Loss: 0.6944, Val Loss: 0.6949


Training: 100%|██████████| 158/158 [13:36<00:00,  5.17s/it]
Validation: 100%|██████████| 23/23 [00:34<00:00,  1.48s/it]


Epoch 2/20 - Train Loss: 0.6931, Val Loss: 0.6904


Training: 100%|██████████| 158/158 [14:00<00:00,  5.32s/it]
Validation: 100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Epoch 3/20 - Train Loss: 0.6874, Val Loss: 0.6832


Training: 100%|██████████| 158/158 [13:37<00:00,  5.18s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.47s/it]


Epoch 4/20 - Train Loss: 0.6781, Val Loss: 0.6670


Training: 100%|██████████| 158/158 [13:46<00:00,  5.23s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 5/20 - Train Loss: 0.6465, Val Loss: 0.6122


Training: 100%|██████████| 158/158 [13:27<00:00,  5.11s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.47s/it]


Epoch 6/20 - Train Loss: 0.5581, Val Loss: 0.4844


Training: 100%|██████████| 158/158 [13:28<00:00,  5.11s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.47s/it]


Epoch 7/20 - Train Loss: 0.4431, Val Loss: 0.3919


Training: 100%|██████████| 158/158 [13:58<00:00,  5.31s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.48s/it]


Epoch 8/20 - Train Loss: 0.3683, Val Loss: 0.3419


Training: 100%|██████████| 158/158 [13:59<00:00,  5.31s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.47s/it]


Epoch 9/20 - Train Loss: 0.3266, Val Loss: 0.3100


Training: 100%|██████████| 158/158 [13:17<00:00,  5.05s/it]
Validation: 100%|██████████| 23/23 [00:32<00:00,  1.42s/it]


Epoch 10/20 - Train Loss: 0.2952, Val Loss: 0.2872


Training: 100%|██████████| 158/158 [13:32<00:00,  5.14s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.45s/it]


Epoch 11/20 - Train Loss: 0.2695, Val Loss: 0.2701


Training: 100%|██████████| 158/158 [13:59<00:00,  5.31s/it]
Validation: 100%|██████████| 23/23 [00:34<00:00,  1.49s/it]


Epoch 12/20 - Train Loss: 0.2473, Val Loss: 0.2559


Training: 100%|██████████| 158/158 [13:56<00:00,  5.29s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.45s/it]


Epoch 13/20 - Train Loss: 0.2263, Val Loss: 0.2410


Training: 100%|██████████| 158/158 [13:58<00:00,  5.30s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 14/20 - Train Loss: 0.2070, Val Loss: 0.2273


Training: 100%|██████████| 158/158 [13:47<00:00,  5.24s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 15/20 - Train Loss: 0.1792, Val Loss: 0.2038


Training: 100%|██████████| 158/158 [13:09<00:00,  4.99s/it]
Validation: 100%|██████████| 23/23 [00:32<00:00,  1.41s/it]


Epoch 16/20 - Train Loss: 0.1498, Val Loss: 0.1886


Training: 100%|██████████| 158/158 [14:01<00:00,  5.33s/it]
Validation: 100%|██████████| 23/23 [00:34<00:00,  1.49s/it]


Epoch 17/20 - Train Loss: 0.1267, Val Loss: 0.1654


Training: 100%|██████████| 158/158 [14:11<00:00,  5.39s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.47s/it]


Epoch 18/20 - Train Loss: 0.1056, Val Loss: 0.1592


Training: 100%|██████████| 158/158 [13:17<00:00,  5.05s/it]
Validation: 100%|██████████| 23/23 [00:32<00:00,  1.41s/it]


Epoch 19/20 - Train Loss: 0.0888, Val Loss: 0.1537


Training: 100%|██████████| 158/158 [13:47<00:00,  5.24s/it]
Validation: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 20/20 - Train Loss: 0.0743, Val Loss: 0.1437


Evaluating: 100%|██████████| 45/45 [01:05<00:00,  1.47s/it]


Saved results for VD0.2_TD0.1_LR5e-05_WD0.01_BS32_EP20

