# Prostate Segmentation

Make sure each HDF5 file contains the keys "names", "data" and "mask". <br>
Expected volume size: 96x96x64, center-cropped after resampling to 1x1x1 mm voxel spacing. <br>
A batch size of 8-10 is recommended to fit within 12 GB of GPU memory. <br>
Model checkpoints are saved in the Data folder.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.models as models
import h5py
from torch.autograd import Variable
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt 
from pytorchtools import EarlyStopping
from random import randint
from segUtil import Modified3DUNet,ProstateDatasetHDF5,DiceLoss

### Reading the train and validation HDF5 files

In [2]:
trainfilename = r"Data/train.h5"
valfilename = r"Data/val.h5"

# Read only the sample names from each file. Files are opened explicitly
# read-only: h5py < 3.0 defaulted to append mode ('a'), which silently creates
# a missing file instead of raising. The context managers close the files even
# if reading fails (the previous manual close() was skipped on an exception).
with h5py.File(trainfilename, "r", libver="latest") as train:
    trainnames = np.array(train["names"])

with h5py.File(valfilename, "r", libver="latest") as val:
    valnames = np.array(val["names"])

### Prostate Dataset class

In [3]:
# Wrap each HDF5 file in the project's dataset class (train first, then val).
data_train, data_val = (
    ProstateDatasetHDF5(h5_path) for h5_path in (trainfilename, valfilename)
)

### Creating DataLoaders for the training and validation data

In [4]:
batch_size, num_workers = 10, 8

trainLoader = torch.utils.data.DataLoader(
    dataset=data_train,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,   # reshuffle the training volumes every epoch
)
valLoader = torch.utils.data.DataLoader(
    dataset=data_val,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,  # keep the validation order fixed
)

# Phase name -> loader, so the training loop can iterate both phases uniformly.
dataLoader = {'train': trainLoader, 'val': valLoader}

### Visualizing a few images

In [2]:
def visualizeImagesTorch(data_train, samp, n_samples=3, slice_idx=20):
    """Plot one slice of several consecutive image/mask pairs from a dataset.

    For each index ``samp, samp + 1, ..., samp + n_samples - 1`` the item is
    fetched from ``data_train``. Its third element is printed (the sample name,
    judging by the training loop's ``(data, mask, name)`` unpacking), followed
    by the image's min and max. Then slice ``slice_idx`` of channel 0 of the
    image is shown next to the same slice of the mask.

    Args:
        data_train: indexable dataset returning ``(image, mask, name)``. Image
            and mask are indexed as ``[channel, slice, :, :]``; presumably
            (C, D, H, W) volumes -- TODO confirm against ProstateDatasetHDF5.
        samp: index of the first sample to show.
        n_samples: number of consecutive samples to show (default 3, as before).
        slice_idx: slice index along the first spatial axis (default 20).
    """
    for idx in range(samp, samp + n_samples):
        timg, mask, lb = data_train[idx]
        # Convert both to numpy so tensors and arrays are handled identically.
        timg = np.asarray(timg)
        mask = np.asarray(mask)

        print(lb)
        print(timg.min())
        print(timg.max())

        plt.subplot(121)
        # Fixed grey scale; assumes image intensities lie in [0, 255] -- TODO confirm.
        plt.imshow(timg[0, slice_idx, :, :], cmap='gray', vmin=0, vmax=255)
        plt.subplot(122)
        plt.imshow(mask[0, slice_idx, :, :], cmap='gray', vmin=0, vmax=1)
        plt.show()


visualizeImagesTorch(dataLoader['train'].dataset, 50)

### Defining the 3D U-Net model

In [6]:
# All training runs on the first CUDA GPU; this cell assumes one is available.
device = torch.device("cuda:0")

# Modified3DUNet(1, 2): presumably 1 input channel and 2 output classes
# (TODO confirm in segUtil). nn.Module.to moves the parameters in place.
model = Modified3DUNet(1, 2)
model.to(device)

print(device)

cuda:0


### Defining the optimizer (the Dice loss is applied inside the training loop)

In [8]:
# Training hyper-parameters.
num_epochs = 200
learning_rate = 1e-4
weightdecay = 1e-3

# Adam over every model parameter, with weight decay applied through the
# optimizer's `weight_decay` argument.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weightdecay,
)

# Batches per epoch in each phase (len of a DataLoader counts batches).
print(len(dataLoader['train']))
print(len(dataLoader['val']))

281
6


### Training the model (make sure to specify a suitable patience and model name; the model is saved according to the early-stopping criterion)

In [None]:
modelname = r"unet"  # passed to early_stopping(...) as the model name when it is called
patience = 10        # forwarded to EarlyStopping; semantics defined in pytorchtools
early_stopping = EarlyStopping(patience=patience, verbose=True)

In [1]:
# Nominal number of training iterations per epoch (informational; the loop
# below does not use it).
niter_total = len(dataLoader['train'].dataset) / batch_size

# Alternate a training pass and a validation pass each epoch. After every
# validation pass, show one random sample of the last validation batch and
# feed the mean validation loss to early_stopping, which may end training.
for epoch in range(num_epochs):

    for phase in ["train", "val"]:
        is_train = phase == 'train'
        if is_train:
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        loss_vector = []
        for ii, (data, mask, name) in enumerate(dataLoader[phase]):

            if ii % 100 == 0:
                print(ii)  # progress: batch index within this phase

            data = data.float().to(device)
            mask = mask.float().to(device)

            # Build the autograd graph only while training. In validation this
            # avoids storing activations for backward, which saves GPU memory.
            with torch.set_grad_enabled(is_train):
                out, seg_layer = model(data)
                # Move the mask's channel axis last and flatten it so it can be
                # compared with out[:, 1] (assumes `out` is the flattened
                # per-voxel output of the model -- TODO confirm in segUtil).
                label = mask.permute(0, 2, 3, 4, 1).contiguous().view(-1)
                loss = DiceLoss(out[:, 1], label)

            loss_vector.append(loss.item())

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        loss_avg = np.mean(loss_vector)
        torch.cuda.empty_cache()

        if is_train:
            print("Epoch : {}, Phase : {}, Loss : {}".format(epoch, phase, loss_avg))
        else:
            print("                 Epoch : {}, Phase : {}, Loss : {}".format(epoch, phase, loss_avg))

            # Show slice 20 of a random sample from the last validation batch:
            # the input, the ground-truth mask and channel 1 of seg_layer.
            # The tensors live on the GPU, so copy them to host memory first;
            # matplotlib cannot convert CUDA tensors to numpy.
            ind = randint(0, data.shape[0] - 1)

            img = seg_layer.detach().cpu().numpy()

            plt.subplot(131)
            plt.imshow(data[ind, 0, 20].detach().cpu().numpy())
            plt.axis('off')
            plt.subplot(132)
            plt.imshow(mask[ind, 0, 20].detach().cpu().numpy())
            plt.axis('off')
            plt.subplot(133)
            plt.imshow(img[ind, 1, 20])
            plt.axis('off')
            plt.show()

        if phase == 'val':
            early_stopping(loss_avg, model, modelname, parentfolder=None)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    if early_stopping.early_stop:
        break