In [None]:
# Query the GPU attached to this runtime; on a CPU-only runtime the command is not found.
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
import os

from google.colab import drive

# Mount the root Google Drive directory so its folders are reachable from the Colab VM.
drive.mount('/content/drive')

# Folder inside the mounted Drive that holds the dataset
# (assumes the Drive path is 'MyDrive/IU-Xray').
target_dir = '/content/drive/MyDrive/IU-Xray'

# exist_ok=True creates the folder only when missing, so the cell is safe to
# re-run and has no gap between "check" and "create".
os.makedirs(target_dir, exist_ok=True)

print(f"IU-Xray directory is mounted at: {target_dir}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
IU-Xray directory is mounted at: /content/drive/MyDrive/IU-Xray


In [None]:
!pip install rouge-score



In [None]:
# Enhanced Radiology Report Generator with Medical Context
# This implementation incorporates clinical data to improve accuracy

import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import nltk
from nltk.translate.bleu_score import corpus_bleu
from rouge_score import rouge_scorer
import random
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration parameters
CONFIG = {
    'image_size': 128,
    'batch_size': 32,
    'epochs': 5,
    'learning_rate': 2e-5,
    'max_report_length': 256,
    't5_model_name': 'google/flan-t5-small',  # Hugging Face checkpoint loaded as the T5 text model
    'image_encoder': 'resnet18',
    'hidden_dim': 768,
    'seed': 42,
    # Clinical finding categories (based on CheXpert labels)
    'clinical_findings': [
        'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
        'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
        'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
        'Pleural Other', 'Fracture', 'Support Devices'
    ]
}

# Set random seeds for every RNG the notebook imports (Python, NumPy, PyTorch)
# so that shuffling and sampling are repeatable across runs.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

Using device: cpu


###########################################
# 1. Data Processing with Medical Context
###########################################

In [None]:
class EnhancedChestXRayDataset(Dataset):
    """Chest X-ray dataset that pairs each image with its report and an optional clinical prompt."""

    def __init__(self, image_paths, reports, clinical_data=None, tokenizer=None, transform=None, max_length=512):
        """
        Enhanced dataset that includes clinical context data

        Args:
            image_paths: Sequence of paths to the images
            reports: Sequence of corresponding reports (same order as image_paths)
            clinical_data: Optional mapping from study id to a dict with the
                optional keys 'age', 'sex' and 'history'. It must support
                ``.get(key, default)`` and return dict-like records.
            tokenizer: Tokenizer for text processing (required by __getitem__)
            transform: Image transformations applied to the loaded RGB image
            max_length: Maximum token length for the prompt and the report
        """
        self.image_paths = image_paths
        self.reports = reports
        self.clinical_data = clinical_data
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _field(record, key, default):
        """Return record[key], falling back to default when it is missing, None or NaN."""
        value = record.get(key, default)
        # NaN is the only value that is not equal to itself.
        if value is None or value != value:
            return default
        return value

    def _get_clinical_context(self, image_path):
        """Build the clinical-context sentence for an image, or "" when nothing is known.

        The record is looked up under the file name without its last extension
        (the key format written by load_indiana_cxr_data, e.g.
        '9_IM-2407-1001.dcm'), then under the full file name, and finally under
        the prefix before the first underscore, which was the only key tried
        previously and never matched the loader's keys.
        """
        if self.clinical_data is None:
            return ""

        file_name = os.path.basename(image_path)
        candidate_keys = (
            os.path.splitext(file_name)[0],
            file_name,
            file_name.split('_')[0],
        )

        patient_data = {}
        for key in candidate_keys:
            patient_data = self.clinical_data.get(key, {})
            if patient_data:
                break

        if not patient_data:
            return ""

        age = self._field(patient_data, 'age', 'Unknown')
        sex = self._field(patient_data, 'sex', 'Unknown')
        history = self._field(patient_data, 'history', 'No history provided')
        return f"Patient: {age} year old {sex}. History: {history}. "

    def _tokenize(self, text):
        """Tokenize text to exactly max_length tokens (padded or truncated), as 'pt' tensors."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            dict with 'image', 'input_ids', 'attention_mask' (tokenized prompt),
            'decoder_input_ids' (report tokens shifted right by one position),
            'labels' (report tokens with padding replaced by -100 so that the
            loss ignores it), and the raw 'report', 'clinical_context' and
            'image_path'.
        """
        # Load and transform image; the context manager closes the file handle
        # once the RGB copy has been made.
        image_path = self.image_paths[idx]
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        clinical_context = self._get_clinical_context(image_path)
        report = self.reports[idx]

        # Create prompt with clinical context
        prompt = f"generate radiology report: {clinical_context}"
        tokenized_input = self._tokenize(prompt)

        # Target token ids of the report, shape (max_length,)
        target_ids = self._tokenize(report).input_ids.squeeze(0)
        pad_id = self.tokenizer.pad_token_id

        # Teacher forcing: the decoder input is the target shifted right by one,
        # starting with the pad token (T5 uses the pad token as decoder start).
        # Tokenizing "" instead gave [eos, pad, pad, ...], i.e. the decoder never
        # saw the previous target tokens.
        start_token = target_ids.new_full((1,), pad_id)
        decoder_input_ids = torch.cat([start_token, target_ids[:-1]])

        # -100 is the ignore index of the cross-entropy loss used by the model.
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        return {
            'image': image,
            'input_ids': tokenized_input.input_ids.squeeze(0),
            'attention_mask': tokenized_input.attention_mask.squeeze(0),
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
            'report': report,
            'clinical_context': clinical_context,
            'image_path': image_path
        }

def load_indiana_cxr_data(base_dir):
    """
    Load the Indiana chest X-ray CSVs and build patient-level train/val/test splits.

    Reads 'indiana_reports.csv' and 'indiana_projections.csv' from base_dir,
    inner-joins them on 'uid', keeps the rows whose image file exists under
    'images/images_normalized', and splits by patient id (the part of the
    file name before the first underscore) 70/15/15 so that no patient
    appears in more than one split.

    Args:
        base_dir: Directory containing the two CSV files and the image folder.

    Returns:
        Tuple (train_img, val_img, test_img, train_rep, val_rep, test_rep,
        clinical_data). clinical_data maps the file name without its last
        extension (e.g. '9_IM-2407-1001.dcm') to a dict with 'patient_id',
        'age', 'sex' and 'history'.

    Raises:
        ValueError: If fewer than 4 patients have usable images; the two-stage
            split needs at least that many to keep every split non-empty.
    """
    # Path configuration
    img_dir = os.path.join(base_dir, 'images', 'images_normalized')
    reports_csv = os.path.join(base_dir, 'indiana_reports.csv')
    projections_csv = os.path.join(base_dir, 'indiana_projections.csv')

    # Load data
    reports_df = pd.read_csv(reports_csv)
    projections_df = pd.read_csv(projections_csv)

    # Merge data while keeping original filenames
    merged_df = pd.merge(
        projections_df,
        reports_df,
        on='uid',
        how='inner'
    )

    def text_or_default(value, default):
        # Empty CSV cells are read as NaN; never let them leak into text as "nan".
        return value if pd.notna(value) else default

    # One (image_path, report, patient_id) record per usable image. Keeping the
    # three values together guarantees they stay aligned when splitting, even
    # if the same file name shows up in more than one merged row.
    records = []
    clinical_data = {}

    for _, row in merged_df.iterrows():
        # Construct image path with .dcm.png extension
        img_file = row['filename']
        image_path = os.path.join(img_dir, img_file)

        if not os.path.exists(image_path):
            print(f"Warning: Missing image {img_file} - skipping entry")
            continue

        # Extract patient ID from filename (e.g., '9' from '9_IM-2407-1001.dcm.png')
        patient_id = img_file.split('_')[0]

        # Create report
        findings = text_or_default(row['findings'], "No findings reported")
        impression = text_or_default(row['impression'], "No impression provided")
        final_report = (
            f"Projection: {row['projection']}\n"
            f"Findings: {findings}\n"
            f"Impression: {impression}"
        )

        records.append((image_path, final_report, patient_id))

        study_id = os.path.splitext(img_file)[0]  # '9_IM-2407-1001.dcm'
        clinical_data[study_id] = {
            'patient_id': patient_id,
            'age': str(text_or_default(row.get('patient_age', 'Unknown'), 'Unknown')),
            'sex': text_or_default(row.get('patient_gender', 'Unknown'), 'Unknown'),
            'history': text_or_default(
                row.get('indication', 'No history provided'), 'No history provided'
            )
        }

    # Sort the patient ids: the iteration order of a set of strings changes
    # between interpreter runs, which would make the seeded split irreproducible.
    unique_patients = sorted({patient_id for _, _, patient_id in records})

    # 70/30 followed by 50/50 needs at least 4 patients: with fewer, the second
    # split would be asked to divide a single patient into two non-empty groups.
    if len(unique_patients) < 4:
        raise ValueError(
            f"Only {len(unique_patients)} patients found - need at least 4 for a train/val/test split"
        )

    train_patients, temp_patients = train_test_split(
        unique_patients,
        test_size=0.3,
        random_state=CONFIG['seed']
    )
    val_patients, test_patients = train_test_split(
        temp_patients,
        test_size=0.5,
        random_state=CONFIG['seed']
    )

    def get_split(patient_ids):
        """Return the (path, report) pairs whose patient belongs to patient_ids."""
        wanted = set(patient_ids)  # O(1) membership instead of scanning a list per record
        return [(path, report) for path, report, patient_id in records if patient_id in wanted]

    def unpack(data):
        """Turn a list of (path, report) pairs into (paths, reports)."""
        return list(zip(*data)) if data else ([], [])

    train_img, train_rep = unpack(get_split(train_patients))
    val_img, val_rep = unpack(get_split(val_patients))
    test_img, test_rep = unpack(get_split(test_patients))

    return (
        train_img, val_img, test_img,
        train_rep, val_rep, test_rep,
        clinical_data
    )

###########################################
# 2. Enhanced Model Architecture
###########################################

In [None]:
class EnhancedRadiologyReportGenerator(nn.Module):
    """Image-conditioned T5 report generator with an auxiliary finding classifier."""

    def __init__(self, config):
        """
        Enhanced model that incorporates clinical context and uses a multi-task approach
        for more accurate report generation

        Args:
            config: Configuration dictionary. Reads 'image_encoder'
                ('resnet50', 'densenet121', 'vit' or 'resnet18'), 'hidden_dim',
                't5_model_name', 'clinical_findings' and (at generation time)
                'max_report_length'.

        Raises:
            ValueError: If config['image_encoder'] is not one of the supported names.
        """
        super(EnhancedRadiologyReportGenerator, self).__init__()

        # Image encoder with the classification head removed, so that it
        # outputs pooled features of size image_features_dim.
        encoder_name = config['image_encoder']
        if encoder_name == 'resnet50':
            backbone = models.resnet50(weights='IMAGENET1K_V2')
            self.image_encoder = nn.Sequential(*list(backbone.children())[:-1])
            self.image_features_dim = 2048
        elif encoder_name == 'densenet121':
            self.image_encoder = models.densenet121(weights='IMAGENET1K_V1')
            self.image_encoder.classifier = nn.Identity()
            self.image_features_dim = 1024
        elif encoder_name == 'vit':
            self.image_encoder = models.vit_b_16(weights='IMAGENET1K_V1')
            self.image_encoder.heads = nn.Identity()
            self.image_features_dim = 768
        elif encoder_name == 'resnet18':
            backbone = models.resnet18(weights='IMAGENET1K_V1')
            # Drop only the final fc layer; the average-pooling layer is kept.
            self.image_encoder = nn.Sequential(*list(backbone.children())[:-1])
            self.image_features_dim = 512
        else:
            raise ValueError(f"Unknown image encoder: {config['image_encoder']}")

        # Freeze the whole image encoder. The previous "unfreeze last layer"
        # loop indexed image_encoder[-1]: for the Sequential ResNets that is the
        # parameter-free pooling layer (so nothing was unfrozen), and for
        # DenseNet/ViT the encoder is not subscriptable and the loop raised.
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # Image feature projection
        self.image_projection = nn.Sequential(
            nn.Linear(self.image_features_dim, config['hidden_dim']),
            nn.LayerNorm(config['hidden_dim']),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(config['hidden_dim'], config['hidden_dim'])
        )

        # T5 text generator
        self.text_decoder = T5ForConditionalGeneration.from_pretrained(config['t5_model_name'])

        # Additional multi-task component: one logit per clinical finding
        self.finding_classifier = nn.Linear(config['hidden_dim'], len(config['clinical_findings']))

        # Projection from the image feature space into T5's hidden size
        t5_config = self.text_decoder.config
        self.image_to_t5_projection = nn.Linear(
            config['hidden_dim'],
            t5_config.d_model
        )

        # Clinical context attention.
        # NOTE(review): forward() does not use this module; it is kept so that
        # the module's parameter names (and saved checkpoints) stay unchanged.
        self.clinical_context_attention = nn.MultiheadAttention(
            embed_dim=config['hidden_dim'],
            num_heads=8,
            batch_first=True
        )

        # Store config
        self.config = config

    def _encode_image(self, images):
        """Return projected image features of shape (batch, hidden_dim)."""
        image_features = self.image_encoder(images)
        if image_features.dim() == 4:  # CNN encoders return (batch, channels, 1, 1)
            image_features = image_features.flatten(1)
        return self.image_projection(image_features)

    def _fuse_text_and_image(self, image_features, input_ids, attention_mask):
        """Encode the prompt with T5 and add the image vector to every encoder position."""
        # Local import: only `transformers` top-level names are imported by the notebook.
        from transformers.modeling_outputs import BaseModelOutput

        encoder_hidden_states = self.text_decoder.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).last_hidden_state

        # Project image features to T5 dimension and broadcast over the sequence
        image_features_t5 = self.image_to_t5_projection(image_features)
        image_features_expanded = image_features_t5.unsqueeze(1).expand(
            -1, encoder_hidden_states.size(1), -1
        )

        # generate() reads `.last_hidden_state` from encoder_outputs, so the
        # fused states are wrapped in a model output instead of a bare tuple.
        return BaseModelOutput(last_hidden_state=encoder_hidden_states + image_features_expanded)

    def forward(self, images, input_ids=None, attention_mask=None, decoder_input_ids=None, labels=None, finding_labels=None):
        """
        Forward pass with enhanced medical context integration

        Args:
            images: Batch of X-ray images
            input_ids: Input token IDs (including clinical context)
            attention_mask: Attention mask for input
            decoder_input_ids: Input IDs for the decoder. Only used when
                `labels` is None; with labels, T5 derives the decoder input by
                shifting the labels right.
            labels: Target token IDs. Padding positions are excluded from the loss.
            finding_labels: Labels for clinical findings (for multi-task learning)

        Returns:
            - ``(loss, finding_logits)`` in training mode when `finding_labels` is given;
            - the T5 output object (with ``.loss`` when `labels` is given) in
              training mode, or in eval mode when `labels` is given;
            - generated token ids in eval mode when `labels` is None.
        """
        image_features = self._encode_image(images)

        # Multi-task prediction of findings (improves feature extraction)
        finding_preds = None
        if self.training and finding_labels is not None:
            finding_preds = self.finding_classifier(image_features)

        encoder_outputs = self._fuse_text_and_image(image_features, input_ids, attention_mask)

        # Teacher-forced pass. Also taken in eval mode when labels are supplied,
        # so that a validation loss can be computed.
        if self.training or labels is not None:
            if labels is not None:
                pad_id = self.text_decoder.config.pad_token_id
                if pad_id is not None:
                    # -100 is ignored by the loss; padded targets must not be trained on.
                    labels = labels.masked_fill(labels == pad_id, -100)
                # Let T5 build the decoder input from the labels (shift right).
                # A caller-supplied decoder input would be used verbatim, which
                # disables teacher forcing if it is not exactly the shifted target.
                decoder_input_ids = None

            outputs = self.text_decoder(
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,  # masks padded prompt positions in cross-attention
                decoder_input_ids=decoder_input_ids,
                labels=labels,
                return_dict=True
            )

            # Return with finding predictions if doing multi-task learning
            if finding_preds is not None:
                return outputs.loss, finding_preds
            return outputs

        # Inference: generate the report from the fused encoder states.
        generation_kwargs = {}
        if attention_mask is not None:
            generation_kwargs['attention_mask'] = attention_mask

        # NOTE: do_sample=True makes the generated text vary between calls
        # unless the random seed is fixed beforehand.
        generated_ids = self.text_decoder.generate(
            encoder_outputs=encoder_outputs,
            max_length=self.config['max_report_length'],
            num_beams=5,  # More beams for better quality
            length_penalty=1.0,
            early_stopping=True,
            repetition_penalty=1.2,  # Avoid repetitive text
            use_cache=True,
            do_sample=True,  # Enable sampling for more diverse outputs
            top_p=0.9,  # Nucleus sampling
            temperature=0.7,  # Temperature for sampling
            **generation_kwargs
        )

        return generated_ids

