In [None]:
import json
import os
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification

In [None]:

# Colab-only: attach Google Drive to the runtime's filesystem.
# NOTE(review): `files` is not used by any cell visible here — confirm before removing.
from google.colab import drive, files

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
 #====================== Paths ======================
base_path = "/content/drive/MyDrive/pneumonia detectio sytem/dataset/chest_xray"

# One Path per dataset split, each a direct child of base_path.
train_dir, val_dir, test_dir = (
    Path(base_path) / split_name for split_name in ("train", "val", "test")
)

In [None]:
# ====================== Display Sample Images ======================
def display_images(folder, num=3, classes=('NORMAL', 'PNEUMONIA', 'NOT_CHEST')):
    """Show up to ``num`` sample images per class in a grid, one row per class.

    Args:
        folder: ``pathlib.Path`` of a dataset split holding one sub-folder per class.
        num: Number of images (grid columns) to show for each class.
        classes: Names of the class sub-folders to display, one grid row each.

    A class folder that is missing or holds fewer than ``num`` images simply
    leaves the remaining cells of its row blank.
    """
    image_suffixes = {'.jpeg', '.jpg', '.png'}

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single column, so axes[row, col] is always valid.
    fig, axes = plt.subplots(
        nrows=len(classes), ncols=num, figsize=(15, 5), squeeze=False
    )

    for row, cls in enumerate(classes):
        class_path = folder / cls
        if class_path.is_dir():
            # Case-insensitive suffix match on a sorted listing: the same files
            # are picked on every run and none is listed twice on
            # case-insensitive filesystems.
            images = sorted(
                p for p in class_path.iterdir()
                if p.is_file() and p.suffix.lower() in image_suffixes
            )[:num]
        else:
            images = []

        for col in range(num):
            ax = axes[row, col]
            if col < len(images):
                # imshow copies the pixel data, so the file handle can be
                # released as soon as the call returns.
                with Image.open(images[col]) as img:
                    ax.imshow(img, cmap='gray')
                ax.set_title(cls)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

display_images(train_dir)


In [None]:
# ====================== Transforms ======================
# The google/vit-base-patch16-224-in21k checkpoint was pre-trained on 224x224
# inputs normalized per channel with mean 0.5 and std 0.5, so the same
# normalization is applied to every split.
# NOTE(review): assumes 3-channel tensors, which is what ImageFolder's default
# RGB loader produces — confirm if a custom loader is introduced.
IMAGE_SIZE = (224, 224)
VIT_MEAN = (0.5, 0.5, 0.5)
VIT_STD = (0.5, 0.5, 0.5)

# Training: light geometric and photometric augmentation before normalization.
train_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])

# Validation / test: deterministic preprocessing only.
val_test_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=VIT_MEAN, std=VIT_STD),
])

In [None]:
# ====================== Load Data ======================
BATCH_SIZE = 32

train_dataset = ImageFolder(train_dir, transform=train_transforms)
val_dataset = ImageFolder(val_dir, transform=val_test_transforms)
test_dataset = ImageFolder(test_dir, transform=val_test_transforms)

# Each ImageFolder derives its own class -> index mapping from the sub-folders
# it finds. If a split lacks a class folder, its label indices shift and no
# longer mean the same thing as in training, so fail loudly instead of
# training or evaluating on mismatched labels.
for split_name, split_dataset in (("val", val_dataset), ("test", test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class folders of the '{split_name}' split {split_dataset.class_to_idx} "
            f"do not match the training split {train_dataset.class_to_idx}"
        )

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# ====================== Compute Class Weights ======================
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights for the classes present in the training labels,
# converted to a float32 tensor.
train_targets = train_dataset.targets
class_weights = torch.tensor(
    compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_targets),
        y=train_targets,
    ),
    dtype=torch.float,
)

In [None]:
# ====================== Load ViT Model ======================
CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Size the classification head from the dataset instead of a hard-coded count:
# the number of logits must equal the number of class folders (and the length
# of `class_weights`), otherwise the weighted loss below fails.
config = ViTConfig.from_pretrained(CHECKPOINT)
config.num_labels = len(train_dataset.classes)
# Store readable label names so predictions can be mapped back to class folders.
config.id2label = {idx: name for name, idx in train_dataset.class_to_idx.items()}
config.label2id = dict(train_dataset.class_to_idx)

model = ViTForImageClassification.from_pretrained(CHECKPOINT, config=config)

