# Concatenation Fusion Model

## Imports & Setup

In [None]:
import os
import random
import json
import numpy as np
import pandas as pd
import pydicom
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    roc_auc_score,
    average_precision_score,
    confusion_matrix
)
from transformers import (
    BertTokenizer,
    BertModel,
    get_linear_schedule_with_warmup,
    ViTModel
)
from torchvision import transforms

# Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with one
# shared value, then run on the GPU whenever one is visible.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

## Load Paths & Labels

In [None]:
image_dir = "/mnt/e/ecs289l/mimic-cxr-download/imageData/"
report_dir = "../download_data/textData/"
labels_file = "../download_data/metadata/edema+pleural_effusion_samples_v2.csv"
model_name_text = 'dmis-lab/biobert-base-cased-v1.1'
model_name_vision = 'google/vit-base-patch16-224-in21k'
max_length = 256

# ---------------------
# Labels and File Loading
# ---------------------
# Load metadata. study_id is read as a string and prefixed with 's' so it
# matches the per-study directory names that contain the DICOM files.
meta = pd.read_csv(labels_file, dtype={'study_id': str})
meta['study_id'] = 's' + meta['study_id']
# {study_id: {'edema': ..., 'effusion': ...}}
label_map = meta.set_index('study_id')[['edema', 'effusion']].to_dict(orient='index')

# Collect every DICOM under image_dir. The list is sorted because os.walk does
# not guarantee an enumeration order; without sorting, the seeded
# train/val/test split would still differ between machines/filesystems.
all_image_paths = sorted(
    os.path.join(root, f)
    for root, _, files in os.walk(image_dir)
    for f in files
    if f.endswith('.dcm')
)

# Keep only images whose parent directory (the study id) has a metadata row,
# and build the matching [edema, effusion] multi-label target in one pass.
paths, labels = [], []
for p in all_image_paths:
    sid = os.path.basename(os.path.dirname(p))
    entry = label_map.get(sid)
    if entry is None:
        continue
    paths.append(p)
    labels.append([entry['edema'], entry['effusion']])


## Tokenizer and Image Transforms

In [None]:
# WordPiece tokenizer matching the text encoder checkpoint.
tokenizer = BertTokenizer.from_pretrained(model_name_text)

