# U-Net Segmentation

A more sophisticated approach is needed to segment the wires from the images. Several segmentation architectures are available; the first one we try is U-Net.



In [1]:
# import modules (download with pip install first if not on local. Type on terminal: pip install <module name>)
import torch
import torch.nn as nn
import torch.nn.functional as F #contains some useful functions like activation functions & convolution operations you can use
import numpy as np

# install torchvision first

# Run on the first GPU when one is available, otherwise fall back to the CPU.
# The string must be "cuda:0" with no space: torch rejects "cuda: 0" as an invalid device string.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using", device, "...")

Using cpu ...


## Dataset

In [2]:
from load_data import *

# Other dataset locations that have been used with the same loader:
#   load_data('/content/drive/My Drive/AIWIRE', 'dataset')   (Google Drive path, presumably Colab)
#   load_data('.', 'small_dataset')
X, y = load_data('.', 'data/iteration_1_dataset')

In [3]:
import dataset
from sklearn.model_selection import train_test_split

batch_size = 50
# Fixed seed so the train/held-out split, and hence the validation loss, is reproducible.
SPLIT_SEED = 0
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SPLIT_SEED)
# Judging by the printed sizes (1600/400), the held-out 20% becomes the 'val' loader.
dataloaders = dataset.gen_dataloaders(X_train, X_test, y_train, y_test, batch_size)

{'train': 1600, 'val': 400}


In [None]:
from scipy.ndimage import gaussian_filter

def norm(im):
    """Min-max normalise `im` to the range [0, 1], returning a new array.

    The input is left untouched. The original implementation subtracted the
    minimum in place, which modified the caller's array.

    Args:
        im: NumPy array (or tensor) of any shape.

    Returns:
        `(im - im.min()) / (im.max() - im.min())`. A constant input yields all
        zeros instead of NaN (0/0).
    """
    shifted = im - im.min()
    peak = shifted.max()
    # A constant input gives an all-zero `shifted`. Dividing by 1 keeps it zero with the
    # same dtype promotion as the normal path, instead of producing 0/0 = NaN.
    return shifted / (peak if peak != 0 else 1)

def modify_gts(gts, sigma=0.4, threshold=0.01):
    """Thicken thin ground-truth masks into binary training targets.

    Each mask is blurred with a Gaussian, min-max normalised with `norm`, then
    thresholded. Pixels next to the originally marked ones therefore also become
    foreground.

    Args:
        gts: iterable of ground-truth masks (NumPy arrays or CPU tensors, e.g. the
            per-sample slices of a label batch).
        sigma: standard deviation of the Gaussian blur, in pixels.
        threshold: normalised-intensity cut-off above which a pixel is foreground.

    Returns:
        list of 0/1 integer NumPy arrays, one per input mask.
    """
    # np.int was removed in NumPy 1.24; it was only an alias of the builtin int.
    return [
        (norm(gaussian_filter(gt, sigma=sigma)) > threshold).astype(int)
        for gt in gts
    ]

In [None]:
import matplotlib.pyplot as plt

# Mean-squared error between prediction and target mask; train_model uses this same object.
loss_function = nn.MSELoss(reduction='mean')


def show(im):
    """Draw one 2-D image in greyscale on the current axes, with the axes hidden."""
    plt.imshow(im, cmap='gray')
    plt.axis('off')

# Inspect a few training samples: raw input vs. thickened ground truth.
n = 3
first_batch = next(iter(dataloaders['train']))

ims = first_batch[0][:n]
gts_thin = first_batch[1][:n]
gts = modify_gts(gts_thin)

print(gts[0].shape)
plt.hist(gts[0].reshape(-1, 1), bins=40, density=True, alpha=1)
plt.title('Pixel values: thickened ground truth')
plt.show()
plt.hist(ims[0].numpy().reshape(-1, 1), bins=40, density=True, alpha=1)
plt.title('Pixel values: input image')
plt.show()

for [im], gt in zip(ims, gts):
    plt.figure(figsize=[5, 5])
    plt.subplot(121)
    show(im)
    plt.subplot(122)
    show(gt)

    # Baseline: loss of a constant prediction equal to the mask's mean value.
    null_pred = torch.full(gt.shape, float(np.mean(gt)))
    base_loss = loss_function(null_pred, torch.Tensor(gt))
    # Percentage of pixels whose value exceeds one standard deviation of the mask.
    above_sigma_pct = 100 * np.mean(gt > np.std(gt))
    print('Class balance (one sigma): ', above_sigma_pct, '%')
    print('Loss for null model:', base_loss.item())

    plt.show()