# Use the GPU when one is available; model and loss weights must share a device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
class_weights = class_weights.to(device)

# Class-weighted cross-entropy using the weights computed from the training labels.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
from torch.amp import autocast  # For mixed precision training
from torch.cuda.amp import GradScaler  # Scales gradients to avoid underflow during AMP

def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=20):
    """Train ``model`` with mixed precision and validate it after every epoch.

    Args:
        model: Network whose forward pass returns an object with a ``.logits``
            attribute.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        ``(model, train_losses, val_losses, train_accuracies, val_accuracies)``.
        Each list has one entry per epoch; losses are averaged over the samples
        seen in that epoch and accuracies are percentages.
    """
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    # Follow the device the model already lives on instead of depending on a
    # notebook-level global.
    device = next(model.parameters()).device

    # Loss scaling protects fp16 gradients from underflow. It is only enabled
    # on CUDA; when disabled, scale/step/update reduce to a plain
    # backward + optimizer.step().
    scaler = GradScaler(enabled=(device.type == "cuda"))

    for epoch in range(num_epochs):
        print(f"\n📘 Epoch {epoch+1}/{num_epochs}")

        # ---------- TRAINING LOOP ----------
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}", leave=False):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass under automatic mixed precision.
            with autocast(device_type=device.type):
                outputs = model(images).logits
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_size = labels.size(0)
            # criterion returns a batch mean, so weight it by the batch size
            # to accumulate a per-sample sum.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        # Average over the samples actually seen (stays correct even if the
        # loader drops or skips samples).
        train_loss = running_loss / total
        train_acc = 100 * correct / total
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        # ---------- VALIDATION LOOP ----------
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        all_preds, all_labels = [], []  # collected for precision / recall / F1

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images).logits
                loss = criterion(outputs, labels)

                batch_size = labels.size(0)
                val_loss += loss.item() * batch_size

                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += batch_size

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss = val_loss / total
        val_acc = 100 * correct / total
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        # Support-weighted averages across classes.
        val_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        val_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        val_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

        print(f"✅ Epoch {epoch+1} | Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}% | "
              f"Precision: {val_precision:.2f}, Recall: {val_recall:.2f}, F1: {val_f1:.2f}")

    return model, train_losses, val_losses, train_accuracies, val_accuracies


In [None]:
# ====================== Train Model ======================
NUM_EPOCHS = 20

# train_model returns the model followed by the four per-epoch metric lists.
training_outputs = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
)
trained_model, train_losses, val_losses, train_accuracies, val_accuracies = training_outputs

In [None]:
# ====================== Plot Accuracy & Loss ======================
# Two side-by-side panels: accuracy on the left, loss on the right,
# each with the training and validation curves over epochs.
fig, (accuracy_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

accuracy_ax.plot(train_accuracies, label='Train Accuracy')
accuracy_ax.plot(val_accuracies, label='Val Accuracy', color='orange')
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.set_title('Model Accuracy')
accuracy_ax.legend()

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss', color='orange')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Model Loss')
loss_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# ====================== Evaluate with Confusion Matrix ======================
def evaluate_model_with_confusion(model, test_loader, class_names):
    """Report test metrics and plot a row-normalized confusion matrix.

    Args:
        model: Network whose forward pass returns an object with ``.logits``.
        test_loader: Iterable of ``(images, labels)`` batches.
        class_names: Class names ordered by label index; used as tick labels
            and to fix the size of the confusion matrix.

    Prints accuracy plus support-weighted precision, recall and F1, then shows
    a heatmap where each row (true class) sums to 1. Rows for classes with no
    test samples are shown as zeros.
    """
    model.eval()
    # Follow the device the model lives on instead of a notebook-level global.
    device = next(model.parameters()).device
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).logits.argmax(dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    acc = 100 * sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)
    print(f"\n📊 Test Accuracy: {acc:.2f}%")
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    print(f"🔍 Precision: {precision:.4f}")
    print(f"🔍 Recall:    {recall:.4f}")
    print(f"🔍 F1 Score:  {f1:.4f}")

    # Pin the label set so the matrix always has one row/column per class name,
    # even when a class never appears in the test labels or predictions.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    # Normalize each row by its number of true samples; rows with no samples
    # stay at zero instead of turning into NaN from a division by zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm, row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm_normalized,
                annot=True, fmt='.2f', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Normalized Confusion Matrix")
    plt.show()

evaluate_model_with_confusion(trained_model, test_loader, class_names=train_dataset.classes)