# Training-time image pipeline: resize to 224x224, apply random geometric and
# photometric augmentation, then convert to a tensor and normalize per channel.
# NOTE(review): these are ImageNet channel statistics; confirm they are
# intended for the checkpoint in model_name_vision, whose own preprocessing
# config may specify different values.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
]
transform = transforms.Compose(
    [transforms.Resize((224, 224))]
    + _train_augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

## Fusion Dataset Definitions

In [None]:
class FusionDataset(Dataset):
    """Pair each chest X-ray DICOM with its free-text report and its labels.

    The study id of an item is the name of the DICOM file's parent directory.
    It selects both the report file ``<report_dir>/<study_id>/report.txt`` and
    the entry in ``labels_map``.

    Args:
        image_paths: DICOM file paths, one item per path.
        report_dir: Root directory holding one sub-directory per study.
        labels_map: Maps study id to either a dict containing the keys in
            ``LABEL_KEYS`` or a sequence already in that order.
        tokenizer: Tokenizer called with ``return_tensors='pt'`` (e.g. the
            BertTokenizer built in this notebook).
        max_length: Pad/truncate length for the report tokens.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    # Order of the entries in the returned 'labels' vector.
    LABEL_KEYS = ('edema', 'effusion')

    def __init__(self, image_paths, report_dir, labels_map, tokenizer, max_length, transform=None):
        self.image_paths = image_paths
        self.report_dir = report_dir
        self.labels_map = labels_map
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _to_uint8(arr, invert=False):
        """Min-max scale a pixel array to uint8 in [0, 255].

        A constant image maps to all zeros instead of dividing by zero.
        ``invert=True`` flips intensities after scaling.
        """
        arr = np.asarray(arr, dtype=np.float32)
        lo = float(arr.min())
        span = float(arr.max()) - lo
        scaled = (arr - lo) / span if span > 0 else np.zeros_like(arr)
        if invert:
            scaled = 1.0 - scaled
        return (scaled * 255.0).astype(np.uint8)

    @classmethod
    def _label_vector(cls, entry):
        """Return a label entry as a list ordered like ``LABEL_KEYS``."""
        if isinstance(entry, dict):
            return [entry[key] for key in cls.LABEL_KEYS]
        return list(entry)

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        dcm = pydicom.dcmread(p)
        # In MONOCHROME1 images the minimum stored value is displayed as white.
        # Invert those so every image follows the MONOCHROME2 convention
        # (higher value = brighter).
        invert = getattr(dcm, 'PhotometricInterpretation', '') == 'MONOCHROME1'
        img = Image.fromarray(self._to_uint8(dcm.pixel_array, invert=invert)).convert('RGB')
        if self.transform:
            img = self.transform(img)

        sid = os.path.basename(os.path.dirname(p))
        report_path = os.path.join(self.report_dir, sid, 'report.txt')
        with open(report_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # labels_map values may be dicts (as built from the metadata CSV);
        # torch.tensor cannot consume a dict, so flatten to an ordered list.
        labels = self._label_vector(self.labels_map[sid])
        return {
            'pixel_values': img,
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.float32)
        }

def fusion_collate_fn(batch):
    """Collate a list of FusionDataset items into a single batch dict.

    Each of the four model fields is stacked along a new leading batch
    dimension. Any other keys on the items are dropped.
    """
    fields = ('pixel_values', 'input_ids', 'attention_mask', 'labels')
    return {field: torch.stack([sample[field] for sample in batch]) for field in fields}

## Train/Val/Test Split and DataLoaders

In [None]:
# 70 / 10 / 20 train / val / test split. The held-out 30% is split again with
# test_size=1/3; the first (2/3) part becomes test (20% overall) and the
# second (1/3) part becomes val (10% overall). The labels are 2-D
# [edema, effusion] rows; sklearn stratifies such targets on each row's
# joint label combination.
train_data, test_val_data, train_labels, test_val_labels = train_test_split(
    paths, labels, test_size=0.3, random_state=SEED, shuffle=True, stratify=labels
)
test_data, val_data, test_labels, val_labels = train_test_split(
    test_val_data, test_val_labels, test_size=1/3, random_state=SEED, shuffle=True, stratify=test_val_labels
)

print("Train / Val / Test sizes:", len(train_data), len(val_data), len(test_data))

# Deterministic pipeline for evaluation: same resize and normalization as the
# training transform, but without random augmentation. Validation/test scores
# are then reproducible and measured on unaltered images.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

batch_size = 32  # default only; the sweep builds loaders with combo['batch_size']
train_ds = FusionDataset(train_data, report_dir, label_map, tokenizer, max_length, transform)
val_ds   = FusionDataset(val_data,   report_dir, label_map, tokenizer, max_length, eval_transform)
test_ds  = FusionDataset(test_data,  report_dir, label_map, tokenizer, max_length, eval_transform)

def get_loader(ds, bs, shuffle=False):
    """Wrap ``ds`` in a DataLoader that batches items with fusion_collate_fn."""
    loader_kwargs = dict(batch_size=bs, shuffle=shuffle, collate_fn=fusion_collate_fn)
    return DataLoader(ds, **loader_kwargs)


## Model

In [None]:
class FusionModel(nn.Module):
    """Late-fusion classifier over ViT image and BERT text embeddings.

    The ``pooler_output`` of each encoder gets its own dropout. The two are
    concatenated (image first, then text) and fed to a two-layer MLP head that
    produces raw logits.

    Args:
        vision_model_name: Checkpoint id/path loaded with ``ViTModel``.
        text_model_name: Checkpoint id/path loaded with ``BertModel``.
        vision_drop: Dropout probability on the image embedding.
        text_drop: Dropout probability on the text embedding.
        hidden_dim: Width of the MLP head's hidden layer.
        num_labels: Number of output logits (default 2, matching the
            [edema, effusion] targets used in this notebook).
        classifier_drop: Dropout probability inside the MLP head.
    """

    def __init__(self, vision_model_name, text_model_name, vision_drop=0.1, text_drop=0.1,
                 hidden_dim=256, num_labels=2, classifier_drop=0.1):
        super().__init__()
        self.image_model = ViTModel.from_pretrained(vision_model_name)
        self.text_model  = BertModel.from_pretrained(text_model_name)
        self.vision_dropout = nn.Dropout(vision_drop)
        self.text_dropout   = nn.Dropout(text_drop)
        fused_dim = self.image_model.config.hidden_size + self.text_model.config.hidden_size
        self.classifier = self._build_head(fused_dim, hidden_dim, num_labels, classifier_drop)

    @staticmethod
    def _build_head(in_dim, hidden_dim, num_labels, dropout):
        """Build the Linear-ReLU-Dropout-Linear head mapping ``in_dim`` features to logits."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, pixel_values, input_ids, attention_mask):
        """Return raw logits (no sigmoid) from the fused image/text embedding."""
        img_out = self.image_model(pixel_values=pixel_values).pooler_output
        txt_out = self.text_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        fusion = torch.cat([self.vision_dropout(img_out), self.text_dropout(txt_out)], dim=1)
        return self.classifier(fusion)


## Metrics

In [None]:
def compute_metrics(y_true, y_pred, y_probs, label_names=('edema', 'effusion')):
    """Compute per-label binary classification metrics for a multi-label task.

    Args:
        y_true: 2-D array of 0/1 ground truth, one column per label.
        y_pred: 2-D array of 0/1 hard predictions, same layout as ``y_true``.
        y_probs: 2-D array of positive-class scores, same layout.
        label_names: Name for each column, in column order.

    Returns:
        ``{label_name: {metric_name: float}}``. Every value is a plain Python
        float, so the result can go straight into ``json.dump``. AUROC/AUPRC
        are NaN when sklearn raises ValueError for the column (e.g. only one
        class present).
    """
    metrics = {}
    for i, name in enumerate(label_names):
        t, p, s = y_true[:, i], y_pred[:, i], y_probs[:, i]
        # average='binary' scores the positive class (label 1) and returns
        # scalars. The default (average=None) returns one array entry per
        # class, which json.dump cannot serialize.
        precision, recall, f1, _ = precision_recall_fscore_support(
            t, p, average='binary', zero_division=0
        )
        try:
            auroc = roc_auc_score(t, s)
        except ValueError:
            auroc = float('nan')
        try:
            auprc = average_precision_score(t, s)
        except ValueError:
            auprc = float('nan')
        # labels=[0, 1] forces a 2x2 matrix even when the column contains a
        # single class, so the four-way unpack cannot fail.
        tn, fp, fn, tp = confusion_matrix(t, p, labels=[0, 1]).ravel()
        sens = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0
        spec = float(tn / (tn + fp)) if (tn + fp) > 0 else 0.0
        metrics[name] = {
            'accuracy': float(accuracy_score(t, p)),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'auroc': float(auroc),
            'auprc': float(auprc),
            'sensitivity': sens,
            'specificity': spec
        }
    return metrics

## Training

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimization pass over ``loader`` and return the mean batch loss.

    The model is put in training mode. Every batch is moved to the global
    ``device`` and gets one zero_grad/backward/step update. The return value is
    the sum of per-batch losses divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch in loader:
        inputs = [batch[key].to(device) for key in ('pixel_values', 'input_ids', 'attention_mask')]
        targets = batch['labels'].to(device)
        optimizer.zero_grad()
        loss = criterion(model(*inputs), targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)


## Evaluation

In [None]:
def evaluate(model, loader, threshold=0.5):
    """Score ``model`` on every batch of ``loader`` via ``compute_metrics``.

    Args:
        model: Called as ``model(pixel_values, input_ids, attention_mask)``
            and expected to return logits.
        loader: Yields dicts with 'pixel_values', 'input_ids',
            'attention_mask' and 'labels' tensors.
        threshold: A label is predicted positive when its sigmoid
            probability is strictly greater than this value.

    Returns:
        The per-label metrics dict produced by ``compute_metrics``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(
                batch['pixel_values'].to(device),
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
            )
            # Apply the sigmoid on the tensor itself rather than converting to
            # numpy and back into a tensor.
            probs = torch.sigmoid(logits).cpu().numpy()
            all_labels.append(batch['labels'].cpu().numpy())
            all_preds.append((probs > threshold).astype(int))
            all_probs.append(probs)
    if not all_labels:
        raise ValueError("evaluate() received a loader with no batches")
    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)
    y_probs = np.vstack(all_probs)
    return compute_metrics(y_true, y_pred, y_probs)

## Hyperparameter Sweep

In [None]:
# Full grid: 2 vision dropouts x 2 text dropouts x 3 learning rates x
# 3 weight decays x 3 batch sizes = 108 configurations, each allowed up to
# 20 epochs. Order follows the nesting of the `for` clauses below, so the
# batch size varies fastest.
hyperparameter_combinations = [
    {
        'vision_drop': vision_drop,
        'text_drop':   text_drop,
        'learning_rate': lr,
        'weight_decay': wd,
        'batch_size': bs,
        'num_epochs': 20,
    }
    for vision_drop in (0.1, 0.2)
    for text_drop in (0.1, 0.2)
    for lr in (1e-5, 5e-5, 2e-4)
    for wd in (0, 0.01, 0.1)
    for bs in (16, 32, 64)
]

results_file = 'fusion_results.json'

# First run only: start the results log as an empty JSON list, so that every
# combo can read-modify-write it. An existing file is left untouched.
_results_missing = not os.path.exists(results_file)
if _results_missing:
    with open(results_file, mode='w') as fh:
        json.dump([], fh)

def _to_jsonable(obj):
    """Recursively convert numpy arrays/scalars inside dicts/lists to JSON-native types."""
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


for combo in hyperparameter_combinations:
    name = f"VD{combo['vision_drop']}_TD{combo['text_drop']}_LR{combo['learning_rate']}_WD{combo['weight_decay']}_BS{combo['batch_size']}_EP{combo['num_epochs']}"
    print(f"🔧 Running combo: {name}")

    # Drop the previous combo's model/optimizer before building this one, so
    # two full sets of weights are never resident on the GPU at once.
    model = optimizer = scheduler = best_state = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    train_loader = get_loader(train_ds, combo['batch_size'], shuffle=True)
    val_loader   = get_loader(val_ds,   combo['batch_size'])
    test_loader  = get_loader(test_ds,  combo['batch_size'])

    model = FusionModel(
        vision_model_name=model_name_vision,
        text_model_name=model_name_text,
        vision_drop=combo['vision_drop'],
        text_drop=combo['text_drop']
    ).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=combo['learning_rate'],
        weight_decay=combo['weight_decay']
    )

    # train_epoch does not step the scheduler, so it is stepped once per epoch
    # below and its horizon is measured in epochs. No warmup: the
    # linear-warmup schedule yields a factor of 0 at step 0, so any warmup
    # would run the whole first epoch at lr = 0.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=combo['num_epochs']
    )

    best_val_acc = float('-inf')
    patience = 3
    no_improve = 0

    for epoch in range(1, combo['num_epochs']+1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_metrics = evaluate(model, val_loader)
        # Mean accuracy over the labels; higher is better.
        val_acc = float(np.mean([m['accuracy'] for m in val_metrics.values()]))
        print(f"Epoch {epoch}/{combo['num_epochs']} - Train Loss: {train_loss:.4f}, Val Avg Acc: {val_acc:.4f}")

        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            no_improve = 0
            # CPU snapshot of the best epoch so the test set is scored with it
            # rather than with whatever epoch training stopped on.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    # Convert any numpy values so json.dump below cannot fail on them.
    test_metrics = _to_jsonable(evaluate(model, test_loader))
    print(f"📝 Test Metrics for {name}: {test_metrics}")

    with open(results_file, 'r') as f:
        results = json.load(f)
    results.append({'name': name, 'combo': combo, 'metrics': test_metrics})
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"✅ Saved results for {name}\n")