In [1]:

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
import os
import pandas as pd
import inspect
if not hasattr(inspect, 'formatargspec'):
    inspect.formatargspec = inspect.formatargvalues
import random
import logging
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import f1_score, roc_auc_score
import numpy as np
from transformers import AutoTokenizer, AutoModel, ViTModel
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import MultilabelF1Score, MultilabelAUROC
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from rouge import Rouge
from sklearn.metrics import roc_auc_score
from sacrebleu import corpus_bleu
from rouge import Rouge


# Configure environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
LABEL_COLS = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
]


RARE_LABELS = ['Fracture', 'Pleural Other', 'Enlarged Cardiomediastinum', 'Pneumothorax']


# Custom loss function that balances class distribution and focuses on hard samples for multilabel classification
class DistributionBalancedFocalLoss(nn.Module):
    def __init__(self, class_freq, beta=0.9999, gamma=2, eps=1e-6):
        super().__init__()     # Compute class weights based on effective number of samples
        self.gamma = gamma
        effective_num = 1.0 - torch.pow(torch.tensor(beta), class_freq)
        self.weights = (1.0 - beta) / (effective_num + eps)
        self.weights = self.weights / self.weights.sum() * len(class_freq)

    def forward(self, inputs, targets):  # Computes the focal loss with class balancing for multilabel targets
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        focal_loss = self.weights.to(inputs.device) * (1 - pt) ** self.gamma * BCE_loss
        return focal_loss.mean()



# Sampler to ensure each batch contains a balanced number of rare and common class samples
class BalancedMultilabelSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rare_labels, samples_per_class=4, batch_size=16):
         # Initializes sampler with dataset and rare label balancing parameters
        self.dataset = dataset
        self.rare_labels = rare_labels
        self.samples_per_class = samples_per_class
        self.batch_size = batch_size

        # Calculate rare indices
        self.rare_indices = []
        for idx, row in enumerate(dataset.df[rare_labels].values):
            if any(row == 1):
                self.rare_indices.append(idx)

        self.common_indices = [i for i in range(len(dataset))
                              if i not in self.rare_indices]

        # Fix n_batches calculation
        self.n_batches = (len(self.rare_indices) // samples_per_class) + \
                        (len(self.common_indices) // (batch_size - samples_per_class))

    def __iter__(self):           # Yields indices to form balanced batches of rare and common samples
        for _ in range(self.n_batches):
            # Sample with replacement if needed
            rare = random.choices(self.rare_indices, k=self.samples_per_class)
            common = random.sample(self.common_indices, k=self.batch_size - self.samples_per_class)
            yield from rare + common  # Individual indices

    def __len__(self):           # Returns total number of samples in an epoch
        return self.n_batches * self.batch_size



# Utility function to print the distribution of positive samples for each label in a dataset split
def print_label_distribution(df, split_name):
    print(f"\nLabel distribution in {split_name} set:")
    label_counts = df[LABEL_COLS].apply(lambda x: (x == 1).sum())
    total_samples = len(df)
    for label, count in label_counts.items():
        print(f"{label}: {count} positive samples ({count/total_samples:.2%})")
    print(f"Total samples in {split_name}: {total_samples}\n")

NUM_CLASSES = len(LABEL_COLS)
BATCH_SIZE = 16
ACCUMULATION_STEPS = 4
NUM_EPOCHS = 10
THRESHOLD = 0.5
MAX_LENGTH = 256
NUM_WORKERS = 0




# Custom focal loss for multilabel classification with optional masking for uncertain labels
class MultiLabelFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, pos_weight=None):           # Initializes focal loss parameters for multilabel tasks
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):         # Computes the masked focal loss for multilabel classification
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none', pos_weight=self.pos_weight)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        mask = (targets != 0.5).float()  # Mask uncertain labels
        return (focal_loss * mask).mean()

def get_stratified_subset(df, n_samples):          # Returns a stratified subset of the dataframe, preserving multilabel distribution
    msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=len(df)-n_samples, random_state=42)
    for train_idx, _ in msss.split(df, df[LABEL_COLS]):
        return df.iloc[train_idx]

