This is an example of creating a custom dataset for CT scans and predicting age using CNNs.

The dataset is similar to ours but much smaller.

Link to dataset:
https://www.kaggle.com/datasets/kmader/siim-medical-images/data

In [1]:
%%capture
!pip install kaggle
!pip install pydicom

In [2]:
import torch
import torch.nn as nn
from torchvision.datasets import VOCDetection, CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import pandas as pd
from PIL import Image
import pydicom
import numpy as np

In [3]:
from google.colab import userdata

# ---
# Use yoyr own Kaggle key and username here
# ---
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')

In [4]:
!kaggle datasets download -d kmader/siim-medical-images

Downloading siim-medical-images.zip to /content
 98% 246M/250M [00:01<00:00, 198MB/s]
100% 250M/250M [00:01<00:00, 171MB/s]


In [5]:
%%capture
!unzip /content/siim-medical-images.zip

In [6]:
# ---
# Custom datset to generate images and target age from DICOM files
# ---
class DICOMDataset(Dataset):
  def __init__(self, root_dir, transform):
    self.root_dir = root_dir
    self.transform = transform
    self.files = os.listdir(root_dir)

  def __len__(self):
    return len(self.files)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()

    # Here, we convert the dicom files to images in our preferred form
    img_name = os.path.join(self.root_dir, self.files[idx])
    ds = pydicom.dcmread(img_name)
    image = ds.pixel_array
    image = Image.fromarray(image)

    if image.mode != 'RGB':
      image = image.convert('RGB')

    image = self.transform(image)

    # The age of each patient is given in the filename
    # Here, we retrieve the age from the filename and use it as the target
    # We need to see how to do this for the actual dataset
    s = img_name.split('/')
    s = s[-1]
    s = s.split('_')
    target = int(s[3])

    return image, target


In [9]:
path = '/content/dicom_dir'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = DICOMDataset(root_dir=path, transform=transform)

In [96]:
train_set, test_set = random_split(dataset, [0.8, 0.2])
train_dataloader = DataLoader(train_set, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=True)
print(f"Training set size: {len(train_set)}, Test set size: {len(test_set)}")

Training set size: 80, Test set size: 20


In [88]:
for X, y in train_dataloader:
  print(f"Input format: {X.shape}")
  print(f"Target format: {y}")
  break

Input format: torch.Size([1, 3, 224, 224])
Target format: tensor([72])


In [103]:
# --- Model Architecture ---
#
# CNN Model for age prediction task
# Input: Tensors representing images
# Output: Integer representing age
#
# We use a number of convolutional layers to downsample the images
# Then we flatten the outputs from convolutional layers and run them
# through a number of dense layers to finally get a single number as output
#
class CNN_model(nn.Module):
  def __init__(self):
    super().__init__()
    self.img_size = 224
    self.channels = 3
    self.hidden_units = 16
    self.convolutional_layer = nn.Sequential(
        nn.Conv2d(in_channels=self.channels, out_channels=self.hidden_units, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(in_channels=self.hidden_units, out_channels=self.hidden_units, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    self.dense_layer = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=self.hidden_units*14*14, out_features=256),
        nn.ReLU(),
        nn.Linear(in_features=256, out_features=32),
        nn.ReLU(),
        nn.Linear(in_features=32, out_features=1)
    )

  def forward(self, x):
    x = self.convolutional_layer(x)
    x = self.dense_layer(x)
    return x

model = CNN_model()

In [104]:
# Training run
num_epochs = 150
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
model.train()

for epoch in range(num_epochs):
  train_loss = 0
  for X, y in train_dataloader:
    target = y.clone().detach().view(-1, 1).float()
    output = model(X)
    loss = loss_fn(output, target)
    train_loss += loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  train_loss /= len(train_dataloader)
  if (epoch + 1) % 10 == 0:
    print(f"Epoch: {epoch + 1}, Loss: {train_loss:.4f}")

Epoch: 10, Loss: 126.9956
Epoch: 20, Loss: 101.3677
Epoch: 30, Loss: 144.1414
Epoch: 40, Loss: 86.1506
Epoch: 50, Loss: 31.0108
Epoch: 60, Loss: 14.3189
Epoch: 70, Loss: 9.6694
Epoch: 80, Loss: 9.3386
Epoch: 90, Loss: 8.9831
Epoch: 100, Loss: 5.3190
Epoch: 110, Loss: 17.8799
Epoch: 120, Loss: 4.6158
Epoch: 130, Loss: 3.2613
Epoch: 140, Loss: 2.0407
Epoch: 150, Loss: 12.1292


In [105]:
# Test run
model.eval()
test_loss = 0

for X, y in test_dataloader:
  target = y.clone().detach().view(-1, 1).float()
  pred = model(X)
  test_loss += loss_fn(pred, y)

test_loss /= len(test_dataloader)
print(f"Final test loss: {test_loss}")

Final test loss: 106.84370422363281


In [106]:
# Looking at som predictions
model.eval()
for X, y in test_dataloader:
  output = model(X)
  print(f"Prediction: {int(output.item())}, Correct: {y.item()}")

Prediction: 56, Correct: 74
Prediction: 66, Correct: 44
Prediction: 55, Correct: 47
Prediction: 65, Correct: 75
Prediction: 57, Correct: 74
Prediction: 62, Correct: 80
Prediction: 60, Correct: 63
Prediction: 71, Correct: 73
Prediction: 64, Correct: 61
Prediction: 67, Correct: 74
Prediction: 70, Correct: 71
Prediction: 57, Correct: 58
Prediction: 67, Correct: 60
Prediction: 61, Correct: 70
Prediction: 63, Correct: 67
Prediction: 61, Correct: 74
Prediction: 68, Correct: 74
Prediction: 64, Correct: 74
Prediction: 58, Correct: 61
Prediction: 66, Correct: 74
