In [1]:
from __future__ import print_function
import pandas as pd
import tensorboard
from os import listdir
from os.path import isfile, join
import os
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import numpy as np
import torchvision.transforms.functional as F
import gc
from time import time
from datetime import datetime
from skimage.transform import resize
from skimage.io import imsave
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader ,TensorDataset
import glob
from torchvision.io import read_image
import random
import matplotlib.pyplot as plt
import torchvision.transforms as T
from IPython.display import clear_output
# Data locations; fill these in before running the notebook.
datapath = ''
train_data_path = ''
test_data_path = ''

image_depth = 8
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_image_paths = []  # to store image paths in list
masks = []  # to store masks

# Side length used for center-cropping and for the prediction buffers.
image_size = size = 352

# Device string used when moving predictions back for saving.
norm = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TrainImageDataset(Dataset):
    """Training pairs of (image, mask) loaded from ``.npy`` files.

    Images come from ``<datapath>/Train_f/`` and masks from
    ``<datapath>/Mask_f/``. An image is paired with the mask file of the same
    name; images without such a mask are skipped. Both tensors are
    center-cropped to ``size`` x ``size`` before the optional transforms run.
    """

    def __init__(self, transform=None, target_transform=None):
        self.img_dir = os.path.join(datapath, 'Train_f/')
        self.lab_dir = os.path.join(datapath, 'Mask_f/')
        # Pair by file name and sort, so index i is always a matching
        # image/mask pair, independent of os.listdir's arbitrary order. Using
        # one shared name list means an image lacking a mask can no longer
        # shift every later pairing by one.
        paired_names = sorted(
            f for f in listdir(self.img_dir)
            if isfile(join(self.img_dir, f)) and isfile(join(self.lab_dir, f))
        )
        self.dirtra_lis = paired_names
        self.dirlab_lis = list(paired_names)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.dirtra_lis)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label)`` as float tensors.

        Both are center-cropped over their last two axes to ``size`` x
        ``size``, then passed through ``transform`` / ``target_transform`` if
        they were given.
        """
        imgpath = os.path.join(self.img_dir, self.dirtra_lis[idx])
        maskpath = os.path.join(self.lab_dir, self.dirlab_lis[idx])
        image = torch.Tensor(np.load(imgpath))
        label = torch.Tensor(np.load(maskpath))
        image = T.CenterCrop(size=size)(image)
        label = T.CenterCrop(size=size)(label)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [3]:
class ValImageDataset(Dataset):
    """Validation pairs: inputs from ``Test_f/`` and labels from ``Val_f/``.

    Files are ``.npy`` arrays under ``datapath``. The i-th input (sorted by
    name) is paired with the i-th label (sorted by name). Both tensors are
    center-cropped to ``size`` x ``size`` before the optional transforms run.

    Raises:
        ValueError: if the two directories hold different numbers of files,
            since positional pairing would then be misaligned.
    """

    def __init__(self, transform=None, target_transform=None):
        self.test_dir = os.path.join(datapath, 'Test_f/')
        self.val_dir = os.path.join(datapath, 'Val_f/')
        # os.listdir returns entries in arbitrary order; without sorting, the
        # positional input/label pairing below would be arbitrary as well.
        self.dirtest_lis = sorted(f for f in listdir(self.test_dir) if isfile(join(self.test_dir, f)))
        self.dirval_lis = sorted(f for f in listdir(self.val_dir) if isfile(join(self.val_dir, f)))
        if len(self.dirtest_lis) != len(self.dirval_lis):
            raise ValueError(
                "Test_f has {} files but Val_f has {}; inputs and labels are paired "
                "by position, so the counts must match".format(
                    len(self.dirtest_lis), len(self.dirval_lis)))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of input/label pairs."""
        return len(self.dirtest_lis)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, label)`` as float tensors.

        Both are center-cropped over their last two axes to ``size`` x
        ``size``, then passed through ``transform`` / ``target_transform`` if
        they were given.
        """
        testpath = os.path.join(self.test_dir, self.dirtest_lis[idx])
        valpath = os.path.join(self.val_dir, self.dirval_lis[idx])
        image = torch.Tensor(np.load(testpath))
        label = torch.Tensor(np.load(valpath))
        image = T.CenterCrop(size=size)(image)
        label = T.CenterCrop(size=size)(label)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [4]:
imgs_train = TrainImageDataset()
# One shuffled sample per step, loaded in the main process.
train_dataloader = DataLoader(
    imgs_train,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
z = len(imgs_train)  # number of training samples
print(z)

1504


In [5]:
imgs_test = ValImageDataset()
validation_loader = DataLoader(
    imgs_test,
    batch_size=1,
    shuffle=False,  # fixed order: loader index i corresponds to sample i
    num_workers=0,
    pin_memory=True,
)
print(len(imgs_test))

244


In [6]:
def conv_block_3d(in_dim, out_dim, activation):
    """Conv3d (3x3x3, stride 1, padding 1) -> BatchNorm3d -> ``activation``.

    Spatial size is preserved; channels go from ``in_dim`` to ``out_dim``.
    """
    return nn.Sequential(
        nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
        activation,)


def conv_trans_block_3d(in_dim, out_dim, activation):
    """ConvTranspose3d (3x3x3, stride 2) -> BatchNorm3d -> ``activation``.

    With padding=1 and output_padding=1 every spatial dimension is exactly
    doubled: ``(s - 1) * 2 - 2 + 3 + 1 == 2 * s``.
    """
    return nn.Sequential(
        nn.ConvTranspose3d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm3d(out_dim),
        activation,)


def max_pooling_3d():
    """2x2x2 max pooling with stride 2: halves each spatial dimension (floor)."""
    return nn.MaxPool3d(kernel_size=2, stride=2, padding=0)


def conv_block_2_3d(in_dim, out_dim, activation):
    """conv_block_3d followed by a second Conv3d -> BatchNorm3d.

    The second convolution is not followed by an activation.
    """
    return nn.Sequential(
        conv_block_3d(in_dim, out_dim, activation),
        nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),)


class UNet(nn.Module):
    """3-level 3D U-Net: three down/pool stages, a bridge, three up stages.

    Each spatial dimension is halved three times and doubled back three
    times, so depth/height/width should be divisible by 8 for the skip
    concatenations to line up.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        num_filters: channel width F of the first stage; levels use F, 2F, 4F
            and the bridge uses 8F.
        final_activation: optional module applied to the output, e.g.
            ``nn.Sigmoid()`` to bound predictions to [0, 1] for a Dice loss.
            Default ``None`` returns the ``self.out`` block's result unchanged
            (conv -> BatchNorm -> LeakyReLU, which is not bounded to [0, 1]).
    """

    def __init__(self, in_dim, out_dim, num_filters, final_activation=None):
        super(UNet, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filters = num_filters
        # A single (parameter-free) activation instance is shared by every block.
        activation = nn.LeakyReLU(0.2, inplace=True)

        # Down sampling
        self.down_1 = conv_block_2_3d(self.in_dim, self.num_filters, activation)
        self.pool_1 = max_pooling_3d()
        self.down_2 = conv_block_2_3d(self.num_filters, self.num_filters * 2, activation)
        self.pool_2 = max_pooling_3d()
        self.down_3 = conv_block_2_3d(self.num_filters * 2, self.num_filters * 4, activation)
        self.pool_3 = max_pooling_3d()

        # Bridge
        self.bridge = conv_block_2_3d(self.num_filters * 4, self.num_filters * 8, activation)

        # Up sampling: each up_* block sees [upsampled, skip] concatenated,
        # hence twice the channels of its output.
        self.trans_3 = conv_trans_block_3d(self.num_filters * 8, self.num_filters * 4, activation)
        self.up_3 = conv_block_2_3d(self.num_filters * 8, self.num_filters * 4, activation)
        self.trans_4 = conv_trans_block_3d(self.num_filters * 4, self.num_filters * 2, activation)
        self.up_4 = conv_block_2_3d(self.num_filters * 4, self.num_filters * 2, activation)
        self.trans_5 = conv_trans_block_3d(self.num_filters * 2, self.num_filters, activation)
        self.up_5 = conv_block_2_3d(self.num_filters * 2, self.num_filters, activation)

        # Output
        self.out = conv_block_3d(self.num_filters, out_dim, activation)
        # Parameter-free, so the state_dict keys are unaffected by this option.
        self.final_activation = final_activation

    def forward(self, x):
        """Segment ``x``.

        Shapes below assume a batched input ``[N, in_dim, D, H, W]`` and
        F = num_filters. Conv3d also accepts an unbatched 4D input, in which
        case drop the N.
        """
        # Down sampling
        down_1 = self.down_1(x)  # -> [N, F, D, H, W]
        pool_1 = self.pool_1(down_1)  # -> [N, F, D/2, H/2, W/2]

        down_2 = self.down_2(pool_1)  # -> [N, 2F, D/2, H/2, W/2]
        pool_2 = self.pool_2(down_2)  # -> [N, 2F, D/4, H/4, W/4]

        down_3 = self.down_3(pool_2)  # -> [N, 4F, D/4, H/4, W/4]
        pool_3 = self.pool_3(down_3)  # -> [N, 4F, D/8, H/8, W/8]

        # Bridge
        bridge = self.bridge(pool_3)  # -> [N, 8F, D/8, H/8, W/8]

        # Up sampling
        trans_3 = self.trans_3(bridge)  # -> [N, 4F, D/4, H/4, W/4]
        concat_3 = torch.cat([trans_3, down_3], dim=1)  # -> [N, 8F, D/4, H/4, W/4]
        up_3 = self.up_3(concat_3)  # -> [N, 4F, D/4, H/4, W/4]

        trans_4 = self.trans_4(up_3)  # -> [N, 2F, D/2, H/2, W/2]
        concat_4 = torch.cat([trans_4, down_2], dim=1)  # -> [N, 4F, D/2, H/2, W/2]
        up_4 = self.up_4(concat_4)  # -> [N, 2F, D/2, H/2, W/2]

        trans_5 = self.trans_5(up_4)  # -> [N, F, D, H, W]
        concat_5 = torch.cat([trans_5, down_1], dim=1)  # -> [N, 2F, D, H, W]
        up_5 = self.up_5(concat_5)  # -> [N, F, D, H, W]

        # Output
        out = self.out(up_5)  # -> [N, out_dim, D, H, W]
        if self.final_activation is not None:
            out = self.final_activation(out)
        return out


# NOTE(review): no final activation, so the Dice loss sees unbounded outputs;
# consider final_activation=nn.Sigmoid() if the masks are 0/1.
model = UNet(in_dim=1, out_dim=1, num_filters=8)

In [7]:
displ = 0  # 0 -> clear this cell's output area afterwards
model.to(device=device)
if not displ:
    clear_output(wait=True)

In [8]:
smooth = 1


def dice_coef(y_true, y_pred):
    """Soft Dice coefficient over all elements of two tensors.

    Returns ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    The sums run over the whole tensor, so a batch is scored as one volume.
    ``smooth`` is read from the module-level constant at call time.
    """
    truth = y_true.reshape(-1)
    pred = y_pred.reshape(-1)
    overlap = (truth * pred).sum()
    denominator = truth.sum() + pred.sum() + smooth
    return (2. * overlap + smooth) / denominator


def dice_coef_loss(y_true, y_pred):
    """Loss = negative Dice coefficient, moved onto the global ``device``."""
    score = dice_coef(y_true, y_pred)
    return score.neg().to(device)

In [9]:
# Optimizer choice ('sgd' or 'adam'), built from torch.optim.
opt = 'adam'
if opt == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
elif opt == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
else:
    # Fail here instead of with a NameError on `optimizer` once training starts.
    raise ValueError("unknown optimizer {!r}; expected 'sgd' or 'adam'".format(opt))

# Mixed precision (currently disabled):
#use_amp = True
#scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

In [10]:
def train_one_epoch(epoch_index, tb_writer):
    """Run one training epoch over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_dataloader``,
    ``device`` and ``dice_coef_loss``.

    Args:
        epoch_index: zero-based epoch number; positions the TensorBoard point.
        tb_writer: SummaryWriter that receives the ``Loss/train`` scalar.

    Returns:
        Mean loss per batch over the epoch (0.0 if the loader is empty).
    """
    num_batches = len(train_dataloader)
    running_loss = 0.
    last_loss = 0.

    for i, (inputs, labels) in enumerate(train_dataloader):
        start = time()
        # Follow the configured device rather than hard-coding CUDA.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad(set_to_none=True)
        outputs = model(inputs)
        loss = dice_coef_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        # Detect the last *batch* (not sample) and average over batches, so
        # this stays correct for any batch_size.
        if i == num_batches - 1:
            last_loss = running_loss / num_batches  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * num_batches + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
        end = time()
        # Estimate from this batch's duration times the batches still to go.
        remaining = np.around((end - start) * (num_batches - i - 1) / 60, decimals=2)
        print(str(remaining)+'Min' + " Sec Remaining "+" Loss = "+str(batch_loss),end="                                   \r")
    return last_loss

In [None]:
#scaler = torch.cuda.amp.GradScaler()
pred_dir = "Predictions"
imgs = np.ndarray((1,16, size, size), dtype=np.uint8)
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.now().strftime('%H %M %S')
writer = SummaryWriter('runs/UnetSeg_{}'.format(timestamp))
epoch_number = 0
EPOCHS = 5
print("Started at " + timestamp)

# dice_coef_loss returns -Dice. For predictions and labels in [0, 1] that lies
# in [-1, 0), so the first finite validation loss beats this initial value.
best_vloss = 1
gc.enable()
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model on the configured device (a bare `model.cuda` attribute
    # access would do nothing) and enable training mode.
    model.to(device)
    model.train(True)
    start1 = time()
    avg_loss = train_one_epoch(epoch_number, writer)
    end1 = time()
    torch.cuda.empty_cache()

    # Eval-mode BatchNorm for validation. train(False) alone does not stop
    # autograd, so no_grad is what keeps validation from building graphs.
    model.train(False)
    running_vloss = 0.0
    num_vbatches = 0
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            startval = time()
            vinputs = vinputs.to(device)
            vlabels = vlabels.to(device)
            voutputs = model(vinputs)
            vloss = dice_coef_loss(voutputs, vlabels)
            running_vloss += float(vloss)
            num_vbatches += 1
            endval = time()
            print("Val Time= " + str(np.around((endval-startval),decimals=2)),end="                                                         \r")
    # NaN for an empty validation set. It never compares as better, so no
    # checkpoint is written; this avoids relying on a leaked loop index.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',{ 'Training' : avg_loss, 'Validation' : avg_vloss },epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
    torch.cuda.empty_cache()
    epoch_number += 1
    # Remaining-time estimate, extrapolated from this epoch's training time.
    print('Elsp_time ')
    print((str(np.around((end1-start1)*(EPOCHS-epoch_number)/60,decimals=1))+" Min "))
    gc.collect()
timestamp1 = datetime.now().strftime('%H %M %S')
print("End at " + timestamp1)

Started at 22 51 49
EPOCH 1:
47.47Min Sec Remaining  Loss = -6.759258667443646e-06                                     

In [None]:
# Reclaim memory left over from training before running inference.
# Large leftovers (vinputs, vlabels, inputs, labels, train_dataloader,
# imgs_train) could be `del`-ed first if memory is tight.
gc.collect()

In [None]:
# Run inference on the CPU and release cached GPU memory.
device = 'cpu'
model = model.to(device)
torch.cuda.empty_cache()

Prediction

In [None]:
flag = 1  # set to 0 to skip the prediction pass
if flag == 1:
    # One uint8 mask volume per validation sample. The slice count 8 is
    # assumed to match the model output depth — TODO confirm against the data.
    imgs = np.ndarray((len(validation_loader),8, size, size), dtype=np.uint8)#Mask
    model.eval()
    # Inference only: skip building autograd graphs for every sample.
    with torch.no_grad():
        for i, (vinputs, vlabels) in enumerate(validation_loader):
            # Keep the input on the same device as the model.
            voutputs = model(vinputs.to(device))
            # Casting float predictions straight into uint8 truncates [0, 1)
            # to 0 and wraps negatives, so binarize first. The 0.5 cut-off
            # assumes 0/1 target masks — TODO confirm.
            imgs[i] = (voutputs > 0.5).cpu().numpy()
    print("Pred Finished")

In [None]:
# Sanity check on the prediction buffer's dimensions.
print("{}".format(imgs.shape))

In [None]:
pred_dir = "Predictions"
count_processed = 0
os.makedirs(pred_dir, exist_ok=True)
# One PNG per (sample, slice). Values are clipped to [0, 1] before scaling,
# so 0/1 masks become 0/255 and out-of-range values cannot wrap around.
for x in range(imgs.shape[0]):
    for y in range(imgs.shape[1]):
        png = (np.clip(imgs[x][y], 0, 1) * 255).astype(np.uint8)
        imsave(os.path.join(pred_dir, 'predictions_' + str(x)+' slice '+str(y) + '.png'), png)
        count_processed += 1
print('Saved {} prediction slices as .png files to {}.'.format(count_processed, pred_dir))

# Dump every slice of the most recent validation output (`voutputs`, the last
# batch the model produced) into an epoch-specific folder. A fresh name is
# used so the `imgs` prediction buffer is not overwritten.
ou_dir = os.path.join(pred_dir, 'epoch' + str(epoch))
os.makedirs(ou_dir, exist_ok=True)
last_pred = voutputs.detach().to(norm).numpy()
# Flatten all leading axes so each entry is one (H, W) slice, whatever the
# output rank; the slice count then comes from the data, not a constant.
last_slices = last_pred.reshape(-1, last_pred.shape[-2], last_pred.shape[-1])
for y, pred_slice in enumerate(last_slices):
    # Raw outputs are unbounded floats; clip into the 8-bit PNG range.
    png = np.clip(pred_slice * 255, 0, 255).astype(np.uint8)
    imsave(os.path.join(ou_dir, 'pre_processed_' + str(y) + '.png'), png)

from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch


class UNet(Module):
    """Residual 3D U-Net with four pooling levels and a sigmoid output.

    Each spatial dimension is halved four times and doubled back, so it
    should be divisible by 16 for the skip concatenations to line up.

    NOTE(review): this shares its name with the ``UNet`` defined in an earlier
    cell; running this cell rebinds ``UNet`` for the rest of the notebook.

    Args:
        num_channels: channels of both the input and the segmentation output.
        feat_channels: five feature widths, encoder level 1 through the base.
        residual: forwarded to every Conv3D_Block; a non-None (truthy) value
            adds the 1x1x1 convolutional residual path, None removes it.
    """
    # __                            __
    #  1|__   ________________   __|1
    #     2|__  ____________  __|2
    #        3|__  ______  __|3
    #           4|__ __ __|4
    # The convolution operations on either side are residual subject to 1*1 Convolution for channel homogeneity

    def __init__(self, num_channels=1, feat_channels=[8, 16, 32, 64, 128], residual='conv'):
        # feat_channels is only read, never mutated, so the list default is safe.
        super(UNet, self).__init__()

        # Encoder downsamplers
        self.pool1 = MaxPool3d((2, 2, 2))
        self.pool2 = MaxPool3d((2, 2, 2))
        self.pool3 = MaxPool3d((2, 2, 2))
        self.pool4 = MaxPool3d((2, 2, 2))

        # Encoder convolutions
        self.conv_blk1 = Conv3D_Block(num_channels, feat_channels[0], residual=residual)
        self.conv_blk2 = Conv3D_Block(feat_channels[0], feat_channels[1], residual=residual)
        self.conv_blk3 = Conv3D_Block(feat_channels[1], feat_channels[2], residual=residual)
        self.conv_blk4 = Conv3D_Block(feat_channels[2], feat_channels[3], residual=residual)
        self.conv_blk5 = Conv3D_Block(feat_channels[3], feat_channels[4], residual=residual)

        # Decoder convolutions (input = upsampled features ++ skip connection)
        self.dec_conv_blk4 = Conv3D_Block(2 * feat_channels[3], feat_channels[3], residual=residual)
        self.dec_conv_blk3 = Conv3D_Block(2 * feat_channels[2], feat_channels[2], residual=residual)
        self.dec_conv_blk2 = Conv3D_Block(2 * feat_channels[1], feat_channels[1], residual=residual)
        self.dec_conv_blk1 = Conv3D_Block(2 * feat_channels[0], feat_channels[0], residual=residual)

        # Decoder upsamplers
        self.deconv_blk4 = Deconv3D_Block(feat_channels[4], feat_channels[3])
        self.deconv_blk3 = Deconv3D_Block(feat_channels[3], feat_channels[2])
        self.deconv_blk2 = Deconv3D_Block(feat_channels[2], feat_channels[1])
        self.deconv_blk1 = Deconv3D_Block(feat_channels[1], feat_channels[0])

        # Final 1*1 Conv Segmentation map
        self.one_conv = Conv3d(feat_channels[0], num_channels, kernel_size=1, stride=1, padding=0, bias=True)

        # Activation function
        self.sigmoid = Sigmoid()

        # Decoder dropout, registered as submodules so model.eval()/train()
        # switch it off/on. A Dropout3d built inside forward() is always in
        # training mode and keeps dropping channels at inference. These layers
        # have no parameters, so the state_dict is unchanged.
        self.dec_dropout3 = Dropout3d(p=0.5)
        self.dec_dropout2 = Dropout3d(p=0.5)

    def forward(self, x):
        """Return a per-voxel sigmoid map with the same spatial size as ``x``."""
        # Encoder part

        x1 = self.conv_blk1(x)

        x_low1 = self.pool1(x1)
        x2 = self.conv_blk2(x_low1)

        x_low2 = self.pool2(x2)
        x3 = self.conv_blk3(x_low2)

        x_low3 = self.pool3(x3)
        x4 = self.conv_blk4(x_low3)

        x_low4 = self.pool4(x4)
        base = self.conv_blk5(x_low4)

        # Decoder part

        d4 = torch.cat([self.deconv_blk4(base), x4], dim=1)
        d_high4 = self.dec_conv_blk4(d4)

        d3 = torch.cat([self.deconv_blk3(d_high4), x3], dim=1)
        d_high3 = self.dec_conv_blk3(d3)
        d_high3 = self.dec_dropout3(d_high3)

        d2 = torch.cat([self.deconv_blk2(d_high3), x2], dim=1)
        d_high2 = self.dec_conv_blk2(d2)
        d_high2 = self.dec_dropout2(d_high2)

        d1 = torch.cat([self.deconv_blk1(d_high2), x1], dim=1)
        d_high1 = self.dec_conv_blk1(d1)

        seg = self.sigmoid(self.one_conv(d_high1))

        return seg


class Conv3D_Block(Module):
    """Two (Conv3d -> BatchNorm3d -> ReLU) stages with an optional residual path.

    Args:
        inp_feat: input channels.
        out_feat: output channels.
        kernel, stride, padding: passed to both Conv3d layers.
        residual: if not None, a bias-free 1x1x1 ``residual_upsampler`` conv is
            created; its output is added whenever ``residual`` is truthy.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel, stride, padding):
        # One Conv3d -> BatchNorm3d -> ReLU stage.
        return Sequential(
            Conv3d(in_ch, out_ch, kernel_size=kernel,
                   stride=stride, padding=padding, bias=True),
            BatchNorm3d(out_ch),
            ReLU())

    def __init__(self, inp_feat, out_feat, kernel=3, stride=1, padding=1, residual=None):
        super(Conv3D_Block, self).__init__()
        self.conv1 = self._conv_bn_relu(inp_feat, out_feat, kernel, stride, padding)
        self.conv2 = self._conv_bn_relu(out_feat, out_feat, kernel, stride, padding)
        self.residual = residual
        if residual is not None:
            # 1x1x1 projection so the input can be added at out_feat channels.
            self.residual_upsampler = Conv3d(inp_feat, out_feat, kernel_size=1, bias=False)

    def forward(self, x):
        main_path = self.conv2(self.conv1(x))
        if self.residual:
            return main_path + self.residual_upsampler(x)
        return main_path


