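This code uses the lung segmentation dataset from Kaggle and the U-Net architecture to train a model to segment lung scans. The dataset has two labels: normal and tuberculosis. The model learns to distinguish between them and segments the lungs so that they are clearly visible to people viewing the scans. (Note: the cells below actually train a simple CNN classifier on the Medical MNIST dataset, with six classes — AbdomenCT, BreastMRI, CXR, ChestCT, Hand and HeadCT — rather than a U-Net segmentation model.)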

In [65]:
import warnings
warnings.filterwarnings('ignore')
import matplotlib as plt
import torch.nn as nn
import torch
import torchvision 
import os
import tqdm
import PIL
from PIL import Image
from torchsummary import summary


# Image preprocessing libraries
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

In [90]:
BATCH = 32   # mini-batch size used by every DataLoader
LR = 0.001   # learning rate passed to the Adam optimizer
# Root folder containing one sub-directory per class (ImageFolder layout).
# Set the MEDICAL_MNIST_DIR environment variable to run on another machine;
# otherwise the original author's local path is used.
dir_path = os.environ.get(
    'MEDICAL_MNIST_DIR',
    r'C:\Users\Sam\OneDrive - Monash University\Monash - Uni\DeepNeuron\Training Project\Begineer Projects\Medical MNIST Data',
)


In [91]:
# Preprocessing applied to every image loaded through `transform`.
# NOTE(review): RandomHorizontalFlip is random augmentation, so any split that
# reads images through this transform (including validation/test subsets made
# with random_split) sees randomly flipped images. Evaluation splits should use
# a flip-free transform.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),   # force a single channel
    transforms.Resize((256, 256)),                 # fixed 256x256 input size
    transforms.ToTensor(),                         # PIL image -> float tensor in [0, 1]
    transforms.RandomHorizontalFlip(),             # augmentation, p=0.5 by default
    transforms.Normalize(mean=[0.5], std=[0.5]),   # (x - 0.5) / 0.5 -> [-1, 1]
])

# Load the entire dataset; ImageFolder creates one class per sub-directory of
# dir_path and labels each image with its folder's index.
dataset = ImageFolder(root=dir_path, transform=transform)

# Check class names and label mapping (class names are the sub-folder names,
# indexed in sorted order).
print(dataset.classes)
print(dataset.class_to_idx)

['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
{'AbdomenCT': 0, 'BreastMRI': 1, 'CXR': 2, 'ChestCT': 3, 'Hand': 4, 'HeadCT': 5}


In [92]:
# Split the dataset into train, validation and test: (70%, 20%, 10%).
# The test share takes the rounding remainder so the three sizes always sum
# to len(dataset).
train_split = int(0.7 * len(dataset))
val_split = int(0.2 * len(dataset))
test_split = len(dataset) - train_split - val_split

# A seeded generator makes the split reproducible across kernel restarts, so
# validation/test numbers from different runs are comparable.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data, test_data = random_split(
    dataset, [train_split, val_split, test_split], generator=split_generator
)

# The shared `transform` includes random augmentation (RandomHorizontalFlip).
# Evaluation should be deterministic, so re-point the val/test subsets (same
# indices) at a second view of the same image folder whose transform has the
# flip removed. ImageFolder orders samples deterministically, so indices match.
eval_transform = transforms.Compose(
    [t for t in transform.transforms
     if not isinstance(t, transforms.RandomHorizontalFlip)]
)
eval_dataset = ImageFolder(root=dir_path, transform=eval_transform)
val_data = torch.utils.data.Subset(eval_dataset, val_data.indices)
test_data = torch.utils.data.Subset(eval_dataset, test_data.indices)

print("Training data samples: ", len(train_data))
print("Val data samples: ", len(val_data))
print("Test data samples: ", len(test_data))

# Only the training loader is shuffled; val/test order stays fixed.
train_load = DataLoader(train_data, batch_size=BATCH, shuffle=True)
val_load = DataLoader(val_data, batch_size=BATCH, shuffle=False)
test_load = DataLoader(test_data, batch_size=BATCH, shuffle=False)

Training data samples:  41267
Val data samples:  11790
Test data samples:  5897


In [93]:
# Sanity-check one training batch. Grayscale(1) + Resize((256, 256)) in the
# transform mean images should be (batch_size, 1, 256, 256).
images, labels = next(iter(train_load))
print("Image batch shape:", images.shape)  # Expected: (batch_size, 1, 256, 256)
print("Label batch shape:", labels.shape)  # Expected: (batch_size,)

Image batch shape: torch.Size([32, 1, 256, 256])
Label batch shape: torch.Size([32])


In [94]:
class CNN(nn.Module):
    """Small three-stage convolutional classifier for single-channel images.

    Each stage is Conv2d(kernel 5, stride 1, padding 2) -> BatchNorm2d -> ReLU
    -> MaxPool2d(2, 2). The convolution keeps the spatial size and the pool
    halves it, rounding down. Three stages therefore turn a square
    ``input_size`` x ``input_size`` image into a 256-channel map of side
    ``input_size // 8``. That map is flattened and fed to one linear layer.

    Args:
        num_classes: number of output logits.
        batch_size: not used by the model; kept so existing calls still work.
        input_size: side length of the square input images. Defaults to 256,
            which reproduces the previously hard-coded ``256 * 32 * 32``
            fully-connected input size.

    Raises:
        ValueError: if ``input_size`` is smaller than 8 (the feature map would
            vanish after three 2x poolings).
    """

    def __init__(self, num_classes, batch_size=BATCH, input_size=256):
        super().__init__()
        if input_size < 8:
            raise ValueError(f"input_size must be >= 8, got {input_size}")

        # Input to the model is (batch, channels=1, height, width).
        self.num_classes = num_classes
        self.input_size = input_size
        self.network = nn.Sequential(
            nn.Conv2d(1, 64, 5, 1, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, 5, 1, 2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, 5, 1, 2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Three floor-halving MaxPool2d layers shrink each side to input_size // 8.
        feature_side = input_size // 8
        self.fc = nn.Linear(256 * feature_side * feature_side, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.network(x)
        x = torch.flatten(x, 1)  # (batch, 256, s, s) -> (batch, 256*s*s)
        return self.fc(x)
    
# Build the classifier with one output per class folder found by ImageFolder,
# then print the output shape and parameter count of each layer.
model = CNN(len(dataset.classes))
# torchsummary defaults to device="cuda"; the freshly built model is on the
# CPU, so request CPU inputs explicitly to avoid a device mismatch on GPU hosts.
summary(model, (1, 256, 256), device="cpu")
        

    

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,664
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
         MaxPool2d-4         [-1, 64, 128, 128]               0
            Conv2d-5        [-1, 128, 128, 128]         204,928
       BatchNorm2d-6        [-1, 128, 128, 128]             256
              ReLU-7        [-1, 128, 128, 128]               0
         MaxPool2d-8          [-1, 128, 64, 64]               0
            Conv2d-9          [-1, 256, 64, 64]         819,456
      BatchNorm2d-10          [-1, 256, 64, 64]             512
             ReLU-11          [-1, 256, 64, 64]               0
        MaxPool2d-12          [-1, 256, 32, 32]               0
           Linear-13                    [-1, 6]       1,572,870
Total params: 2,599,814
Trainable param

In [105]:
# Training: 
import tqdm
import torch.optim as optim


class TrainModel:
    """Drive training, validation and testing of a classification model.

    Args:
        model: the ``nn.Module`` to train/evaluate; assumed to already live on
            ``device``.
        optim: optimizer constructed over ``model.parameters()``.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the *mean* loss over the batch (the default for
            ``nn.CrossEntropyLoss``). TODO confirm if another loss is used.
        device: device every batch is moved to before the forward pass.
    """

    def __init__(self, model, optim, criterion, device):
        self.optimizer = optim
        self.criterion = criterion
        self.model = model
        self.device = device

    @staticmethod
    def _averages(loss_sum, correct, total):
        """Turn summed per-sample loss and correct count into averages."""
        if total == 0:
            return 0.0, 0.0
        return loss_sum / total, correct / total

    def _evaluate(self, loader):
        """Return (summed per-sample loss, #correct, #samples) over ``loader``."""
        model = self.model.eval()  # inference mode (e.g. BatchNorm uses running stats)
        loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():  # no backward pass, so skip gradient bookkeeping
            for batch, labels in loader:
                batch, labels = batch.to(self.device), labels.to(self.device)
                outputs = model(batch)
                n = labels.size(0)
                # Re-weight the batch-mean loss so short final batches count fairly.
                loss_sum += self.criterion(outputs, labels).item() * n
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += n
        return loss_sum, correct, total

    def train_epoch(self, trainloader):
        """Run one optimisation pass over ``trainloader``.

        Returns:
            (avg_loss, avg_acc): per-sample average loss and the fraction of
            samples predicted correctly over the epoch. Each batch's statistics
            come from its forward pass *before* that batch's optimizer step.
        """
        model = self.model.train()  # training mode (affects BatchNorm/Dropout)
        progress = tqdm.tqdm(trainloader)
        correct = 0
        running_loss = 0.0
        total = 0

        for data, label in progress:
            data, label = data.to(self.device), label.to(self.device)

            self.optimizer.zero_grad()
            output = model(data)
            loss = self.criterion(output, label)
            loss.backward()          # back-propagate through the model
            self.optimizer.step()    # update parameters

            # Statistics must use this batch's own labels; the batch-mean loss
            # is re-weighted by batch size so the epoch average is per sample.
            n = label.size(0)
            _, predicted = torch.max(output, 1)
            correct += (predicted == label).sum().item()
            running_loss += loss.item() * n
            total += n
            progress.set_postfix(loss=loss.item())

        return self._averages(running_loss, correct, total)

    def eval_model(self, validloader):
        """Evaluate the model on ``validloader`` without updating any weights.

        Returns:
            (avg_loss_val, avg_acc_val): per-sample average loss and accuracy.
        """
        loss_sum, correct, total = self._evaluate(tqdm.tqdm(validloader))
        return self._averages(loss_sum, correct, total)

    def test(self, testloader):
        """Measure accuracy on ``testloader`` (images unseen during training).

        Returns:
            A string such as ``'Accuracy on the test images: 97.50 %'``;
            an empty loader reports 0.00 %.
        """
        _, correct, total = self._evaluate(testloader)
        accuracy = 100 * correct / total if total else 0.0
        return ('Accuracy on the test images: %.2f %%' % accuracy)




In [None]:
# Pick the GPU when available and move the model there *before* building the
# optimizer. TrainModel sends every batch to `device`, so the model must be on
# the same device, and the optimizer should reference the on-device parameters.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = nn.CrossEntropyLoss().to(device)

CNN_model = TrainModel(model, optimizer, criterion, device)

# Per-epoch accuracy history, kept for plotting.
epoch = 1
simp_acc_hist = []
simp_acc_hist_val = []
for e in range(epoch):
    print(f'Epoch {e + 1}/{epoch}')
    print('-' * 10)
    simple_train_loss, simple_train_acc = CNN_model.train_epoch(train_load)
    simp_acc_hist.append(simple_train_acc)
    print(f'Train loss {simple_train_loss} accuracy {simple_train_acc}')

    simple_val_loss, simple_val_acc = CNN_model.eval_model(val_load)
    simp_acc_hist_val.append(simple_val_acc)

    print(f'Val loss {simple_val_loss} accuracy {simple_val_acc}')
    print()
