In [None]:
!pip install SimpleITK

Collecting SimpleITK
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.4/52.4 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.4.0


In [None]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import sys
import nibabel as nib
import SimpleITK as sitk
import numpy as np
import os
import nibabel as nib
import h5py
import gc
import matplotlib.pyplot as plt
import json

from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from tqdm import tqdm
#from brain_extraction import BrainExtraction
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [None]:
# Colab-only step: mount Google Drive at /content/drive. The image folder, the
# label CSV and the training-history folder used below all live under that path.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
class ConvBlock(nn.Module):
  """One Conv3d -> BatchNorm3d -> MaxPool3d -> ReLU stage.

  Args:
    in_channel: number of channels of the incoming tensor.
    out_channel: number of channels produced by the convolution.
    kernel: kernel size of the convolution; also used as the pooling window.
    padding: padding of the convolution; also used as the pooling padding.

  No stride is passed to the pooling layer, so it uses the MaxPool3d default.
  """

  def __init__(self, in_channel, out_channel, kernel, padding):
    super().__init__()
    self.in_channels, self.out_channels = in_channel, out_channel
    self.kernel, self.padding = kernel, padding

    convolution = nn.Conv3d(in_channel, out_channel, kernel_size=kernel, padding=padding)
    normalisation = nn.BatchNorm3d(out_channel)
    pooling = nn.MaxPool3d(kernel, padding=padding)
    # The layer order fixes the state_dict keys ("blocks.0.*", "blocks.1.*").
    self.blocks = nn.Sequential(convolution, normalisation, pooling, nn.ReLU())

  def forward(self, x):
    """Run ``x`` through the four layers in order and return the result."""
    return self.blocks(x)

class SkipConnectionConvBlock(nn.Module):
  """Stack of Conv3d -> BatchNorm3d -> ReLU layers with concatenating skip connections.

  Layers come in pairs. Starting with the third layer, the first layer of every
  pair receives the channel-wise concatenation of a skip tensor and the previous
  layer's output. The first skip tensor is the block input; every later skip
  tensor is the output that fed the previous concatenation, so it skips over
  exactly one pair and has ``out_channel`` channels. That is the width the
  constructor sizes the corresponding convolutions for.

  Args:
    in_channel: channels of the block input.
    out_channel: channels produced by every convolution in the block.
    kernel: convolution kernel size.
    padding: convolution padding.
    num_conv_blocks: number of conv layers; must be even and at least 2.
    device: optional device forwarded to the layer constructors. With the
      default ``None`` the layers are created on PyTorch's default device and
      the owner is expected to move the module with ``.to(...)``.

  Raises:
    ValueError: if ``num_conv_blocks`` is odd or smaller than 2.
  """

  def __init__(self, in_channel, out_channel, kernel, padding, num_conv_blocks, device=None):
    super().__init__()
    # Validate before building anything so a bad count fails fast.
    if num_conv_blocks % 2 == 1:
      raise ValueError('number of conv blocks should be even')
    if num_conv_blocks < 2:
      raise ValueError('number of conv blocks should be at least 2')

    self.in_channels = in_channel
    self.out_channels = out_channel
    self.kernel = kernel
    self.padding = padding
    self.num_conv_blocks = num_conv_blocks

    self.conv_blocks = nn.ModuleList()
    self.batch_norms = nn.ModuleList()

    # Width of the skip tensor that the next concatenating layer will receive.
    skip_channels = in_channel
    for i in range(num_conv_blocks):
      if i == 0:
        layer_in = in_channel
      elif i % 2 == 0:
        layer_in = out_channel + skip_channels
        skip_channels = out_channel
      else:
        layer_in = out_channel
      self.conv_blocks.append(
          nn.Conv3d(layer_in, out_channel, kernel_size=kernel, padding=padding, device=device)
      )
      self.batch_norms.append(nn.BatchNorm3d(out_channel, device=device))

    # Kept as an attribute for compatibility; forward() does not apply it.
    self.maxpool = nn.MaxPool3d(kernel, padding=padding)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply all layers and return the output of the last one.

    The output has ``out_channel`` channels; with padding that preserves the
    spatial size (e.g. kernel 3, padding 1) the spatial shape equals the input's.
    """
    num_layers = len(self.conv_blocks)
    skip = x
    out = x
    for i, (conv, norm) in enumerate(zip(self.conv_blocks, self.batch_norms)):
      if i != 0 and i % 2 == 0:
        layer_in = torch.cat([skip, out], dim=1)
        # The next concatenation (two layers on) skips over this pair, so it
        # needs the tensor that fed this pair. Drop the reference when no
        # later pair exists so the activation can be freed early.
        skip = out if i + 2 < num_layers else None
      else:
        layer_in = out
      out = conv(layer_in)
      del layer_in  # release the (possibly concatenated) input as soon as possible
      out = self.relu(norm(out))
    return out

class CNNModel(torch.nn.Module):
  """3D CNN: skip-connection stem, pooled conv stages, global average pool, linear head, sigmoid.

  Stage ``i`` produces ``start_out_channel * 2**i`` channels. The stem maps
  ``in_channel`` to the first stage width, so the first pooled stage receives
  that width as its input.

  Args:
    in_channel: channels of the input volume.
    start_out_channel: channels of the stem and of the first stage.
    num_classes: number of output units of the linear head.
    H, W, Z: accepted for interface compatibility; not used by the model.
    number_of_blocks: number of pooled conv stages; must be at least 1.

  All layers are created on PyTorch's default device; move the whole model
  with ``model.to(device)`` after construction.

  Raises:
    ValueError: if ``number_of_blocks`` is smaller than 1.
  """

  def __init__(self, in_channel, start_out_channel, num_classes, H,W,Z,number_of_blocks=2):
    super().__init__()
    if number_of_blocks < 1:
      raise ValueError('number_of_blocks must be at least 1')

    self.num_classes = num_classes
    self.number_of_blocks = number_of_blocks
    self.out_channels = [start_out_channel * (2**i) for i in range(number_of_blocks)]
    self.in_channels = [in_channel] + self.out_channels[:-1]

    self.kernel_size = [3]*number_of_blocks
    self.padding = [1]*number_of_blocks

    self.skip_connection_block = SkipConnectionConvBlock(
        self.in_channels[0], self.out_channels[0], self.kernel_size[0], self.padding[0], 4
    )
    # The stem already outputs out_channels[0] channels, so that is what the
    # first pooled stage sees as input.
    self.in_channels[0] = self.out_channels[0]
    self.resnet_blocks = nn.ModuleList([
        ConvBlock(self.in_channels[i], self.out_channels[i], self.kernel_size[i], self.padding[i])
        for i in range(number_of_blocks)
    ])

    # No device argument here: the conv stages above are also built on the
    # default device, and the caller moves the complete model at once.
    self.fc = nn.Linear(self.out_channels[-1], self.num_classes)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return sigmoid outputs of shape (batch, num_classes) for a 5-D input."""
    x = self.skip_connection_block(x)
    for stage in self.resnet_blocks:
      x = stage(x)
    pooled = x.mean(dim=(2, 3, 4))  # global average over the three spatial axes
    return self.sigmoid(self.fc(pooled))