class Deconv3D_Block(Module):
    """Learned 3D upsampling: ConvTranspose3d (output_padding=1) then ReLU.

    With the defaults (kernel 3, stride 2, padding 1) every spatial dimension
    is doubled.
    """

    def __init__(self, inp_feat, out_feat, kernel=3, stride=2, padding=1):
        super(Deconv3D_Block, self).__init__()
        kernel_3d = (kernel,) * 3
        stride_3d = (stride,) * 3
        padding_3d = (padding,) * 3
        upsample = ConvTranspose3d(inp_feat, out_feat, kernel_size=kernel_3d,
                                   stride=stride_3d, padding=padding_3d,
                                   output_padding=1, bias=True)
        self.deconv = Sequential(upsample, ReLU())

    def forward(self, x):
        return self.deconv(x)


class ChannelPool3d(AvgPool1d):
    """Average-pool a 5D tensor ``(N, C, D, W, H)`` along its channel axis.

    Spatial positions are flattened and channels moved to the last axis, so
    ``AvgPool1d`` slides over them. The result is reshaped back to
    ``(N, C_out, D, W, H)``, where ``C_out`` follows AvgPool1d's output-length
    formula (``(C + 2*padding - kernel_size) // stride + 1`` with the default
    ``ceil_mode=False``).
    """

    def __init__(self, kernel_size, stride, padding):
        super(ChannelPool3d, self).__init__(kernel_size, stride, padding)
        self.pool_1d = AvgPool1d(self.kernel_size, self.stride, self.padding, self.ceil_mode)

    def forward(self, inp):
        n, c, d, w, h = inp.size()
        # (N, C, D*W*H) -> (N, D*W*H, C): AvgPool1d pools over the last axis.
        flat = inp.reshape(n, c, d * w * h).permute(0, 2, 1)
        pooled = self.pool_1d(flat)
        # Take the channel count from the pooled result so any stride/padding
        # works. reshape (not view) because the permuted tensor is not contiguous.
        c_out = pooled.size(-1)
        return pooled.permute(0, 2, 1).reshape(n, c_out, d, w, h)


if __name__ == '__main__':
    # NOTE(review): inside a Jupyter kernel __name__ is '__main__', so this block
    # runs whenever the cell runs. `import time` then rebinds the notebook's
    # `time` (imported earlier via `from time import time`) to the module, and
    # later bare `time()` calls would fail — TODO confirm whether this is intended.
    import time
    import torch
    from torch.autograd import Variable

    # Selects CUDA device 0; this line fails on machines without CUDA.
    torch.cuda.set_device(0)
    # Conv3D_Block only checks `residual is not None` / truthiness, so 'pool'
    # behaves like 'conv' here. This also rebinds the notebook-level `model`.
    model =UNet(residual='pool').cuda()
