In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/rsna-pneumonia-processed-dataset/stage2_train_metadata.csv
/kaggle/input/rsna-pneumonia-processed-dataset/stage2_test_metadata.csv
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/7a530f4d-13af-4e4e-bf2e-080c7fb27ffb.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/86f084ef-539c-404b-9a91-c9a5d5ab3771.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/673989cb-131a-43fc-8276-710c65ddb2d6.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/85a77a2c-9fb8-4ceb-af2b-4de716d4d366.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/9f756055-99fc-4a80-97b4-f89f98cf92d2.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/799c13e6-dff4-4c52-9aaf-b9ea6daeacb1.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/ff68830c-84cf-44c5-983d-6d332f86d82a.png
/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images/a67dee22-6255-4c02-b0d2-d3cc0903c833.png
/kaggle/input/rsna-pneu

In [2]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from transformers import ViTImageProcessor, ViTForImageClassification
from tqdm import tqdm

# Define constants
IMAGES_PATH = '/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images'
METADATA = '/kaggle/input/rsna-pneumonia-processed-dataset/stage2_train_metadata.csv'
BATCH_SIZE = 16
NUM_EPOCHS = 12  # Changed to 12 epochs
LEARNING_RATE = 2e-5
MODEL_NAME = "google/vit-base-patch16-224"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load metadata
df = pd.read_csv(METADATA)

# Prepare dataset focusing on the Target column (pneumonia or not)
# Get unique patient IDs and their corresponding Target values
patient_targets = df[['patientId', 'Target']].drop_duplicates()
print(f"Total unique patients: {len(patient_targets)}")
print(f"Patients with pneumonia: {patient_targets['Target'].sum()}")
print(f"Patients without pneumonia: {len(patient_targets) - patient_targets['Target'].sum()}")

# Create a custom dataset
class PneumoniaDataset(Dataset):
    def __init__(self, patient_ids, targets, images_path, transform=None):
        self.patient_ids = patient_ids
        self.targets = targets
        self.images_path = images_path
        self.transform = transform
        
    def __len__(self):
        return len(self.patient_ids)
    
    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        image_path = os.path.join(self.images_path, f"{patient_id}.png")
        
        # Load image
        try:
            image = Image.open(image_path).convert('RGB')
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
            # Return a dummy image if the file isn't found
            image = Image.new('RGB', (224, 224), color='black')
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        # Get the target (pneumonia or not)
        target = self.targets[idx]
        
        return image, target

# Create train-validation split
patient_ids = patient_targets['patientId'].values
targets = patient_targets['Target'].values

train_ids, val_ids, train_targets, val_targets = train_test_split(
    patient_ids, targets, test_size=0.2, stratify=targets, random_state=42
)

print(f"Training samples: {len(train_ids)}")
print(f"Validation samples: {len(val_ids)}")

# Load the image processor
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Define transformations
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# Create datasets
train_dataset = PneumoniaDataset(train_ids, train_targets, IMAGES_PATH, transform=train_transforms)
val_dataset = PneumoniaDataset(val_ids, val_targets, IMAGES_PATH, transform=val_transforms)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Initialize the pretrained model
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,  # Binary classification
    ignore_mismatched_sizes=True  # Handles case where final layer shape doesn't match
)

# Move model to device
model = model.to(DEVICE)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler - updated to use 12 epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    predictions = []
    true_labels = []
    
    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.float().to(device).unsqueeze(1)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        running_loss += loss.item()
        
        # Store predictions and true labels
        preds = torch.sigmoid(outputs) >= 0.5
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    
    return running_loss / len(dataloader), accuracy, precision, recall, f1

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    predictions = []
    true_labels = []
    raw_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation"):
            images = images.to(device)
            labels = labels.float().to(device).unsqueeze(1)
            
            # Forward pass
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            
            # Store predictions and true labels
            probs = torch.sigmoid(outputs)
            preds = probs >= 0.5
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            raw_probs.extend(probs.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    auc = roc_auc_score(true_labels, raw_probs)
    
    return running_loss / len(dataloader), accuracy, precision, recall, f1, auc

# Training loop
best_val_f1 = 0.0
patience = 5
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    # Train
    train_loss, train_acc, train_prec, train_recall, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE
    )
    
    # Validate
    val_loss, val_acc, val_prec, val_recall, val_f1, val_auc = validate(
        model, val_loader, criterion, DEVICE
    )
    
    # Update learning rate
    scheduler.step()
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}, AUC: {val_auc:.4f}")
    
    # Check if this is the best model so far
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        
        # Save the best model
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'vit_state_dict': model.vit.state_dict(),  # Save just the ViT part for transfer
            'val_f1': val_f1,
            'epoch': epoch,
        }, 'best_pneumonia_vit_model.pth')
        
        print("Saved best model!")
    else:
        patience_counter += 1
        
    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping after {epoch+1} epochs!")
        break
    
    print("-" * 50)