In [None]:
class NumpyDataset(Dataset):#old block
    """In-memory dataset built from NumPy volumes and a NumPy label array.

    Each volume gets a leading axis of length 1 and is converted to a float32
    tensor when the dataset is created; the labels become one int64 tensor.
    """

    def __init__(self, data, labels):
        converted = []
        for volume in data:
            with_leading_axis = np.expand_dims(volume, axis=0)
            converted.append(torch.from_numpy(with_leading_axis).float())
        self.data = converted
        self.labels = torch.from_numpy(labels).long()

    def __len__(self):
        """Number of volumes held by the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(volume, label)`` pair stored at ``index``."""
        volume = self.data[index]
        label = self.labels[index]
        return volume, label

In [None]:
class CustomDirDataset(Dataset):
    """Dataset that loads one image file per item with nibabel.

    Only the file paths and the label tensor are kept in memory; the image is
    read from disk every time an item is requested.
    """

    def __init__(self, data_path_list, labels):
        self.labels = torch.from_numpy(labels).long()
        self.data = data_path_list

    def __len__(self):
        """Number of image paths in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the image at ``index``; return it as float32 with a leading axis of 1, plus its label."""
        volume = nib.load(self.data[index]).get_fdata()
        tensor = torch.from_numpy(np.expand_dims(volume, axis=0)).float()
        return tensor, self.labels[index]

In [None]:
# HDF5 image loader: input locations on the mounted Google Drive.

# Root folder that create_hdf5_dataset scans for per-visit image directories.
IMAGE_DIR = '/content/drive/MyDrive/DL_CNN/OAS_STRIP'
# Label CSV; create_hdf5_dataset reads its 'MRI ID' and 'Group' columns.
LABEL_PATH = '/content/drive/MyDrive/DL_CNN/OAS_LABELS/data_labels.csv'

