# SegNet Implementation - Segmentation

### Importacion de librerías

In [1]:
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

%matplotlib inline

import torch
from torch.autograd import Variable
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from torch.nn import init
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs), targets)


def CrossEntropy2d(input, target, weight=None, size_average=False):
    # input:(n, c, h, w) target:(n, h, w)
    n, c, h, w = input.size()

    input = input.transpose(1, 2).transpose(2, 3).contiguous()
    input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0].view(-1, c)

    target_mask = target >= 0
    target = target[target_mask]
    #loss = F.nll_loss(F.log_softmax(input), target, weight=weight, size_average=False)
    loss = F.cross_entropy(input, target, weight=weight, size_average=False)
    if size_average:
        loss /= target_mask.sum().data[0]

    return loss

### Data

In [8]:
class LeftVentricleDataset(Dataset):
    def __init__(self, transform = None):
        
        data = sio.loadmat('cardiac-dig.mat')
        images_LV = data['images_LV']
        endo_LV = data['endo_LV']
        epi_lv = data['epi_LV']
        lv_phase = data['lv_phase']
        rwt = data['rwt']
        areas = data['areas']
        dims = data['dims']
        
        # 1 - MRI Left Ventricle
        # Swap of indices to use in PyTorch
        images_LV_m = np.swapaxes(images_LV,0,2)
        images_LV_m = np.swapaxes(images_LV_m, 1, 2)
        images_LV_m = torch.from_numpy(images_LV_m)
        self.images_LV = images_LV_m
        self.images_LV = images_LV_m[0:2020, : ,:]
        
        # 2 - Cardiac phase
        self.lv_phase = torch.from_numpy(lv_phase.flatten('C'))
        self.lv_phase = self.lv_phase[0:2020]
        
        # 3 - RWT (6 indices)
        self.rwt = torch.from_numpy(rwt)
        self.rwt = self.rwt[:, 0:2020]
        
        # 4 - Areas (cavity and myocardium)
        self.areas = torch.from_numpy(areas)
        self.areas = self.areas[:, 0:2020]
        
        # 5 - Dims
        self.dims = torch.from_numpy(dims)
        self.dims = self.dims[:, 0:2020]
        
        # 6 - Endo (segmentation)
        endo_LV_m = np.swapaxes(endo_LV,0,2)
        endo_LV_m = np.swapaxes(endo_LV_m, 1, 2)
        endo_LV_m = torch.from_numpy(endo_LV_m)
        self.endo_LV = endo_LV_m
        self.endo_LV = endo_LV_m[0:2020, : ,:]
        
        
    def __getitem__(self, index):
        
        image_lv = self.images_LV[index, :, :]
        
        endo_lv = self.endo_LV[index, :, :]
        
        phase_lv = self.lv_phase[index]
        
        rwt = self.rwt[:, index]
        
        areas = self.areas[:, index]
        
        dims = self.dims[:, index]
        
        sample = {'images_lv': image_lv, 'endo_lv': endo_lv, 'phase_lv': phase_lv, 
                  'rwt_lv': rwt, 'areas_lv': areas, 'dims_lv': dims}
        
        return sample

    def __len__(self):
        return self.images_LV.__len__()

In [9]:
dataset_lv = LeftVentricleDataset()

In [10]:
train_loader = torch.utils.data.DataLoader(dataset = dataset_lv, batch_size = 2)

# SegNet - ArXiv: https://arxiv.org/pdf/1511.00561.pdf

