<a href="https://colab.research.google.com/github/alimomennasab/ChestXRay-Classification/blob/main/ConvNext.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pandas as pd
import os
import numpy as np
from torchvision.ops import StochasticDepth
from typing import List
from PIL import Image
from glob import glob
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Mount Google Drive so the dataset stored under "My Drive" is readable from the Colab VM.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

Mounted at /content/drive


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the random seeds so weight initialization, DataLoader shuffling and the torch-based
# augmentations are repeatable across runs (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices
np.random.seed(SEED)

## 2. Loading Data

In [None]:
# Dataset root on the mounted Google Drive; the trailing path separator is kept on purpose.
drive_root = '/content/drive/My Drive'
data_dir = os.path.join(drive_root, 'chest_xray', '')

In [None]:
# Split dataset into training, validation, and test sets
# Paths to the pre-made train / val / test folders (nothing is split here; the folders already exist).
train_dir, val_dir, test_dir = (os.path.join(data_dir, split) for split in ('train', 'val', 'test'))

In [None]:
print(os.listdir(data_dir))
# os.listdir order is arbitrary; sort so the list matches ImageFolder's class indices,
# which are assigned by sorting the class folder names.
classes = sorted(os.listdir(train_dir))
print(classes)

['test', 'val', 'train']
['PNEUMONIA', 'NORMAL']


In [None]:
# Pneumonia images
pneumonia_files = os.listdir(data_dir + "/train/PNEUMONIA")
print('No. of training examples for Pneumonia:', len(pneumonia_files))
print(pneumonia_files[:5])

No. of training examples for Pneumonia: 3875
['person557_virus_1097.jpeg', 'person553_bacteria_2316.jpeg', 'person537_bacteria_2264.jpeg', 'person543_bacteria_2281.jpeg', 'person496_virus_1003.jpeg']


In [None]:
# Normal (healthy) images
normal_files = os.listdir(data_dir + "/train/NORMAL")
print('No. of training examples for Normal:', len(normal_files))
print(normal_files[:5])

No. of training examples for Normal: 1341
['IM-0524-0001.jpeg', 'IM-0515-0001.jpeg', 'IM-0508-0001.jpeg', 'IM-0516-0001.jpeg', 'IM-0511-0001-0002.jpeg']


In [None]:
# There are almost three times as many pneumonia images as normal images, so we use class weighting.

# Define classes. The order must match ImageFolder's class indices, which it assigns by
# sorting the class folder names (NORMAL -> 0, PNEUMONIA -> 1).
classes = ['NORMAL', 'PNEUMONIA']


def count_images(directory):
    """Count the files in `directory` whose extension ImageFolder accepts as an image.

    Filtering by extension keeps stray non-image files (e.g. hidden OS metadata files)
    from skewing the class counts and therefore the loss weights.
    """
    return sum(
        1 for name in os.listdir(directory)
        if name.lower().endswith(datasets.folder.IMG_EXTENSIONS)
    )


num_normal_train = count_images(os.path.join(train_dir, classes[0]))
num_pneumonia_train = count_images(os.path.join(train_dir, classes[1]))
if min(num_normal_train, num_pneumonia_train) == 0:
    raise ValueError(f'Every class needs at least one training image: '
                     f'NORMAL={num_normal_train}, PNEUMONIA={num_pneumonia_train}')
total_train = num_pneumonia_train + num_normal_train

# Inverse-frequency weights: the rarer class gets the larger weight.
class_weights = torch.tensor([total_train/num_normal_train, total_train/num_pneumonia_train]).to(device)

## 3. Preparing the Datasets and DataLoaders

In [None]:
# Shared final step: ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
normalize = transforms.Normalize(mean=0.5, std=0.5)

# Training pipeline: random crop/resize, horizontal flip, small affine warp and colour
# jitter for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    normalize,
])

# Validation / test pipeline: deterministic resize + centre crop to the same 224x224 size.
val_and_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# Build datasets and loaders. We don't need a custom Dataset class: the data already follows
# ImageFolder's <split>/<class>/<image> layout.
batch_size = 16

train_dataset = ImageFolder(train_dir, transform=train_transform)
val_dataset = ImageFolder(val_dir, transform=val_and_test_transform)
test_dataset = ImageFolder(test_dir, transform=val_and_test_transform)

