# Chest X-ray Classification and Localization using DenseNet

This notebook implements a transfer learning approach using DenseNet for chest X-ray classification and localization on the NIH Chest X-ray dataset.

In [1]:
!pip install torch torchvision pandas numpy matplotlib tqdm scikit-learn

Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting pandas
  Downloading pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy
  Downloading numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting matplotlib
  Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m30.2 MB/s[0m

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision import models
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

In [3]:
def set_all_seeds(seed=10):
    import os, random, numpy as np, torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
set_all_seeds(10)

## Data Loading and Preprocessing

In [103]:
IMAGES_PATH_RESIZED = os.path.join('resized_images_20k')
PREPROCESSED_IMAGES_PATH = os.path.join('bbox_resized_filtered_images_20k_labled.csv')
df_preprocessed = pd.read_csv(PREPROCESSED_IMAGES_PATH)

In [6]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class ChestXrayDataset(Dataset):
    def __init__(self, image_dir, df, transform=None):
        self.image_dir = image_dir
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['Image Index']
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        # Get labels (binary classification: normal vs abnormal)
        label = 1 if self.df.iloc[idx]['Finding Label'] != 'No Finding' else 0
        return image, torch.tensor(label, dtype=torch.float32)

Using device: cuda


In [7]:
import torch
print(torch.cuda.is_available())  # Should return True
print(torch.cuda.device_count())  # Should return > 0
print(torch.cuda.get_device_name(0))  # Should return the name of the GPU

True
1
NVIDIA A40


In [104]:
def get_transforms():
    """Define image transformations."""
    return {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=(-10, 10)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

g = torch.Generator()
g.manual_seed(10)

# Split data and create data loaders
train_df, temp_df = train_test_split(df_preprocessed, test_size=0.3, random_state=10)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=10)

# Get transforms
transforms_dict = get_transforms()


In [44]:
# Create datasets
train_dataset = ChestXrayDataset(IMAGES_PATH_RESIZED, train_df, transform=transforms_dict['train'])
val_dataset = ChestXrayDataset(IMAGES_PATH_RESIZED, val_df, transform=transforms_dict['val'])
test_dataset = ChestXrayDataset(IMAGES_PATH_RESIZED, test_df, transform=transforms_dict['val'])

# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, generator=g)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")



Training samples: 13839
Validation samples: 2966
Test samples: 2966


## Model Training

In [111]:
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attn_map = self.sigmoid(self.conv(x))  # Generate attention map
        return x * attn_map, attn_map

class ResNetWithAttention(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.5):  # Changed default to 14 for NIH dataset
        super().__init__()
        # Load pre-trained ResNet18 instead of DenseNet
        resnet = models.resnet50(pretrained=True)
        
        # Remove the final fully connected layer
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        
        # Unfreeze more layers for fine-tuning
        for param in self.features.parameters():
            param.requires_grad = False
            
        # Unfreeze the last two blocks
        for param in list(self.features.children())[-3].parameters():
            param.requires_grad = True
        for param in list(self.features.children())[-2].parameters():
            param.requires_grad = True
        
        self.attention = SpatialAttention(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Simpler classifier with less dropout
        self.classifier = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
            nn.Sigmoid()  # Keep sigmoid for multi-label
        )

    def forward(self, x):
        x = self.features(x)
        x, attn_map = self.attention(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits, attn_map

In [112]:
class AttentionMemory:
    def __init__(self, memory_size=5, alpha=0.7):
        """
        Args:
            memory_size: Number of previous epochs to consider
            alpha: Exponential moving average factor (higher = more weight to recent)
        """
        #Keep the last 5 attention maps per image
        self.memory_size = memory_size
        #Controls how much more we trust recent maps (lower = longer memory)
        self.alpha = alpha
        self.attention_history = {}  # Maps image_id -> list of attention maps
        self.correct_predictions = {}  # Maps image_id -> list of correctness flags
    
    def update(self, image_ids, attention_maps, predictions, labels):
        """Update memory with new attention maps"""
        for i, img_id in enumerate(image_ids):
            # Get current attention map and whether prediction was correct
            attn = attention_maps[i].detach().cpu()
            correct = (predictions[i].round() == labels[i]).all().float()
            
            # Initialize if first time seeing this image
            if img_id not in self.attention_history:
                self.attention_history[img_id] = []
                self.correct_predictions[img_id] = []
            
            # Add new data
            self.attention_history[img_id].append(attn)
            self.correct_predictions[img_id].append(correct.item())
            
            # Keep only most recent entries
            if len(self.attention_history[img_id]) > self.memory_size:
                self.attention_history[img_id].pop(0)
                self.correct_predictions[img_id].pop(0)
    
    def get_historical_attention(self, image_ids):
        """Get exponentially weighted historical attention maps"""
        batch_history = []
        
        for img_id in image_ids:
            if img_id not in self.attention_history or not self.attention_history[img_id]:
                # No history available
                batch_history.append(None)
                continue
            
            # Get history for this image
            history = self.attention_history[img_id]
            correctness = self.correct_predictions[img_id]
            
            if len(history) == 1:
                # Only one entry in history
                batch_history.append(history[0])
                continue
            
            # Calculate weighted average, giving more weight to:
            # 1. More recent attention maps
            # 2. Attention maps from correct predictions
            weights = []
            for i, is_correct in enumerate(correctness):
                # Position weight (more recent = higher weight)
                pos_weight = self.alpha ** (len(correctness) - i - 1)
                # Correctness weight (correct predictions get higher weight)
                correct_weight = 1.2 if is_correct else 0.8
                weights.append(pos_weight * correct_weight)
            
            # Normalize weights
            weights = [w / sum(weights) for w in weights]
            
            # Calculate weighted attention
            weighted_attn = torch.zeros_like(history[0])
            for i, attn in enumerate(history):
                weighted_attn += weights[i] * attn
            
            batch_history.append(weighted_attn)
        
        return batch_history

In [113]:
class SelfCorrectiveAttentionLoss(nn.Module):
    def __init__(self, lambda_consistency, lambda_sparsity):
        super(SelfCorrectiveAttentionLoss, self).__init__()
        self.cls_loss = nn.BCELoss()  # Keep BCE for multi-label
        self.lambda_consistency = lambda_consistency
        self.lambda_sparsity = lambda_sparsity
    
    def forward(self, pred, target, attn_map, historical_attn=None):
        # Classification loss (primary objective)
        cls_loss = self.cls_loss(pred, target)
        
        # Initialize attention losses
        consistency_loss = torch.tensor(0.0).to(pred.device)
        sparsity_loss = torch.tensor(0.0).to(pred.device)
        
        # Simpler sparsity loss
        attn_flat = attn_map.view(attn_map.size(0), -1)
        sparsity_loss = -torch.mean(torch.sum(attn_flat * torch.log(attn_flat + 1e-8), dim=1)) / 1000
        
        # Consistency loss (if historical attention is available)
        if historical_attn is not None:
            valid_indices = [i for i, h in enumerate(historical_attn) if h is not None]
            
            if valid_indices:
                valid_hist = torch.stack([historical_attn[i] for i in valid_indices]).to(pred.device)
                valid_curr = attn_map[valid_indices]
                
                # For multi-label, consider prediction correct if all labels match
                correct_mask = (pred.round() == target).all(dim=1).float()[valid_indices].view(-1, 1, 1)
                incorrect_mask = 1 - correct_mask
                
                consistency_term = torch.abs(valid_curr - valid_hist) * correct_mask
                correction_term = torch.exp(-torch.abs(valid_curr - valid_hist)) * incorrect_mask
                
                consistency_loss = torch.mean(consistency_term + correction_term)
        
        # Total loss with reduced weights
        total_loss = cls_loss + self.lambda_consistency * consistency_loss + self.lambda_sparsity * sparsity_loss
        
        return total_loss, cls_loss, consistency_loss, sparsity_loss

In [54]:
def visualize_attention_maps(model, data_loader, num_samples=5, save_dir='attention_maps'):
    """
    Visualize attention maps for a few samples
    
    Args:
        model: Trained model
        data_loader: DataLoader for the dataset
        num_samples: Number of samples to visualize
        save_dir: Directory to save visualizations
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    
    # Get a few samples
    samples = []
    for images, labels, img_ids in data_loader:
        for i in range(min(len(images), num_samples - len(samples))):
            samples.append((images[i], labels[i], img_ids[i]))
        if len(samples) >= num_samples:
            break
    
    with torch.no_grad():
        for i, (image, label, img_id) in enumerate(samples):
            # Get model prediction and attention map
            image = image.unsqueeze(0).to(device)
            logit, attn_map = model(image)
            pred = torch.sigmoid(logit).item()
            
            # Convert to numpy for visualization
            image_np = image.squeeze().cpu().numpy().transpose(1, 2, 0)
            # Denormalize image
            image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            image_np = np.clip(image_np, 0, 1)
            
            # Get attention map
            attn_map_np = attn_map.squeeze().cpu().numpy()
            
            # Create visualization
            plt.figure(figsize=(12, 5))
            
            # Original image
            plt.subplot(1, 3, 1)
            plt.imshow(image_np)
            plt.title(f"Original Image\nTrue: {label.item()}")
            plt.axis('off')
            
            # Attention map
            plt.subplot(1, 3, 2)
            plt.imshow(attn_map_np, cmap='hot')
            plt.title(f"Attention Map\nPred: {pred:.4f}")
            plt.axis('off')
            
            # Overlay
            plt.subplot(1, 3, 3)
            plt.imshow(image_np)
            plt.imshow(attn_map_np, cmap='hot', alpha=0.5)
            plt.title("Overlay")
            plt.axis('off')
            
            plt.tight_layout()
            plt.savefig(f"{save_dir}/attention_{img_id.replace('.', '_')}.png")
            plt.close()

In [123]:

def train_with_self_correction(model, train_loader, val_loader, num_epochs=15):
    """
    Train the model with self-corrective attention 
    """
    """
    Train the model with self-corrective attention 
    """
    # Initialize attention memory
    memory = AttentionMemory(memory_size=3, alpha=0.7)
    
    # Loss and optimizer
    criterion = SelfCorrectiveAttentionLoss(lambda_consistency=0, lambda_sparsity=0.001)

    optimizer = optim.AdamW([
        {'params': list(model.features.children())[-2].parameters(), 'lr': 1e-4},
        {'params': model.attention.parameters(), 'lr': 2e-4},
        {'params': model.classifier.parameters(), 'lr': 2e-4}
    ], weight_decay=1e-4) 

    # Learning rate scheduler with more patience
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.7, patience=3, verbose=True
    )
        
    # Track metrics
    best_val_loss = float('inf')
    history = {
        'train_loss': [], 'train_acc': [], 'train_auc': [],
        'val_loss': [], 'val_acc': [], 'val_auc': [],
        'cls_loss': [], 'consistency_loss': [], 'sparsity_loss': []
    }
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        epoch_cls_loss = 0.0
        epoch_consistency_loss = 0.0
        epoch_sparsity_loss = 0.0
        
        # For AUC calculation
        train_preds = []
        train_labels = []
        
        for images, labels, img_ids in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images = images.to(device)
            labels = labels.to(device)
            
            # Get historical attention maps for batch
            historical_attn = memory.get_historical_attention(img_ids)
            
            # Forward pass
            optimizer.zero_grad()
            logits, attn_maps = model(images)
            
            # Calculate loss
            loss, cls_loss, consistency_loss, sparsity_loss = criterion(
                logits, labels, attn_maps, historical_attn
            )
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Track metrics
            train_loss += loss.item()
            epoch_cls_loss += cls_loss.item()
            epoch_consistency_loss += consistency_loss.item()
            epoch_sparsity_loss += sparsity_loss.item()
            
            # Calculate accuracy (all labels must match)
            preds = (logits > 0.5).float()
            train_correct += (preds == labels).all(dim=1).sum().item()

            # Store predictions and labels for AUC
            train_preds.append(logits.detach().cpu().numpy())
            train_labels.append(labels.cpu().numpy())
            
            # Update attention memory
            memory.update(img_ids, attn_maps, logits, labels)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0

        # For AUC calculation
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for images, labels, img_ids in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                
                # Get historical attention
                historical_attn = memory.get_historical_attention(img_ids)
                
                # Forward pass
                logits, attn_maps = model(images)
                
                # Calculate loss
                loss, _, _, _ = criterion(logits, labels, attn_maps, historical_attn)
                
                # Track metrics
                val_loss += loss.item()
                preds = (logits > 0.5).float()
                val_correct += (preds == labels).all(dim=1).sum().item()

                # Store predictions and labels for AUC
                val_preds.append(logits.cpu().numpy())
                val_labels.append(labels.cpu().numpy())
        
        # Calculate epoch metrics
        avg_train_loss = train_loss / len(train_loader)
        avg_cls_loss = epoch_cls_loss / len(train_loader)
        avg_consistency_loss = epoch_consistency_loss / len(train_loader)
        avg_sparsity_loss = epoch_sparsity_loss / len(train_loader)
        train_accuracy = 100 * train_correct / len(train_loader.dataset)

        # Convert lists to numpy arrays for AUC calculation
        train_preds = np.concatenate(train_preds, axis=0)
        train_labels = np.concatenate(train_labels, axis=0)
        val_preds = np.concatenate(val_preds, axis=0)
        val_labels = np.concatenate(val_labels, axis=0)
        
        # Calculate per-label AUC and average
        train_aucs = [roc_auc_score(train_labels[:, i], train_preds[:, i]) 
                     for i in range(train_preds.shape[1])]
        train_auc = np.mean(train_aucs)
        
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / len(val_loader.dataset)
        
        # Calculate per-label AUC and average for validation
        val_aucs = [roc_auc_score(val_labels[:, i], val_preds[:, i]) 
                   for i in range(val_preds.shape[1])]
        val_auc = np.mean(val_aucs)
        
        # Update history
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_accuracy)
        history['train_auc'].append(train_auc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(val_accuracy)
        history['val_auc'].append(val_auc)
        history['cls_loss'].append(avg_cls_loss)
        history['consistency_loss'].append(avg_consistency_loss)
        history['sparsity_loss'].append(avg_sparsity_loss)
        
        # Print metrics
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.4f} (Cls: {avg_cls_loss:.4f}, ' + 
              f'Consistency: {avg_consistency_loss:.4f}, Sparsity: {avg_sparsity_loss:.4f})')
        print(f'Train Accuracy: {train_accuracy:.2f}%, AUC: {train_auc:.4f}')
        print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%, AUC: {val_auc:.4f}')
        
        # Update scheduler
        scheduler.step(val_auc)  # Use AUC for scheduling
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model_self_corrective.pth')
            print('Model saved!')
        
        print('-' * 60)

        #if epoch % 2 == 0:  # Save every 2 epochs to reduce storage
        #    print("Visualizing attention maps...")
        #    visualize_attention_maps(model, train_loader, 
        #                            save_dir=f'attention_maps/epoch_{epoch+1}/train')
        #    visualize_attention_maps(model, val_loader, 
        #                            save_dir=f'attention_maps/epoch_{epoch+1}/val')

    
    plt.figure(figsize=(15, 10))
    
    # Plot 1: Loss curves
    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.ylim(0, 1)
    plt.legend()
    
    # Plot 2: Accuracy curves
    plt.subplot(2, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Accuracy Curves')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.ylim(0, 100)
    plt.legend()
    
    # Plot 3: AUC curves
    plt.subplot(2, 2, 3)
    plt.plot(history['train_auc'], label='Train AUC')
    plt.plot(history['val_auc'], label='Validation AUC')
    plt.title('AUC Curves')
    plt.xlabel('Epoch')
    plt.ylabel('AUC Score')
    plt.ylim(0, 1)
    plt.legend()
    
    # Plot 4: Component losses
    plt.subplot(2, 2, 4)
    plt.plot(history['cls_loss'], label='Classification Loss')
    plt.plot(history['consistency_loss'], label='Consistency Loss')
    plt.plot(history['sparsity_loss'], label='Sparsity Loss')
    plt.title('Component Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.ylim(0, 1)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_metrics.png')
    plt.close()
    
    return model, history

In [124]:
class ChestXrayDatasetWithIDs(Dataset):
    def __init__(self, image_dir, df, transform=None):
        self.image_dir = image_dir
        self.df = df
        self.transform = transform
        
        # Create label encoder for the Finding Label column
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(df['Finding Label'])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['Image Index']
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        # Get label and convert to one-hot encoding
        label = self.df.iloc[idx]['Finding Label']
        label_idx = self.label_encoder.transform([label])[0]
        num_classes = len(self.label_encoder.classes_)
        label_onehot = torch.zeros(num_classes, dtype=torch.float32)
        label_onehot[label_idx] = 1.0
        
        # Return image ID along with image and labels
        return image, label_onehot, img_name

# Create datasets with image IDs
train_dataset = ChestXrayDatasetWithIDs(IMAGES_PATH_RESIZED, train_df, transform=transforms_dict['train'])
val_dataset = ChestXrayDatasetWithIDs(IMAGES_PATH_RESIZED, val_df, transform=transforms_dict['val'])
test_dataset = ChestXrayDatasetWithIDs(IMAGES_PATH_RESIZED, test_df, transform=transforms_dict['val'])

batch_size = 32
# Create data loaders
g = torch.Generator()
g.manual_seed(10)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, generator=g)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Number of classes: {len(train_dataset.label_encoder.classes_)}")
print("Classes:", train_dataset.label_encoder.classes_)

Training samples: 13839
Validation samples: 2966
Test samples: 2966
Number of classes: 10
Classes: ['Atelectasis' 'Cardiomegaly' 'Effusion' 'Infiltrate' 'Infiltration'
 'Mass' 'No Finding' 'Nodule' 'Pneumonia' 'Pneumothorax']


In [None]:
# Create model
num_classes = len(train_dataset.label_encoder.classes_)
model = ResNetWithAttention(num_classes=num_classes).to(device)

# model = ResNetWithAttention().to(device)

# Train with self-correction
trained_model, history = train_with_self_correction(model, train_loader, val_loader, num_epochs=10)

Epoch 1/10: 100%|██████████| 433/433 [00:29<00:00, 14.45it/s]


Epoch 1/10:
Train Loss: 0.2420 (Cls: 0.2420, Consistency: 0.0000, Sparsity: 0.0104)
Train Accuracy: 36.19%, AUC: 0.5972
Validation Loss: 0.1887, Accuracy: 42.28%, AUC: 0.7111
Model saved!
------------------------------------------------------------


Epoch 2/10: 100%|██████████| 433/433 [00:29<00:00, 14.80it/s]


Epoch 2/10:
Train Loss: 0.1892 (Cls: 0.1892, Consistency: 0.5533, Sparsity: 0.0068)
Train Accuracy: 43.75%, AUC: 0.6999
Validation Loss: 0.1887, Accuracy: 43.46%, AUC: 0.7388
Model saved!
------------------------------------------------------------


Epoch 3/10: 100%|██████████| 433/433 [00:29<00:00, 14.44it/s]


Epoch 3/10:
Train Loss: 0.1819 (Cls: 0.1819, Consistency: 0.5359, Sparsity: 0.0058)
Train Accuracy: 46.00%, AUC: 0.7552
Validation Loss: 0.1817, Accuracy: 47.44%, AUC: 0.7721
Model saved!
------------------------------------------------------------


Epoch 4/10: 100%|██████████| 433/433 [00:30<00:00, 14.13it/s]


Epoch 4/10:
Train Loss: 0.1789 (Cls: 0.1789, Consistency: 0.5313, Sparsity: 0.0056)
Train Accuracy: 46.53%, AUC: 0.7724
Validation Loss: 0.1856, Accuracy: 43.80%, AUC: 0.7317
------------------------------------------------------------


Epoch 5/10: 100%|██████████| 433/433 [00:29<00:00, 14.46it/s]


Epoch 5/10:
Train Loss: 0.1770 (Cls: 0.1770, Consistency: 0.5188, Sparsity: 0.0065)
Train Accuracy: 48.19%, AUC: 0.7884
Validation Loss: 0.1803, Accuracy: 50.10%, AUC: 0.7767
Model saved!
------------------------------------------------------------


Epoch 6/10: 100%|██████████| 433/433 [00:30<00:00, 14.30it/s]


Epoch 6/10:
Train Loss: 0.1738 (Cls: 0.1738, Consistency: 0.5085, Sparsity: 0.0048)
Train Accuracy: 49.33%, AUC: 0.7950
Validation Loss: 0.1856, Accuracy: 44.00%, AUC: 0.7580
------------------------------------------------------------


Epoch 7/10: 100%|██████████| 433/433 [00:30<00:00, 14.36it/s]


In [12]:
memory = AttentionMemory(memory_size=5, alpha=0.7)