
Connect to Google Drive for datasets (colab)

In [1]:
#from google.colab import drive
#drive.mount('/content/gdrive')

Install dependencies (colab)

In [2]:
#!pip install simpleitk

Dependencies

In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

import scipy

import SimpleITK as sitk
reader = sitk.ImageFileReader()
reader.SetImageIO("MetaImageIO")

import numpy as np

import os

import pathlib

from natsort import natsorted

#Set GPU/Cuda Device to run model on
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

np.random.seed(46)

Using cuda device


Dataset Directories <Br>
Comment out directory not in use


In [2]:

#Toy Dataset Slices 
dummy_train_img_slice_dir = pathlib.Path(r"spider_toy_dset_slices/dummy_train_img_slices")
dummy_train_label_slice_dir = pathlib.Path(r"spider_toy_dset_slices/dummy_train_label_slices")
dummy_test_img_slice_dir = pathlib.Path(r"spider_toy_dset_slices/dummy_test_img_slices")
dummy_test_label_slice_dir= pathlib.Path(r"spider_toy_dset_slices/dummy_test_label_slices")


Image Slice Class

In [3]:
from transforms import mri_transforms


#TODO add bool image label for interp 
class Mri_Slice:
    def __init__(self, path):
        mri_mha = sitk.ReadImage(path, imageIO = "MetaImageIO") #explicitly setting ioreader just in case
        #resample images to 64x64
        #slice = mri_transforms.resample_img_to_res(itk_slice=slice, out_res= [64, 64], is_label= False, smoothing=False)

        #get 2d array from mri slice
        mri_a = np.array(sitk.GetArrayFromImage(mri_mha)) #mri_array
        
        mri_a_float32 = mri_a.astype(dtype = np.float32)
        #TODO: set bounds to [-1000, 2000] https://en.wikipedia.org/wiki/Hounsfield_scale
        self.hu_a = mri_a_float32

Sort directories


In [4]:
#get lists from directories

#toy dset slices
image_path = dummy_train_img_slice_dir
label_path = dummy_train_label_slice_dir

image_dir_list = os.listdir(image_path)
label_dir_list = os.listdir(label_path)

print(image_path) 
print(label_path)

#local dset
'''
image_dir_list = os.listdir(local_img_idr)
label_dir_list = os.listdir(local_label_dir)
'''
#sort lists
image_dir_list = natsorted(image_dir_list)
label_dir_list = natsorted(label_dir_list)
#empty lists to hold x and y dimensions of images
row_list = []
col_list = []

#dirlen = len(os.listdir(dummy_train_label_dir))
dirlen = len(os.listdir(label_path))

print(dirlen)

#print(local_label_idr)


spider_toy_dset_slices\dummy_train_img_slices
spider_toy_dset_slices\dummy_train_label_slices
924


Get max dimension of slice in dset for x y padding <br>
Images have slices with 0 label info removed and are cropped using zero crop 

In [5]:
from transforms import array_transforms

for idx in range(0, dirlen):
  #print("dirlen", dirlen)
  
 #toy dset 
  
  img_path = image_path.joinpath(image_dir_list[idx])
  lbl_path = label_path.joinpath(label_dir_list[idx])#first part before joinpath is pathlib.Path, second part is the directory of hte file 
  '''
  img_path = local_img_idr.joinpath(image_dir_list[idx])
  label_path = local_label_dir.joinpath(label_dir_list[idx]) #first part before joinpath is pathlib.Path, second part is the directory of hte file 
  '''
  image = Mri_Slice(img_path)
  label = Mri_Slice(lbl_path)


  #if(image.hu_a.shape[0] > 600): #if image way too high res
    #print("high res image in directory", img_path)

  #print("label after: ", label_a.shape)
  
  row_list.append(image.hu_a.shape[0]) #add row value to list
  col_list.append(image.hu_a.shape[1]) #add col value to list 

  
#calculate max 
row_dim_max = max(row_list)
col_dim_max = max(col_list)