def get_cached_splits():        # Loads or generates and caches train/val/test splits, ensuring stratification and valid file paths
    cache_files = ["train_df.parquet", "val_df.parquet", "test_df.parquet"]
    if all(os.path.exists(f) for f in cache_files):
        logger.info("Loading cached splits")
        return (
            pd.read_parquet("train_df.parquet"),
            pd.read_parquet("val_df.parquet"),
            pd.read_parquet("test_df.parquet")
        )

    logger.info("Generating new splits")
    df1 = pd.read_csv('cxr-record-list.csv.gz')
    df2 = pd.read_csv('mimic-cxr-2.0.0-chexpert.csv')
    df3 = pd.read_csv('mimic-cxr-2.0.0-metadata.csv')

    merge_df = df1.merge(df2, on=['subject_id', 'study_id'], how='inner')
    merge_df = merge_df.merge(df3[['dicom_id', 'ViewPosition']], on='dicom_id', how='inner')

    with open('IMAGE_FILENAMES.txt', 'r') as f:
        image_paths = [line.strip() for line in f if line.strip()]

    image_paths = [p for p in image_paths if p.startswith(('files/p10/', 'files/p11/', 'files/p12/'))]
    dicom_to_path = {os.path.splitext(os.path.basename(p))[0]: p for p in image_paths}

    df = merge_df.copy()
    df['image_rel_path'] = df['dicom_id'].astype(str).map(dicom_to_path)
    df = df[df['image_rel_path'].notnull()]

    # Enhanced label processing
    df[LABEL_COLS] = df[LABEL_COLS].replace(-1, 0.5).fillna(0.0)

    df['image_path'] = df['image_rel_path'].apply(lambda x: os.path.join('mimic-cxr-jpg', '2.1.0', x))

    def build_report_path(row):         # Generates the file path for each radiology report using subject and study IDs from the DataFrame row
        p_prefix = f"p{str(row['subject_id'])[:2]}"
        return os.path.join(
            'mimic-cxr-reports', 'files', p_prefix,
            f"p{row['subject_id']}", f"s{row['study_id']}.txt"
        )

    df['report_path'] = df.apply(build_report_path, axis=1)
    df = df[df['image_path'].apply(os.path.exists) & df['report_path'].apply(os.path.exists)]

    def get_stratified_splits(df):              # Splits the dataset into train/val/test (60/20/20) using multilabel stratification to preserve label distributions, then caches splits as Parquet files
        y = df[LABEL_COLS].values
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        train_idx, temp_idx = next(msss.split(df, y))
        msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
        val_idx, test_idx = next(msss_val.split(df.iloc[temp_idx], df.iloc[temp_idx][LABEL_COLS].values))
        return df.iloc[train_idx], df.iloc[temp_idx].iloc[val_idx], df.iloc[temp_idx].iloc[test_idx]

    train_df, val_df, test_df = get_stratified_splits(df)
    train_df.to_parquet("train_df.parquet")
    val_df.to_parquet("val_df.parquet")
    test_df.to_parquet("test_df.parquet")
    return train_df, val_df, test_df