# class_weights is indexed in the order of `classes`; make sure ImageFolder assigned the same labels.
assert train_dataset.classes == classes, (train_dataset.classes, classes)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 4. Defining and Choosing the Model

In [None]:
class LayerScaler(nn.Module):
    """Learnable per-channel scale (LayerScale) applied to a (N, C, H, W) feature map.

    Args:
        init_value: starting value of every per-channel scale.
        dimensions: number of channels C.
    """

    def __init__(self, init_value: float, dimensions: int):
        super().__init__()
        initial = torch.full((dimensions,), init_value, dtype=torch.get_default_dtype())
        self.gamma = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        # Reshape gamma from (C,) to (1, C, 1, 1) so it broadcasts over batch and spatial dims.
        scale = self.gamma.view(1, -1, 1, 1)
        return scale * x
        
class ConvNormAct(nn.Sequential):
    """Small utility layer: Conv2d -> normalization -> activation.

    Padding is ``kernel_size // 2``, which keeps the spatial size for odd kernels at stride 1.
    Extra keyword arguments (e.g. ``stride``, ``groups``) are forwarded to ``nn.Conv2d``.
    ``norm`` is called with ``out_features``; ``act`` is called with no arguments.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        norm = nn.BatchNorm2d,
        act = nn.ReLU,
        **kwargs
    ):
        conv = nn.Conv2d(
            in_features,
            out_features,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            **kwargs
        )
        layers = [conv, norm(out_features), act()]
        super().__init__(*layers)

class BottleNeckBlock(nn.Module):
    """ConvNeXt-style residual block.

    Residual branch: depth-wise 7x7 conv -> GroupNorm(1 group) -> 1x1 expand -> GELU
    -> 1x1 project, then per-channel LayerScale and stochastic depth. The output is
    ``x + branch(x)``.

    Args:
        in_features: channels of the input.
        out_features: channels of the output. Must equal ``in_features`` because the
            shortcut is a plain identity (no projection).
        expansion: width multiplier of the hidden 1x1 layer (``out_features * expansion``).
        drop_p: stochastic-depth drop probability for the residual branch.
        layer_scaler_init_value: initial value of the LayerScale gamma.

    Raises:
        ValueError: if ``in_features != out_features``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        expansion: int = 4,
        drop_p: float = .0,
        layer_scaler_init_value: float = 1e-6,
    ):
        super().__init__()
        if in_features != out_features:
            raise ValueError(
                f"BottleNeckBlock uses an identity shortcut, so in_features ({in_features}) "
                f"must equal out_features ({out_features})"
            )
        expanded_features = out_features * expansion
        self.block = nn.Sequential(
            # Spatial mixing: depth-wise 7x7 conv (one filter per channel, width unchanged).
            nn.Conv2d(
                in_features, in_features, kernel_size=7, padding=3, bias=False, groups=in_features
            ),
            # GroupNorm with a single group normalizes each sample over (C, H, W) jointly.
            # NOTE: this is not identical to ConvNeXt's channels-last LayerNorm, which
            # normalizes over C only, separately at every spatial position.
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            # narrow -> wide: 1x1 conv expands the channels by `expansion`.
            nn.Conv2d(in_features, expanded_features, kernel_size=1),
            nn.GELU(),
            # wide -> narrow: 1x1 conv projects back to out_features.
            nn.Conv2d(expanded_features, out_features, kernel_size=1),
        )
        self.layer_scaler = LayerScaler(layer_scaler_init_value, out_features)
        # mode="batch": the residual branch is dropped (or kept) for the whole batch at once.
        self.drop_path = StochasticDepth(drop_p, mode="batch")

    def forward(self, x: Tensor) -> Tensor:
        branch = self.drop_path(self.layer_scaler(self.block(x)))
        # Out-of-place add: never mutate a tensor that another module handed back.
        return x + branch

