# BrainDead-Solution: Cognitive Radiology Inference Demo

This notebook demonstrates how to run inference with the BrainDead-Solution cognitive radiology system. Given a chest X-ray image, the system produces disease classifications and a generated clinical report.

## System Architecture

The system consists of three main modules:
1. **PRO-FA**: Progressive Region-based Feature Aggregation for hierarchical feature extraction
2. **MIX-MLP**: Multi-scale Interactive eXpert MLP for disease classification
3. **RCTA**: Region-aware Cognitive Text Attention for report generation
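
The data flow implied by the list above can be sketched with stand-in modules. Everything in this block is hypothetical: `ToyEncoder`, `ToyClassifier`, `ToyDecoder`, their tensor shapes, and their internals are assumptions chosen for illustration, not the actual PRO-FA, MIX-MLP, or RCTA implementations. Only the dimensions (256, 14, 5000, 8 heads) are taken from the model setup used in this notebook.

```python
import torch
import torch.nn as nn

# Dimensions taken from the model setup in this notebook.
EMBED_DIM, NUM_CLASSES, VOCAB_SIZE, NUM_HEADS = 256, 14, 5000, 8


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for PRO-FA: image -> sequence of region features."""

    def __init__(self, embed_dim=EMBED_DIM, patch=32):
        super().__init__()
        # A strided convolution splits the image into non-overlapping patches.
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)

    def forward(self, images):                     # (B, 3, H, W)
        feats = self.proj(images)                  # (B, D, H/patch, W/patch)
        return feats.flatten(2).transpose(1, 2)    # (B, N, D)


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for MIX-MLP: pooled features -> disease logits."""

    def __init__(self, embed_dim=EMBED_DIM, num_classes=NUM_CLASSES):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, feats):                      # (B, N, D)
        return self.head(feats.mean(dim=1))        # (B, num_classes)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for RCTA: tokens attend to region features."""

    def __init__(self, embed_dim=EMBED_DIM, vocab_size=VOCAB_SIZE, num_heads=NUM_HEADS):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, tokens, feats):              # tokens: (B, T), feats: (B, N, D)
        queries = self.embed(tokens)               # (B, T, D)
        context, _ = self.attn(queries, feats, feats)
        return self.out(context)                   # (B, T, vocab_size)


torch.manual_seed(0)
toy_encoder, toy_classifier, toy_decoder = ToyEncoder(), ToyClassifier(), ToyDecoder()

images = torch.randn(2, 3, 224, 224)               # dummy batch, not real X-rays
start_tokens = torch.zeros(2, 1, dtype=torch.long) # assumed start-token id 0

with torch.no_grad():
    feats = toy_encoder(images)                        # shared visual features
    disease_logits = toy_classifier(feats)             # classification branch
    token_logits = toy_decoder(start_tokens, feats)    # report-generation branch

print(feats.shape, disease_logits.shape, token_logits.shape)
```

The point of the sketch is the branching: one encoder pass yields features that feed both the classifier and the report decoder.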

In [None]:
# Install dependencies (if needed)
# !pip install -r ../requirements.txt

# Import required libraries
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
import sys
import os

# Add project root to path
sys.path.append('..')

# Import our models
from models.encoder import PROFA
from models.classifier import MIXMLP
from models.decoder import RCTA

print("Libraries imported successfully!")

## Load Pre-trained Models

Load the three modules of the cognitive radiology system.
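
The loading cell in this section expects the checkpoint file to hold a single dictionary with `encoder`, `classifier`, and `decoder` keys, each mapping to a `state_dict`. A minimal sketch of writing and reading a file with that layout follows. The `nn.Linear` stand-ins, the `make_modules` helper, and the temporary `demo_checkpoint.pt` path are assumptions for illustration, and the project's actual training script may store additional entries.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_modules():
    """Hypothetical stand-ins for the three modules; only the dict layout matters."""
    return {
        "encoder": nn.Linear(4, 4),
        "classifier": nn.Linear(4, 2),
        "decoder": nn.Linear(4, 8),
    }


trained = make_modules()    # pretend these hold trained weights
restored = make_modules()   # freshly initialized, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo_checkpoint.pt")

    # Save one state_dict per module under the keys the loading cell looks for.
    torch.save({name: m.state_dict() for name, m in trained.items()}, demo_path)

    # Read the file back and copy each state_dict into the matching module.
    state = torch.load(demo_path, map_location="cpu")
    for name, module in restored.items():
        module.load_state_dict(state[name])

weights_match = all(
    torch.equal(trained[name].weight, restored[name].weight) for name in trained
)
print(sorted(state.keys()), weights_match)
```

Because `load_state_dict` checks parameter names and shapes by default, the modules being restored need the same constructor arguments as the ones that were saved.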

In [None]:
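# Initialize models (hyperparameters must match the checkpoint loaded below,
# otherwise load_state_dict will raise a size-mismatch error)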
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# PRO-FA Encoder
encoder = PROFA(embed_dim=256, num_heads=8).to(device)

# MIX-MLP Classifier
classifier = MIXMLP(embed_dim=256, num_classes=14).to(device)

# RCTA Decoder
decoder = RCTA(
    embed_dim=256,
    num_heads=8,
    vocab_size=5000,
    max_seq_len=256,
    num_decoder_layers=2
).to(device)

# Load pretrained weights (if available)
checkpoint_path = "../checkpoints/full_system_best.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load each module's weights if its key is present, and record which were found
    loaded = []
    for name, module in [('encoder', encoder), ('classifier', classifier), ('decoder', decoder)]:
        if name in checkpoint:
            module.load_state_dict(checkpoint[name])
            loaded.append(name)

    # Report what was actually loaded instead of assuming success
    if loaded:
        print(f"Pretrained weights loaded for: {', '.join(loaded)}")
    else:
        print("Checkpoint found, but it has none of the expected keys - using randomly initialized models")
else:
    print("No pretrained weights found - using randomly initialized models")

# Set models to evaluation mode
encoder.eval()
classifier.eval()
decoder.eval()

print("Models loaded and ready for inference!")