# Load the best model for final evaluation
checkpoint = torch.load('best_pneumonia_vit_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Final evaluation on validation set
final_val_loss, final_val_acc, final_val_prec, final_val_recall, final_val_f1, final_val_auc = validate(
    model, val_loader, criterion, DEVICE
)

print("\nFinal Validation Metrics:")
print(f"Loss: {final_val_loss:.4f}")
print(f"Accuracy: {final_val_acc:.4f}")
print(f"Precision: {final_val_prec:.4f}")
print(f"Recall: {final_val_recall:.4f}")
print(f"F1 Score: {final_val_f1:.4f}")
print(f"AUC: {final_val_auc:.4f}")

# Save the final model (full model and ViT separately)
torch.save({
    'model_state_dict': model.state_dict(),
    'vit_state_dict': model.vit.state_dict(),  # Just the ViT for transfer learning
    'config': model.vit.config,
    'val_metrics': {
        'loss': final_val_loss,
        'accuracy': final_val_acc,
        'precision': final_val_prec,
        'recall': final_val_recall,
        'f1': final_val_f1,
        'auc': final_val_auc
    },
}, 'final_pneumonia_vit_model.pth')

print("Saved final model!")
print("ViT encoder can be loaded separately from this checkpoint for your report generation task.")

2025-04-17 10:33:26.445528: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744886006.711317      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744886006.781058      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda
Total unique patients: 26684
Patients with pneumonia: 6012
Patients without pneumonia: 20672
Training samples: 21347
Validation samples: 5337


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/12


Training: 100%|██████████| 1335/1335 [13:13<00:00,  1.68it/s]
Validation: 100%|██████████| 334/334 [01:22<00:00,  4.05it/s]


Train Loss: 0.3834, Acc: 0.8293, Prec: 0.6688, Recall: 0.4807, F1: 0.5593
Val Loss: 0.3801, Acc: 0.8250, Prec: 0.7757, Recall: 0.3136, F1: 0.4467, AUC: 0.8761
Saved best model!
--------------------------------------------------
Epoch 2/12


Training: 100%|██████████| 1335/1335 [13:32<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.82it/s]


Train Loss: 0.3517, Acc: 0.8441, Prec: 0.6965, Recall: 0.5464, F1: 0.6124
Val Loss: 0.3548, Acc: 0.8362, Prec: 0.7477, Recall: 0.4118, F1: 0.5311, AUC: 0.8818
Saved best model!
--------------------------------------------------
Epoch 3/12


Training: 100%|██████████| 1335/1335 [13:31<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.82it/s]


Train Loss: 0.3380, Acc: 0.8507, Prec: 0.7160, Recall: 0.5593, F1: 0.6280
Val Loss: 0.3500, Acc: 0.8456, Prec: 0.7647, Recall: 0.4542, F1: 0.5699, AUC: 0.8866
Saved best model!
--------------------------------------------------
Epoch 4/12


Training: 100%|██████████| 1335/1335 [13:31<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.81it/s]


Train Loss: 0.3234, Acc: 0.8575, Prec: 0.7257, Recall: 0.5913, F1: 0.6516
Val Loss: 0.3468, Acc: 0.8424, Prec: 0.7335, Recall: 0.4717, F1: 0.5742, AUC: 0.8869
Saved best model!
--------------------------------------------------
Epoch 5/12


Training: 100%|██████████| 1335/1335 [13:32<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.82it/s]


Train Loss: 0.3017, Acc: 0.8693, Prec: 0.7589, Recall: 0.6152, F1: 0.6795
Val Loss: 0.3807, Acc: 0.8439, Prec: 0.6631, Recall: 0.6240, F1: 0.6429, AUC: 0.8839
Saved best model!
--------------------------------------------------
Epoch 6/12


Training: 100%|██████████| 1335/1335 [13:32<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:08<00:00,  4.86it/s]


Train Loss: 0.2784, Acc: 0.8802, Prec: 0.7821, Recall: 0.6491, F1: 0.7094
Val Loss: 0.3814, Acc: 0.8375, Prec: 0.6418, Recall: 0.6306, F1: 0.6362, AUC: 0.8820
--------------------------------------------------
Epoch 7/12


Training: 100%|██████████| 1335/1335 [13:31<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.82it/s]


Train Loss: 0.2515, Acc: 0.8933, Prec: 0.8073, Recall: 0.6917, F1: 0.7450
Val Loss: 0.3821, Acc: 0.8417, Prec: 0.6627, Recall: 0.6048, F1: 0.6324, AUC: 0.8805
--------------------------------------------------
Epoch 8/12


Training: 100%|██████████| 1335/1335 [13:32<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.81it/s]


Train Loss: 0.2221, Acc: 0.9095, Prec: 0.8367, Recall: 0.7435, F1: 0.7873
Val Loss: 0.4337, Acc: 0.8389, Prec: 0.6948, Recall: 0.5075, F1: 0.5865, AUC: 0.8664
--------------------------------------------------
Epoch 9/12


Training: 100%|██████████| 1335/1335 [13:32<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:09<00:00,  4.81it/s]


Train Loss: 0.1932, Acc: 0.9216, Prec: 0.8610, Recall: 0.7778, F1: 0.8173
Val Loss: 0.4691, Acc: 0.8372, Prec: 0.6736, Recall: 0.5374, F1: 0.5979, AUC: 0.8674
--------------------------------------------------
Epoch 10/12


Training: 100%|██████████| 1335/1335 [13:33<00:00,  1.64it/s]
Validation: 100%|██████████| 334/334 [01:08<00:00,  4.85it/s]
  checkpoint = torch.load('best_pneumonia_vit_model.pth')


Train Loss: 0.1726, Acc: 0.9318, Prec: 0.8786, Recall: 0.8094, F1: 0.8425
Val Loss: 0.5138, Acc: 0.8361, Prec: 0.6727, Recall: 0.5300, F1: 0.5928, AUC: 0.8613
Early stopping after 10 epochs!


Validation: 100%|██████████| 334/334 [01:09<00:00,  4.83it/s]



Final Validation Metrics:
Loss: 0.3807
Accuracy: 0.8439
Precision: 0.6631
Recall: 0.6240
F1 Score: 0.6429
AUC: 0.8839
Saved final model!
ViT encoder can be loaded separately from this checkpoint for your report generation task.


In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from transformers import ViTImageProcessor, ViTForImageClassification
from tqdm import tqdm

# Define constants
IMAGES_PATH = '/kaggle/input/rsna-pneumonia-processed-dataset/Training/Images'
METADATA = '/kaggle/input/rsna-pneumonia-processed-dataset/stage2_train_metadata.csv'
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 2e-5
MODEL_NAME = "google/vit-base-patch16-224"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Load metadata
df = pd.read_csv(METADATA)

# Prepare dataset focusing on the Target column (pneumonia or not)
# Get unique patient IDs and their corresponding Target values
patient_targets = df[['patientId', 'Target']].drop_duplicates()
print(f"Total unique patients: {len(patient_targets)}")
print(f"Patients with pneumonia: {patient_targets['Target'].sum()}")
print(f"Patients without pneumonia: {len(patient_targets) - patient_targets['Target'].sum()}")

# Create a custom dataset
class PneumoniaDataset(Dataset):
    def __init__(self, patient_ids, targets, images_path, transform=None):
        self.patient_ids = patient_ids
        self.targets = targets
        self.images_path = images_path
        self.transform = transform
        
    def __len__(self):
        return len(self.patient_ids)
    
    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        image_path = os.path.join(self.images_path, f"{patient_id}.png")
        
        # Load image
        try:
            image = Image.open(image_path).convert('RGB')
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
            # Return a dummy image if the file isn't found
            image = Image.new('RGB', (224, 224), color='black')
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        # Get the target (pneumonia or not)
        target = self.targets[idx]
        
        return image, target

# Create train-validation split
patient_ids = patient_targets['patientId'].values
targets = patient_targets['Target'].values

train_ids, val_ids, train_targets, val_targets = train_test_split(
    patient_ids, targets, test_size=0.2, stratify=targets, random_state=42
)

print(f"Training samples: {len(train_ids)}")
print(f"Validation samples: {len(val_ids)}")

# Load the image processor
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Define transformations
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
])