# Label table, read by create_hdf5_dataset as a module-level global.
df = pd.read_csv(LABEL_PATH)

def _load_volume(file_path):
    """Load one image with nibabel and move its third axis to the front."""
    volume = nib.load(file_path).get_fdata()
    # Same axis order as swapaxes(0, 2) followed by swapaxes(1, 2): (0, 1, 2) -> (2, 0, 1).
    return np.moveaxis(volume, 2, 0)


def create_hdf5_dataset(folder_path=IMAGE_DIR, output_file='brain_images.h5'):
    """Pack every labelled 'n4' image under ``folder_path`` into one HDF5 file.

    A sub-folder of ``folder_path`` is used when its name contains 'MR' and
    appears in the 'MRI ID' column of the module-level ``df``. Inside such a
    folder every file whose name contains 'n4' and ends with '.gz' is stored.

    The output file holds three datasets with one row per image:
      * 'images'      -- float32 volumes, gzip-compressed, one chunk per image
      * 'labels'      -- bool, True when the visit's 'Group' is not 'Nondemented'
      * 'patient_ids' -- the visit folder name as a variable-length string
    plus the attributes 'total_images' and 'image_shape'.

    Args:
        folder_path: root directory that is scanned for visit folders.
        output_file: path of the HDF5 file to create (overwritten if present).

    Raises:
        FileNotFoundError: if no matching image is found under ``folder_path``.
    """
    # Any 'Group' value other than 'Nondemented' counts as a positive label.
    state_dict = {
        visit: bool(flag)
        for visit, flag in zip(df['MRI ID'], df['Group'] != 'Nondemented')
    }

    file_mapping = []
    for patient_visit in sorted(os.listdir(folder_path)):
        if 'MR' not in patient_visit or patient_visit not in state_dict:
            continue
        patient_dir = os.path.join(folder_path, patient_visit)
        for root, dirs, files in os.walk(patient_dir):
            # os.walk order is arbitrary; sorting makes the row index of every
            # image (and any index-based split built on it) reproducible.
            dirs.sort()
            for file_name in sorted(files):
                if 'n4' in file_name and file_name.endswith('.gz'):
                    file_path = os.path.join(root, file_name)
                    if os.path.exists(file_path):
                        file_mapping.append((patient_visit, file_path))

    total_images = len(file_mapping)
    print(f"{total_images} images")
    if total_images == 0:
        # Fail with an explicit message instead of an IndexError on file_mapping[0].
        raise FileNotFoundError(f"no labelled 'n4' .gz images found under {folder_path!r}")

    # All images are written with the shape of the first one.
    img_shape = _load_volume(file_mapping[0][1]).shape

    with h5py.File(output_file, 'w') as f:
        images_dataset = f.create_dataset(
            'images',
            shape=(total_images,) + img_shape,
            dtype='float32',
            chunks=(1,) + img_shape,
            compression='gzip',
            compression_opts=4,
        )
        labels_dataset = f.create_dataset('labels', shape=(total_images,), dtype='bool')
        patient_ids = f.create_dataset(
            'patient_ids', shape=(total_images,), dtype=h5py.special_dtype(vlen=str)
        )

        # No tensors are placed on the GPU here, so no CUDA cache flush is
        # needed; each volume is released when the next one replaces it.
        for idx, (patient_visit, file_path) in enumerate(tqdm(file_mapping)):
            images_dataset[idx] = _load_volume(file_path)
            labels_dataset[idx] = state_dict[patient_visit]
            patient_ids[idx] = patient_visit

        f.attrs['total_images'] = total_images
        f.attrs['image_shape'] = img_shape

# Build brain_images.h5 in the working directory with the default arguments.
# The recorded run below processed 1368 images in about 43 minutes.
create_hdf5_dataset()

1368 images


100%|██████████| 1368/1368 [43:34<00:00,  1.91s/it]


In [None]:
#hd5 block