## Model

In [None]:
# Get a batch of training data
inputs, masks = next(iter(dataloaders['train']))

print(inputs.shape, masks.shape)
print(inputs.dtype, masks.dtype)

for x in [inputs.numpy(), masks.numpy()]:
    print(x.min(), x.max(), x.mean(), x.std())

# install torchsummary first
from torchsummary import summary
# import model from python file
import Unet_pytorch
import importlib
importlib.reload(Unet_pytorch)

model = Unet_pytorch.UNet(1)
model = model.to(device)

print("MODEL ARCHITECTURE ...")
summary(model, input_size=(1,128,64))

save_path = './saved_models/model_1'

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
import datetime
# One TensorBoard run directory per execution, named after the start time.
# Avoid ':' in the name: it is not a legal path character on Windows.
writer = SummaryWriter('./logs/{0}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
import matplotlib.gridspec as gridspec

import matplotlib.pyplot as plt
def visualise(outputs, title, n=7):
    """Show up to `n` images from `outputs` side by side in a single row.

    Args:
        outputs: sequence of images. Entries shaped (H, W, 3) or (H, W, 4) are
            drawn as colour images; every other entry is drawn in greyscale.
        title: figure title.
        n: maximum number of images (and subplot columns).
    """
    fig = plt.figure(figsize=[10, 3])
    fig.suptitle(title, fontsize=14)
    for i, output in enumerate(outputs[:n]):
        plt.subplot(1, n, i + 1)
        plt.axis('off')
        # imshow expects colour images channel-last, so inspect the last axis, not the first.
        is_colour = output.ndim == 3 and output.shape[-1] in (3, 4)
        if is_colour:
            plt.imshow(output)
        else:
            plt.imshow(output, cmap='gray')
    plt.show()

In [None]:
from collections import defaultdict
import torch.optim as optim
from torch.optim import lr_scheduler

import time
import copy

def _show_val_predictions(inputs, outputs, labels):
    """Plot one validation batch: inputs, predictions, and predictions overlaid with ground truth.

    Args:
        inputs: input batch tensor; axis 1 (channel) is squeezed away for display.
        outputs: model output tensor for `inputs`; axis 1 is squeezed away.
        labels: ground-truth batch as yielded by the dataloader (before modify_gts).
    """
    visualise(inputs.squeeze(1).cpu().numpy(), 'Simulated images')
    preds = outputs.detach().squeeze(1).cpu().numpy()
    visualise(preds, 'Network catheter predictions')

    # Prediction copied into 3 channels, channel-last, as the overlay background.
    colour_outputs = norm(outputs.detach().repeat(1, 3, 1, 1).permute(0, 2, 3, 1).cpu().numpy())
    # Ground-truth pixels are added to the red channel where the prediction is weak
    # and to the green channel where it is strong.
    labels_np = labels.cpu().numpy()
    red_labels = np.zeros(colour_outputs.shape)
    red_labels[:, :, :, 0] = labels_np * (1.3 - preds)
    red_labels[:, :, :, 1] = labels_np * preds
    visualise(norm(colour_outputs + red_labels), 'Predictions with ground truth overlay')


def train_model(model, optimizer, scheduler = None, num_epochs=25, train_stats_period=10):
    """Train `model`, validate it every epoch, and checkpoint the best validation loss.

    Relies on notebook globals: `dataloaders` ('train' / 'val'), `device`,
    `loss_function`, `modify_gts` (turns the label batch into targets), `writer`
    (TensorBoard) and `save_path` (checkpoint file for the best state dict).

    Args:
        model: segmentation network. Assumed to return (N, 1, H, W), given the
            squeeze(1) calls below; TODO confirm for multi-class models.
        optimizer: torch optimizer over model.parameters().
        scheduler: optional LR scheduler, stepped once per epoch after the training phase.
        num_epochs: number of train + val passes.
        train_stats_period: print the batch loss every this many batches.
    """
    import os

    best_loss = float('inf')
    # torch.save does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            epoch_samples = 0
            running_loss = 0.0  # sum of batch_loss * batch_size, for a sample-weighted mean

            for batch_id, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                # Stack into one array first: a tensor built from a list of arrays is very slow.
                labels_mod = torch.as_tensor(np.stack(modify_gts(labels)), dtype=torch.float32).to(device)

                optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs.float())
                    # squeeze(1) drops only the channel axis. A bare squeeze() would
                    # also drop the batch axis of a size-1 batch.
                    loss = loss_function(outputs.squeeze(1), labels_mod)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                batch_loss = loss.item()
                if batch_id % train_stats_period == 0:
                    print('Batch', batch_id, ':', batch_loss)

                batch_len = inputs.size(0)
                running_loss += batch_loss * batch_len
                epoch_samples += batch_len

            # Since PyTorch 1.1 the scheduler must be stepped after optimizer.step().
            if phase == 'train' and scheduler is not None:
                scheduler.step()

            # loss_function averages within a batch (nn.MSELoss default), so weight each
            # batch by its size. If no batches ran, the loss is NaN and nothing is saved.
            epoch_loss = running_loss / epoch_samples if epoch_samples else float('nan')
            print('Loss ' + phase, ': {:.8f}'.format(epoch_loss))
            writer.add_scalar('Loss ' + phase, epoch_loss, epoch)

            if phase == 'val' and epoch_samples:
                _show_val_predictions(inputs, outputs, labels)

            # Checkpoint the weights with the lowest validation loss seen so far.
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                torch.save(model.state_dict(), save_path)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.8f}'.format(best_loss))