row_dim_max = ((row_dim_max + 15) // 16) * 16 #nearest multiple of 16
col_dim_max = ((col_dim_max + 15) // 16) * 16 #nearest multiple of 16

print("row max:", max(row_list))
print("col max:", max(col_list))





KeyboardInterrupt: 

Dataset Class

In [10]:
from transforms import tensor_transforms

class SpiderDataset(Dataset):
    def __init__(self, labels_dir, img_dir, transform=None, target_transform=None):
        self.labels_dir = labels_dir
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(os.listdir(self.labels_dir))

    def __getitem__(self, idx):
        label_dir_list = os.listdir(self.labels_dir)
        image_dir_list = os.listdir(self.img_dir)

        image_dir_list = natsorted(image_dir_list)
        label_dir_list = natsorted(label_dir_list)

        img_path = self.img_dir.joinpath(image_dir_list[idx])
        label_path = self.labels_dir.joinpath(label_dir_list[idx])

        image = Mri_Slice(img_path)
        label = Mri_Slice(label_path)

        image_a = image.hu_a
        label_a = label.hu_a

        #comment out the part not being used whether for 3d or 2d model 
    
        '''
        #3d tensor for 3D CNN
        image_tensor = torch.from_numpy(image.hu_a)
        label_tensor = torch.from_numpy(label.hu_a)
        '''

        image_tensor = torch.from_numpy(image_a)
        label_tensor = torch.from_numpy(label_a)
        
        image_tensor = image_tensor.to(torch.float32)
        label_tensor = label_tensor.to(torch.float32)
        
        #pad to max resolution of slice in dset 
        image_tensor = tensor_transforms.pad_to_resolution(image_tensor, [row_dim_max, col_dim_max])
        label_tensor = tensor_transforms.pad_to_resolution(label_tensor, [row_dim_max, col_dim_max])

        image_tensor = image_tensor.unsqueeze(0)
        label_tensor = label_tensor.unsqueeze(0)

        image_tensor = image_tensor.to(device)
        label_tensor = label_tensor.to(device)

        #print(image_tensor.shape)
        return image_tensor, label_tensor




Dataset Classes


In [11]:
#toy train test dataset to test network running
#local_train_set = SpiderDataset(local_img_idr, local_label_idr)
dummy_train_set = SpiderDataset(dummy_train_label_slice_dir, dummy_train_img_slice_dir)

dummy_test_set = SpiderDataset(dummy_test_label_slice_dir, dummy_test_img_slice_dir)

print("train dataset len",dummy_train_set.__len__())
print("test dataset len",dummy_test_set.__len__())

train dataset len 924
test dataset len 200


Create Unet Instance







In [1]:
from models import unet 

input_channels = 1 #Hounsfield scale
output_channels = 3 #Vertebra, disc and spinal canal masks SHOULD BE 3 FOR 3 MASKS
output_channels = 16 #one for every part of the spine
model = unet.UNet(in_channels = input_channels, out_channels = output_channels)
model.to(device)
model.to(torch.float32)
#for param in model.parameters():
 #   print(param.device)

NameError: name 'device' is not defined

Hyperparameters


In [13]:
epochs = 5 #testing
lr = 0.001 #testing
batchsize = 2 #testing
loss_func = nn.MSELoss()
loss_func.to(device)
optim = torch.optim.Adam(model.parameters(), lr=lr)


Dataloader

In [14]:
dummy_train_dataloader = DataLoader(dummy_train_set, batch_size = batchsize, shuffle=True)

dummy_test_dataloader = DataLoader(dummy_test_set, batch_size = batchsize, shuffle=True)

'''
for batch in dummy_train_dataloader:
    for tensor in batch:
        print("min", torch.min(tensor))
        print("max", torch.max(tensor))
'''

'\nfor batch in dummy_train_dataloader:\n    for tensor in batch:\n        print("min", torch.min(tensor))\n        print("max", torch.max(tensor))\n'

Tensor Dimensions Check

In [15]:
for idx_slice in dummy_train_dataloader:
    for tensor in idx_slice:
        print(tensor.shape)
        break

torch.Size([2, 1, 899, 514])
torch.Size([2, 1, 899, 514])
torch.Size([2, 1, 899, 514])
torch.Size([2, 1, 899, 514])
torch.Size([2, 1, 899, 514])
torch.Size([2, 1, 899, 514])


KeyboardInterrupt: 

One Epoch <br>
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

In [16]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting

    #swap train dataloader for dset
    for i, data in enumerate(dummy_train_dataloader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        #print("labels", labels.shape)
        #print(inputs.shape)
        #print(device)

        #inputs = inputs.transpose(-1,0)
        #labels = labels.transpose(-1,0)
        #inputs = inputs.reshape(inputs.shape(1), inputs.shape(0), inputs.shape(3), inputs.shape(4))
        #labels = labels.reshape(labels.shape(1), labels.shape(0), labels.shape(3), labels.shape(4))

      
        # Zero your gradients for every batch!
        optim.zero_grad()

        #2d
        # Make predictions for this batch
        outputs = model(inputs)

        #print("outputs", outputs.shape)

        # Compute the loss and its gradients
        loss = loss_func(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optim.step()

        #3d tensor, probably won't use keeping here just in case  
        '''
        #run everything while indexing through z axis
        for axis in images, masks:
          for idx in range(0, axis.size(1)):
            # Zero your gradients for every batch!
            optim.zero_grad()
            print(inputs[:, idx, : ,:])
            # Make predictions for this batch
            outputs = model(inputs[:, idx, : ,:])

            # Compute the loss and its gradients
            loss = loss_func(outputs[:, idx, : ,:], labels[:, idx, : ,:])
            loss.backward()

            # Adjust learning weights
            optim.step()
        '''
        # Gather data and report
        running_loss += loss.item()
        #if i % 1000 == 999:
            #print("goes in")
        last_loss = running_loss / 1000 # loss per batch
        print('  batch {} loss: {}'.format(i + 1, last_loss))
        tb_x = epoch_index * len(dummy_train_dataloader) + i + 1
        tb_writer.add_scalar('Loss/train', last_loss, tb_x)
        running_loss = 0.
        #if ends here

    print("loss", loss)
    return last_loss

Train Loop <br>
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html


In [18]:
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/spider_seg_{}'.format(timestamp))
epoch_number = 0


best_vloss = 1_000_000.

for epoch in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    print("avg loss in epoch", avg_loss)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(dummy_test_dataloader):
            vinputs, vlabels = vdata

            voutputs = model(vinputs)
            vloss = loss_func(voutputs, vlabels)
            running_vloss += vloss
            

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        #commented out saving the model for now to debug loss being 0 
        '''
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
        '''
    epoch_number += 1

EPOCH 1:


OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB (GPU 0; 6.00 GiB total capacity; 5.15 GiB already allocated; 0 bytes free; 5.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF