In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import json
import os
from PIL import Image
from torchvision import transforms
from typing import Tuple, Dict, List
import glob

class ChestXrayDataset(Dataset):
    """
    PyTorch Dataset for Chest X-ray images with multi-hot encoded labels.

    Reads metadata from a CSV file (columns used: ``filename``, ``view`` and
    ``finding``), loads the referenced images from ``images_dir`` and returns:
    - image_tensor: Preprocessed image (3, 224, 224), ImageNet-normalized
    - metadata_tensor: One-hot encoded view information (num_views,)
    - label_tensor: Multi-hot encoded finding labels (num_classes,)

    Rows whose filename is missing or whose image file does not exist are
    skipped once, at construction time.
    """

    def __init__(self, metadata_path: str, images_dir: str, class_to_idx_path: str = None):
        """
        Initialize the dataset.

        Args:
            metadata_path: Path to CSV file with metadata
            images_dir: Directory containing images
            class_to_idx_path: Path to save/load the class_to_idx JSON mapping
                (default: "../data/class_to_idx.json"). If the file exists it is
                loaded as-is; otherwise the mapping is built from the CSV and saved.
        """
        self.metadata_path = metadata_path
        self.images_dir = images_dir
        self.class_to_idx_path = class_to_idx_path or "../data/class_to_idx.json"

        # Load metadata
        self.metadata = pd.read_csv(metadata_path)

        # Reusing a saved mapping keeps label indices stable across runs/splits.
        # NOTE(review): a loaded mapping is not validated against this CSV; findings
        # absent from it are silently ignored by _encode_finding.
        if os.path.exists(self.class_to_idx_path):
            with open(self.class_to_idx_path, 'r', encoding='utf-8') as f:
                self.class_to_idx = json.load(f)
        else:
            self.class_to_idx = self._build_class_mapping()
            self._save_class_mapping()

        # NOTE(review): unlike classes, the view mapping is rebuilt from each CSV,
        # so splits with different view sets get different metadata widths/orders.
        self.view_to_idx = self._build_view_mapping()

        # Image preprocessing transforms
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])

        # Positional (iloc) indices of rows whose image file exists.
        self.valid_indices = self._filter_valid_samples()

    @staticmethod
    def _split_finding(finding: object) -> List[str]:
        """Split a '/'-separated finding into stripped, non-empty class names."""
        return [cls.strip() for cls in str(finding).split('/') if cls.strip()]

    def _build_class_mapping(self) -> Dict[str, int]:
        """Build class_to_idx mapping (sorted, for stable ordering) from the finding column."""
        all_classes = set()
        for finding in self.metadata['finding'].dropna():
            all_classes.update(self._split_finding(finding))
        return {cls: idx for idx, cls in enumerate(sorted(all_classes))}

    def _build_view_mapping(self) -> Dict[str, int]:
        """Build view_to_idx mapping (sorted) from the view column."""
        unique_views = sorted(self.metadata['view'].dropna().unique())
        return {view: idx for idx, view in enumerate(unique_views)}

    def _save_class_mapping(self):
        """Save class_to_idx mapping to JSON, creating the parent directory if needed."""
        directory = os.path.dirname(self.class_to_idx_path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.class_to_idx_path, 'w', encoding='utf-8') as f:
            json.dump(self.class_to_idx, f, indent=2)

    def _filter_valid_samples(self) -> List[int]:
        """
        Return the *positional* indices of rows whose image file exists.

        Positions (not index labels) are returned because __getitem__ looks rows
        up with ``iloc``; the two only coincide for a default RangeIndex.
        """
        return [
            pos
            for pos, filename in enumerate(self.metadata['filename'])
            if not pd.isna(filename)
            and os.path.exists(os.path.join(self.images_dir, str(filename)))
        ]

    def _encode_finding(self, finding: str) -> torch.Tensor:
        """Convert a finding string to a multi-hot vector; unknown classes are ignored."""
        label_vector = torch.zeros(len(self.class_to_idx), dtype=torch.float32)
        if pd.isna(finding):
            return label_vector
        for cls in self._split_finding(finding):
            idx = self.class_to_idx.get(cls)
            if idx is not None:
                label_vector[idx] = 1.0
        return label_vector

    def _encode_view(self, view: str) -> torch.Tensor:
        """Convert a view string to a one-hot vector; missing/unknown views give all zeros."""
        view_vector = torch.zeros(len(self.view_to_idx), dtype=torch.float32)
        if pd.isna(view):
            return view_vector
        idx = self.view_to_idx.get(view)
        if idx is not None:
            view_vector[idx] = 1.0
        return view_vector

    def __len__(self) -> int:
        return len(self.valid_indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get a sample from the dataset.

        Returns:
            image_tensor: Preprocessed image (3, 224, 224)
            metadata_tensor: One-hot encoded view (num_views,)
            label_tensor: Multi-hot encoded finding (num_classes,)
        """
        row = self.metadata.iloc[self.valid_indices[idx]]

        image_path = os.path.join(self.images_dir, str(row['filename']))
        # The context manager closes the file handle; convert() returns a fully
        # loaded copy, so the image stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = self.transform(image)

        metadata_tensor = self._encode_view(row['view'])
        label_tensor = self._encode_finding(row['finding'])

        return image_tensor, metadata_tensor, label_tensor

    def get_class_names(self) -> List[str]:
        """Get list of class names ordered by index."""
        return [cls for cls, _ in sorted(self.class_to_idx.items(), key=lambda x: x[1])]

    def get_view_names(self) -> List[str]:
        """Get list of view names ordered by index."""
        return [view for view, _ in sorted(self.view_to_idx.items(), key=lambda x: x[1])]

    def print_dataset_info(self):
        """Print dataset information."""
        print("Dataset Info:")
        print(f"  Total samples: {len(self)}")
        print(f"  Number of classes: {len(self.class_to_idx)}")
        print(f"  Number of views: {len(self.view_to_idx)}")
        print(f"  Classes: {self.get_class_names()}")
        print(f"  Views: {self.get_view_names()}")

# Test the dataset
if __name__ == "__main__":
    # Initialize dataset
    dataset = ChestXrayDataset(
        metadata_path="../data/metadata_sample_minimal.csv",
        images_dir="../data/images_sample/"
    )

    # Print dataset info
    dataset.print_dataset_info()
    # Looked up once instead of rebuilding (and re-sorting) it for every label.
    all_class_names = dataset.get_class_names()

    # Create DataLoader
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Inspect only the first batch
    for batch_idx, (images, metadata, labels) in enumerate(dataloader):
        print(f"\nBatch {batch_idx + 1}:")
        print(f"  Image tensor shape: {images.shape}")
        print(f"  Metadata tensor shape: {metadata.shape}")
        print(f"  Label tensor shape: {labels.shape}")

        # Show some details
        print(f"  Image dtype: {images.dtype}")
        print(f"  Metadata dtype: {metadata.dtype}")
        print(f"  Label dtype: {labels.dtype}")

        # Show sample labels
        print("  Sample label vectors:")
        for i in range(min(2, labels.shape[0])):
            # flatten() always yields a 1-D index tensor; squeeze() produced a
            # non-iterable 0-d tensor when exactly one class was active.
            active_classes = torch.nonzero(labels[i]).flatten()
            if active_classes.numel() > 0:
                class_names = [all_class_names[j] for j in active_classes.tolist()]
                print(f"    Sample {i}: {class_names}")
            else:
                print(f"    Sample {i}: No active classes")

        # Only show first batch for testing
        break

    print("\nDataset test completed successfully!")


Dataset Info:
  Total samples: 116
  Number of classes: 28
  Number of views: 7
  Classes: ['Aspergillosis', 'Aspiration', 'Bacterial', 'COVID-19', 'Chlamydophila', 'E.Coli', 'Fungal', 'H1N1', 'Herpes', 'Influenza', 'Klebsiella', 'Legionella', 'Lipoid', 'MERS-CoV', 'MRSA', 'Mycoplasma', 'No Finding', 'Nocardia', 'Pneumocystis', 'Pneumonia', 'SARS', 'Staphylococcus', 'Streptococcus', 'Tuberculosis', 'Unknown', 'Varicella', 'Viral', 'todo']
  Views: ['AP', 'AP Erect', 'AP Supine', 'Axial', 'Coronal', 'L', 'PA']

Batch 1:
  Image tensor shape: torch.Size([4, 3, 224, 224])
  Metadata tensor shape: torch.Size([4, 7])
  Label tensor shape: torch.Size([4, 28])
  Image dtype: torch.float32
  Metadata dtype: torch.float32
  Label dtype: torch.float32
  Sample label vectors:
    Sample 0: ['COVID-19', 'Pneumonia', 'Viral']
    Sample 1: ['COVID-19', 'Pneumonia', 'Viral']

Dataset test completed successfully!


In [7]:
import torch
import torch.nn as nn
from torchvision import models

class ImageEncoder(nn.Module):
    """
    Image encoder using an EfficientNet-B0 backbone.

    Features:
    - Loads EfficientNet-B0 (ImageNet weights when ``pretrained`` is True)
    - Sets ``requires_grad = False`` on every backbone parameter
    - Replaces the classifier head with Identity so the pooled features are returned
    - Input: (B, 3, 224, 224), Output: (B, feature_dim) where feature_dim is 1280

    NOTE: ``requires_grad = False`` does not freeze BatchNorm running statistics;
    they are still updated whenever this module is in train mode. Call ``.eval()``
    on the encoder if the backbone must stay completely fixed.
    """

    def __init__(self, pretrained: bool = True):
        """
        Initialize the ImageEncoder.

        Args:
            pretrained: Whether to load pretrained ImageNet weights
        """
        super(ImageEncoder, self).__init__()

        self.backbone = self._load_backbone(pretrained)

        # Freeze all parameters in the backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Read the feature width from the original head *before* replacing it:
        # nn.Identity has no in_features, so querying it afterwards always fell
        # back to the hard-coded value.
        self.feature_dim = self._classifier_in_features(self.backbone.classifier, default=1280)

        # Replace the classifier with Identity to output raw pooled features
        self.backbone.classifier = nn.Identity()

    @staticmethod
    def _load_backbone(pretrained: bool) -> nn.Module:
        """Build EfficientNet-B0, preferring the weights API over the deprecated flag."""
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only supports the legacy boolean flag.
            return models.efficientnet_b0(pretrained=pretrained)
        return models.efficientnet_b0(weights=weights_enum.DEFAULT if pretrained else None)

    @staticmethod
    def _classifier_in_features(classifier: nn.Module, default: int = 1280) -> int:
        """Return in_features of the last nn.Linear inside ``classifier``, or ``default``."""
        linears = [m for m in classifier.modules() if isinstance(m, nn.Linear)]
        return linears[-1].in_features if linears else default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the encoder.

        Args:
            x: Input tensor of shape (B, 3, 224, 224)

        Returns:
            Feature tensor of shape (B, feature_dim)
        """
        return self.backbone(x)

    def get_feature_dim(self) -> int:
        """Get the output feature dimension."""
        return self.feature_dim

# Test the ImageEncoder
if __name__ == "__main__":
    # Create encoder
    encoder = ImageEncoder(pretrained=True)
    # eval(): in train mode BatchNorm would normalise with this random batch's
    # statistics and overwrite the pretrained running statistics.
    encoder.eval()
    feature_dim = encoder.get_feature_dim()

    print("ImageEncoder initialized successfully!")
    print(f"Feature dimension: {feature_dim}")

    # Test with random batch
    batch_size = 4
    random_batch = torch.randn(batch_size, 3, 224, 224)

    print(f"\nInput shape: {random_batch.shape}")

    # Forward pass
    with torch.no_grad():
        output = encoder(random_batch)

    print(f"Output shape: {output.shape}")
    print(f"Expected shape: [{batch_size}, {feature_dim}]")

    # Verify the output shape
    expected_shape = (batch_size, feature_dim)
    if output.shape == expected_shape:
        print("✅ Output shape is correct!")
    else:
        print("❌ Output shape is incorrect!")

    # Show some statistics
    print("\nOutput statistics:")
    print(f"  Mean: {output.mean().item():.4f}")
    print(f"  Std: {output.std().item():.4f}")
    print(f"  Min: {output.min().item():.4f}")
    print(f"  Max: {output.max().item():.4f}")

    # Check if parameters are frozen (counts parameter tensors, not elements)
    frozen_params = sum(1 for param in encoder.backbone.parameters() if not param.requires_grad)
    total_params = sum(1 for param in encoder.backbone.parameters())
    print("\nParameter status:")
    print(f"  Frozen parameters: {frozen_params}/{total_params}")
    print(f"  Trainable parameters: {total_params - frozen_params}/{total_params}")

    print("\nImageEncoder test completed successfully!")




ImageEncoder initialized successfully!
Feature dimension: 1280

Input shape: torch.Size([4, 3, 224, 224])
Output shape: torch.Size([4, 1280])
Expected shape: [4, 1280]
✅ Output shape is correct!

Output statistics:
  Mean: 0.1566
  Std: 0.2077
  Min: -0.1969
  Max: 1.0907

Parameter status:
  Frozen parameters: 211/211
  Trainable parameters: 0/211

ImageEncoder test completed successfully!


In [8]:
import torch
import torch.nn as nn

class MetadataEncoder(nn.Module):
    """
    Small MLP that turns a metadata vector into a dense embedding.

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim) -> ReLU. The trailing ReLU makes every
    embedding value non-negative.

    Shapes with the defaults: (B, 7) in, (B, 128) out.
    NOTE(review): input_dim=7 presumably matches the number of one-hot views
    produced by the dataset; confirm it is kept in sync.
    """

    def __init__(self, input_dim: int = 7, hidden_dim: int = 64, output_dim: int = 128):
        """
        Args:
            input_dim: Width of the incoming metadata vector (default: 7)
            hidden_dim: Width of the hidden layer (default: 64)
            output_dim: Width of the produced embedding (default: 128)
        """
        super().__init__()

        # Keep the configured sizes for the accessor methods below.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Layers are listed in the same order as they are applied (and initialised).
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of metadata vectors.

        Args:
            x: Tensor of shape (B, input_dim)

        Returns:
            Tensor of shape (B, output_dim)
        """
        embedding = self.encoder(x)
        return embedding

    def get_input_dim(self) -> int:
        """Return the configured input width."""
        return self.input_dim

    def get_output_dim(self) -> int:
        """Return the configured embedding width."""
        return self.output_dim

    def get_hidden_dim(self) -> int:
        """Return the configured hidden-layer width."""
        return self.hidden_dim

# Test the MetadataEncoder
if __name__ == "__main__":
    encoder = MetadataEncoder()

    print("MetadataEncoder initialized successfully!")
    # Report the configured layer widths.
    for title, getter in (("Input", encoder.get_input_dim),
                          ("Hidden", encoder.get_hidden_dim),
                          ("Output", encoder.get_output_dim)):
        print(f"{title} dimension: {getter()}")

    # Push a random batch through the network without tracking gradients.
    batch_size = 4
    random_batch = torch.randn(batch_size, 7)
    print(f"\nInput shape: {random_batch.shape}")

    with torch.no_grad():
        output = encoder(random_batch)

    expected_shape = (batch_size, 128)
    print(f"Output shape: {output.shape}")
    print(f"Expected shape: [{batch_size}, 128]")
    print("✅ Output shape is correct!" if output.shape == expected_shape
          else "❌ Output shape is incorrect!")

    # Summary statistics of the embedding values.
    print("\nOutput statistics:")
    for stat_name, reduce in (("Mean", torch.mean), ("Std", torch.std),
                              ("Min", torch.min), ("Max", torch.max)):
        print(f"  {stat_name}: {reduce(output).item():.4f}")

    # Element-wise parameter counts.
    params = list(encoder.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    print("\nParameter count:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # Layer-by-layer description of the Sequential stack.
    print("\nNetwork architecture:")
    for i, layer in enumerate(encoder.encoder):
        if isinstance(layer, nn.Linear):
            desc = f"Linear({layer.in_features} → {layer.out_features})"
        elif isinstance(layer, nn.ReLU):
            desc = "ReLU()"
        else:
            continue
        print(f"  Layer {i}: {desc}")

    print("\nMetadataEncoder test completed successfully!")


MetadataEncoder initialized successfully!
Input dimension: 7
Hidden dimension: 64
Output dimension: 128

Input shape: torch.Size([4, 7])
Output shape: torch.Size([4, 128])
Expected shape: [4, 128]
✅ Output shape is correct!

Output statistics:
  Mean: 0.0814
  Std: 0.1235
  Min: 0.0000
  Max: 0.6653

Parameter count:
  Total parameters: 8,832
  Trainable parameters: 8,832

Network architecture:
  Layer 0: Linear(7 → 64)
  Layer 1: ReLU()
  Layer 2: Linear(64 → 128)
  Layer 3: ReLU()

MetadataEncoder test completed successfully!


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionClassifier(nn.Module):
    """
    Fusion classifier that combines image and metadata features using:
    - Gating: metadata produces a per-channel sigmoid gate over the image features
    - Attention: a per-sample scalar weight from the scaled dot product of the
      projected (gated) image features and projected metadata features
    - Concatenation of gated image, metadata and attended image features
    - Feedforward classifier producing raw logits (no sigmoid/softmax applied)
    """

    def __init__(self, img_feat_dim: int = 1280, meta_feat_dim: int = 128,
                 num_classes: int = 28, hidden_dim: int = 512,
                 attention_dim: int = 256, dropout: float = 0.3):
        """
        Initialize the FusionClassifier.

        Args:
            img_feat_dim: Image feature dimension (default: 1280)
            meta_feat_dim: Metadata feature dimension (default: 128)
            num_classes: Number of output classes (default: 28)
            hidden_dim: Hidden layer dimension (default: 512)
            attention_dim: Attention projection dimension (default: 256)
            dropout: Dropout rate (default: 0.3)
        """
        super(FusionClassifier, self).__init__()

        self.img_feat_dim = img_feat_dim
        self.meta_feat_dim = meta_feat_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim

        # Gating mechanism: metadata -> one gate value per image channel
        self.gate = nn.Linear(meta_feat_dim, img_feat_dim)

        # Attention mechanism - project both modalities to a common space
        self.img_proj = nn.Linear(img_feat_dim, attention_dim)
        self.meta_proj = nn.Linear(meta_feat_dim, attention_dim)

        # Fusion dimension after concatenation (1280 + 128 + 256 = 1664 by default)
        fusion_dim = img_feat_dim + meta_feat_dim + attention_dim

        # Feedforward classifier
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, img_feat: torch.Tensor, meta_feat: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the fusion classifier.

        Each sample is processed independently of the rest of the batch.

        Args:
            img_feat: Image features of shape (B, img_feat_dim)
            meta_feat: Metadata features of shape (B, meta_feat_dim)

        Returns:
            Logits of shape (B, num_classes)
        """
        # 1. Gating mechanism
        gate = torch.sigmoid(self.gate(meta_feat))  # (B, img_feat_dim)
        gated_img = img_feat * gate  # (B, img_feat_dim)

        # 2. Attention mechanism: project both features to a common space
        img_proj = self.img_proj(gated_img)  # (B, attention_dim)
        meta_proj = self.meta_proj(meta_feat)  # (B, attention_dim)

        # Scaled dot-product compatibility score, one scalar per sample: (B, 1)
        scores = (img_proj * meta_proj).sum(dim=1, keepdim=True) / (self.attention_dim ** 0.5)
        # Per-sample weight in (0, 1). The previous softmax over dim=0 normalised
        # across the *batch*, so a sample's logits depended on the other samples
        # it was batched with (and the weight was always 1.0 for batch size 1).
        attention_weights = torch.sigmoid(scores)

        # Apply attention to the projected image features
        attended_img = img_proj * attention_weights  # (B, attention_dim)

        # 3. Concatenation fusion
        fusion_vector = torch.cat([gated_img, meta_feat, attended_img], dim=1)

        # 4. Feedforward classifier
        return self.classifier(fusion_vector)  # (B, num_classes)

    def get_fusion_dim(self) -> int:
        """Get the fusion vector dimension."""
        return self.img_feat_dim + self.meta_feat_dim + self.attention_dim

    def get_num_classes(self) -> int:
        """Get the number of output classes."""
        return self.num_classes

# Test the FusionClassifier
if __name__ == "__main__":
    # Create fusion classifier
    fusion_model = FusionClassifier()
    # eval() disables Dropout so the smoke-test output is deterministic.
    fusion_model.eval()
    num_classes = fusion_model.get_num_classes()

    print("FusionClassifier initialized successfully!")
    print(f"Image feature dimension: {fusion_model.img_feat_dim}")
    print(f"Metadata feature dimension: {fusion_model.meta_feat_dim}")
    print(f"Attention dimension: {fusion_model.attention_dim}")
    print(f"Fusion dimension: {fusion_model.get_fusion_dim()}")
    print(f"Number of classes: {num_classes}")

    # Dummy inputs sized from the model's configuration rather than literals
    batch_size = 4
    img_feat = torch.randn(batch_size, fusion_model.img_feat_dim)
    meta_feat = torch.randn(batch_size, fusion_model.meta_feat_dim)

    print("\nInput shapes:")
    print(f"  Image features: {img_feat.shape}")
    print(f"  Metadata features: {meta_feat.shape}")

    # Forward pass
    with torch.no_grad():
        output = fusion_model(img_feat, meta_feat)

    print(f"\nOutput shape: {output.shape}")
    print(f"Expected shape: [{batch_size}, {num_classes}]")

    # Verify the output shape
    expected_shape = (batch_size, num_classes)
    if output.shape == expected_shape:
        print("✅ Output shape is correct!")
    else:
        print("❌ Output shape is incorrect!")

    # Show some statistics
    print("\nOutput statistics:")
    print(f"  Mean: {output.mean().item():.4f}")
    print(f"  Std: {output.std().item():.4f}")
    print(f"  Min: {output.min().item():.4f}")
    print(f"  Max: {output.max().item():.4f}")

    # Check parameter count (elements)
    total_params = sum(p.numel() for p in fusion_model.parameters())
    trainable_params = sum(p.numel() for p in fusion_model.parameters() if p.requires_grad)
    print("\nParameter count:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # Show layer details
    print("\nNetwork architecture:")
    print("  Gating mechanism:")
    print(f"    Gate: Linear({fusion_model.meta_feat_dim} → {fusion_model.img_feat_dim})")
    print("  Attention mechanism:")
    print(f"    Image projection: Linear({fusion_model.img_feat_dim} → {fusion_model.attention_dim})")
    print(f"    Metadata projection: Linear({fusion_model.meta_feat_dim} → {fusion_model.attention_dim})")
    print("  Classifier:")
    for i, layer in enumerate(fusion_model.classifier):
        if isinstance(layer, nn.Linear):
            print(f"    Layer {i}: Linear({layer.in_features} → {layer.out_features})")
        elif isinstance(layer, nn.ReLU):
            print(f"    Layer {i}: ReLU()")
        elif isinstance(layer, nn.Dropout):
            print(f"    Layer {i}: Dropout({layer.p})")

    print("\nFusionClassifier test completed successfully!")


FusionClassifier initialized successfully!
Image feature dimension: 1280
Metadata feature dimension: 128
Attention dimension: 256
Fusion dimension: 1664
Number of classes: 28

Input shapes:
  Image features: torch.Size([4, 1280])
  Metadata features: torch.Size([4, 128])

Output shape: torch.Size([4, 28])
Expected shape: [4, 28]
✅ Output shape is correct!

Output statistics:
  Mean: 0.0064
  Std: 0.1564
  Min: -0.4314
  Max: 0.3584

Parameter count:
  Total parameters: 1,392,924
  Trainable parameters: 1,392,924

Network architecture:
  Gating mechanism:
    Gate: Linear(128 → 1280)
  Attention mechanism:
    Image projection: Linear(1280 → 256)
    Metadata projection: Linear(128 → 256)
  Classifier:
    Layer 0: Linear(1664 → 512)
    Layer 1: ReLU()
    Layer 2: Dropout(0.3)
    Layer 3: Linear(512 → 28)

FusionClassifier test completed successfully!