class HD5Dataset(torch.utils.data.Dataset):
    """Dataset over the 'images' and 'labels' arrays of an HDF5 file.

    The file is opened read-only in the constructor and stays open until the
    object is destroyed; items are read from it on demand.

    Args:
        h5_file: path of the HDF5 file to read.
    """

    def __init__(self, h5_file='brain_images.h5'):
        self.h5_file = h5py.File(h5_file, 'r')
        self.images = self.h5_file['images']
        self.labels = self.h5_file['labels']

    def __len__(self):
        """Number of stored labels (one per image)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float32 image with a leading axis of 1 and an int64 label."""
        image = torch.from_numpy(self.images[idx]).float().unsqueeze(0)
        # Convert the stored label to int64 in NumPy first. Handing the raw
        # value to torch.tensor(...) raised a warning on every item in the
        # recorded runs of this notebook.
        label = torch.from_numpy(np.asarray(self.labels[idx], dtype=np.int64))
        return image, label

    def __del__(self):
        # h5_file is missing when __init__ failed before the file was opened.
        handle = getattr(self, 'h5_file', None)
        if handle is not None:
            handle.close()

def create_data_loaders(batch_size=32):
    """Split the default HD5Dataset 70/15/15 and wrap each part in a DataLoader.

    The train and validation sizes are ``int(0.7 * n)`` and ``int(0.15 * n)``;
    the test part receives every remaining index. Both splits use
    ``random_state=42``. Only the training loader shuffles.

    NOTE(review): indices are split per image. If several images belong to the
    same subject they can end up in different parts -- confirm this is intended.

    Args:
        batch_size: batch size used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    full_dataset = HD5Dataset()
    n_total = len(full_dataset)
    n_train = int(0.7 * n_total)
    n_val = int(0.15 * n_total)

    train_idx, holdout_idx = train_test_split(
        list(range(n_total)), train_size=n_train, random_state=42
    )
    val_idx, test_idx = train_test_split(holdout_idx, train_size=n_val, random_state=42)

    def make_loader(indices, shuffle):
        # Every loader reads from the same underlying dataset object.
        subset = Subset(full_dataset, indices)
        return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(train_idx, True)
    val_loader = make_loader(val_idx, False)
    test_loader = make_loader(test_idx, False)
    return train_loader, val_loader, test_loader

# Loaders consumed by the training and test cells; batch size 4.
train_loader, val_loader, test_loader = create_data_loaders(batch_size=4)
# NOTE(review): `dataset` and `dataloader` open the HDF5 file a second time and
# are not referenced by any later cell visible here -- confirm they are needed.
dataset = HD5Dataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
# 1 input channel, 16 starting feature channels, 1 output unit, 3 pooled stages.
# 256,128,256 fill the H, W, Z parameters, which CNNModel.__init__ does not read.
model = CNNModel(1, 16, 1, 256,128,256,3) #re-run after GPU limit reached
model.to(device)

In [None]:
# BCELoss takes probabilities, matching the Sigmoid at the end of CNNModel.forward.
criterion = nn.BCELoss() #re-run after GPU limit reached
# The optimizer is built from model.parameters(), so re-create it whenever `model` is re-created.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [None]:
# Collect unreferenced Python objects and ask PyTorch to release its cached
# CUDA memory before the training cell runs.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
#new training block ; this block was re-run beyond the first 50 epoch run, and was cut off when Colab limit reached

# Folder on the mounted Drive where the training loop writes loss_history.json;
# created on first use.
save_dir = '/content/drive/MyDrive/DL_CNN/training_history'
os.makedirs(save_dir, exist_ok=True)
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: network whose output column 0 is passed to ``criterion``.
        dataloader: iterable of ``(x, y)`` batches.
        criterion: loss called as ``criterion(output[:, 0], y.float())``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        The unweighted mean of the per-batch loss values, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    batch_losses = []

    # Wrap the loader itself rather than enumerate(...): tqdm can then read its
    # length and show a total and an ETA.
    for x, y in tqdm(dataloader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output[:, 0], y.float())
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        # Drop the batch tensors and flush caches before the next batch is loaded.
        del x, y, output, loss
        torch.cuda.empty_cache()
        gc.collect()

    if not batch_losses:
        raise ValueError("dataloader yielded no batches")
    return sum(batch_losses) / len(batch_losses)

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradients and return the mean batch loss.

    Args:
        model: network whose output column 0 is passed to ``criterion``.
        val_loader: iterable of ``(x, y)`` batches.
        criterion: loss called as ``criterion(output[:, 0], y.float())``.
        device: device the batches are moved to.

    Returns:
        The unweighted mean of the per-batch loss values, as a float.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():  # no autograd graph is kept, which lowers memory use
        # Wrap the loader itself rather than enumerate(...): tqdm can then read
        # its length and show a total and an ETA.
        for x, y in tqdm(val_loader):
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            loss = criterion(output[:, 0], y.float())
            batch_losses.append(loss.item())
            del x, y, output, loss
            torch.cuda.empty_cache()
            gc.collect()

    if not batch_losses:
        raise ValueError("val_loader yielded no batches")
    return sum(batch_losses) / len(batch_losses)

# Per-epoch mean losses. Both lists are rewritten to loss_history.json after
# every epoch so the curves survive an interrupted session.
# NOTE(review): the lists restart empty whenever this cell is re-run, so the
# history file of an earlier run is overwritten. Only the losses are saved,
# not the model weights -- confirm whether a checkpoint should be written too.
val_losses_history = []
train_losses_history = []

for epoch in range(50):
    # Labels the two progress bars that follow.
    print(f"Starting epoch {epoch}")

    # The per-batch work lives in train_epoch / validate_epoch instead of being
    # repeated inline here.
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses_history.append(float(train_loss))
    print(f"Train Loss: {train_loss:.5f} for epoch: {epoch}")

    val_loss = validate_epoch(model, val_loader, criterion, device)
    val_losses_history.append(float(val_loss))
    print(f"Val Loss: {val_loss:.5f} for epoch: {epoch}")

    history = {'train_losses': train_losses_history,'val_losses': val_losses_history,'last_epoch': epoch}
    with open(f'{save_dir}/loss_history.json', 'w') as f:
        json.dump(history, f)

    torch.cuda.empty_cache()
    gc.collect()

def plot_training_history(history_path):
    """Plot the training and validation loss curves stored in a JSON file.

    Args:
        history_path: path of a JSON file holding an object with the keys
            'train_losses' and 'val_losses'; each is plotted against its index.
    """
    with open(history_path, 'r') as history_file:
        losses = json.load(history_file)

    _, ax = plt.subplots(figsize=(10, 6))
    curves = (('train_losses', 'Training Loss'), ('val_losses', 'Validation Loss'))
    for key, label in curves:
        ax.plot(losses[key], label=label, marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss Over Time')
    ax.legend()
    ax.grid(True)
    plt.show()

# Draw the curves from the history file that the training loop writes after every epoch.
plot_training_history(f'{save_dir}/loss_history.json')


Starting epoch 0


  label = torch.tensor(self.labels[idx]).long()
Training epoch 0: 240it [09:16,  2.32s/it]


Train Loss: 0.5877 for epoch: 0


Validation epoch 0: 52it [00:48,  1.07it/s]


Val Loss: 0.6272 for epoch: 0

Starting epoch 1


Training epoch 1: 240it [09:16,  2.32s/it]


Train Loss: 0.5826 for epoch: 1


Validation epoch 1: 52it [00:48,  1.07it/s]


Val Loss: 0.9891 for epoch: 1

Starting epoch 2


Training epoch 2: 72it [02:46,  2.30s/it]

In [None]:
# Testing: evaluate the current `model` on the held-out test loader.
from sklearn.metrics import classification_report, confusion_matrix

model.eval()
test_losses = []
test_predictions = []
test_labels = []

with torch.no_grad():
    # Wrap the loader itself so tqdm can show a total instead of a bare counter.
    for x, y in tqdm(test_loader):
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        loss = criterion(output[:, 0], y.float())
        test_losses.append(loss.item())
        test_predictions.extend(output.cpu().numpy()[:, 0])
        test_labels.extend(y.cpu().numpy())

test_predictions = np.array(test_predictions)
test_labels = np.array(test_labels)

# The model outputs probabilities (sigmoid); 0.5 is the decision threshold.
binary_predictions = (test_predictions >= 0.5).astype(int)

correct = (binary_predictions == test_labels).sum()
accuracy = correct / len(test_labels)

print(f"Test Accuracy: {accuracy:.5f}")
# Unweighted mean over batches; a smaller final batch counts as much as a full one.
print(f"Average Test Loss: {np.mean(test_losses):.5f}")

print("Classification Report:")
print(classification_report(test_labels, binary_predictions))

# confusion_matrix was imported but never called, although the recorded output
# of this cell contains a confusion matrix.
print("Confusion Matrix:")
print(confusion_matrix(test_labels, binary_predictions))


  label = torch.tensor(self.labels[idx]).long()
52it [00:38,  1.37it/s]

Test Accuracy: 0.6990
Average Test Loss: 0.5894

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.44      0.58        95
           1       0.66      0.92      0.77       111

    accuracy                           0.70       206
   macro avg       0.74      0.68      0.67       206
weighted avg       0.73      0.70      0.68       206


Confusion Matrix:
[[ 42  53]
 [  9 102]]