###########################################
# 3. Enhanced Training & Evaluation
###########################################

In [None]:
from torch.cuda.amp import autocast, GradScaler  # ADD THIS AT TOP

def train_epoch_enhanced(model, dataloader, optimizer, scheduler, device, clinical_findings=None):
    """
    Train for one epoch with optional multi-task finding loss and mixed precision.

    Args:
        model: Model called with images/input_ids/attention_mask/decoder_input_ids/labels
            (and finding_labels when available). It must return either an
            object with ``.loss`` or, with finding labels, ``(loss, logits)``.
        dataloader: Iterable of batches (dicts of tensors).
        optimizer: Optimizer updating the model parameters.
        scheduler: Optional LR scheduler, stepped once per batch.
        device: Device the batches are moved to; mixed precision is enabled
            only when it is a CUDA device.
        clinical_findings: If truthy and a batch has 'clinical_labels', the
            finding classification loss is added with weight 0.2.

    Returns:
        dict with the mean 'total_loss', 'report_loss' and 'finding_loss' per batch.
    """
    FINDING_LOSS_WEIGHT = 0.2

    model.train()
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)  # transparent no-op when AMP is disabled

    epoch_loss = 0.0
    report_loss = 0.0
    finding_loss = 0.0
    steps = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        # Move batch to device
        images = batch['image'].to(device, non_blocking=True)
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        decoder_input_ids = batch['decoder_input_ids'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        finding_labels = None
        if clinical_findings and 'clinical_labels' in batch:
            # BCE-with-logits needs floating point targets
            finding_labels = batch['clinical_labels'].to(device, non_blocking=True).float()

        optimizer.zero_grad()

        # Forward pass (mixed precision on CUDA)
        with autocast(enabled=use_amp):
            if finding_labels is not None:
                loss_reports, finding_preds = model(
                    images=images,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    labels=labels,
                    finding_labels=finding_labels
                )
                finding_loss_val = F.binary_cross_entropy_with_logits(finding_preds, finding_labels)
                loss = loss_reports + FINDING_LOSS_WEIGHT * finding_loss_val
            else:
                outputs = model(
                    images=images,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    labels=labels
                )
                loss_reports = outputs.loss
                finding_loss_val = None
                loss = loss_reports

        # Backward pass on the scaled loss
        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Exactly one parameter update per batch: scaler.step() already calls
        # optimizer.step(), so it must not be called a second time.
        scaler.step(optimizer)
        scaler.update()

        if scheduler:
            scheduler.step()

        # Accumulate running statistics
        steps += 1
        epoch_loss += loss.item()
        report_loss += loss_reports.item()
        if finding_loss_val is not None:
            finding_loss += finding_loss_val.item()

        progress_bar.set_postfix({
            "loss": loss.item(),
            "report_loss": report_loss / steps,
            "finding_loss": finding_loss / steps if finding_labels is not None else 0
        })

    num_batches = max(steps, 1)  # avoid dividing by zero on an empty dataloader
    return {
        'total_loss': epoch_loss / num_batches,
        'report_loss': report_loss / num_batches,
        'finding_loss': finding_loss / num_batches if clinical_findings else 0
    }

def evaluate_enhanced(model, dataloader, tokenizer, device, clinical_findings=None):
    """
    Evaluate report generation with loss, BLEU, ROUGE and sample predictions.

    Args:
        model: Model to evaluate (switched to eval mode). It is called once
            with labels to obtain a loss and once without to generate token ids.
        dataloader: Iterable of batches; each batch must contain 'image',
            'input_ids', 'attention_mask', 'decoder_input_ids', 'labels' and
            'report', and may contain 'clinical_context' / 'clinical_labels'.
        tokenizer: Tokenizer used to decode the generated ids.
        device: Device the batches are moved to.
        clinical_findings: If truthy, 'clinical_labels' of the batches are collected.

    Returns:
        dict with 'loss' (mean over the batches for which the model returned a
        loss, NaN if there were none), 'bleu', 'rouge', 'clinical_metrics'
        and up to five 'samples'.
    """
    model.eval()
    val_loss = 0.0
    loss_batches = 0

    generated_reports = []
    reference_reports = []
    clinical_contexts = []

    # For clinical accuracy
    finding_preds = []
    finding_labels = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating")
        for batch in progress_bar:
            # Move batch to device
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            decoder_input_ids = batch['decoder_input_ids'].to(device)
            labels = batch['labels'].to(device)

            if clinical_findings and 'clinical_labels' in batch:
                finding_labels.append(batch['clinical_labels'].cpu().numpy())

            # Loss calculation. In eval mode the model may return generated ids
            # (a tensor without `.loss`) instead of an output object; such
            # batches are skipped for the loss rather than raising.
            outputs = model(
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                labels=labels
            )
            batch_loss = getattr(outputs, 'loss', None)
            if batch_loss is not None:
                val_loss += batch_loss.item()
                loss_batches += 1

            # Generate reports
            generated_ids = model(
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            # Decode the generated reports and collect the references
            gen_reports = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            generated_reports.extend(gen_reports)
            reference_reports.extend(batch['report'])

            if 'clinical_context' in batch:
                clinical_contexts.extend(batch['clinical_context'])

    rouge_metrics = ['rouge1', 'rouge2', 'rougeL']
    rouge_scores = {metric: 0.0 for metric in rouge_metrics}
    bleu_score = 0.0

    if generated_reports:
        # Calculate BLEU score
        references = [[report.split()] for report in reference_reports]
        candidates = [report.split() for report in generated_reports]
        bleu_score = corpus_bleu(references, candidates)

        # Calculate and average ROUGE F-measures
        scorer = rouge_scorer.RougeScorer(rouge_metrics, use_stemmer=True)
        for gen, ref in zip(generated_reports, reference_reports):
            scores = scorer.score(ref, gen)
            for metric in rouge_metrics:
                rouge_scores[metric] += scores[metric].fmeasure
        for metric in rouge_metrics:
            rouge_scores[metric] /= len(generated_reports)

    # Clinical accuracy metrics. Nothing in this function fills
    # `finding_preds` yet, so the block runs only if both lists have entries;
    # concatenating an empty list would raise.
    clinical_metrics = {}
    if clinical_findings and finding_labels and finding_preds:
        all_labels = np.concatenate(finding_labels, axis=0)
        all_preds = np.concatenate(finding_preds, axis=0)
        preds_binary = (all_preds > 0.5).astype(int)

        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, preds_binary, average='macro'
        )
        accuracy = accuracy_score(all_labels, preds_binary)

        clinical_metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    # Sample predictions
    samples = []
    for i in range(min(5, len(generated_reports))):
        context = clinical_contexts[i] if clinical_contexts else ""
        samples.append({
            'clinical_context': context,
            'generated': generated_reports[i],
            'reference': reference_reports[i]
        })

    return {
        'loss': val_loss / loss_batches if loss_batches else float('nan'),
        'bleu': bleu_score,
        'rouge': rouge_scores,
        'clinical_metrics': clinical_metrics,
        'samples': samples
    }

def analyze_medical_accuracy(generated_reports, reference_reports, medical_terms=None):
    """
    Analyze the medical accuracy of generated reports using medical terminology

    A term counts as present in a report when any of its synonyms occurs as a
    whole word or phrase (an optional plural 's'/'es' is accepted). Whole-word
    matching keeps 'normal' from matching inside 'abnormal' and 'mass' from
    matching inside 'massive'. Negation is not interpreted here: "no effusion"
    still counts as a mention of 'effusion'.

    Args:
        generated_reports: List of generated reports
        reference_reports: List of reference reports (paired by position)
        medical_terms: Dictionary mapping a term name to its list of synonyms;
            a default set of radiology terms is used when None

    Returns:
        dict with per-term 'precision', 'recall' and 'f1' dictionaries, their
        macro averages 'avg_precision', 'avg_recall', 'avg_f1', and
        'negation_accuracy' (the result of analyze_negations).
    """
    import re  # local import: `re` is not imported in the notebook's visible import cell

    if medical_terms is None:
        # Common radiology findings and terms
        medical_terms = {
            'cardiomegaly': ['cardiomegaly', 'enlarged heart', 'cardiac enlargement'],
            'edema': ['edema', 'pulmonary edema', 'fluid overload'],
            'pneumonia': ['pneumonia', 'consolidation', 'infectious process'],
            'pleural_effusion': ['pleural effusion', 'effusion', 'fluid in pleural space'],
            'pneumothorax': ['pneumothorax', 'collapsed lung'],
            'atelectasis': ['atelectasis', 'lung collapse', 'loss of lung volume'],
            'nodule': ['nodule', 'mass', 'opacity'],
            'normal': ['normal', 'no acute findings', 'no acute abnormality']
        }

    def build_matcher(synonyms):
        """Compile one whole-word regex covering all synonyms (None if there are none)."""
        alternatives = [
            r"\s+".join(re.escape(word) for word in synonym.lower().split())
            for synonym in synonyms
            if synonym.strip()
        ]
        if not alternatives:
            return None
        joined = "|".join(alternatives)
        return re.compile(rf"(?<!\w)(?:{joined})(?:e?s)?(?!\w)")

    matchers = {term: build_matcher(synonyms) for term, synonyms in medical_terms.items()}

    # Materialize the inputs so they can be traversed again by analyze_negations.
    generated_reports = list(generated_reports)
    reference_reports = list(reference_reports)

    # Per-term counts: reports mentioning the term in the reference, in the
    # generated text, and in both.
    counts_ref = {term: 0 for term in medical_terms}
    counts_gen = {term: 0 for term in medical_terms}
    counts_both = {term: 0 for term in medical_terms}

    for gen_report, ref_report in zip(generated_reports, reference_reports):
        gen_text = gen_report.lower()
        ref_text = ref_report.lower()

        for term, matcher in matchers.items():
            if matcher is None:
                continue
            term_in_ref = matcher.search(ref_text) is not None
            term_in_gen = matcher.search(gen_text) is not None

            counts_ref[term] += term_in_ref
            counts_gen[term] += term_in_gen
            counts_both[term] += term_in_ref and term_in_gen

    # Calculate precision, recall and F1 per term
    precision_results = {}
    recall_results = {}
    f1_results = {}

    for term in medical_terms:
        precision = counts_both[term] / counts_gen[term] if counts_gen[term] > 0 else 0
        recall = counts_both[term] / counts_ref[term] if counts_ref[term] > 0 else 0
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0

        precision_results[term] = precision
        recall_results[term] = recall
        f1_results[term] = f1

    # Calculate macro averages (0 when an empty term dictionary was supplied)
    num_terms = len(medical_terms)
    avg_precision = sum(precision_results.values()) / num_terms if num_terms else 0
    avg_recall = sum(recall_results.values()) / num_terms if num_terms else 0
    avg_f1 = sum(f1_results.values()) / num_terms if num_terms else 0

    # Negation handling is analyzed separately
    correct_negations = analyze_negations(generated_reports, reference_reports)

    return {
        'precision': precision_results,
        'recall': recall_results,
        'f1': f1_results,
        'avg_precision': avg_precision,
        'avg_recall': avg_recall,
        'avg_f1': avg_f1,
        'negation_accuracy': correct_negations
    }

def analyze_negations(generated_reports, reference_reports):
    """
    Analyze how well the model handles negations in medical reports

    For every report pair and every tracked condition, the reference is
    searched for a negation phrase directly followed by the condition (e.g.
    "no pneumothorax"). Each (report, condition) pair is counted at most once,
    even if the reference negates the condition with several different
    phrases. The pair is correct when the generated report also contains a
    negation phrase directly followed by that condition.

    Args:
        generated_reports: List of generated reports
        reference_reports: List of reference reports (paired by position)

    Returns:
        dict with 'condition_accuracy' (per condition: 'total', 'correct',
        'accuracy'), 'overall_accuracy', 'total_negations' and
        'correct_negations'. Accuracies are 0 when nothing was negated.
    """
    import re  # local import: `re` is not imported in the notebook's visible import cell

    # Common negation phrases
    negation_phrases = ['no', 'not', 'without', 'free of', 'absence of', 'negative for', 'clear of', 'ruled out']

    # Common medical conditions that are frequently negated
    conditions = [
        'pneumothorax', 'effusion', 'infiltrate', 'consolidation', 'edema',
        'cardiomegaly', 'pneumonia', 'nodule', 'fracture', 'mass'
    ]

    # "<negation> <condition>" with any whitespace in between. The look-behind
    # requires the negation to start a word, so e.g. "piano mass" does not
    # count as "no mass".
    negation_alternation = "|".join(
        r"\s+".join(re.escape(word) for word in phrase.split())
        for phrase in negation_phrases
    )
    condition_patterns = {
        condition: re.compile(rf"(?<!\w)(?:{negation_alternation})\s+{re.escape(condition)}")
        for condition in conditions
    }

    negation_accuracy = {condition: {'total': 0, 'correct': 0} for condition in conditions}
    total_negations = 0
    correct_negations = 0

    for gen_report, ref_report in zip(generated_reports, reference_reports):
        gen_lower = gen_report.lower()
        ref_lower = ref_report.lower()

        for condition, pattern in condition_patterns.items():
            # Only conditions negated in the reference are scored
            if not pattern.search(ref_lower):
                continue

            negation_accuracy[condition]['total'] += 1
            total_negations += 1

            # Correct when the generated report negates the same condition
            if pattern.search(gen_lower):
                negation_accuracy[condition]['correct'] += 1
                correct_negations += 1

    # Calculate accuracy
    for stats in negation_accuracy.values():
        stats['accuracy'] = stats['correct'] / stats['total'] if stats['total'] > 0 else 0

    overall_accuracy = correct_negations / total_negations if total_negations > 0 else 0

    return {
        'condition_accuracy': negation_accuracy,
        'overall_accuracy': overall_accuracy,
        'total_negations': total_negations,
        'correct_negations': correct_negations
    }

###########################################
# 4. Custom Clinical Accuracy Metrics
###########################################

In [None]:
def calculate_chexbert_score(generated_reports, reference_reports, chexbert_model_path=None):
    """
    Calculate CheXbert score for radiology report evaluation
    This is a specialized metric for measuring clinical accuracy of radiology reports

    CheXbert requires a specialized model which must be downloaded separately
    See: https://github.com/stanfordmlgroup/CheXbert

    Args:
        generated_reports: List of generated reports
        reference_reports: List of reference reports
        chexbert_model_path: Path to CheXbert model

    Returns:
        None when the `chexbert` package cannot be imported or no model path is
        given; otherwise a dict with 'label_metrics' (per condition precision,
        recall, f1, accuracy) and 'overall' (their macro averages). Every
        ratio with an empty denominator is reported as 0.
    """

    def ratio(numerator, denominator):
        # Metrics over an empty set are reported as 0 instead of raising.
        return numerator / denominator if denominator > 0 else 0

    try:
        from chexbert import CheXbertLabeler

        if chexbert_model_path is None:
            print("CheXbert model path not provided, skipping CheXbert evaluation")
            return None

        # Initialize CheXbert labeler (uses the notebook-level `device`)
        labeler = CheXbertLabeler(chexbert_model_path, device=device)

        # Get labels for generated and reference reports
        gen_labels = labeler.label(generated_reports)
        ref_labels = labeler.label(reference_reports)

        # Confusion counts and metrics per condition
        label_metrics = {}
        for i, condition in enumerate(labeler.get_conditions()):
            true_pos = 0
            false_pos = 0
            false_neg = 0
            true_neg = 0

            for gen_label, ref_label in zip(gen_labels, ref_labels):
                # Only the value 1 is treated as positive; anything else is negative
                gen_positive = gen_label[i] == 1
                ref_positive = ref_label[i] == 1

                if gen_positive and ref_positive:
                    true_pos += 1
                elif gen_positive:
                    false_pos += 1
                elif ref_positive:
                    false_neg += 1
                else:
                    true_neg += 1

            precision = ratio(true_pos, true_pos + false_pos)
            recall = ratio(true_pos, true_pos + false_neg)
            f1 = ratio(2 * precision * recall, precision + recall)
            accuracy = ratio(true_pos + true_neg, true_pos + true_neg + false_pos + false_neg)

            label_metrics[condition] = {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'accuracy': accuracy
            }

        # Macro averages over the conditions
        num_conditions = len(label_metrics)
        overall_metrics = {
            metric: ratio(sum(m[metric] for m in label_metrics.values()), num_conditions)
            for metric in ('precision', 'recall', 'f1', 'accuracy')
        }

        return {
            'label_metrics': label_metrics,
            'overall': overall_metrics
        }

    except ImportError:
        print("CheXbert package not installed, skipping CheXbert evaluation")
        return None

def _overlap_precision_recall(predicted, expected):
    """Return (precision, recall) of set `predicted` against set `expected`.

    Either value is 0 when its denominator set is empty.
    """
    shared = len(predicted & expected)
    precision = shared / len(predicted) if predicted else 0
    recall = shared / len(expected) if expected else 0
    return precision, recall


def calculate_radgraph_score(generated_reports, reference_reports, radgraph_path=None):
    """
    Calculate RadGraph score for radiology report evaluation
    RadGraph is a specialized graph-based metric for radiology report evaluation

    See: https://github.com/ncbi-nlp/RadGraph

    Every (generated, reference) pair is scored on its own: entities and
    relations are extracted from both reports, set-overlap precision/recall is
    computed separately for entities and for relations, the two are averaged,
    and F1 is derived from the averaged precision/recall. The per-pair scores
    are then averaged over the number of pairs that were actually scored.

    Args:
        generated_reports: List of generated reports
        reference_reports: List of reference reports
        radgraph_path: Path to RadGraph model. Only its presence is checked:
            when it is None the evaluation is skipped. The value itself is not
            passed to the RadGraph constructor.

    Returns:
        dict with 'precision', 'recall' and 'f1' averaged over the scored
        pairs (all 0.0 when there are no pairs), or None when the evaluation
        is skipped (missing packages or no `radgraph_path`).
    """
    try:
        import spacy
        from radgraph import RadGraph

        if radgraph_path is None:
            print("RadGraph model path not provided, skipping RadGraph evaluation")
            return None

        # zip() stops at the shorter list, so count the pairs it really yields;
        # dividing by len(generated_reports) would skew the averages when the
        # lists differ in length and would raise ZeroDivisionError when empty.
        report_pairs = list(zip(generated_reports, reference_reports))
        if not report_pairs:
            return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

        # Initialize RadGraph
        # NOTE(review): this assumes RadGraph accepts a spaCy pipeline and returns
        # objects exposing `.ents` and `.relations` — confirm against the
        # installed radgraph package version.
        nlp = spacy.load("en_core_web_sm")
        radgraph = RadGraph(nlp)

        total_precision = 0.0
        total_recall = 0.0
        total_f1 = 0.0

        for gen_report, ref_report in report_pairs:
            # Extract entities and relations from reports
            gen_graph = radgraph(gen_report)
            ref_graph = radgraph(ref_report)

            # Entities are compared as (text, label) pairs
            gen_entities = {(ent.text, ent.label_) for ent in gen_graph.ents}
            ref_entities = {(ent.text, ent.label_) for ent in ref_graph.ents}

            # Relations are compared as (source text, relation, target text) triples
            gen_relations = {(rel[0].text, rel[1], rel[2].text) for rel in gen_graph.relations}
            ref_relations = {(rel[0].text, rel[1], rel[2].text) for rel in ref_graph.relations}

            precision_entities, recall_entities = _overlap_precision_recall(gen_entities, ref_entities)
            precision_relations, recall_relations = _overlap_precision_recall(gen_relations, ref_relations)

            # Combined metrics: entities and relations are weighted equally
            precision = (precision_entities + precision_relations) / 2
            recall = (recall_entities + recall_relations) / 2
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

            total_precision += precision
            total_recall += recall
            total_f1 += f1

        # Average metrics over the pairs that were scored
        num_pairs = len(report_pairs)
        return {
            'precision': total_precision / num_pairs,
            'recall': total_recall / num_pairs,
            'f1': total_f1 / num_pairs
        }

    except ImportError:
        print("RadGraph package not installed, skipping RadGraph evaluation")
        return None

###########################################
# 5. Main Training Loop with Medical Context
###########################################

In [None]:
def main():
    """Main function to train and evaluate the enhanced model

    Pipeline, in order:
      1. Build train (augmented) and validation/test (deterministic) image transforms.
      2. Load the tokenizer named by CONFIG['t5_model_name'].
      3. Load the IU X-Ray splits from Google Drive and wrap them in datasets/dataloaders.
      4. Train for up to CONFIG['epochs'] epochs with AdamW + OneCycleLR, evaluating
         on the validation split after every epoch, checkpointing on the best
         validation loss and stopping early after `patience` epochs without improvement.
      5. Reload the best checkpoint (if one was written in this run), evaluate on
         the test split and print BLEU / ROUGE-L and medical-accuracy figures.

    Reads CONFIG keys: 'image_size', 't5_model_name', 'max_report_length',
    'batch_size', 'learning_rate', 'epochs'. Uses the module-level `device`.
    Writes the checkpoint file "best_model.pth" to the working directory.
    """

    # Set up image transforms
    image_hw = (CONFIG['image_size'], CONFIG['image_size'])
    # NOTE(review): the mean/std values below are the ones commonly used with
    # ImageNet-pretrained backbones — confirm they match the image encoder.
    transform_train = transforms.Compose([
        transforms.Resize(image_hw),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Validation/test: no augmentation, only resize + normalization
    transform_val = transforms.Compose([
        transforms.Resize(image_hw),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Initialize tokenizer from CONFIG so it cannot drift from the configured T5 model
    tokenizer = T5Tokenizer.from_pretrained(CONFIG['t5_model_name'])

    print("Loading IU X-Ray dataset from Google Drive...")
    dataset_path = '/content/drive/MyDrive/IU-Xray'  # Updated path
    train_img_paths, val_img_paths, test_img_paths, train_reports, val_reports, test_reports, clinical_data = load_indiana_cxr_data(dataset_path)

    print(f"Training samples: {len(train_img_paths)}")
    print(f"Validation samples: {len(val_img_paths)}")
    print(f"Test samples: {len(test_img_paths)}")

    # Create datasets with clinical context
    train_dataset = EnhancedChestXRayDataset(
        train_img_paths, train_reports, clinical_data, tokenizer,
        transform=transform_train, max_length=CONFIG['max_report_length']
    )

    val_dataset = EnhancedChestXRayDataset(
        val_img_paths, val_reports, clinical_data, tokenizer,
        transform=transform_val, max_length=CONFIG['max_report_length']
    )

    test_dataset = EnhancedChestXRayDataset(
        test_img_paths, test_reports, clinical_data, tokenizer,
        transform=transform_val, max_length=CONFIG['max_report_length']
    )

    # Create dataloaders; every split shares these settings, only shuffling differs
    loader_kwargs = dict(
        batch_size=CONFIG['batch_size'],
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Initialize model
    print("Initializing enhanced model...")
    model = EnhancedRadiologyReportGenerator(CONFIG).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=CONFIG['learning_rate'],
        steps_per_epoch=len(train_dataloader), epochs=CONFIG['epochs']
    )

    print("Starting training...")
    checkpoint_path = "best_model.pth"
    checkpoint_saved = False  # True once this run has written checkpoint_path
    best_val_loss = float('inf')
    patience = 2
    no_improve = 0

    for epoch in range(CONFIG['epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
        train_metrics = train_epoch_enhanced(
            model, train_dataloader, optimizer, scheduler, device
        )

        print(f"Train loss: {train_metrics['total_loss']:.4f}")

        eval_results = evaluate_enhanced(
            model, val_dataloader, tokenizer, device
        )

        val_loss = eval_results['loss']
        bleu_score = eval_results['bleu']
        rouge_scores = eval_results['rouge']

        print(f"Validation loss: {val_loss:.4f}")
        print(f"BLEU score: {bleu_score:.4f}")
        print(f"ROUGE-1: {rouge_scores['rouge1']:.4f}")
        print(f"ROUGE-2: {rouge_scores['rouge2']:.4f}")
        print(f"ROUGE-L: {rouge_scores['rougeL']:.4f}")

        print("\nSample predictions:")
        for i, sample in enumerate(eval_results['samples'][:2]):
            print(f"Example {i+1}:")
            if sample['clinical_context']:
                print(f"Clinical context: {sample['clinical_context']}")
            print(f"Generated: {sample['generated']}")
            print(f"Reference: {sample['reference']}")
            print()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve = 0
            print("Saving best model...")
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            no_improve += 1
            if no_improve >= patience:
                print("Early stopping!")
                break

    print("\nEvaluating on test set...")
    if checkpoint_saved:
        # map_location lets a checkpoint written on GPU load on a CPU-only runtime
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # A validation loss that never compares below inf (e.g. NaN) means nothing
        # was saved; do not crash or pick up a stale file from an earlier run.
        print("No checkpoint was saved during training; evaluating the final model weights instead.")
    test_results = evaluate_enhanced(model, test_dataloader, tokenizer, device)

    print(f"Test BLEU score: {test_results['bleu']:.4f}")
    print(f"Test ROUGE-L: {test_results['rouge']['rougeL']:.4f}")

    print("\nAnalyzing medical accuracy...")
    med_accuracy = analyze_medical_accuracy(
        [s['generated'] for s in test_results['samples']],
        [s['reference'] for s in test_results['samples']]
    )

    print(f"Medical term F1 score: {med_accuracy['avg_f1']:.4f}")
    print(f"Negation accuracy: {med_accuracy['negation_accuracy']['overall_accuracy']:.4f}")
    print("Training complete!")


###########################################
# 6. Dataset Details and Helper Functions
###########################################

In [None]:
!pip install huggingface_hub
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineG

In [None]:
!pip install huggingface_hub[hf_xet]

Collecting hf-xet>=0.1.4 (from huggingface_hub[hf_xet])
  Downloading hf_xet-1.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (494 bytes)
Downloading hf_xet-1.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (53.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 MB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: hf-xet
Successfully installed hf-xet-1.0.3


In [None]:
def get_dataset_info():
    """
    Print a short catalog of public chest X-ray datasets, followed by the
    dataset combinations recommended for this project.

    For every dataset the name is printed on its own line (preceded by a blank
    line), then one indented "key: value" line per attribute. Nothing is returned.
    """
    catalog = {
        "MIMIC-CXR": dict(
            description="MIMIC Chest X-ray (MIMIC-CXR) Database is a large publicly available dataset of chest radiographs with free-text radiology reports.",
            size="377,110 images from 227,835 radiographic studies",
            features="Chest X-rays with associated radiology reports, patient metadata, and CheXpert labels",
            access="PhysioNet Credentialed Access (requires training and data use agreement)",
            url="https://physionet.org/content/mimic-cxr/2.0.0/",
            paper="Johnson AEW, et al. MIMIC-CXR, a de-identified publicly available database of chest radiographs with free-text reports. Scientific Data (2019).",
        ),
        "Indiana CXR (Open-I)": dict(
            description="Open-I Chest X-ray collection from Indiana University",
            size="~8,000 chest X-rays with reports",
            features="Images with findings, impression, and indications",
            access="Public access, no restrictions",
            url="https://openi.nlm.nih.gov/gridquery?it=xg&coll=cxr",
            paper="Demner-Fushman D, et al. Preparing a collection of radiology examinations for distribution and retrieval. JAMIA (2016).",
        ),
        "CheXpert": dict(
            description="Large dataset of chest X-rays with labels for 14 common findings",
            size="224,316 chest radiographs of 65,240 patients",
            features="Labels for 14 observations, frontal and lateral views",
            access="Public access, requires form submission",
            url="https://stanfordmlgroup.github.io/competitions/chexpert/",
            paper="Irvin J, et al. CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison. AAAI (2019).",
        ),
        "PadChest": dict(
            description="Large-scale chest X-ray dataset with multi-label annotations",
            size="160,000 images from 67,000 patients",
            features="Multiple views, radiographic reports and labels (extracted with NLP)",
            access="Public access",
            url="https://bimcv.cipf.es/bimcv-projects/padchest/",
            paper="Bustos A, et al. PadChest: A large chest x-ray image dataset with multi-label annotated reports. Medical Image Analysis (2020).",
        ),
        "RSNA Pneumonia Detection": dict(
            description="Dataset for pneumonia detection from the RSNA Pneumonia Detection Challenge",
            size="~30,000 chest X-rays",
            features="Images with bounding box annotations for pneumonia",
            access="Public access (Kaggle)",
            url="https://www.kaggle.com/c/rsna-pneumonia-detection-challenge",
            paper="RSNA Challenge (2018)",
        ),
    }

    # One header line per dataset, then its attributes in declaration order
    for title, attributes in catalog.items():
        print(f"\n{title}:")
        for field in attributes:
            print(f"  {field}: {attributes[field]}")

    recommendations = (
        "MIMIC-CXR (primary) - provides both images and associated reports",
        "MIMIC-CXR + CheXpert labels - for improved clinical accuracy validation",
        "Indiana CXR (Open-I) - smaller but publicly accessible without restrictions",
    )

    print("\n\nRecommended dataset combinations for this project:")
    for rank, text in enumerate(recommendations, start=1):
        print(f"{rank}. {text}")


# Entry point: run the training/evaluation pipeline defined in main() when this
# code is executed as the top-level module.
if __name__ == "__main__":
    main()

Loading IU X-Ray dataset from Google Drive...
Training samples: 5245
Validation samples: 1118
Test samples: 1103
Initializing enhanced model...


  scaler = GradScaler()  # INITIALIZE SCALER HERE


Starting training...

Epoch 1/5


  with autocast():  # WRAP FORWARD PASS IN AUTOCAST
Training:  40%|███▉      | 65/164 [1:11:17<1:31:18, 55.34s/it, loss=11.5, report_loss=0, finding_loss=0]Process Process-20:
Process Process-17:
Process Process-19:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.11/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.11/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 363, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 363, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 363, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.11/multiprocessing/ut