class MemoryOptimizedDataset(Dataset):              # Memory-efficient dataset for loading images, reports, and labels with augmentations and error handling
    def __init__(self, df, tokenizer, rare_labels):   # Initializes dataset with data, tokenizer, rare label info, and sets up augmentations
        self.df = df
        self.tokenizer = tokenizer
        self.rare_labels = rare_labels
        self.rare_indices = []
        self.common_indices = []

        # Precompute rare indices
        for idx, row in df.iterrows():
            if any(row[label] == 1 for label in self.rare_labels):
                self.rare_indices.append(idx)

        # Base augmentations
        self.base_transform = A.Compose([
            A.Resize(height=256, width=256),  # ✅ Explicit parameter names
            A.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # ✅ Tuple for size
            A.Rotate(limit=10),
            A.RandomBrightnessContrast(p=0.3),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        # Strong augmentations for rare classes
        self.rare_transform = A.Compose([
            A.RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0)),  # Fixed here
            A.Resize(height=256, width=256),
            A.RandomResizedCrop(size=(224, 224), scale=(0.4, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=45, p=0.7),
            A.ElasticTransform(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.CLAHE(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        self.valid_indices = self._validate_files()
        self.sample_weights = self._calculate_sample_weights()

    def _calculate_sample_weights(self):        # Calculates sample weights based on inverse class frequency
        # Calculate inverse class frequencies
        class_counts = self.df[LABEL_COLS].sum(axis=0)
        class_weights = 1 / (class_counts + 1e-6)

        # Assign sample weight = sum of weights for its positive labels
        sample_weights = self.df[LABEL_COLS].apply(
            lambda x: sum(class_weights[x == 1]),
            axis=1
        )
        return sample_weights.values

    def _validate_files(self):          # Validates image and report file existence, logging any missing files
        valid_indices = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            if os.path.exists(row['image_path']) and os.path.exists(row['report_path']):
                valid_indices.append(idx)
            else:
                logger.warning(f"Missing files for index {idx}")
        return valid_indices

    def __len__(self):               # Returns number of valid samples in the dataset
        return len(self.valid_indices)

    def __getitem__(self, idx):             # Loads and returns a transformed image, multilabels, and tokenized report for a given index and retries up to 5 times if loading fails, returns dummy data if all attempts fail
        for attempt in range(5):  # Try up to 5 times
            try:
                idx = int(idx)
                actual_idx = idx
                row = self.df.iloc[actual_idx]

                # Load image
                image = Image.open(row['image_path']).convert("RGB")
                image_np = np.array(image)

                # Apply rare or base transforms
                if actual_idx in self.rare_indices:
                    transformed = self.rare_transform(image=image_np)
                else:
                    transformed = self.base_transform(image=image_np)
                image_tensor = transformed["image"]

                # Load and tokenize report
                with open(row['report_path'], 'r', encoding='utf-8') as f:
                    report = f.read().strip()

                tokens = self.tokenizer(
                    report[:500],
                    max_length=256,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                )

                return (
                    image_tensor,
                    torch.FloatTensor(row[LABEL_COLS].values.astype(np.float32)),
                    tokens.input_ids.squeeze(0).long(),
                    tokens.attention_mask.squeeze(0)
                )

            except Exception as e:
                logger.warning(f"Attempt {attempt+1} failed: {str(e)}")
                # Use rare_indices and common_indices for retry
                if self.rare_indices and self.common_indices:
                    idx = random.choice(self.rare_indices + self.common_indices)
                else:
                    idx = random.randint(0, len(self.df) - 1)

        # Fallback to dummy data after 5 failed attempts
        logger.error(f"Failed to load sample after 5 attempts. Returning dummy data.")
        return (
            torch.randn(3, 224, 224),
            torch.zeros(len(LABEL_COLS)),
            torch.zeros(256, dtype=torch.long),
            torch.zeros(256)
        )


class EfficientFlamingo(nn.Module):   # Multimodal model combining vision (ViT) and text (BioClinicalBERT) encoders for classification and report generation
    def __init__(self, num_classes=NUM_CLASSES, vocab_size=None):       # Initializes vision and text encoders, fusion layers, classifier, and report head
        super().__init__()
        self.vision_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.text_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

        # Freeze first 6 layers of text encoder
        text_encoder_params = list(self.text_encoder.parameters())
        for param in text_encoder_params[:6]:
            param.requires_grad_(False)

        self.vision_proj = nn.Linear(768, 256)
        self.text_proj = nn.Linear(768, 256)

        # Modified fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.classifier = nn.Linear(256, num_classes)

        # Changed report head to output sequence
        self.report_head = nn.Linear(256, vocab_size)

    def forward(self, images, input_ids=None, attention_mask=None):          # Computes fused vision and text features for multilabel classification and report generation
        # Image features
        vision_features = self.vision_encoder(images).last_hidden_state.mean(1)
        vision_features = self.vision_proj(vision_features)

        # Text features (keep sequence dimension)
        text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state  # [batch, seq_len, 768]
        text_features = self.text_proj(text_features)   # [batch, seq_len, 256]

        # Fuse features for each token in sequence
        fused = torch.cat([
            vision_features.unsqueeze(1).expand(-1, text_features.size(1), -1),
            text_features
        ], dim=-1)  # [batch, seq_len, 512]

        fused = self.fusion(fused)  # [batch, seq_len, 256]

        # Classification (global average pooling)
        cls_logits = self.classifier(fused.mean(1))

        # Report generation (per-token predictions)
        report_logits = self.report_head(text_features)  # [batch, seq_len, vocab_size]

        return cls_logits, report_logits


def create_dataloader(dataset, batch_size, sampler=None):       # Creates a DataLoader for a given dataset and optional sampler
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False
    )

def train_epoch(model, loader, criterion, optimizer, scaler, device):       # Trains the model for one epoch, computes combined classification and report generation loss, and updates weights
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(total=len(loader), desc="Training", unit="batch")

    for batch_idx, (images, labels, input_ids, attn_mask) in enumerate(loader):
        try:
            # Move data to device FIRST
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            input_ids = input_ids.to(device, non_blocking=True)
            attn_mask = attn_mask.to(device, non_blocking=True)

            optimizer.zero_grad()

            # Mixed precision context
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                # Forward pass
                cls_logits, report_logits = model(images, input_ids, attn_mask)

                # Classification loss
                loss_cls = criterion[0](cls_logits, labels)

                # Report generation loss (FIXED DIMENSIONS)
                loss_report = criterion[1](
                    report_logits.view(-1, report_logits.size(-1)),  # [batch_size*seq_len, vocab_size]
                    input_ids.view(-1)                               # [batch_size*seq_len]
                )

                # Combined loss
                loss = 0.7 * loss_cls + 0.3 * loss_report

            # Backpropagation
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            # Update progress bar
            progress_bar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "CLS Loss": f"{loss_cls.item():.4f}",
                "RPT Loss": f"{loss_report.item():.4f}"
            })
            progress_bar.update(1)

        except Exception as e:
            print(f"\nError in batch {batch_idx}: {str(e)}")
            traceback.print_exc()
            continue

    progress_bar.close()
    return total_loss / len(loader)


def validate(model, loader, device, display_samples=0):     # Evaluates the model on validation/test data, computes classification and report generation metrics
    model.eval()
    f1_metric.reset()
    auroc_metric.reset()
    per_class_f1_metric = MultilabelF1Score(num_labels=NUM_CLASSES, average=None, threshold=THRESHOLD).to(device)

    # Initialize text metrics
    hypotheses = []
    references = []
    import rougescore


    from nltk.translate.bleu_score import SmoothingFunction
    from nltk import word_tokenize
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning, module='nltk.translate.bleu_score')
    smooth = SmoothingFunction().method4

    with torch.no_grad():
        for images, labels, input_ids, attn_mask in loader:
            images = images.to(device)
            input_ids = input_ids.to(device)
            attn_mask = attn_mask.to(device)
            labels = labels.to(device)

            # Forward pass
            cls_logits, report_logits = model(images, input_ids, attn_mask)
            probs = torch.sigmoid(cls_logits)

            # Update classification metrics
            f1_metric.update(probs, labels.int())
            auroc_metric.update(probs, labels.int())
            per_class_f1_metric.update(probs, labels.int())

            # Decode reports
            pred_ids = report_logits.argmax(-1)
            hyps = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for ids in pred_ids]
            refs = [tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for ids in input_ids]

            hypotheses.extend(hyps)
            references.extend(refs)

    # Compute classification metrics
    macro_f1 = f1_metric.compute().item()
    macro_auroc = auroc_metric.compute().item()
    per_class_f1 = per_class_f1_metric.compute()


    # Ensure hypotheses are strings
    if isinstance(hypotheses[0], list):
        hypotheses = [' '.join(h) for h in hypotheses]

    # Ensure references is a list of lists of strings
    if isinstance(references[0], str):
        references = [[r] for r in references]

    bleu_score = corpus_bleu(
        hypotheses,
        references,
        tokenize='none'
    ).score / 100


    # Compute ROUGE-L using rougescore
    rouge_scores = []
    for hyp, ref in zip(hypotheses, references):
        try:
            score = rougescore.rouge_l(hyp, ref, alpha=0.5)
            rouge_scores.append(score)
        except:
            rouge_scores.append(0.0)
    avg_rouge = sum(rouge_scores)/len(rouge_scores) if rouge_scores else 0.0

    # Log rare classes
    logger.info("\nRare Class Performance:")
    for cls in RARE_LABELS:
        idx = LABEL_COLS.index(cls)
        logger.info(f"{cls}: {per_class_f1[idx]:.4f}")

    # Display first 5 samples if requested
    if display_samples > 0:
        logger.info("\nGenerated Reports (First 5):")
        for i in range(min(5, len(hypotheses))):
            logger.info(f"\nSample {i+1}:")
            logger.info(f"Generated: {hypotheses[i]}")
            logger.info(f"Reference: {references[i]}")
            logger.info("-"*80)

    return {
        'macro_f1': macro_f1,
        'macro_auroc': macro_auroc,
        'bleu4': bleu_score,
        'rouge_l': avg_rouge
    }







