In [3]:
import os
import cv2
import torch
import argparse
from pylab import *
import numpy as np
import torch.nn as nn
from model import UNet
import SimpleITK as sitk
from scipy import ndimage
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from data import SegThorDataset, Rescale, ToTensor, Normalize

In [4]:
# ---- Configuration ------------------------------------------------------
# Dataset roots; each one holds a sub-directory per patient.
TRAIN_PATH, TEST_PATH = 'data/train', 'data/test'

# Epoch count constant (the train() call in this notebook passes epochs explicitly).
epochs = 5

# Empty module-level lists for image slices and their ground-truth label slices.
data, label = [], []

In [3]:
def _read_volume(path):
    """Read an image file with SimpleITK and return its voxels as a numpy array."""
    # GetArrayFromImage already yields axes in (z, y, x) order, so no transpose is needed.
    return sitk.GetArrayFromImage(sitk.ReadImage(path))


def _downsample_slices(volume, zoom_factor, order):
    """Return every axis-0 slice of ``volume`` rescaled by ``zoom_factor`` using spline ``order``."""
    return [ndimage.zoom(slice_2d, zoom=zoom_factor, order=order) for slice_2d in volume]


def read_data(datapath, train, zoom_factor=0.25):
    """Load every patient volume under ``datapath`` as a list of downsampled 2D slices.

    Each sub-directory ``<patient>`` of ``datapath`` must contain ``<patient>.nii.gz``
    and, when ``train == 1``, a ground-truth volume ``GT.nii.gz``.

    Args:
        datapath: root directory with one sub-directory per patient.
        train: ``1`` to also load ground-truth label slices; any other value loads images only.
        zoom_factor: scale applied to each 2D slice (default 0.25, i.e. 4x smaller
            for faster training).

    Returns:
        ``(data, label)`` -- lists of 2D arrays, one entry per slice -- when ``train == 1``;
        otherwise just ``data``. Fresh lists are built on every call.

    Raises:
        ValueError: if a patient's ground-truth volume has a different number of
            slices than its image volume (image/label lists would go out of sync).
    """
    data, label = [], []
    # Sorted so the slice order is reproducible across file systems.
    for patient in sorted(os.listdir(datapath)):
        patient_dir = os.path.join(datapath, patient)

        volume = _read_volume(os.path.join(patient_dir, patient + '.nii.gz'))
        # Images keep scipy's default cubic spline (order=3) for smooth resizing.
        data.extend(_downsample_slices(volume, zoom_factor, order=3))

        if train == 1:
            # Ground truth lives next to the image of the same patient under `datapath`.
            gt_volume = _read_volume(os.path.join(patient_dir, 'GT.nii.gz'))
            if gt_volume.shape[0] != volume.shape[0]:
                raise ValueError(
                    "patient {!r}: image has {} slices but GT has {}".format(
                        patient, volume.shape[0], gt_volume.shape[0]))
            # Nearest-neighbour (order=0) keeps the original integer class ids;
            # spline interpolation would blend neighbouring labels into wrong classes.
            label.extend(_downsample_slices(gt_volume, zoom_factor, order=0))

    if train == 1:
        return data, label
    return data

In [7]:
# Training loop for the UNet model (later move it to a separate python file)
def train(epochs, batch_size, learning_rate, data_root="data", save_path="models/model.pt"):
    """Train a UNet on the SegThor training split and save it with ``torch.save``.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for the DataLoader.
        learning_rate: learning rate given to the Adam optimizer.
        data_root: dataset root passed to ``SegThorDataset`` (default ``"data"``).
        save_path: file the trained model is written to (default ``"models/model.pt"``).
            The whole module object is pickled, not only its ``state_dict``.

    Returns:
        The trained model.

    Raises:
        ValueError: if the training DataLoader yields no batches.
    """
    # Images and targets get the same Rescale factor so they stay spatially aligned.
    # NOTE(review): confirm Rescale uses nearest-neighbour for targets, otherwise
    # label ids get blended at organ boundaries.
    train_loader = torch.utils.data.DataLoader(
        SegThorDataset(data_root, train=1,
                       transform=transforms.Compose([
                           Rescale(0.25),
                           ToTensor()
                       ]),
                       target_transform=transforms.Compose([
                           Rescale(0.25),
                           ToTensor()
                       ])),
        batch_size=batch_size, shuffle=True)
    if len(train_loader) == 0:
        raise ValueError("training dataset under {!r} is empty".format(data_root))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 10)

        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            # NOTE(review): binary_cross_entropy requires `output` in [0, 1] and float
            # targets of the same shape -- confirm UNet ends in a sigmoid and that the
            # labels are binary / one-hot.
            loss = F.binary_cross_entropy(output, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        print("Loss: {:.4f}\n".format(epoch_loss))

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model, save_path)
    return model


# Kept at top level (not inside train) so it runs once instead of recursing.
if __name__ == "__main__":
    train(epochs=2, batch_size=4, learning_rate=0.001)