class ConvNexStage(nn.Sequential):
    """One encoder stage: a stride-2 downsampler followed by ``depth`` BottleNeckBlocks.

    The downsampler (GroupNorm + 2x2 conv with stride 2) halves H and W and maps
    ``in_features`` to ``out_features`` channels. Extra keyword arguments are forwarded
    unchanged to every BottleNeckBlock.
    """

    def __init__(
        self, in_features: int, out_features: int, depth: int, **kwargs
    ):
        downsampler = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            nn.Conv2d(in_features, out_features, kernel_size=2, stride=2),
        )
        blocks = []
        for _ in range(depth):
            blocks.append(BottleNeckBlock(out_features, out_features, **kwargs))
        super().__init__(downsampler, *blocks)

class ConvNextStem(nn.Sequential):
    """Patchify stem: a non-overlapping 4x4 conv with stride 4, then BatchNorm2d.

    Reduces H and W by 4x and maps ``in_features`` to ``out_features`` channels.
    NOTE(review): the ConvNeXt paper uses a LayerNorm here rather than BatchNorm.
    """

    def __init__(self, in_features: int, out_features: int):
        patchify = nn.Conv2d(in_features, out_features, kernel_size=4, stride=4)
        norm = nn.BatchNorm2d(out_features)
        super().__init__(patchify, norm)

class ConvNextEncoder(nn.Module):
    """ConvNeXt backbone: a patchify stem followed by one ConvNexStage per entry of depths/widths.

    Every stage (including the first) starts with a stride-2 downsampler, so the total
    spatial reduction is ``4 * 2 ** len(depths)`` (4 from the stem).

    Args:
        in_channels: channels of the input image.
        stem_features: channels produced by the stem (input width of the first stage).
        depths: number of BottleNeckBlocks in each stage.
        widths: output channels of each stage.
        drop_p: maximum stochastic-depth probability. The rate grows linearly from 0 for
            the first block to ``drop_p`` for the last block of the whole network.

    Raises:
        ValueError: if ``depths`` and ``widths`` have different lengths.
    """

    def __init__(
        self,
        in_channels: int,
        stem_features: int,
        depths: List[int],
        widths: List[int],
        drop_p: float = .0,
    ):
        super().__init__()
        if len(depths) != len(widths):
            raise ValueError(
                f"depths and widths must have the same length, got {len(depths)} and {len(widths)}"
            )
        self.stem = ConvNextStem(in_channels, stem_features)

        # One drop-path rate per *block* (not per stage), linearly spaced over the network.
        # Previously stage i received block i's rate, which gave a wrong schedule for drop_p > 0.
        block_drop_probs = torch.linspace(0, drop_p, sum(depths)).tolist()

        stage_in_widths = [stem_features, *widths[:-1]]
        stages = []
        first_block = 0
        for in_features, out_features, depth in zip(stage_in_widths, widths, depths):
            # depth=0 builds only the downsampler; blocks are appended so each gets its own rate.
            stage = ConvNexStage(in_features, out_features, 0)
            for block_p in block_drop_probs[first_block:first_block + depth]:
                stage.append(BottleNeckBlock(out_features, out_features, drop_p=block_p))
            first_block += depth
            stages.append(stage)
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x