def validate_with_thresholds(model, loader, device, thresholds):   # Evaluates the model using class-specific thresholds for multilabel classification
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels, input_ids, attention_mask in loader:
            images = images.to(device)
            labels = labels.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            cls_logits, _ = model(images, input_ids, attention_mask)
            probs = torch.sigmoid(cls_logits)
            all_probs.append(probs.cpu())
            all_labels.append(labels.cpu())

    all_probs = torch.cat(all_probs)
    all_labels = torch.cat(all_labels)

    # Apply class-specific thresholds
    preds = torch.zeros_like(all_probs)
    for cls, thresh in thresholds.items():
        if cls in LABEL_COLS:
            idx = LABEL_COLS.index(cls)
            preds[:, idx] = (all_probs[:, idx] > thresh).float()

    # Compute metrics
    macro_f1 = f1_score(all_labels.numpy(), preds.numpy(), average='macro', zero_division=0)
    macro_auroc = safe_macro_roc_auc_score(all_labels.numpy(), all_probs.numpy())

    return macro_f1, macro_auroc


def safe_macro_roc_auc_score(y_true, y_score):          # Computes macro AUROC, skipping classes with only one label present

    aucs = []
    for i in range(y_true.shape[1]):
        y_true_i = y_true[:, i]
        y_score_i = y_score[:, i]
        if len(np.unique(y_true_i)) < 2:
            continue  # skip this class
        try:
            auc = roc_auc_score(y_true_i, y_score_i)
            aucs.append(auc)
        except Exception:
            continue
    if len(aucs) == 0:
        return np.nan
    return np.mean(aucs)


