In [1]:
# In this notebook, we try to implement a Shallow CNN
# in PyTorch for "Retina Blood Vessel" dataset
#

In [2]:
# imports
#
import sys
import os
import glob
import time
import numpy as np
import scipy as sp
import skimage
from skimage import segmentation, io, filters, morphology
import sklearn
from sklearn import ensemble, metrics, svm
import matplotlib.pyplot as plt

import torch, torchvision

from IPython.core.debugger import set_trace

In [56]:
# Globals
#
train_image_ipath = 'RetinaBloodVessels/train/image/'
train_mask_ipath = 'RetinaBloodVessels/train/mask/'
test_image_ipath = 'RetinaBloodVessels/test/image/'
test_mask_ipath = 'RetinaBloodVessels/test/mask/'
NROWS, NCOLS = 512, 512
EPSILON = 1e-6
br = set_trace

In [437]:
# functions and classes

def read_images(path, rescale=True):
    images_fnames = sorted(glob.glob(os.path.join(path, '*.png')))
    images = []
    for fn in images_fnames:
        img = io.imread(fn)
        if rescale:
            img = np.float64(img)
            # img = (img - img.min()) / (img.max() + EPSILON)
            # img = 2*img - 1
            # img = np.float64(img)/img.max()
            img = (img - img.mean()) / img.std()
        images.append(img)
    images = np.array(images)
    return images

def gray(img):
    gr = img.mean(axis=2)
    gr = (gr - gr.min()) / (gr.max() - gr.min() + EPSILON)
    return gr

def show(img):
    if img.max != 255:
        img = np.float64(img)
        img = np.uint8(255*(img - img.min())/(img.max()-img.min() + EPSILON))
    fig = plt.figure(figsize=(5, 5))
    ax = fig.subplots()
    ax.imshow(img, cmap='gray')
    return True

class Activ(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(1))
    def forward(self, x):
        return torch.sin(self.a*torch.pi*x)

# ShallowCNN Block
class ShallowCNN(torch.nn.Module):
    def __init__(self, ksize=[(1, 1)], och=[1]):
        super().__init__()
        self.ksize = ksize
        self.conv = torch.nn.ModuleList()
        chc = 0
        for ks, oc in zip(ksize, och):
            self.conv.append(torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=3, out_channels=oc,
                                kernel_size=ks, stride=1, 
                                padding='same', bias=True),
                # torch.nn.LogSigmoid()))
                # torch.nn.Identity()))
                torch.nn.ReLU()))
                # torch.nn.Tanh()))
                # Activ()))
            chc += oc
        self.aggregate = torch.nn.Conv2d(in_channels=chc, out_channels=1,
                                         kernel_size=1, stride=1,
                                         padding='same', bias=True)
        self.sigmoid = torch.nn.Sigmoid()
        
        # weight = torch.ones_like(self.conv.weight)
        # weight /= weight.sum()
        # self.conv.weight = torch.nn.parameter.Parameter(weight,
        #                                                 requires_grad=False)
        
    def forward(self, batch):
        # print(batch.shape)
        out = []
        for layer in self.conv:
            out.append(layer(batch))
            # print(out[-1].shape)
        out = torch.concatenate(out, axis=1)
        out = self.aggregate(out)
        out = self.sigmoid(out)
        # print(out.shape)
        return out


