
Connect to Google Drive for datasets (colab)

In [1]:
#from google.colab import drive
#drive.mount('/content/gdrive')

Install dependencies (colab)

In [2]:
#!pip install simpleitk

Dependencies

In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

import scipy

import SimpleITK as sitk
reader = sitk.ImageFileReader()
reader.SetImageIO("MetaImageIO")

import numpy as np

import os

import pathlib

from natsort import natsorted

#Set GPU/Cuda Device to run model on
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

np.random.seed(46)

Using cuda device


Dataset Directories <Br>
Comment out directory not in use


In [2]:
#Toy Dataset Directory from Spider | Grand Challenge 
dummy_train_img_dir = pathlib.Path(r"spider_toy_dset/dummy_train_images")
dummy_train_label_dir = pathlib.Path(r"spider_toy_dset/dummy_train_labels")
dummy_test_img_dir = pathlib.Path(r"spider_toy_dset/dummy_test_images")
dummy_test_label_dir= pathlib.Path(r"spider_toy_dset/dummy_test_labels")

'''
#Colab Google Drive Directories Toy Dataset
dummy_train_img_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_train_images")
dummy_train_label_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_train_labels")
dummy_test_img_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_test_images")
dummy_test_label_dir= pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_test_labels")
'''


'\n#Colab Google Drive Directories Toy Dataset\ndummy_train_img_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_train_images")\ndummy_train_label_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_train_labels")\ndummy_test_img_dir = pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_test_images")\ndummy_test_label_dir= pathlib.Path(r"/content/gdrive/MyDrive/Spider Dummy Dataset/dummy_test_labels")\n'

Image class

In [3]:
from transforms import mri_transforms

class Mri:
    def __init__(self, path):
        mri_mha = sitk.ReadImage(path, imageIO = "MetaImageIO") #explicitly setting ioreader just in case

        #resampling
        #mri_mha_resampled = mri_transforms.resample_img(mri_mha, out_spacing= [1, 0.3, 0.3])
        #TODO separate resample (bilinear, nearestNeighbor) for images and labels

        mri_a = np.array(sitk.GetArrayFromImage(mri_mha)) #mri_array

        #transpose array to format z x y
        if(mri_a.shape[0] > mri_a.shape[1] or mri_a.shape[0] > mri_a.shape[2]): #if z axis isn't first
          mri_a = np.transpose(mri_a, (2, 0, 1))
      
        mri_a_float32 = mri_a.astype(dtype = np.float32)
        #TODO: set bounds to [-1000, 2000] https://en.wikipedia.org/wiki/Hounsfield_scale
        self.hu_a = mri_a_float32

Check for sorted directory and dimension order <Br>
Get max dimension on x y for zero padding 

In [4]:
#get lists from directories
label_dir_list = os.listdir(dummy_train_label_dir)
image_dir_list = os.listdir(dummy_train_img_dir) 
#sort lists
image_dir_list = natsorted(image_dir_list)
label_dir_list = natsorted(label_dir_list)
#empty lists to hold x and y dimensions of images
dim1_list = []
dim2_list = []

dirlen = len(os.listdir(dummy_train_label_dir))

print(dirlen)

print(dummy_train_label_dir)


44
spider_toy_dset\dummy_train_labels


Get max dimension of slice in dset for x y padding (creates 912x912 images, too large for training)

In [5]:
for idx in range(0, dirlen):
  img_path = dummy_train_img_dir.joinpath(image_dir_list[idx])
  label_path = dummy_train_label_dir.joinpath(label_dir_list[idx])

  image = Mri(img_path)
  label = Mri(label_path)
  '''
  print(idx, "image: ", image.hu_a.shape)
  print(idx, "label: ", label.hu_a.shape)
  '''
  dim1_list.append(image.hu_a.shape[1]) #add x value to list
  dim2_list.append(image.hu_a.shape[2]) #add y value to list 

#calculate max 
x_dim_max = max(dim1_list)
y_dim_max = max(dim2_list)

print("x max:", max(dim1_list))
print("y max:", max(dim2_list))

x_dim_max = 912 
y_dim_max = 528
#912 and 528 was done by calculating the nearest multiple of 16 **above** x and y


x max: 899
y max: 514


Find empty slices in labels <br>
Create new label and image arrays without the slices w/o mask info <br>
Cell works for 1 image

In [7]:
from transforms import array_transforms

#grab the first image from the dset for testing
img_path = dummy_train_img_dir.joinpath(image_dir_list[0]) 
label_path = dummy_train_label_dir.joinpath(label_dir_list[0])

image = Mri(img_path)
label = Mri(label_path)


test_image_hu, test_label_hu = array_transforms.remove_empty_slices(image.hu_a, label.hu_a)


size of array before trimming 50
size of array after trimming what it should be 30
True
(30, 578, 448)


For slices that aren't empty delete the surrounding 0s to bring resolution down <br>
Will work on the test arrays from the cell above <br>
Cell works for 1 image

In [8]:
array_transforms.crop_zero(test_image_hu, test_label_hu)



original image res (30, 578, 448)
x max 406
y max 143


(406, 143)

find max dims for cropping 

In [6]:
x_max = list()
y_max  = list()

for idx in range(0, dirlen): #for images in directory 
  img_path = dummy_train_img_dir.joinpath(image_dir_list[idx])
  label_path = dummy_train_label_dir.joinpath(label_dir_list[idx])

  image = Mri(img_path)
  label = Mri(label_path)

  x ,y = array_transforms.crop_zero(image.hu_a ,label.hu_a)
  x_max.append(x)
  y_max.append(y)


