In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset ,WeightedRandomSampler
from torchvision import transforms
import os
from PIL import Image
from sklearn.metrics import precision_score, recall_score, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
# Constants
BATCH_SIZE = 16
IMAGE_SIZE = (180, 180)
EPOCHS = 15
GCS_PATH = "/kaggle/input/labeled-chest-xray-images/chest_xray"

In [3]:
# Get filenames for train and test directories
train_filenames = []
for root, _, files in os.walk(os.path.join(GCS_PATH, 'train')):
    for file in files:
        if file.endswith(('jpeg', 'jpg', 'png')):
            train_filenames.append(os.path.join(root, file))

test_filenames = []
for root, _, files in os.walk(os.path.join(GCS_PATH, 'test')):
    for file in files:
        if file.endswith(('jpeg', 'jpg', 'png')):
            test_filenames.append(os.path.join(root, file))

In [4]:
# Optionally, create a validation set from the train set (if desired)
train_filenames, val_filenames = train_test_split(train_filenames, test_size=0.2, random_state=42)

In [5]:
# Count classes for weighted sampling
COUNT_NORMAL = len([f for f in train_filenames if "NORMAL" in f])
COUNT_PNEUMONIA = len([f for f in train_filenames if "PNEUMONIA" in f])
print(f"Normal images count in training set: {COUNT_NORMAL}")
print(f"Pneumonia images count in training set: {COUNT_PNEUMONIA}")

Normal images count in training set: 1090
Pneumonia images count in training set: 3095


In [6]:
# Class labels
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]

In [7]:
# Custom Dataset class
class ChestXRayDataset(Dataset):
    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = 1 if "PNEUMONIA" in file_path else 0
        image = Image.open(file_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

In [8]:
# Data transformations
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

In [9]:
train_dataset = ChestXRayDataset(train_filenames, transform=transform)
val_dataset = ChestXRayDataset(val_filenames, transform=transform)
test_dataset = ChestXRayDataset(test_filenames, transform=transform)

In [10]:
# Class weights for handling imbalance
class_weights = torch.tensor([
    1.0 / COUNT_NORMAL,
    1.0 / COUNT_PNEUMONIA
])

In [11]:
sample_weights = [
    class_weights[1 if "PNEUMONIA" in f else 0].item() for f in train_filenames
]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

In [12]:
# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [13]:
# Model definition
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(2)
    )
def dense_block(in_features, out_features, dropout_rate):
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.ReLU(),
        nn.BatchNorm1d(out_features),
        nn.Dropout(dropout_rate)
    )