# Seed torch so weight initialisation (and torch-driven shuffling) is reproducible.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)

num_class = 1
model = Unet_pytorch.UNet(num_class).float()
model = model.to(device)

optimizer_ft = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Optional LR decay (10x every 25 epochs): pass
#   scheduler=lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
# to train_model below.
train_model(model, optimizer_ft, num_epochs=100)

## Evaluate on <u>real</u> data

In [None]:
# Reload the best checkpoint written by train_model. map_location lets a checkpoint
# saved on a GPU be loaded on a CPU-only machine, and vice versa.
model.load_state_dict(torch.load(save_path, map_location=device))
model = model.to(device)
model.eval()

def infer(model, im, target_device=None):
    """Run `model` on a single 2-D image and return its 2-D prediction as a NumPy array.

    Args:
        model: network taking a (1, 1, H, W) float batch.
        im: 2-D image (NumPy array or tensor); converted to float32.
        target_device: device to run on; defaults to the notebook-level `device`.

    Returns:
        Model output with its two leading singleton axes removed, as a CPU NumPy array.
    """
    if target_device is None:
        target_device = device
    batch = torch.as_tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(target_device)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        pred = model(batch)
    return pred.squeeze(0).squeeze(0).cpu().numpy()

print('Performance on real data')
import imageio
from PIL import Image

# Decode only frame 100 instead of every frame of the GIF; the reader is closed afterwards.
with imageio.get_reader('cropped_gif.gif') as gif:
    frame = gif.get_data(100)
# Collapse to one greyscale channel: a no-op for single-channel frames, and needed if the
# GIF decodes to RGB/RGBA. Then resize to width 64 x height 128, matching the
# (1, 128, 64) model summary input.
test_im = np.array(Image.fromarray(frame).convert('L').resize([64, 128]))
test_im = norm(test_im)

plt.subplot(121)
show(test_im)
plt.title('Real acquisition')

test_pred = infer(model, test_im)
plt.subplot(122)
show(test_pred)
plt.title('Predicted catheter \nsegmentation')
plt.show()

### Testing skeletonisation

In [None]:
print('Skeletonisation')
from skimage.morphology import skeletonize

# Cut-off applied to the raw network output to get a binary segmentation mask.
SEG_THRESHOLD = 0.5

# First simulated image from the exploration cell, as a 2-D float tensor. It gets its
# own name so the real-data `test_im` from the previous cell is not overwritten.
sim_im = ims[0].squeeze(0).float()
sim_pred = infer(model, sim_im)
sim_mask = sim_pred > SEG_THRESHOLD

panels = [
    (sim_im, 'Simulated\n image'),
    (sim_mask, 'Predicted\n segmentation'),
    (skeletonize(sim_mask), 'Skeletonised'),
    (np.squeeze(gts_thin[0]), 'Ground truth'),
]
for pos, (img, title) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), pos)
    show(img)
    plt.title(title)
plt.show()