<a href="https://colab.research.google.com/github/alimomennasab/ChestXRay-Classification/blob/main/DenseNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##1. Setup

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pandas as pd
import os
import numpy as np
from PIL import Image
from glob import glob
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

##2. Loading Data

In [19]:
data_dir = '/content/drive/My Drive/chest_xray/'

In [20]:
# Split dataset into training, validation, and test sets
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')
test_dir = os.path.join(data_dir, 'test')

In [21]:
print(os.listdir(data_dir))
classes = os.listdir(data_dir + "/train")
print(classes)

['train', 'chest_xray', 'val', '__MACOSX', 'test']
['PNEUMONIA', 'NORMAL']


In [22]:
# Pneumonia images
pneumonia_files = os.listdir(data_dir + "/train/PNEUMONIA")
print('No. of training examples for Pneumonia:', len(pneumonia_files))
print(pneumonia_files[:5])

No. of training examples for Pneumonia: 3875
['person556_virus_1096.jpeg', 'person536_bacteria_2257.jpeg', 'person581_bacteria_2390.jpeg', 'person592_bacteria_2434.jpeg', 'person581_virus_1125.jpeg']


In [23]:
# Normal (healthy) images
normal_files = os.listdir(data_dir + "/train/NORMAL")
print('No. of training examples for Normal:', len(normal_files))
print(normal_files[:5])

No. of training examples for Normal: 1341
['IM-0526-0001.jpeg', 'IM-0524-0001.jpeg', 'IM-0507-0001.jpeg', 'IM-0508-0001.jpeg', 'IM-0520-0001.jpeg']


In [24]:
# There are almost three times more pneumonia images than normal images, so we will use class weighing

# Define classes
classes = ['NORMAL', 'PNEUMONIA']

# Define class weights
num_pneumonia_train = len(os.listdir(os.path.join(train_dir, classes[1])))
num_normal_train = len(os.listdir(os.path.join(train_dir, classes[0])))
total_train = num_pneumonia_train + num_normal_train
class_weights = torch.tensor([total_train/num_normal_train, total_train/num_pneumonia_train]).to(device)

##3. Preparing Dataset and DataLoader

In [25]:
# Define transforms
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.05,0.05)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

val_and_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))
])

In [26]:
# Define data directories. We won't use a custom class because the dataset is already well-formatted.

train_dataset = ImageFolder('/content/drive/My Drive/chest_xray/train', transform = train_transform)
val_dataset = ImageFolder('/content/drive/My Drive/chest_xray/val', transform = val_and_test_transform)
test_dataset = ImageFolder('/content/drive/My Drive/chest_xray/test', transform = val_and_test_transform)

train_loader = DataLoader(train_dataset, batch_size = 16, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = 16, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = 16, shuffle = False)

##4. Defining and Choosing Model

In [27]:
class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, num_layers):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                nn.Sequential(
                    nn.BatchNorm2d(in_channels + i*growth_rate),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(in_channels + i*growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=True)
                )
            )

    def forward(self, x):
        for layer in self.layers:
            out = layer(x)
            x = torch.cat([x, out], dim=1)
        return x


class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.layers(x)


class DenseNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, growth_rate=32, block_config=(6, 12, 24, 16)):
        super(DenseNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 2*growth_rate, kernel_size=7, stride=2, padding=3, bias=True),
            nn.BatchNorm2d(2*growth_rate),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        num_features = 2*growth_rate
        self.dense_blocks = nn.ModuleList()
        self.transition_blocks = nn.ModuleList()

        for i, num_layers in enumerate(block_config):
            self.dense_blocks.append(DenseBlock(num_features, growth_rate, num_layers))
            num_features += num_layers * growth_rate

            if i != len(block_config) - 1:
                out_channels = num_features // 2
                self.transition_blocks.append(TransitionBlock(num_features, out_channels))
                num_features = out_channels

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.features(x)

        for i, block in enumerate(self.dense_blocks):
            x = block(x)
            if i != len(self.dense_blocks) - 1:
                x = self.transition_blocks[i](x)

        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


In [28]:
# Initializing model
model = DenseNet()

In [29]:
if torch.cuda.is_available():
    model.cuda()

##5. Testing Model

In [30]:
x = torch.randn(2, 3, 224, 224).to(device)
print(model(x).shape) 

torch.Size([2, 2])


##6. Defining Main Training

In [31]:
# Hyper-parameters
num_epochs = 30
learning_rate = 0.001
patience = 10

In [32]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

curr_lr = learning_rate
total_step = len(train_loader)

# Early stopping parameters
early_stopping_counter = 0
best_loss = float('inf')

# For updating learning rate
def update_lr(optimizer, lr):    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    
    # Decay learning rate
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

    # Calculate validation loss
    with torch.no_grad():
        val_loss = 0
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels)
        val_loss /= len(val_loader)

        # Save the model if the validation loss is the best observed yet
        if val_loss < best_loss:
            print(f'Saving model with validation loss of {val_loss:.4f}...')
            torch.save(model.state_dict(), 'best_model.pth')
            best_loss = val_loss

    # Early stopping if overfitting
    if early_stopping_counter >= patience:
        print(f'Validation loss has not improved for {patience} epochs. Early stopping...')
        break
    elif val_loss < best_loss:
        best_loss = val_loss
        early_stopping_counter = 0


Epoch [1/30], Step [100/326] Loss: 0.5958
Epoch [1/30], Step [200/326] Loss: 0.5650
Epoch [1/30], Step [300/326] Loss: 0.3751
Saving model with validation loss of 0.5921...
Epoch [2/30], Step [100/326] Loss: 0.4353
Epoch [2/30], Step [200/326] Loss: 0.7278
Epoch [2/30], Step [300/326] Loss: 0.6477
Epoch [3/30], Step [100/326] Loss: 0.3848
Epoch [3/30], Step [200/326] Loss: 0.7005
Epoch [3/30], Step [300/326] Loss: 0.3854
Epoch [4/30], Step [100/326] Loss: 0.4455
Epoch [4/30], Step [200/326] Loss: 0.4852
Epoch [4/30], Step [300/326] Loss: 0.6449
Epoch [5/30], Step [100/326] Loss: 0.4308
Epoch [5/30], Step [200/326] Loss: 0.4755
Epoch [5/30], Step [300/326] Loss: 0.6101
Epoch [6/30], Step [100/326] Loss: 0.6314
Epoch [6/30], Step [200/326] Loss: 0.3303
Epoch [6/30], Step [300/326] Loss: 0.4165
Epoch [7/30], Step [100/326] Loss: 0.7264
Epoch [7/30], Step [200/326] Loss: 0.4038
Epoch [7/30], Step [300/326] Loss: 0.4659
Epoch [8/30], Step [100/326] Loss: 0.4623
Epoch [8/30], Step [200/326] 

##7. Testing

In [33]:
# Load the saved model checkpoint
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint)

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Model accuracy on test images: {} %'.format(100 * correct / total))

Model accuracy on test images: 74.67948717948718 %