class DiceLoss(torch.nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        # #comment out if your model contains a sigmoid or equivalent activation layer
        # inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)

        return 1 - dice

# Dice Binary Cross Entropy Coefficient
class DiceBCELoss(torch.nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        # #comment out if your model contains a sigmoid or equivalent activation layer
        # inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = torch.nn.functional.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE
    
class Train:
    """
    Initialize, train, evaluate, decode
    """
    def __init__(self, model, data, label, nepoch=4, bsize=8):
        """
        data: [batch, num_channel, rows, cols]
        label: [batch, num_channel, rows, cols]
        """
        self.model = model
        self.data = data
        self.label = label
        self.nepoch = nepoch
        self.device = (torch.device("mps")
                       if torch.backends.mps.is_available()
                       else torch.device('cpu'))
        # self.crit = torch.nn.MSELoss()
        # self.crit = DiceLoss()
        self.crit =DiceBCELoss()
        self.bsize = bsize
    def run(self, lr=1e-4):
        model = self.model
        model = model.to(self.device)
        crit = self.crit
        optim = torch.optim.Adam(model.parameters(), lr=lr)
        # sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
        #     optim, factor=0.5, patience=4, threshold=0.001)
        t1 = time.time()
        for epoch in range(self.nepoch):
            # t1 = time.time()
            model, optim = self._train(model, crit, optim)
            # model.train()
            # for bc in range(1, self.data.shape[0] // self.bsize):
            #     slc = slice((bc-1)*self.bsize, bc*self.bsize)
            #     batch, lb = self.data[slc], self.label[slc]
            #     batch, lb = batch.to(self.device), lb.to(self.device)
            #     optim.zero_grad()
            #     out = model(batch).squeeze(dim=1)
            #     loss = crit(out, lb)
            #     loss.backward()
            #     optim.step()
            
            loss = self._valid(model, crit)
            # sch.step(loss)
            # if lr != optim.param_groups[0]['lr']:
            #     lr = optim.param_groups[0]['lr']
            #     print(f'Learning rate changed to {lr:.04f}.')
            if (epoch % 10 == 0) | (epoch == (self.nepoch-1)):
                print(f'Ep: {epoch}, Secs: {time.time() - t1:.0f}, ' + 
                      f'loss: {loss:.04f}')
                t1 = time.time()
            
    def _train(self, model, crit, optim):
        model.train()
        for bc in range(1, self.data.shape[0] // self.bsize):
            slc = slice((bc-1)*self.bsize, bc*self.bsize)
            batch, lb = self.data[slc], self.label[slc]
            batch, lb = batch.to(self.device), lb.to(self.device)
            optim.zero_grad()
            out = model(batch).squeeze(dim=1)
            loss = crit(out, lb)
            loss.backward()
            optim.step()
        return model, optim
    
    def _valid(self, model, crit):
        model = model.to(self.device)
        model.eval()
        loss_sum = 0.0
        with torch.no_grad():
            for bc in range(1, self.data.shape[0] // self.bsize):
                slc = slice((bc-1)*self.bsize, bc*self.bsize)
                batch, lb = self.data[slc], self.label[slc]
                batch, lb = batch.to(self.device), lb.to(self.device)
                out = model(batch).squeeze(dim=1)
                loss = crit(out, lb)
                loss_sum += loss.item()
        return loss_sum / bc

In [358]:
# read all images and masks and store in two matrices
train_images = read_images(train_image_ipath)
train_masks = 1 * (read_images(train_mask_ipath, rescale=False) > 0)
test_images = read_images(test_image_ipath)
test_masks = 1 *(read_images(test_mask_ipath, rescale=False) > 0)
print(train_images.shape, train_masks.shape,
      test_images.shape, test_masks.shape)
print(train_images.min(), train_images.max(), train_images.mean())

(80, 512, 512, 3) (80, 512, 512) (20, 512, 512, 3) (20, 512, 512)
-1.5989004704853944 4.097728417232479 1.5034270125132329e-18


In [359]:
# ibat = torch.Tensor(train_images[:2, :, :, :])
# ibat = torch.swapdims(ibat, 1, 3)
# scnn = ShallowCNN(ksize=[1, 3, 5],
#                   och=[20, 20, 20])
# print(scnn)
# # obat = scnn(ibat)
# # for iimg, oimg in zip(ibat, obat):
# #     show(torch.swapdims(iimg, 0, 2).detach().numpy())
# #     show(torch.swapdims(oimg, 0, 2).detach().numpy())

In [451]:
# scnn = ShallowCNN(ksize=[1, 3, 5, 7, 9, 11],
#                   och = [1, 1, 1, 1, 1, 1])
# scnn = ShallowCNN(ksize=[(1, 1), (3, 1), (1, 3), (3, 3),],
#                          # (1, 5), (5, 1), (3, 5), (5, 3), (5, 5)],
#                   och = [1, 1, 1, 1])
scnn = ShallowCNN(ksize=[(1, 1), (3, 3), (5, 5), (7, 7)],
                  och = [1, 1, 1, 1])
train = Train(scnn, torch.Tensor(train_images).swapdims(1, 3),
              torch.Tensor(train_masks), nepoch=80, bsize=4)
t0 = time.time()
train.run(lr=1.0e-2)
print(f'Finished in {time.time() - t0:.0f} seconds.')
# Records:
#    Time: 299, Loss: 1.1772 
#    Time: 313, Loss: 1.1651
#    Time: 63,  Loss: 1.1791
#    Time: 84,  Loss: 1.1367

Ep: 0, Secs: 1, loss: 1.1561
Ep: 10, Secs: 6, loss: 1.1437
Ep: 20, Secs: 6, loss: 1.1416
Ep: 30, Secs: 6, loss: 1.1422
Ep: 40, Secs: 6, loss: 1.1392
Ep: 50, Secs: 6, loss: 1.1386
Ep: 60, Secs: 6, loss: 1.1393
Ep: 70, Secs: 6, loss: 1.1396
Ep: 79, Secs: 6, loss: 1.1379
Finished in 52 seconds.


In [None]:
# One issue with CNN is that it get the convolution and then the activation
# is used. Instead, if we could have a function with all local values as input
# then a better nonlinearity could be found.

# Using kernel==1 and polynomial activation, we can use AvgPool2d to design
# a locally optimized segmentor.
# AvgPool2d is a convolution with an all-one window. So, nothing!

# We can shift an image into different directions to make an overlapped 
# neighbors. Then we can define a PyTorch multi-input nonlinear function and
# then try to optimize its parameters to estimate the best function for 
# the segmentation.