class ChestXRayModel(nn.Module):
    def __init__(self):
        super(ChestXRayModel, self).__init__()
        self.feature_extractor = nn.Sequential(
            conv_block(3, 16),
            conv_block(16, 32),
            conv_block(32, 64),
            conv_block(64, 128),
            nn.Dropout(0.2),
            conv_block(128, 256),
            nn.Dropout(0.2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            dense_block(256 * (IMAGE_SIZE[0] // 32) * (IMAGE_SIZE[1] // 32), 512, 0.7),
            dense_block(512, 128, 0.5),
            dense_block(128, 64, 0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = self.classifier(x)
        return x

In [14]:
# Instantiate model
model = ChestXRayModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


ChestXRayModel(
  (feature_extractor): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 128, kerne

In [16]:
print(model)

ChestXRayModel(
  (feature_extractor): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 128, kerne

In [None]:

# Define loss function, optimizer, and metrics
criterion = nn.BCEWithLogitsLoss()  # Handles binary classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score


In [None]:
# Function to compute and print metrics
def compute_metrics(labels, preds):
    precision = precision_score(labels, preds)
    recall = recall_score(labels, preds)
    f1 = f1_score(labels, preds)
    return precision, recall, f1

In [None]:
# Function to plot confusion matrix
def plot_confusion_matrix(labels, preds, class_names):
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
# Function to visualize sample images from the dataset
def visualize_samples(data_loader):
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        image = images[i].permute(1, 2, 0).numpy()  # Convert to HWC format for display
        ax = axes[i]
        ax.imshow(image)
        ax.set_title(f"Label: {labels[i].item()}")
        ax.axis('off')
    plt.show()

In [None]:
"""# Learning rate scheduler (exponential decay)
def exponential_decay(lr0, s):
    def exponential_decay_fn(epoch):
        return lr0 * 0.1 ** (epoch / s)
    return exponential_decay_fn

exponential_decay_fn = exponential_decay(0.01, 20)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=exponential_decay_fn)"""

In [None]:
# Visualize sample images from the training set
def visualize_samples(data_loader):
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    for i in range(5):
        image = images[i].permute(1, 2, 0).numpy()  # Convert to HWC format for display
        ax = axes[i]
        ax.imshow(image)
        ax.set_title(f"Label: {labels[i].item()}")
        ax.axis('off')
    plt.show()

# Visualization of sample images from the training set
visualize_samples(train_loader)

In [None]:
"""from tqdm import tqdm

# Training loop with tqdm for progress tracking
def train(model, train_loader, val_loader, epochs):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        # Wrap the train_loader with tqdm for progress tracking
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as tepoch:
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device, dtype=torch.float32)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.squeeze(), labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Calculate training accuracy
                predicted = (outputs.squeeze() > 0.5).float()
                correct_train += (predicted == labels).sum().item()
                total_train += labels.size(0)

                # Update the tqdm progress bar with loss and accuracy
                train_accuracy = 100 * correct_train / total_train
                tepoch.set_postfix(loss=running_loss / (tepoch.n + 1), train_accuracy=train_accuracy)

        # Training accuracy for the current epoch
        train_accuracy = 100 * correct_train / total_train
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}, Training Accuracy: {train_accuracy:.2f}%")

        evaluate(model, val_loader)

# The evaluation function remains the same as previously defined
def evaluate(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = (outputs.squeeze() > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    
    accuracy = 100 * correct / total
    print(f'Validation Accuracy: {accuracy:.2f}%')

"""

In [None]:
# Training loop with tqdm for progress tracking
def train(model, train_loader, val_loader, epochs):
    train_losses = []
    val_losses = []
    metrics_dict = {'precision': [], 'recall': [], 'f1': []}
    
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as tepoch:
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device).float()  # Cast labels to float
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.squeeze(), labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Calculate training accuracy
                predicted = (outputs.squeeze() > 0.5).float()
                correct_train += (predicted == labels).sum().item()
                total_train += labels.size(0)

                train_accuracy = 100 * correct_train / total_train
                tepoch.set_postfix(loss=running_loss / (tepoch.n + 1), train_accuracy=train_accuracy)

        # Calculate training accuracy and append loss
        train_accuracy = 100 * correct_train / total_train
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}, Training Accuracy: {train_accuracy:.2f}%")
        train_losses.append(running_loss / len(train_loader))

        # Validation evaluation
        model.eval()
        val_running_loss = 0.0
        correct_val = 0
        total_val = 0
        all_labels = []
        all_preds = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device).float()  # Cast labels to float
                outputs = model(inputs)
                val_loss = criterion(outputs.squeeze(), labels)
                val_running_loss += val_loss.item()

                # Calculate validation accuracy
                predicted = (outputs.squeeze() > 0.5).float()
                correct_val += (predicted == labels).sum().item()
                total_val += labels.size(0)
                
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())

        val_accuracy = 100 * correct_val / total_val
        val_losses.append(val_running_loss / len(val_loader))
        print(f"Validation Loss: {val_running_loss / len(val_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%")

        # Compute additional metrics
        precision, recall, f1 = compute_metrics(all_labels, all_preds)
        metrics_dict['precision'].append(precision)
        metrics_dict['recall'].append(recall)
        metrics_dict['f1'].append(f1)
        
        # Plot confusion matrix
        plot_confusion_matrix(all_labels, all_preds, CLASS_NAMES)

    # Plot the training and validation losses
    plt.plot(range(epochs), train_losses, label='Train Loss')
    plt.plot(range(epochs), val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

    # Plot Precision, Recall, F1 Score over epochs
    plt.figure(figsize=(10, 6))
    plt.plot(range(epochs), metrics_dict['precision'], label="Precision")
    plt.plot(range(epochs), metrics_dict['recall'], label="Recall")
    plt.plot(range(epochs), metrics_dict['f1'], label="F1 Score")
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.title('Precision, Recall, F1 Score over Epochs')
    plt.legend()
    plt.show()


In [None]:
# Train the model with updated loop
train(model, train_loader, val_loader, EPOCHS)

In [None]:
def evaluate_test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device).float()  # Cast labels to float
            outputs = model(inputs)
            predicted = (outputs.squeeze() > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

    # Optionally, print or plot confusion matrix
    plot_confusion_matrix(all_labels, all_preds, CLASS_NAMES)


In [None]:
evaluate_test(model, test_loader)

In [None]:
torch.save(model, '/kaggle/working/model.h5')