# Create datasets
train_dataset = PneumoniaDataset(train_ids, train_targets, IMAGES_PATH, transform=train_transforms)
val_dataset = PneumoniaDataset(val_ids, val_targets, IMAGES_PATH, transform=val_transforms)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Initialize the pretrained model
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=1,  # Binary classification
    ignore_mismatched_sizes=True  # Handles case where final layer shape doesn't match
)

# Move model to device
model = model.to(DEVICE)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    predictions = []
    true_labels = []
    
    for images, labels in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.float().to(device).unsqueeze(1)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        running_loss += loss.item()
        
        # Store predictions and true labels
        preds = torch.sigmoid(outputs) >= 0.5
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    
    return running_loss / len(dataloader), accuracy, precision, recall, f1

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    predictions = []
    true_labels = []
    raw_probs = []
    
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation"):
            images = images.to(device)
            labels = labels.float().to(device).unsqueeze(1)
            
            # Forward pass
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            
            # Store predictions and true labels
            probs = torch.sigmoid(outputs)
            preds = probs >= 0.5
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
            raw_probs.extend(probs.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    auc = roc_auc_score(true_labels, raw_probs)
    
    return running_loss / len(dataloader), accuracy, precision, recall, f1, auc

# Training loop
best_val_f1 = 0.0
patience = 5
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    # Train
    train_loss, train_acc, train_prec, train_recall, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE
    )
    
    # Validate
    val_loss, val_acc, val_prec, val_recall, val_f1, val_auc = validate(
        model, val_loader, criterion, DEVICE
    )
    
    # Update learning rate
    scheduler.step()
    
    # Print metrics
    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Recall: {train_recall:.4f}, F1: {train_f1:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Prec: {val_prec:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}, AUC: {val_auc:.4f}")
    
    # Check if this is the best model so far
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        patience_counter = 0
        
        # Save the best model
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'vit_state_dict': model.vit.state_dict(),  # Save just the ViT part for transfer
            'val_f1': val_f1,
            'epoch': epoch,
        }, 'best_pneumonia_vit_model.pth')
        
        print("Saved best model!")
    else:
        patience_counter += 1
        
    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping after {epoch+1} epochs!")
        break
    
    print("-" * 50)

# Load the best model for final evaluation
checkpoint = torch.load('best_pneumonia_vit_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

# Final evaluation on validation set
final_val_loss, final_val_acc, final_val_prec, final_val_recall, final_val_f1, final_val_auc = validate(
    model, val_loader, criterion, DEVICE
)

print("\nFinal Validation Metrics:")
print(f"Loss: {final_val_loss:.4f}")
print(f"Accuracy: {final_val_acc:.4f}")
print(f"Precision: {final_val_prec:.4f}")
print(f"Recall: {final_val_recall:.4f}")
print(f"F1 Score: {final_val_f1:.4f}")
print(f"AUC: {final_val_auc:.4f}")

# Save the final model (full model and ViT separately)
torch.save({
    'model_state_dict': model.state_dict(),
    'vit_state_dict': model.vit.state_dict(),  # Just the ViT for transfer learning
    'config': model.vit.config,
    'val_metrics': {
        'loss': final_val_loss,
        'accuracy': final_val_acc,
        'precision': final_val_prec,
        'recall': final_val_recall,
        'f1': final_val_f1,
        'auc': final_val_auc
    },
}, 'final_pneumonia_vit_model.pth')

print("Saved final model!")
print("ViT encoder can be loaded separately from this checkpoint for your report generation task.")