In [11]:
class SegNet(nn.Module):
    
    def __init__(self):
        super(SegNet, self).__init__()
        
        self.conv1_1 = nn.Conv2d(1, 64, kernel_size = 3, stride = 1, padding = 1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size = 3, stride = 1, padding = 1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size = 3, stride = 1, padding = 1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size = 3, stride = 1, padding = 1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size = 3, stride = 1, padding = 1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1)

        self.pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, return_indices = True)
        self.pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, return_indices = True)
        self.pool3 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, return_indices = True)
        self.pool4 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, return_indices = True)
        self.pool5 = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0, return_indices = True)
        
        self.unpool5 = nn.MaxUnpool2d(kernel_size = 2, stride = 2, padding = 0)
        self.unpool4 = nn.MaxUnpool2d(kernel_size = 2, stride = 2, padding = 0)
        self.unpool3 = nn.MaxUnpool2d(kernel_size = 2, stride = 2, padding = 0)
        self.unpool2 = nn.MaxUnpool2d(kernel_size = 2, stride = 2, padding = 0)
        self.unpool1 = nn.MaxUnpool2d(kernel_size = 2, stride = 2, padding = 0)

        self.deconv5_1 = nn.ConvTranspose2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.deconv5_2 = nn.ConvTranspose2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.deconv5_3 = nn.ConvTranspose2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.deconv4_1 = nn.ConvTranspose2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.deconv4_2 = nn.ConvTranspose2d(512, 512, kernel_size = 3, stride = 1, padding = 1)
        self.deconv4_3 = nn.ConvTranspose2d(512, 256, kernel_size = 3, stride = 1, padding = 1)
        self.deconv3_1 = nn.ConvTranspose2d(256, 256, kernel_size = 3, stride = 1, padding = 1)
        self.deconv3_2 = nn.ConvTranspose2d(256, 256, kernel_size = 3, stride = 1, padding = 1)
        self.deconv3_3 = nn.ConvTranspose2d(256, 128, kernel_size = 3, stride = 1, padding = 1)
        self.deconv2_1 = nn.ConvTranspose2d(128, 128, kernel_size = 3, stride = 1, padding = 1)
        self.deconv2_2 = nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 1, padding = 1)
        self.deconv1_1 = nn.ConvTranspose2d(64, 64, kernel_size = 3, stride = 1, padding = 1)
        self.deconv1_2 = nn.ConvTranspose2d(64, 2, kernel_size = 3, stride = 1, padding = 1)

        self.batch_norm1 = nn.BatchNorm2d(64)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.batch_norm3 = nn.BatchNorm2d(256)
        self.batch_norm4 = nn.BatchNorm2d(512)
    
    def forward(self, x):
        
        size_1 = x.size()
        x = self.conv1_1(x)
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.conv1_2(x)
        x = self.batch_norm1(x)
        x = F.relu(x)
        x, idxs1 = self.pool1(x)
        
        size_2 = x.size()
        x = self.conv2_1(x)
        x = self.batch_norm2(x)
        x = F.relu(x)
        x = self.conv2_2(x)
        x = self.batch_norm2(x)
        x = F.relu(x)
        x, idxs2 = self.pool2(x)
        
        size_3 = x.size()
        x = self.conv3_1(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.conv3_2(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.conv3_3(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x, idxs3 = self.pool3(x)
        
        size_4 = x.size()
        x = self.conv4_1(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.conv4_2(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.conv4_3(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x, idxs4 = self.pool4(x)
        
        size_5 = x.size()
        x = self.conv5_1(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.conv5_2(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.conv5_3(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x, idxs5 = self.pool5(x)

        
        x = self.unpool5(x, idxs5, output_size = size_5)
        x = self.deconv5_1(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.deconv5_2(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.deconv5_3(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        
        x = self.unpool4(x, idxs4, output_size = size_4)
        x = self.deconv4_1(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.deconv4_2(x)
        x = self.batch_norm4(x)
        x = F.relu(x)
        x = self.deconv4_3(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        
        x = self.unpool3(x, idxs3, output_size = size_3)
        x = self.deconv3_1(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.deconv3_2(x)
        x = self.batch_norm3(x)
        x = F.relu(x)
        x = self.deconv3_3(x)
        x = self.batch_norm2(x)
        x = F.relu(x)
        
        x = self.unpool2(x, idxs2, output_size = size_2)
        x = self.deconv2_1(x)
        x = self.batch_norm2(x)
        x = F.relu(x)
        x = self.deconv2_2(x)
        x = self.batch_norm1(x)
        x = F.relu(x)
        
        x = self.unpool1(x, idxs1, output_size = size_1)
        x = self.deconv1_1(x)
        x = self.batch_norm1(x)
        x = F.relu(x)
        x = self.deconv1_2(x)
        
        return x

In [12]:
learning_rate = 0.01
num_epochs = 3500
model = SegNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, momentum = 0.9, weight_decay = 0.005)

In [13]:
def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 2, 80, 80)
    return x

In [14]:
for epoch in range(num_epochs):
    for batch_idx, data in enumerate(train_loader):
        img = data['images_lv']
        img = torch.unsqueeze(img, 1)
        img = img.type(torch.FloatTensor)
        
        label_endo = data['endo_lv']
        label_endo = torch.unsqueeze(label_endo, 1)
        label_endo = label_endo.type(torch.LongTensor)
        
        optimizer.zero_grad() # eliminar gradientes acumulados
        
        # Forward
        output_r = model(img)
        output = output_r.view(output_r.size(0),output_r.size(1), -1)
        output = torch.transpose(output,1,2).contiguous()
        output = output.view(-1, output.size(2))
        label_endo = label_endo.view(-1)
        loss = criterion(output, label_endo)
        
        # Backward
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print('Epoch [{}/{}], Loss:{:.4f}'
                .format(epoch+1, num_epochs, loss.data))
            
        if epoch % 1 == 0:
            pic = to_img(output_r.data)
            #prediction = output_r.data.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()
            #prediction = voc.colorize_mask(prediction)
            #print(prediction.shape)
            save_image(pic[1,1,:,:], './img_segnet/image_{}.png'.format(epoch))

Epoch [1/3500], Loss:0.6303
Epoch [1/3500], Loss:0.4070
Epoch [1/3500], Loss:0.3319
Epoch [1/3500], Loss:0.1834
Epoch [1/3500], Loss:0.4049
Epoch [1/3500], Loss:0.1089
Epoch [1/3500], Loss:0.0667
Epoch [1/3500], Loss:0.1085
Epoch [1/3500], Loss:0.0561
Epoch [1/3500], Loss:0.0432
Epoch [1/3500], Loss:0.1665
Epoch [1/3500], Loss:0.0671
Epoch [1/3500], Loss:0.0986
Epoch [1/3500], Loss:0.0607
Epoch [1/3500], Loss:0.0612
Epoch [1/3500], Loss:0.0641
Epoch [1/3500], Loss:0.1452
Epoch [1/3500], Loss:0.0572
Epoch [1/3500], Loss:0.0838
Epoch [1/3500], Loss:0.0357


KeyboardInterrupt: 