# Intro
---
This project is practice in building a convolutional neural network (CNN). The model is trained on the OASIS-1 dataset from the Open Access Series of Imaging Studies (OASIS). The dataset was downloaded from: https://www.kaggle.com/datasets/ninadaithal/imagesoasis

Acknowledgments: “Data were provided by OASIS-1: Cross-Sectional: Principal Investigators: D. Marcus, R. Buckner, J. Csernansky, J. Morris; P50 AG05681, P01 AG03991, P01 AG026276, R01 AG021910, P20 MH071616, U24 RR021382.”

In [None]:
# Mounting Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Commands to retrieve and unzip data

In [None]:
!cp "/content/drive/MyDrive/OASIS-I/archive.zip" /content/

In [None]:
!unzip /content/archive.zip -d /content/OASIS-I_extracted

## Note
---
Because the OASIS-1 dataset contains hundreds of similar MRI slices for each subject, randomly splitting the slices into train and test sets leaves the two sets dependent on each other. One slice from a subject may end up in the training set while a neighboring, highly similar slice ends up in the test set, so the model is rewarded for recognizing near-duplicate slices rather than for generalizing to new subjects. This inflates the accuracy. To fix this, all slices from a given subject should stay in a single set.
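
The leakage described above can be reproduced on a toy example. The sketch below uses only the standard library; the file names, the 3-subject layout, and the helper names (`split_by_slice`, `split_by_subject`, `shared_subjects`) are made up for illustration and are not part of this notebook's pipeline, which uses `GroupShuffleSplit` instead.

```python
import random
import re

# Hypothetical file names: 3 subjects with 4 slices each.
slices = [f"OAS1_{s:04d}_MR1_mpr-1_{n}.jpg" for s in (1, 2, 3) for n in range(100, 104)]
subject_re = re.compile(r"(OAS1_\d{4})")


def subject_of(name):
    """Extract the subject ID (e.g. OAS1_0001) from a file name."""
    return subject_re.search(name).group(1)


def split_by_slice(names, test_fraction, seed):
    """Naive split: shuffle individual slices and ignore the subject."""
    rng = random.Random(seed)
    shuffled = names[:]
    rng.shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]


def split_by_subject(names, test_fraction, seed):
    """Group split: shuffle subjects, then assign all of a subject's slices together."""
    rng = random.Random(seed)
    subjects = sorted({subject_of(n) for n in names})
    rng.shuffle(subjects)
    n_test = max(1, round(len(subjects) * test_fraction))
    test_subjects = set(subjects[:n_test])
    train = [n for n in names if subject_of(n) not in test_subjects]
    test = [n for n in names if subject_of(n) in test_subjects]
    return train, test


def shared_subjects(train, test):
    """Subjects that contribute slices to both sets."""
    return {subject_of(n) for n in train} & {subject_of(n) for n in test}


naive_train, naive_test = split_by_slice(slices, 0.4, seed=0)
group_train, group_test = split_by_subject(slices, 0.4, seed=0)
print("shared subjects (slice split):  ", sorted(shared_subjects(naive_train, naive_test)))
print("shared subjects (subject split):", sorted(shared_subjects(group_train, group_test)))
```

With these toy sizes the slice-level split always shares at least one subject between the two sets, while the subject-level split never does.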

Furthermore, the number of subjects in each dementia category varies heavily, leading to an imbalanced dataset. This can cause the convolutional neural network to learn to predict "Non Demented" most of the time and still get a high overall accuracy. Thus, the testing phase will focus on metrics other than accuracy, such as the F1-score. Also, since there are only 2 subjects in the "Moderate Dementia" class, the "Moderate Dementia" and "Mild Dementia" classes were merged into a single "Mild+ Dementia" class for better statistical reliability.
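
To make the accuracy problem concrete, here is a rough calculation for a hypothetical model that always predicts the largest class. It uses the per-class subject counts printed later in this notebook as a stand-in for the class balance, which is an assumption: the real metrics are computed per slice, so the exact numbers will differ. The function name `majority_baseline_scores` is invented for this sketch.

```python
# Subject counts per class (after merging), used as a stand-in for the class balance.
counts = {"Mild+ Dementia": 23, "Non Demented": 266, "Very mild Dementia": 58}


def majority_baseline_scores(counts):
    """Accuracy and macro F1 of a model that always predicts the largest class."""
    total = sum(counts.values())
    majority = max(counts, key=counts.get)
    accuracy = counts[majority] / total
    f1_scores = []
    for name in counts:
        if name == majority:
            precision = counts[name] / total  # every sample is predicted as this class
            recall = 1.0                      # so every true member is found
            f1_scores.append(2 * precision * recall / (precision + recall))
        else:
            f1_scores.append(0.0)  # never predicted: recall is 0, F1 is taken as 0
    return accuracy, sum(f1_scores) / len(f1_scores)


accuracy, macro_f1 = majority_baseline_scores(counts)
print(f"accuracy: {accuracy:.3f}, macro F1: {macro_f1:.3f}")
```

The accuracy looks respectable even though two of the three classes are never predicted, while the macro F1 is pulled down sharply by those two classes.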

In [None]:
! mkdir -p /content/OASIS-I_extracted/Data/Mild+\ Dementia

In [None]:
! mv /content/OASIS-I_extracted/Data/Mild\ Dementia/* /content/OASIS-I_extracted/Data/Mild+\ Dementia/
! mv /content/OASIS-I_extracted/Data/Moderate\ Dementia/* /content/OASIS-I_extracted/Data/Mild+\ Dementia/

In [None]:
! rmdir /content/OASIS-I_extracted/Data/Mild\ Dementia
! rmdir /content/OASIS-I_extracted/Data/Moderate\ Dementia

In [None]:
# Importing stuff
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from PIL import Image
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import GroupShuffleSplit

In [None]:
# Creating the dataset
class OASISDataset(Dataset):
  def __init__(self, data_dir, img_transform):
    self.img_transform = img_transform
    self.image_paths = []
    self.labels = []
    self.label_names = sorted(os.listdir(data_dir))

    for label_id, label_name in enumerate(self.label_names):
      label_dir = os.path.join(data_dir, label_name)
      images = os.listdir(label_dir)
      for image in images:
        image_path = os.path.join(label_dir, image)
        self.image_paths.append(image_path)
        self.labels.append(label_id)

  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, idx):
    image_path = self.image_paths[idx]
    image = Image.open(image_path)
    image = self.img_transform(image)
    label = torch.tensor(self.labels[idx], dtype=torch.long)
    return image, label

  def get_label_names(self):
    return self.label_names

  def get_image_paths(self):
    return self.image_paths

data_dir = "/content/OASIS-I_extracted/Data"

img_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((248, 496)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

data = OASISDataset(data_dir, img_transform)

In [None]:
# Determining the number of subjects in each class
data_dir = "/content/OASIS-I_extracted/Data"
subject_id_re = re.compile(r'(OAS1_\d{4})')
class_names = data.get_label_names()
class_counts = {class_name: set() for class_name in class_names}
for class_name in class_names:
  class_dir = os.path.join(data_dir, class_name)
  for name in os.listdir(class_dir):
    found = subject_id_re.search(name)
    if found:
      subject_id = found.group(1)
      class_counts[class_name].add(subject_id)

for class_name, subjects in class_counts.items():
  print(f"{class_name}: {len(subjects)} subjects")

# Storing the number of subjects in each class into a list
# (same order as class_names) for later use in the weighted loss
class_counts_array = []
for class_name in class_names:
  class_counts_array.append(len(class_counts[class_name]))

Mild+ Dementia: 23 subjects <br>
Non Demented: 266 subjects <br>
Very mild Dementia: 58 subjects <br>

In [None]:
# Setting up training, validation, and testing dataloaders
image_paths = data.get_image_paths()
subject_ids = []
subject_id_re = re.compile(r'(OAS1_\d{4})')
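for path in image_paths:
  found = subject_id_re.search(path)
  if not found:
    # Every image needs a subject ID, otherwise subject_ids and data fall out of alignment
    raise ValueError(f"No subject ID found in {path}")
  subject_ids.append(found.group(1))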

subject_ids = np.array(subject_ids)

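# Uses GroupShuffleSplit to split by subject: 70% of subjects for training, 30% held out
# random_state is fixed so the split is reproducible
gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
indices = np.arange(len(data))
train_idx, temp_idx = next(gss.split(indices, groups=subject_ids))
temp_subjects = subject_ids[temp_idx]

# Splits the held-out subjects again: about 2/3 for validation, 1/3 for testing
gss_val_test = GroupShuffleSplit(n_splits=1, test_size=0.333, random_state=42)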
temp_indices = np.arange(len(temp_idx))
val_rel_idx, test_rel_idx = next(gss_val_test.split(temp_indices, groups=temp_subjects))
val_idx = temp_idx[val_rel_idx]
test_idx = temp_idx[test_rel_idx]

train_dataset = Subset(data, train_idx)
validation_dataset = Subset(data, val_idx)
test_dataset = Subset(data, test_idx)

batch_size = 64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
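
A quick sanity check after a subject-wise split is to confirm that no subject ID shows up in more than one set. The helper below is a sketch, not part of the original notebook; `subjects_in` and `check_no_subject_overlap` are hypothetical names, and the toy IDs are made up.

```python
def subjects_in(indices, subject_ids):
    """Set of subject IDs that appear at the given dataset indices."""
    return {subject_ids[i] for i in indices}


def check_no_subject_overlap(splits, subject_ids):
    """Raise ValueError if any subject appears in more than one split.

    `splits` maps a split name to a sequence of dataset indices.
    """
    names = list(splits)
    for a in range(len(names)):
        for b in range(a + 1, len(names)):
            first = subjects_in(splits[names[a]], subject_ids)
            second = subjects_in(splits[names[b]], subject_ids)
            shared = first & second
            if shared:
                raise ValueError(f"{names[a]} and {names[b]} share subjects: {sorted(shared)}")


# Toy data (hypothetical): 6 slices from 3 subjects, one subject per split.
toy_subject_ids = ["OAS1_0001", "OAS1_0001", "OAS1_0002", "OAS1_0002", "OAS1_0003", "OAS1_0003"]
check_no_subject_overlap({"train": [0, 1], "val": [2, 3], "test": [4, 5]}, toy_subject_ids)
print("no subject appears in more than one split")
```

In this notebook the equivalent call would be `check_no_subject_overlap({"train": train_idx, "val": val_idx, "test": test_idx}, subject_ids)`.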

In [None]:
# Showing sample images
def imshow(img):
  img = img/2 + 0.5  # undo Normalize([0.5], [0.5]) so pixel values are back in [0, 1]
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()

dataiter = iter(train_dataloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))

In [None]:
# Creating the cnn
class ConvolutionalNeuralNetwork(nn.Module):
  def __init__(self):
    super().__init__()
    self.cnn_stack = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
        nn.ReLU(),
        nn.Conv2d(16, 32, 5, 1, 2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Dropout(0.4)
    )

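    # The two 2x2 max-pools reduce the 248 x 496 input to 62 x 124, with 32 channels
    self.fc = nn.Sequential(
        nn.Flatten(),
        nn.Linear(32 * 62 * 124, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, 3)
    )
    # Shape check: pass a dummy batch with the real input size through the conv stack
    dummy_input = torch.randn(1, 1, 248, 496)
    with torch.no_grad():
      output = self.cnn_stack(dummy_input)
    print(output.shape)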

  def forward(self, x):
    x = self.cnn_stack(x)
    x = self.fc(x)
    return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = ConvolutionalNeuralNetwork()
model.to(device)
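
The `in_features` of the first `nn.Linear` has to equal channels × height × width of the tensor leaving `cnn_stack`. The sketch below traces that shape by hand using the standard output-size formula for convolution and pooling layers with dilation 1, `(size + 2 * padding - kernel) // stride + 1`. The helper names `conv_out` and `feature_map_shape` are made up for this example; the layer parameters are copied from the model above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(height, width):
    """Trace (channels, height, width) through the layers of cnn_stack."""
    # (out_channels or None for pooling, kernel, stride, padding), in order
    layers = [
        (16, 5, 1, 2),    # Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
        (32, 5, 1, 2),    # Conv2d(16, 32, 5, 1, 2)
        (None, 2, 2, 0),  # MaxPool2d(kernel_size=2, stride=2)
        (32, 3, 1, 1),    # Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        (32, 3, 1, 1),    # Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        (None, 2, 2, 0),  # MaxPool2d(kernel_size=2, stride=2)
    ]
    channels = 1  # grayscale input
    for out_channels, kernel, stride, padding in layers:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
        if out_channels is not None:
            channels = out_channels  # pooling layers keep the channel count
    return channels, height, width


channels, height, width = feature_map_shape(248, 496)
print((channels, height, width), "->", channels * height * width, "flattened features")
```

ReLU and Dropout are left out because they do not change the tensor shape.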

In [None]:
# Setting up loss function and optimizer, using weighted loss for imbalanced dataset
learning_rate = 1e-4
epochs = 10

class_counts_tensor = torch.tensor(class_counts_array, dtype=torch.float32)
class_weights = 1 / class_counts_tensor
class_weights = class_weights / torch.sum(class_weights)

loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
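
The class weights above are inverse frequencies scaled to sum to 1. As a worked example, the same arithmetic in plain Python with the subject counts printed earlier (the function name `inverse_frequency_weights` is invented for this sketch):

```python
def inverse_frequency_weights(counts):
    """Weights proportional to 1 / count, scaled so that they sum to 1."""
    inverse = [1.0 / c for c in counts]
    total = sum(inverse)
    return [w / total for w in inverse]


# Subject counts in sorted class order, as printed earlier in the notebook.
class_names = ["Mild+ Dementia", "Non Demented", "Very mild Dementia"]
weights = inverse_frequency_weights([23, 266, 58])
for name, weight in zip(class_names, weights):
    print(f"{name}: {weight:.3f}")
```

The rarest class receives the largest weight, and the ratio between any two weights is the inverse of the ratio between their counts. Note that these counts are per subject, while the loss is applied per slice.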

In [None]:
# Running the training and validation loop
train_losses = []
val_losses = []
for epoch in range(epochs):
  model.train()
  train_loss = 0.0
  validation_loss = 0.0
  for i, (X, y) in enumerate(train_dataloader):
    X, y = X.to(device), y.to(device)

    pred = model(X)
    loss = loss_fn(pred, y)
    train_loss += loss.item()

    loss.backward()
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()

  # Validation
  model.eval()
  with torch.no_grad():
    for (X, y) in validation_dataloader:
      X, y = X.to(device), y.to(device)
      pred = model(X)
      loss = loss_fn(pred, y)
      validation_loss += loss.item()

  train_losses.append(train_loss/len(train_dataloader))
  val_losses.append(validation_loss/len(validation_dataloader))
  print(f'[{epoch+1}] train loss:\t{train_losses[epoch]:.3f}')
  print(f'[{epoch+1}] val loss:\t{val_losses[epoch]:.3f}')

print('Finished training and validation')

In [None]:
# Plotting training and validation loss
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Losses vs Epochs')
plt.legend()
plt.show()

In [None]:
# Writing training losses and validation losses to a text file
with open('/content/drive/MyDrive/OASIS-I/train_loss.txt', 'w') as f:
    for loss in train_losses:
        f.write(f"{loss}\n")

with open('/content/drive/MyDrive/OASIS-I/val_loss.txt', 'w') as f:
    for loss in val_losses:
        f.write(f"{loss}\n")

In [None]:
# Saving the trained model
torch.save(model.state_dict(), '/content/drive/MyDrive/OASIS-I/cnn.pth')

In [None]:
# Loading the trained model
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime
model.load_state_dict(torch.load('/content/drive/MyDrive/OASIS-I/cnn.pth', map_location=device))

In [None]:
# Running the testing loop
classes = data.get_label_names()
class_correct = {classname: 0 for classname in classes}
class_total = {classname: 0 for classname in classes}
overall_correct = 0
overall_total = 0
all_predictions = []
all_targets = []
model.eval()
with torch.no_grad():
  for (X, y) in test_dataloader:
    X, y = X.to(device), y.to(device)
    outputs = model(X)
    _, predictions = torch.max(outputs, 1)
    overall_total += y.size(0)
    overall_correct += (predictions == y).sum().item()
    all_predictions.extend(predictions.cpu().numpy())
    all_targets.extend(y.cpu().numpy())
    # Convert to plain Python ints before using them as list indices
    for label, prediction in zip(y.tolist(), predictions.tolist()):
      if label == prediction:
        class_correct[classes[label]] += 1
      class_total[classes[label]] += 1

overall_accuracy = 100 * float(overall_correct) / overall_total
print(f'Overall accuracy of the network on test images: {overall_accuracy:.2f} %')
for classname, correct_count in class_correct.items():
  accuracy = 100 * float(correct_count) / class_total[classname]
  print(f'Accuracy for class: {classname:5s} is {accuracy:.2f} %')

In [None]:
# Displaying classification scores and plotting the confusion matrix
print(classification_report(all_targets, all_predictions, target_names=classes))
cm = confusion_matrix(all_targets, all_predictions)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
disp.plot(cmap='Blues')
plt.xticks(rotation=45)
plt.title('Confusion Matrix')
plt.show()

## Stuff to Learn/Improve and Final Thoughts
---
Due to the class imbalance, the model overfits and has issues with accuracy. The training and validation losses look good for the first few epochs, but the validation loss then diverges from the training loss, and the model ends up overfitting by a substantial amount.
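
One common way to react to a validation loss that starts rising is early stopping: keep track of the best validation loss and stop once it has not improved for a set number of epochs. This notebook does not use it; the class below is a minimal sketch, with `EarlyStopping`, `patience`, and `min_delta` as illustrative names and the loss values made up.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, epoch, val_loss):
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation losses that improve and then diverge.
example_val_losses = [0.90, 0.70, 0.62, 0.66, 0.75, 0.91, 1.20]
stopper = EarlyStopping(patience=3)
for epoch, val_loss in enumerate(example_val_losses):
    if stopper.step(epoch, val_loss):
        print(f"stopping at epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
```

In a training loop, `step` would be called once per epoch with the averaged validation loss, and the model weights from the best epoch would typically be saved at the point where the best loss is updated.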

Stuff to Learn/Improve:
- Figure out how to stop the model from overfitting
- Learn ways to approach class imbalance (undersampling, oversampling, focal loss, ...)
- https://arxiv.org/abs/2109.09850 - Could be a good paper to read
- Compare with different models
- Write better documentation and learn better coding practices
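
As a starting point for the focal loss item above: focal loss scales the cross-entropy of each sample by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class, so easy, confidently correct samples contribute much less. The sketch below is a plain-Python, single-sample version without the optional per-class weighting; the logits are made up and `gamma=2.0` is just a commonly used default, not a value tuned for this dataset.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(p_t)."""
    return -math.log(softmax(logits)[target])


def focal_loss(logits, target, gamma=2.0):
    """Focal loss for one sample: -(1 - p_t)^gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


easy = [4.0, 0.0, 0.0]  # confident and correct when the target is class 0
hard = [0.0, 1.0, 0.0]  # prefers the wrong class when the target is class 0
for name, logits in (("easy", easy), ("hard", hard)):
    print(name, f"CE={cross_entropy(logits, 0):.4f}", f"focal={focal_loss(logits, 0):.4f}")
```

With `gamma=0` the focal loss reduces to ordinary cross-entropy; larger values shrink the loss of the easy sample far more than that of the hard one.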

Final Thoughts:
- Is splitting by subject the correct thing to do? Am I preparing the data the wrong way?
- How complex should the model be? How can I know what the best number of in/out channels for the Conv2d layer is, when to use dropout, how large the fully connected layer should be?