def identify_rare_labels(df, categorical_columns, threshold=0.05):          # Identifies rare labels in the dataset based on a frequency threshold
    rare_labels = {}
    for col in categorical_columns:
        # Calculate positive class frequency (label=1)
        pos_freq = df[col].mean()
        # Store column name if positive samples are rare
        if pos_freq < threshold:
            rare_labels[col] = [1]  # Mark positive class as rare
    return rare_labels


if __name__ == "__main__":          # Main script: sets up device, data, model, training, and evaluation pipeline
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    assert torch.cuda.is_available(), "CUDA not available!"
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.cuda.empty_cache()

    logger.info("Loading data splits")
    train_df, val_df, test_df = get_cached_splits()

    train_df = get_stratified_subset(train_df, 2000)
    val_df = get_stratified_subset(val_df, 400)
    test_df = get_stratified_subset(test_df, 400)

    print_label_distribution(train_df, 'Training (used)')
    print_label_distribution(val_df, 'Validation (used)')
    print_label_distribution(test_df, 'Testing (used)')

    thresholds = {
        'Pleural Other': 0.15,
        'Fracture': 0.10,
        'Enlarged Cardiomediastinum': 0.20
    }

    # Calculate class frequencies from the ACTUAL training subset
    class_freq = torch.tensor(
        train_df[LABEL_COLS].replace(0.5, np.nan).mean(axis=0).values,
        dtype=torch.float32
    )

    categorical_columns = LABEL_COLS  # Use your existing label list
    threshold = 0.05
    rare_labels = identify_rare_labels(
    train_df,
    categorical_columns,  # Positional argument
    threshold=threshold
)

    # Inverse frequency weighting (handle division by zero)
    pos_weight = (1 - class_freq) / (class_freq + 1e-6)
    pos_weight = pos_weight.to(device)

    RARE_LABELS = ['Pleural Other', 'Fracture', 'Enlarged Cardiomediastinum']

    logger.info("Initializing tokenizer")
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    logger.info("Creating datasets")
    train_dataset = MemoryOptimizedDataset(
        df=train_df,
        tokenizer=tokenizer,
        rare_labels=RARE_LABELS
    )  # Add RARE_LABELS
    val_dataset = MemoryOptimizedDataset(val_df, tokenizer, RARE_LABELS)  # Add RARE_LABELS

    class_freq = torch.tensor(train_df[LABEL_COLS].mean(axis=0).values, dtype=torch.float32)

    # Initialize balanced sampler and loss
    sampler = BalancedMultilabelSampler(
        train_dataset,
        rare_labels=RARE_LABELS,
        samples_per_class=4
    )

    criterion = (
    DistributionBalancedFocalLoss(
        class_freq=class_freq,
        beta=0.99999,
        gamma=5
    ),
    nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
)

    logger.info("Creating dataloaders")
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=sampler,
        num_workers=0,
        pin_memory=False,
        persistent_workers=False
    )
    val_loader = create_dataloader(val_dataset, BATCH_SIZE)

    logger.info("Initializing model")
    model = EfficientFlamingo(num_classes=NUM_CLASSES, vocab_size=len(tokenizer)).to(device)

    f1_metric = MultilabelF1Score(num_labels=NUM_CLASSES, average='macro', threshold=0.5).to(device)
    auroc_metric = MultilabelAUROC(num_labels=NUM_CLASSES, average='macro').to(device)

    optimizer = torch.optim.AdamW([
        {'params': model.vision_encoder.parameters(), 'lr': 1e-5},
        {'params': model.text_encoder.parameters(), 'lr': 1e-5},
        {'params': model.fusion.parameters(), 'lr': 1e-4},
        {'params': model.classifier.parameters(), 'lr': 1e-4},
        {'params': model.report_head.parameters(), 'lr': 1e-4}
    ], weight_decay=0.01)

    scaler = GradScaler()
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-4,
        steps_per_epoch=len(train_loader),
        epochs=NUM_EPOCHS
    )

    # Training loop
    best_f1 = 0
    logger.info("Starting training")
    try:
        for epoch in range(1, NUM_EPOCHS+1):
            logger.info(f"Epoch {epoch}/{NUM_EPOCHS}")
            train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device)

            # Validation with metrics
            val_metrics = validate(model, val_loader, device)

            logger.info(
                f"Epoch {epoch:02d} | "
                f"Loss: {train_loss:.4f} | "
                f"Val F1: {val_metrics['macro_f1']:.4f} | "
                f"BLEU-4: {val_metrics['bleu4']:.4f} | "
                f"ROUGE-L: {val_metrics['rouge_l']:.4f}"
            )

            if val_metrics['macro_f1'] > best_f1:
                best_f1 = val_metrics['macro_f1']
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'thresholds': thresholds,
                    'rare_labels': rare_labels,  # <-- ADD THIS LINE
                }, 'best_model.pth')
                logger.info("New best model saved")

            scheduler.step()
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Critical error: {str(e)}")
        raise

    logger.info(f"Training completed. Best validation Macro F1: {best_f1:.4f}")

    logger.info("\nTesting with adjusted thresholds...")

    # Load checkpoint with device mapping
    checkpoint = torch.load("best_model.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Get preserved parameters from checkpoint
    rare_labels = checkpoint.get('rare_labels', RARE_LABELS)
    thresholds = checkpoint.get('thresholds', {
        'Pleural Other': 0.3,
        'Fracture': 0.25,
        'Enlarged Cardiomediastinum': 0.35
    })

    # Reinitialize test dataset with proper parameters
    test_dataset = MemoryOptimizedDataset(
        test_df,
        tokenizer,
        rare_labels=rare_labels
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )

    # Run validation with sample display
    test_metrics = validate(
        model=model,
        loader=test_loader,
        device=device,
        display_samples=5  # Show top 5 reports
    )

    # Log comprehensive results
    logger.info("\n=== Final Test Metrics ===")
    logger.info(f"Clinical Metrics:")
    logger.info(f"  Macro F1: {test_metrics['macro_f1']:.4f}")
    logger.info(f"  Macro AUROC: {test_metrics['macro_auroc']:.4f}")
    logger.info(f"\nReport Generation Metrics:")
    logger.info(f"  BLEU-4: {test_metrics['bleu4']:.4f}")
    logger.info(f"  ROUGE-L: {test_metrics['rouge_l']:.4f}")

    # Log threshold-adjusted performance if needed
    if 'thresholds' in checkpoint:
        logger.info("\nThreshold-Adjusted Performance:")
        test_f1, test_auroc = validate_with_thresholds(
            model,
            test_loader,
            device,
            thresholds
        )
        logger.info(f"  Macro F1: {test_f1:.4f}")
        logger.info(f"  Macro AUROC: {test_auroc:.4f}")




[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\13joe\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\13joe\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
INFO:__main__:Using device: cuda
INFO:__main__:Loading data splits
INFO:__main__:Loading cached splits


CUDA device name: NVIDIA GeForce RTX 4070 Laptop GPU
CUDA memory allocated: 0.00GB


INFO:__main__:Initializing tokenizer



Label distribution in Training (used) set:
Atelectasis: 347 positive samples (17.35%)
Cardiomegaly: 348 positive samples (17.40%)
Consolidation: 77 positive samples (3.85%)
Edema: 195 positive samples (9.75%)
Enlarged Cardiomediastinum: 52 positive samples (2.60%)
Fracture: 41 positive samples (2.05%)
Lung Lesion: 56 positive samples (2.80%)
Lung Opacity: 403 positive samples (20.15%)
No Finding: 760 positive samples (38.00%)
Pleural Effusion: 410 positive samples (20.50%)
Pleural Other: 21 positive samples (1.05%)
Pneumonia: 137 positive samples (6.85%)
Pneumothorax: 76 positive samples (3.80%)
Support Devices: 450 positive samples (22.50%)
Total samples in Training (used): 2000


Label distribution in Validation (used) set:
Atelectasis: 69 positive samples (17.25%)
Cardiomegaly: 70 positive samples (17.50%)
Consolidation: 15 positive samples (3.75%)
Edema: 39 positive samples (9.75%)
Enlarged Cardiomediastinum: 10 positive samples (2.50%)
Fracture: 8 positive samples (2.00%)
Lung Le

INFO:__main__:Creating datasets
INFO:__main__:Creating dataloaders
INFO:__main__:Initializing model
INFO:__main__:Starting training
INFO:__main__:Epoch 1/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0236
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 01 | Loss: 3.0605 | Val F1: 0.0678 | BLEU-4: 0.4065 | ROUGE-L: 0.3036
INFO:__main__:New best model saved
INFO:__main__:Epoch 2/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 02 | Loss: 2.9666 | Val F1: 0.0699 | BLEU-4: 0.3914 | ROUGE-L: 0.3093
INFO:__main__:New best model saved
INFO:__main__:Epoch 3/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 03 | Loss: 2.8890 | Val F1: 0.0778 | BLEU-4: 0.3914 | ROUGE-L: 0.2868
INFO:__main__:New best model saved
INFO:__main__:Epoch 4/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 04 | Loss: 2.8181 | Val F1: 0.0736 | BLEU-4: 0.3938 | ROUGE-L: 0.2853
INFO:__main__:Epoch 5/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 05 | Loss: 2.7666 | Val F1: 0.0728 | BLEU-4: 0.3950 | ROUGE-L: 0.2919
INFO:__main__:Epoch 6/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 06 | Loss: 2.7119 | Val F1: 0.0620 | BLEU-4: 0.4125 | ROUGE-L: 0.3017
INFO:__main__:Epoch 7/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 07 | Loss: 2.6576 | Val F1: 0.0531 | BLEU-4: 0.4233 | ROUGE-L: 0.3106
INFO:__main__:Epoch 8/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 08 | Loss: 2.6040 | Val F1: 0.0492 | BLEU-4: 0.4304 | ROUGE-L: 0.3180
INFO:__main__:Epoch 9/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 09 | Loss: 2.5506 | Val F1: 0.0320 | BLEU-4: 0.4280 | ROUGE-L: 0.3322
INFO:__main__:Epoch 10/10


Training:   0%|          | 0/184 [00:00<?, ?batch/s]

INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:Epoch 10 | Loss: 2.5084 | Val F1: 0.0132 | BLEU-4: 0.4350 | ROUGE-L: 0.3462
INFO:__main__:Training completed. Best validation Macro F1: 0.0778
INFO:__main__:
Testing with adjusted thresholds...
INFO:__main__:
Rare Class Performance:
INFO:__main__:Pleural Other: 0.0000
INFO:__main__:Fracture: 0.0000
INFO:__main__:Enlarged Cardiomediastinum: 0.0000
INFO:__main__:
Generated Reports (First 5):
INFO:__main__:
Sample 1:
INFO:__main__:Generated: . _ _ _ _ _ _ _ _ _. _ _ _ _ _ _ _. _ _ _ _. _ _. _ _ _. _ _ _ _ _ _ _ _ _ _ _. _ _ _ _ _ _ _ _ _ _ _ _ _ _. _ _ _ _ _ _ _ _ _ _ _ transportation. _ _... _ _ _ : _ _ _ _ _. _ _.. _.. _ _ _ _ प _.. _ _ _. _ _ _ _. _ :...... _... Siemens Siemens geological :. _. _ _. _. _.. : : _ _ : : : adaptations vegetable Reef. tributaries :bbed......... _rants : : _ histories : _ Transylvania _ _ _.