class ClassificationHead(nn.Sequential):
    """Global-average-pool a (N, C, H, W) map to (N, C), LayerNorm it, and project to logits.

    Args:
        num_channels: channel count C of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, num_channels: int, num_classes: int = 2):
        pool_and_flatten = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(1)]
        super().__init__(
            *pool_and_flatten,
            nn.LayerNorm(num_channels),
            nn.Linear(num_channels, num_classes),
        )
    
class ConvNext(nn.Sequential):
    """Complete classifier: a ConvNextEncoder backbone followed by a ClassificationHead.

    As an nn.Sequential, forward() runs the registered children in order: ``encoder``
    then ``head``. The head's input width is the last stage width, ``widths[-1]``.
    """

    def __init__(self,  
                 in_channels: int,
                 stem_features: int,
                 depths: List[int],
                 widths: List[int],
                 drop_p: float = .0,
                 num_classes: int = 2):
        super().__init__()
        encoder = ConvNextEncoder(in_channels, stem_features, depths, widths, drop_p)
        self.add_module("encoder", encoder)
        head = ClassificationHead(widths[-1], num_classes)
        self.add_module("head", head)


In [None]:
# Initialize the model: RGB input, 64-channel stem, and four stages.
model = ConvNext(
    in_channels=3,
    stem_features=64,
    depths=[3, 4, 6, 4],
    widths=[256, 512, 1024, 2048],
)

In [None]:
# Move the parameters to the selected device (a no-op when `device` is the CPU).
# The assignment keeps the notebook from displaying the full model repr.
model = model.to(device)

## 5. Sanity-Checking the Model Output Shape

In [None]:
# Smoke test: one dummy image should produce one logit per class.
# eval() + no_grad() stop this check from updating BatchNorm running statistics with
# random noise (and from building an autograd graph); train mode is restored afterwards.
model.eval()
with torch.no_grad():
    dummy_out = model(torch.rand(1, 3, 224, 224, device=device))
model.train()
assert dummy_out.shape == (1, 2), dummy_out.shape
dummy_out.shape

torch.Size([1, 2])

## 6. Defining the Main Training Loop

In [None]:
# Hyper-parameters
num_epochs = 30        # maximum number of passes over the training set
learning_rate = 1e-3   # initial Adam learning rate
patience = 10          # epochs without validation improvement before early stopping

In [None]:
# Loss and optimizer. class_weights counteracts the class imbalance in the training set.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

curr_lr = learning_rate
total_step = len(train_loader)

# Early stopping parameters
early_stopping_counter = 0  # consecutive epochs without a new best validation loss
best_loss = float('inf')


def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of `optimizer` to `lr`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def evaluate_loss(model, loader, criterion, device):
    """Return the average of the per-batch `criterion` losses of `model` over `loader`.

    Switches `model` to eval mode (BatchNorm uses its running statistics and stochastic
    depth is disabled) and runs without autograd. The caller must call `model.train()`
    before resuming training.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    if len(loader) == 0:
        raise ValueError('evaluation loader is empty')
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # .item() accumulates a Python float instead of a device tensor.
            total_loss += criterion(model(images), labels).item()
    return total_loss / len(loader)


for epoch in range(num_epochs):
    # Validation at the end of each epoch switches to eval mode, so switch back here.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print ("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Decay learning rate by a factor of 3 every 20 epochs
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

    val_loss = evaluate_loss(model, val_loader, criterion, device)

    if val_loss < best_loss:
        # New best: checkpoint it and reset the early-stopping counter.
        print(f'Saving model with validation loss of {val_loss:.4f}...')
        torch.save(model.state_dict(), 'best_model.pth')
        best_loss = val_loss
        early_stopping_counter = 0
    else:
        # No improvement this epoch; stop once it has happened `patience` times in a row.
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print(f'Validation loss has not improved for {patience} epochs. Early stopping...')
            break


Epoch [1/30], Step [100/326] Loss: 0.8381
Epoch [1/30], Step [200/326] Loss: 0.4606
Epoch [1/30], Step [300/326] Loss: 0.8287
Saving model with validation loss of 0.5096...
Epoch [2/30], Step [100/326] Loss: 0.5864
Epoch [2/30], Step [200/326] Loss: 0.4673
Epoch [2/30], Step [300/326] Loss: 0.5147
Saving model with validation loss of 0.4660...
Epoch [3/30], Step [100/326] Loss: 0.5141
Epoch [3/30], Step [200/326] Loss: 0.5784
Epoch [3/30], Step [300/326] Loss: 0.4019
Epoch [4/30], Step [100/326] Loss: 0.6047
Epoch [4/30], Step [200/326] Loss: 0.3296
Epoch [4/30], Step [300/326] Loss: 0.4036
Saving model with validation loss of 0.4536...
Epoch [5/30], Step [100/326] Loss: 0.7433
Epoch [5/30], Step [200/326] Loss: 0.4903
Epoch [5/30], Step [300/326] Loss: 0.3827
Epoch [6/30], Step [100/326] Loss: 0.4108
Epoch [6/30], Step [200/326] Loss: 0.7480
Epoch [6/30], Step [300/326] Loss: 0.5949
Epoch [7/30], Step [100/326] Loss: 0.5605
Epoch [7/30], Step [200/326] Loss: 0.5370
Epoch [7/30], Step 

## 7. Testing

In [None]:
# Load the saved model checkpoint
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint)

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Model accuracy on test images: {} %'.format(100 * correct / total))

Model accuracy on test images: 70.99358974358974 %