print("x max in dset", max(x_max))
print("y max in dset", max(y_max))

NameError: name 'array_transforms' is not defined

Dataset Class

In [7]:
from transforms import tensor_transforms

class SpiderDataset(Dataset):
    def __init__(self, labels_dir, img_dir, transform=None, target_transform=None):
        self.labels_dir = labels_dir
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(os.listdir(self.labels_dir))

    def __getitem__(self, idx):
        label_dir_list = os.listdir(self.labels_dir)
        image_dir_list = os.listdir(self.img_dir)

        image_dir_list = natsorted(image_dir_list)
        label_dir_list = natsorted(label_dir_list)

        img_path = self.img_dir.joinpath(image_dir_list[idx])
        label_path = self.labels_dir.joinpath(label_dir_list[idx])

        image = Mri(img_path)
        label = Mri(label_path)

        #image = self.transform(image)
        #label = self.target_transform(label)

        #comment out the part not being used whether for 3d or 2d model 
    
        '''
        #3d tensor for 3D CNN
        image_tensor = torch.from_numpy(image.hu_a)
        label_tensor = torch.from_numpy(label.hu_a)
        '''

        #2d tensor for 2D CNN, get random slice from image 
        rand_idx = np.random.randint(0, image.hu_a.shape[0])

        image_tensor = torch.from_numpy(image.hu_a[rand_idx])
        label_tensor = torch.from_numpy(label.hu_a[rand_idx])
        
        image_tensor = image_tensor.to(torch.float32)
        label_tensor = label_tensor.to(torch.float32)
        
        #pad to max resolution of slice in dset 
        image_tensor = tensor_transforms.pad_to_resolution(image_tensor, [x_dim_max, y_dim_max])
        label_tensor = tensor_transforms.pad_to_resolution(label_tensor, [x_dim_max, y_dim_max])

        image_tensor = image_tensor.unsqueeze(0)
        label_tensor = label_tensor.unsqueeze(0)

        image_tensor = image_tensor.to(device)
        label_tensor = label_tensor.to(device)

        return image_tensor, label_tensor

#toy train test dataset to test network running
dummy_train_set = SpiderDataset(dummy_train_label_dir, dummy_train_img_dir)
dummy_test_set = SpiderDataset(dummy_test_label_dir, dummy_test_img_dir)

print("train dataset len",dummy_train_set.__len__())
print("test dataset len",dummy_test_set.__len__())


train dataset len 44
test dataset len 10


Create Unet Instance







In [8]:
from models import unet 

input_channels = 1 #Hounsfield scale
output_channels = 1 #Vertebra, disc and spinal canal masks SHOULD BE 3 FOR 3 MASKS
model = unet.UNet(in_channels = input_channels, out_channels = output_channels)
model.to(device)
model.to(torch.float32)
for param in model.parameters():
    print(param.device)

cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0


Hyperparameters


In [9]:
epochs = 1 #testing
lr = 0.001 #testing
batchsize = 2 #testing
loss_func = nn.MSELoss()
loss_func.to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)


Dataloader

In [10]:
dummy_train_dataloader = DataLoader(dummy_train_set, batch_size = batchsize, shuffle=True)
dummy_test_dataloader = DataLoader(dummy_test_set, batch_size = batchsize, shuffle=True)

Dataloader Iterate through Z Axis of tensor (3D tensor)

In [11]:
for images, masks in dummy_test_dataloader:
  for i in images, masks:
    print(i.shape)
    break


torch.Size([2, 1, 912, 528])
torch.Size([2, 1, 912, 528])
torch.Size([2, 1, 912, 528])
torch.Size([2, 1, 912, 528])
torch.Size([2, 1, 912, 528])


One Epoch <br>
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

In [11]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(dummy_train_dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        print(inputs.shape)
        print(device)

        #inputs = inputs.transpose(-1,0)
        #labels = labels.transpose(-1,0)
        #inputs = inputs.reshape(inputs.shape(1), inputs.shape(0), inputs.shape(3), inputs.shape(4))
        #labels = labels.reshape(labels.shape(1), labels.shape(0), labels.shape(3), labels.shape(4))

      
        # Zero your gradients for every batch!
        optim.zero_grad()

        #2d
        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_func(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optim.step()

        #3d tensor, probably won't use keeping here just in case  
        '''
        #run everything while indexing through z axis
        for axis in images, masks:
          for idx in range(0, axis.size(1)):
            # Zero your gradients for every batch!
            optim.zero_grad()
            print(inputs[:, idx, : ,:])
            # Make predictions for this batch
            outputs = model(inputs[:, idx, : ,:])

            # Compute the loss and its gradients
            loss = loss_func(outputs[:, idx, : ,:], labels[:, idx, : ,:])
            loss.backward()

            # Adjust learning weights
            optim.step()
        '''
        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(dummy_train_dataloader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

Train Loop <br>
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html


In [12]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/spider_seg_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)


    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(dummy_test_dataloader):
            vinputs, vlabels = vdata

            voutputs = model(vinputs)
            vloss = loss_func(voutputs, vlabels)
            running_vloss += vloss
            

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
LOSS train 0.0 valid 185.86248779296875
EPOCH 2:
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 528])
cuda
torch.Size([2, 1, 912, 5