In [None]:
import os
import cv2
from glob import glob
from tqdm.notebook import tqdm_notebook as tqdm 

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

### Mount Google Drive and extract the dataset zip

In [None]:
import zipfile
from google.colab import drive
# Mount Google Drive at /content/drive so the dataset zip stored there can be read (Colab only).
drive.mount('/content/drive')

In [None]:
# Extract the contents of MyDrive/dataset.zip (on the mounted drive) into the current
# working directory. The context manager closes the archive even if extraction fails.
zip_path = os.path.join(os.getcwd(), "drive", "MyDrive", "dataset.zip")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(os.getcwd())

In [None]:
# Output directories: serialized models for each framework/format, plus the folders the
# evaluation cells write prediction strips into (cv2.imwrite fails silently -- it only
# returns False -- when the target directory does not exist).
for out_dir in (
    './saved_model/tf2',
    './saved_model/keras',
    './saved_model/pytorch',
    './saved_model/torchscript',
    './results/tf_keras',
    './results/pytorch',
):
    os.makedirs(out_dir, exist_ok=True)

### Read images

In [None]:
# Train set. Both lists are sorted so that image i and mask i describe the same sample
# (assumes image and mask files sort into the same order -- TODO confirm for this dataset).
source_images_path = os.path.join(os.getcwd(), "dataset", "training", "images")
source_masks_path = os.path.join(os.getcwd(), "dataset", "training", "masks")

source_images = sorted(glob(os.path.join(source_images_path, "*")))
source_masks = sorted(glob(os.path.join(source_masks_path, "*")))
assert len(source_images) == len(source_masks), (
    f"Found {len(source_images)} training images but {len(source_masks)} masks"
)

# Shuffle images and masks with ONE permutation so pairs stay aligned by construction
# (same result as two calls with the same random_state, without relying on that).
source_images, source_masks = shuffle(source_images, source_masks, random_state=1024)


# Test set (kept in sorted order)
test_images_path = os.path.join(os.getcwd(), "dataset", "testing", "images")
test_masks_path = os.path.join(os.getcwd(), "dataset", "testing", "masks")

test_images = sorted(glob(os.path.join(test_images_path, "*")))
test_masks = sorted(glob(os.path.join(test_masks_path, "*")))
assert len(test_images) == len(test_masks), (
    f"Found {len(test_images)} test images but {len(test_masks)} masks"
)

### Create Train and Val dataset splits

In [None]:
def create_split(source_images, source_masks, test_size=0.05, random_state=77):
    """Split paired image/mask path lists into train and validation subsets.

    Both lists go through a single ``train_test_split`` call, so one permutation is
    applied to images and masks and every pair stays aligned (sklearn also raises if
    the two lengths differ, instead of silently mis-pairing).

    Args:
        source_images: list of image paths.
        source_masks: list of mask paths aligned index-by-index with ``source_images``.
        test_size: fraction (float) or absolute count (int) held out for validation.
        random_state: seed controlling the split.

    Returns:
        ``((train_x, train_y), (val_x, val_y))``.
    """
    train_x, val_x, train_y, val_y = train_test_split(
        source_images, source_masks, test_size=test_size, random_state=random_state
    )
    return (train_x, train_y), (val_x, val_y)

(train_x, train_y), (val_x, val_y) = create_split(source_images, source_masks)

# Report split sizes (the held-out part is the validation set, not the test set)
print(f"TrainX: {len(train_x)} TrainY: {len(train_y)}")
print(f"ValX: {len(val_x)} ValY: {len(val_y)}")

### Image and Mask read functions

In [None]:
def read_image(path, _format=None):
    """Read an image as RGB, resize it to (H, W) and scale it to [0, 1].

    Relies on the module-level ``H`` and ``W`` set in the configuration cells.

    Args:
        path: image file path.
        _format: ``'channel_first'`` returns a (3, H, W) array (PyTorch layout);
            any other value returns (H, W, 3) (Keras layout).

    Returns:
        float32 array with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    if x is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height).
    x = cv2.resize(x, (W, H))
    x = (x / 255.0).astype(np.float32)

    if _format == 'channel_first':
        x = np.moveaxis(x, -1, 0)

    return x

def read_mask(path, _format=None):
    """Read a grayscale mask, resize it to (H, W) and scale it to [0, 1].

    Relies on the module-level ``H`` and ``W`` set in the configuration cells.
    Values are scaled but NOT thresholded (resizing can produce fractional edges).

    Args:
        path: mask file path.
        _format: ``'channel_first'`` returns a (1, H, W) array (PyTorch layout);
            any other value returns (H, W, 1) (Keras layout).

    Returns:
        float32 array with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if x is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask: {path}")
    # cv2.resize expects dsize as (width, height).
    x = cv2.resize(x, (W, H))
    x = (x / 255.0).astype(np.float32)
    x = np.expand_dims(x, axis=-1)

    if _format == 'channel_first':
        x = np.moveaxis(x, -1, 0)

    return x

# ------------------------------------------------------------------------------------------------------------

# TensorFlow / Keras Implementation

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras import backend as K
from tensorflow.keras.utils import CustomObjectScope

In [None]:
# Image geometry, channels-last (H, W, C) as Keras expects
H = W = 224
INPUT_CHANNEL, OUTPUT_CHANNEL = 3, 1
INPUT_SHAPE = (H, W, INPUT_CHANNEL)

# Training hyper-parameters: batch size, Adam learning rate, epochs
batch_size, lr, num_epochs = 8, 1e-4, 10

### Data preprocessing pipeline

In [None]:
def preprocess(image_path, mask_path):
    """tf.data map function: load one (image, mask) pair from two path tensors.

    Decoding runs in NumPy/OpenCV through ``tf.numpy_function``. Its outputs have
    unknown static shape, so the shapes are set explicitly before returning.
    """
    def _load_pair(image_bytes, mask_bytes):
        # numpy_function passes string tensors in as Python bytes
        return read_image(image_bytes.decode()), read_mask(mask_bytes.decode())

    image, mask = tf.numpy_function(
        _load_pair, [image_path, mask_path], [tf.float32, tf.float32]
    )
    image.set_shape([H, W, INPUT_CHANNEL])
    mask.set_shape([H, W, OUTPUT_CHANNEL])
    return image, mask

def tf_dataset(images, masks, batch):
    """Build a shuffled, batched, prefetched tf.data pipeline of (image, mask) pairs.

    Args:
        images: list of image paths.
        masks: list of mask paths aligned with ``images``.
        batch: batch size; the shuffle buffer holds ``40 * batch`` path pairs.

    Returns:
        ``tf.data.Dataset`` yielding ``(image, mask)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, masks))
    dataset = dataset.shuffle(buffer_size=batch * 40)
    # Decode files in parallel; element order is still deterministic by default.
    dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
# Create the tf.data input pipelines for the train and validation splits
train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
val_dataset = tf_dataset(val_x, val_y, batch=batch_size)

### UNet model - Encoder: VGG16

In [None]:
def conv_block(inputs, num_filters):
    """Apply two stages of 3x3 Conv2D ('same', no bias) -> BatchNorm -> ReLU.

    The convolutions have no bias because the following BatchNormalization already
    adds a learned shift.
    """
    x = inputs
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

def decoder_block(inputs, skip, num_filters):
    """Upsample 2x with a transposed conv, concatenate the encoder skip, then refine."""
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip])
    return conv_block(merged, num_filters)

def build_vgg16_unet_tf(input_shape):
    """Build a U-Net whose encoder is an ImageNet-pretrained, frozen VGG16.

    Args:
        input_shape: (H, W, C) of the input images, e.g. (224, 224, 3).

    Returns:
        Keras ``Model`` named "VGG16_U-Net" that outputs a sigmoid map with
        ``OUTPUT_CHANNEL`` channels at the input resolution.
    """
    inputs = Input(shape=input_shape)

    # Pre-trained VGG16 built on top of our Input tensor, with every layer frozen
    encoder = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)
    for layer in encoder.layers:
        layer.trainable = False

    # Skip connections, shallowest first (224, 112, 56, 28 px for a 224x224 input)
    skip_layer_names = ["block1_conv2", "block2_conv2", "block3_conv3", "block4_conv3"]
    skips = [encoder.get_layer(layer_name).output for layer_name in skip_layer_names]

    # Bridge: deepest VGG16 feature map (14x14 for a 224x224 input)
    x = encoder.get_layer("block5_conv3").output

    # Decoder: climb back up, pairing each stage with the matching skip (deepest first)
    for skip, filters in zip(reversed(skips), (512, 256, 128, 64)):
        x = decoder_block(x, skip, filters)

    # 1x1 conv + sigmoid -> per-pixel probabilities, (224, 224, 1) for the defaults
    outputs = Conv2D(OUTPUT_CHANNEL, 1, padding="same", activation="sigmoid")(x)

    return Model(inputs, outputs, name="VGG16_U-Net")

In [None]:
# Build the U-Net (frozen VGG16 encoder, trainable decoder) and print its layer summary
model_tf = build_vgg16_unet_tf(INPUT_SHAPE)
model_tf.summary()

### Metric and compilation

In [None]:
# Dice Coefficient Metric
def dice_coef(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient pooled over the whole batch.

    Dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth), with both tensors
    flattened so every pixel of every sample contributes to a single score.
    ``smooth`` keeps the ratio defined when both masks are empty.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return (2. * overlap + smooth) / denominator

In [None]:
# Compile: binary cross-entropy on the sigmoid output, Adam with learning rate `lr`,
# and pixel-level recall / precision / accuracy plus the custom Dice metric.
model_tf.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(lr),
    metrics=[
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
        "accuracy",
        dice_coef
    ]
)

### Training 

In [None]:
# Batches per epoch, rounded up so a final partial batch is still consumed
train_steps = -(-len(train_x) // batch_size)
val_steps = -(-len(val_x) // batch_size)

history = model_tf.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=num_epochs,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    shuffle=True,  # Keras documents this as ignored for tf.data inputs; the pipeline shuffles itself
)

### Save model

In [None]:
# Save in the TensorFlow SavedModel (TF2) format
tf.saved_model.save(model_tf, './saved_model/tf2')

# Save in the Keras HDF5 format; reloading it requires `dice_coef` to be supplied
# as a custom object (see the load cell below).
model_tf.save("./saved_model/keras/model_keras.h5")

### Load model

In [None]:
# Reload the Keras model. The custom Dice metric used at compile time must be registered.
# NOTE(review): `dice_coef` must still name the metric function here, not a rebound list.
with CustomObjectScope({'dice_coef': dice_coef}):
    model_tf = tf.keras.models.load_model("./saved_model/keras/model_keras.h5")

### Loss and Metrics Visualization 

In [None]:
# Plot training vs. validation curves from the Keras History object.
# Metric lists are read from a dict rather than bound to globals: the original
# `dice_coef = ...` assignment shadowed the metric function needed to reload the model.
metric_history = history.history
epochs = range(len(metric_history['accuracy']))

for key, label in (('accuracy', 'accuracy'), ('loss', 'loss'), ('dice_coef', 'Dice coef')):
    # Create the figure BEFORE plotting so each metric gets its own, non-empty figure
    fig, ax = plt.subplots()
    ax.plot(epochs, metric_history[key], 'r', label=f'Training {label}')
    ax.plot(epochs, metric_history[f'val_{key}'], 'b', label=f'Validation {label}')
    ax.set_title(f'Training and validation {label}')
    ax.set_xlabel('Epoch')
    ax.legend(loc=0)
    plt.show()

### Evaluation (on test set)

In [None]:
def intersection_over_union(y_true, y_pred):
    """Intersection-over-Union (Jaccard index) of the positive class for binary masks.

    Args:
        y_true: ground-truth mask array; non-zero entries count as positive.
        y_pred: predicted mask array with the same number of elements.

    Returns:
        TP / (TP + FP + FN) as a float. Returns 1.0 when neither mask has any
        positive pixel (perfect agreement on an empty mask) instead of dividing by zero.
    """
    truth = np.asarray(y_true).ravel().astype(bool)
    pred = np.asarray(y_pred).ravel().astype(bool)
    # Counting directly avoids sklearn's confusion_matrix, which returns a 1x1 matrix
    # (breaking a 4-way unpack) when only one class appears in both arrays.
    intersection = np.count_nonzero(truth & pred)  # TP
    union = np.count_nonzero(truth | pred)         # TP + FP + FN
    if union == 0:
        return 1.0
    return intersection / union


def save_results(image, mask, y_pred, save_image_path):
    """Write [image | ground-truth mask | predicted mask] side by side to one file.

    Args:
        image: RGB image of shape (h, w, 3) with values in 0-255.
        mask: ground-truth mask of shape (h, w) with values 0/1.
        y_pred: predicted mask of shape (h, w) with values 0/1.
        save_image_path: output path; missing parent directories are created.

    Raises:
        IOError: if OpenCV fails to write the file.
    """
    # cv2.imwrite expects BGR channel order
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # 10-px grey separator, sized from the image itself instead of the global H
    line = np.full((image.shape[0], 10, 3), 128, dtype=np.uint8)

    def _mask_to_bgr(binary_mask):
        # (h, w) 0/1 mask -> (h, w, 3) uint8 image with values 0/255
        stacked = np.repeat(np.expand_dims(binary_mask, axis=-1), 3, axis=-1)
        return (stacked * 255).astype(np.uint8)

    cat_images = np.concatenate(
        [image, line, _mask_to_bgr(mask), line, _mask_to_bgr(y_pred)], axis=1
    )

    out_dir = os.path.dirname(save_image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure through its return value rather than raising
    if not cv2.imwrite(save_image_path, cat_images):
        raise IOError(f"Failed to write {save_image_path}")

In [None]:
# Evaluation on test set: predict each test image, save an [image | truth | prediction]
# strip, and collect per-image pixel metrics that are averaged at the end.
os.makedirs(os.path.join("results", "tf_keras"), exist_ok=True)

SCORE = []
for image_path, mask_path in tqdm(zip(test_images, test_masks), total=len(test_masks)):
    # File name without directory or extension (os.path handles / and \ separators)
    name = os.path.splitext(os.path.basename(image_path))[0]

    # Read the image as in training: RGB, resized, scaled to [0, 1]
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (W, H))  # dsize is (width, height)
    x = np.expand_dims(image / 255.0, axis=0)

    # Read the mask and binarise at 0.5. A bare astype(int32) would truncate every
    # value below 1.0 (interpolated edges, lossy-compression noise) to 0.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.resize(mask / 255.0, (W, H))
    mask = (mask > 0.5).astype(np.int32)

    # Prediction: probabilities for the single image -> (H, W) binary mask
    y_pred = model_tf.predict(x, verbose=0)[0]
    y_pred = np.squeeze(y_pred, axis=-1)
    y_pred = (y_pred > 0.5).astype(np.int32)

    # Saving the prediction
    save_image_path = os.path.join("results", "tf_keras", f"{name}.png")
    save_results(image, mask, y_pred, save_image_path)

    # Pixel-level metrics on the flattened masks
    mask_flat = mask.flatten()
    pred_flat = y_pred.flatten()
    acc_value = accuracy_score(mask_flat, pred_flat)
    iou = intersection_over_union(mask_flat, pred_flat)
    recall_value = recall_score(mask_flat, pred_flat, labels=[0, 1], average="binary")
    precision_value = precision_score(mask_flat, pred_flat, labels=[0, 1], average="binary")
    SCORE.append([name, acc_value, iou, recall_value, precision_value])

# Mean of each metric over all test images
score = np.mean([s[1:] for s in SCORE], axis=0)
print(f"Accuracy: {score[0]:0.5f}")
print(f"IOU: {score[1]:0.5f}")
print(f"Recall: {score[2]:0.5f}")
print(f"Precision: {score[3]:0.5f}")

# ------------------------------------------------------------------------------------------------------------

# PyTorch Implementation

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg16

from torchmetrics.functional import accuracy, precision, recall

from torch_lr_finder import LRFinder

In [None]:
# Image geometry, channels-first (C, H, W) as PyTorch expects
H = W = 224
INPUT_CHANNEL, OUTPUT_CHANNEL = 3, 1
INPUT_SHAPE = (INPUT_CHANNEL, H, W)

# Split sizes, used to turn running per-sample sums into epoch averages
total_train_len, total_val_len = len(train_x), len(val_x)

# Training hyper-parameters: batch size, Adam learning rate, epochs
batch_size, lr, num_epochs = 8, 1e-4, 20

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device, type(device)

### Data preprocessing pipeline

In [None]:
# Inheriting from torch.utils.data.Dataset

class DatasetPreprocessor(Dataset):
    """Map-style dataset yielding (image, mask) tensors from parallel path lists.

    Items are loaded from disk on access with ``read_image`` / ``read_mask`` in
    channel-first layout and returned as CPU tensors. Moving them to a device is
    left to the training loop; ``device`` is stored but not used here.
    """

    def __init__(self, inputs: list, targets: list, device: torch.device):
        """
        Args:
            inputs: image file paths.
            targets: mask file paths aligned index-by-index with ``inputs``.
            device: target device, kept for reference only.

        Raises:
            ValueError: if ``inputs`` and ``targets`` differ in length (a mismatch
                would otherwise surface as an IndexError or silent mis-pairing).
        """
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )
        self.inputs = inputs
        self.targets = targets
        self.device = device

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index: int):
        """Load sample ``index`` and return ``(image_tensor, mask_tensor)``."""
        x = read_image(self.inputs[index], _format='channel_first')
        y = read_mask(self.targets[index], _format='channel_first')
        return torch.from_numpy(x), torch.from_numpy(y)

In [None]:
# Datasets for both splits
train_dataset = DatasetPreprocessor(train_x, train_y, device)
val_dataset = DatasetPreprocessor(val_x, val_y, device)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
# Loading runs in the main process (num_workers=0); pin_memory is left at its default.
train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
)
val_dataloader = DataLoader(
    dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
)

In [None]:
# Number of batches per epoch in each loader (drop_last is not set, so a final partial batch counts)
len(train_dataloader), len(val_dataloader)

### UNet model - Encoder: VGG16

In [None]:
class conv_block(nn.Module):
    """
    Convolutional block: two 3x3 convolutions (padding 1, no bias), each followed by
    batch normalization and ReLU. Spatial size is preserved; channels go in_c -> out_c.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names (and creation order) define state_dict keys -- keep them stable
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        hidden = self.relu(self.bn1(self.conv1(inputs)))
        return self.relu(self.bn2(self.conv2(hidden)))


class decoder_block(nn.Module):
    """
    Decoder block: a stride-2 transposed convolution doubles height and width and maps
    in_c -> out_c channels. The result is concatenated along dim 1 with the encoder skip
    connection (which must carry out_c channels, since conv_block takes 2 * out_c) and
    refined by a conv_block.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        merged = torch.cat([self.up(inputs), skip], dim=1)
        return self.conv(merged)


class build_vgg16_unet_torch(nn.Module):
    """U-Net with an ImageNet-pretrained, frozen VGG16 encoder (PyTorch version).

    The encoder is ``vgg16().features`` cut into four skip stages plus a bottleneck.
    The decoder mirrors it back up to the input resolution, and a 1x1 conv + sigmoid
    produces ``OUTPUT_CHANNEL`` probability maps.
    """
    def __init__(self):
        super().__init__()

        # Pre-trained VGG16; every parameter frozen so only the decoder is trained
        encoder = vgg16(pretrained=True)
        for param in encoder.parameters():
            param.requires_grad = False
        vgg_features = encoder.features

        # Encoder stages. Attribute names and assignment order define state_dict keys
        # and parameter order, so they must stay exactly as they are.
        self.e1 = nn.Sequential(*vgg_features[0:4])
        self.e2 = nn.Sequential(*vgg_features[4:9])
        self.e3 = nn.Sequential(*vgg_features[9:16])
        self.e4 = nn.Sequential(*vgg_features[16:23])

        # Bottleneck
        self.b1 = nn.Sequential(*vgg_features[23:30])

        # Decoder stages: (in_channels, out_channels)
        self.d1 = decoder_block(512, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # Classifier: 1x1 conv to OUTPUT_CHANNEL logits, then sigmoid
        self.logits = nn.Conv2d(64, OUTPUT_CHANNEL, kernel_size=1, padding=0)
        self.outputs = nn.Sigmoid()

    def forward(self, inputs):
        # Encoder: keep every stage output for the skip connections
        skip1 = self.e1(inputs)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        skip4 = self.e4(skip3)

        # Bottleneck, then decode from the deepest skip to the shallowest
        x = self.b1(skip4)
        x = self.d1(x, skip4)
        x = self.d2(x, skip3)
        x = self.d3(x, skip2)
        x = self.d4(x, skip1)

        return self.outputs(self.logits(x))

In [None]:
# Instantiate the PyTorch U-Net, move it to the selected device, and display it
model_torch = build_vgg16_unet_torch().to(device)
model_torch

In [None]:
# torchsummary builds its dummy input on the requested device; pass the device the model
# actually lives on ("cuda" or "cpu"), since a hard-coded 'cuda' crashes on CPU-only runtimes.
summary(model_torch, input_size=INPUT_SHAPE, batch_size=-1, device=device.type)

### Metric, Loss and Optimizer

In [None]:
# Dice Coefficient Metric
def dice_coef(y_true, y_pred, smooth=1.):
    """Soft Dice coefficient pooled over the whole batch (PyTorch version).

    Each sample is flattened, but the sums run over every pixel of every sample, so
    the result is one scalar tensor rather than a per-sample mean. ``smooth`` keeps
    the ratio defined when both inputs are all zeros.
    """
    n = y_pred.size(0)
    truth = y_true.view(n, -1).float()
    pred = y_pred.view(n, -1).float()
    overlap = torch.sum(truth * pred).float()
    return (2. * overlap + smooth) / (torch.sum(truth) + torch.sum(pred) + smooth)

In [None]:
# Binary cross-entropy on the sigmoid probabilities. Adam only receives parameters
# with requires_grad=True, i.e. the decoder, because the VGG16 encoder is frozen.
criterion = nn.BCELoss()
trainable_params = [param for param in model_torch.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=lr)

In [None]:
class Trainer:
    """Minimal Keras-style training loop for a PyTorch model.

    Workflow: ``Trainer(model, device)`` -> ``compile(criterion, optimizer, metrics)``
    -> ``fit(...)``. Per-epoch averages are appended to ``self.history`` under the key
    names a Keras ``History`` uses (``loss``, ``val_loss``, ``accuracy``, ...), and the
    learning rate is recorded under ``lr``.
    """

    # Metric order used for running sums, progress-bar postfixes and return tuples.
    _METRIC_NAMES = ("loss", "accuracy", "precision", "recall", "dice_coef")

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.history = {
            "loss" : [],
            "val_loss" : [],
            "accuracy" : [],
            "val_accuracy" : [],
            "precision" : [],
            "val_precision" : [],
            "recall" : [],
            "val_recall" : [],
            "dice_coef" : [],
            "val_dice_coef" : [],
            "lr" : []
        }
        self.is_compiled = False

    def compile(self, criterion, optimizer, metrics):
        """Attach the loss, optimizer and metric functions.

        Args:
            criterion: called as ``criterion(preds, targets)`` and returning a scalar tensor.
            optimizer: torch optimizer over the model's trainable parameters.
            metrics: dict with callables under "accuracy", "precision" and "recall"
                (called as ``f(preds, targets, average='samples')``) and "dice_coef"
                (called as ``f(targets, preds)``).
        """
        self.criterion = criterion
        self.optimizer = optimizer
        self.accuracy = metrics["accuracy"]
        self.precision = metrics["precision"]
        self.recall = metrics["recall"]
        self.dice_coef = metrics["dice_coef"]

        self.is_compiled = True

    def calculate_metrics(self, preds, targets):
        """Compute (accuracy, precision, recall, dice) for one batch on the CPU.

        ``preds`` is detached first so the metric computation never extends the
        autograd graph of the training step. Targets are binarised at 0.5.
        """
        preds = preds.detach().to('cpu')
        targets = targets.detach().to('cpu') > 0.5
        acc = self.accuracy(preds, targets, average='samples')
        pre = self.precision(preds, targets, average='samples')
        re = self.recall(preds, targets, average='samples')
        dice = self.dice_coef(targets, preds)

        return acc, pre, re, dice

    def epoch_runner(self, epoch, num_epochs, mode, dataloader, total_len):
        """Run one pass over ``dataloader`` and return per-sample epoch averages.

        Args:
            epoch: zero-based epoch index (used for the progress-bar label).
            num_epochs: total number of epochs (progress-bar label).
            mode: 'train' (forward, backward, optimizer step) or 'val'
                (forward only, under ``torch.no_grad``).
            dataloader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
            total_len: number of samples; batch values are weighted by batch size
                and divided by this.

        Returns:
            ``(loss, accuracy, precision, recall, dice_coef)`` epoch averages
            (all 0.0 if the dataloader yields no batches).

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'val'.
        """
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
        is_train = mode == 'train'
        # Switch BatchNorm/Dropout behaviour once per epoch instead of once per batch.
        self.model.train(is_train)

        names = self._METRIC_NAMES
        postfix_keys = {name: name if is_train else f'{mode}_{name}' for name in names}
        running = dict.fromkeys(names, 0.0)
        # Returned unchanged if the loop body never runs (empty dataloader).
        epoch_values = dict.fromkeys(names, 0.0)

        num_batches = len(dataloader)
        pbar = tqdm(enumerate(dataloader), total=num_batches, leave=True)
        for batch_idx, (inputs, targets) in pbar:
            batch_size = inputs.size(0)
            inputs = inputs.to(device=self.device, non_blocking=True)
            targets = targets.to(device=self.device, non_blocking=True)

            if is_train:
                # forward
                preds = self.model(inputs)
                loss = self.criterion(preds, targets)
                # backward + optimizer step
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            else:
                with torch.no_grad():
                    preds = self.model(inputs)
                    loss = self.criterion(preds, targets)

            # Batch metrics; running sums are weighted by batch size
            acc, pre, re, dice = self.calculate_metrics(preds, targets)
            batch_values = dict(zip(
                names, (loss.item(), acc.item(), pre.item(), re.item(), dice.item())
            ))
            for name in names:
                running[name] += batch_values[name] * batch_size

            pbar.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            if batch_idx < num_batches - 1:
                shown = batch_values
            else:
                # Last batch: display the epoch-level averages instead of batch values.
                epoch_values = {name: running[name] / total_len for name in names}
                shown = epoch_values
            postfix = {postfix_keys[name]: shown[name] for name in names}
            postfix['lr'] = self.optimizer.param_groups[0]['lr']
            pbar.set_postfix(postfix)

        self.lr = self.optimizer.param_groups[0]['lr']

        return tuple(epoch_values[name] for name in names)

    def fit(self, num_epochs, train_dataloader, total_train_len, val_dataloader=None, total_val_len=None):
        """Train for ``num_epochs`` epochs, optionally validating after each one.

        Args:
            num_epochs: number of epochs.
            train_dataloader: training batches.
            total_train_len: number of training samples.
            val_dataloader: optional validation batches.
            total_val_len: number of validation samples; defaults to
                ``len(val_dataloader.dataset)`` when omitted.

        Returns:
            ``self.history``.

        Raises:
            RuntimeError: if ``compile`` has not been called.
        """
        if not self.is_compiled:
            raise RuntimeError("Please compile first!")
        if val_dataloader is not None and total_val_len is None:
            # Avoid dividing by None at the end of the validation epoch.
            total_val_len = len(val_dataloader.dataset)

        for epoch in range(num_epochs):
            # Train set
            train_results = self.epoch_runner(
                epoch, num_epochs, 'train', train_dataloader, total_train_len
            )
            for name, value in zip(self._METRIC_NAMES, train_results):
                self.history[name].append(value)
            self.history["lr"].append(self.lr)

            # Validation set
            if val_dataloader is not None:
                val_results = self.epoch_runner(
                    epoch, num_epochs, 'val', val_dataloader, total_val_len
                )
                for name, value in zip(self._METRIC_NAMES, val_results):
                    self.history[f"val_{name}"].append(value)

        return self.history

In [None]:
# Attach the loss, optimizer and metric functions (torchmetrics functional
# accuracy/precision/recall plus the PyTorch dice_coef) to the trainer
trainer = Trainer(model_torch, device)
torch_metrics = {
    "accuracy": accuracy,
    "precision": precision,
    "recall": recall,
    "dice_coef": dice_coef,
}
trainer.compile(criterion=criterion, optimizer=optimizer, metrics=torch_metrics)

In [None]:
# Train with per-epoch validation; returns the trainer's history dict
history = trainer.fit(
    num_epochs,
    train_dataloader,
    total_train_len,
    val_dataloader=val_dataloader,
    total_val_len=total_val_len,
)

### Save the PyTorch model along with the optimizer state

In [None]:
# Bundle model weights with the optimizer state so training can be resumed later.
# (Weights only: torch.save(trainer.model.state_dict(), "./saved_model/pytorch/model_torch.pth"))
model_state = trainer.model.state_dict()
optimizer_state = trainer.optimizer.state_dict()
checkpoint = {'state_dict': model_state, 'optimizer': optimizer_state}

torch.save(checkpoint, "./saved_model/pytorch/model_torch.pth.tar")

### Load the PyTorch model

In [None]:
# Load the checkpoint back into the model and optimizer. map_location remaps tensors
# saved in a GPU session onto the current device, so this also works on CPU-only runtimes.
checkpoint = torch.load("./saved_model/pytorch/model_torch.pth.tar", map_location=device)

model_torch.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

### Save & load the training history

In [None]:
import pickle

# Persist the PyTorch training history, then read it back to check the round trip.
# Only unpickle files this notebook wrote itself: pickle.load can execute arbitrary code.
with open('history.pickle', 'wb') as history_file:
    pickle.dump(trainer.history, history_file, protocol=pickle.HIGHEST_PROTOCOL)

with open('history.pickle', 'rb') as history_file:
    b = pickle.load(history_file)

### Loss and Metrics Visualization 

In [None]:
# Plot training vs. validation curves from the Trainer history dict.
# Metric lists are read from the dict rather than bound to globals: the original
# `dice_coef = ...` assignment shadowed the Dice metric function used by the trainer.
metric_history = history
epochs = range(len(metric_history['accuracy']))

for key, label in (('accuracy', 'accuracy'), ('loss', 'loss'), ('dice_coef', 'Dice coef')):
    # Create the figure BEFORE plotting so each metric gets its own, non-empty figure
    fig, ax = plt.subplots()
    ax.plot(epochs, metric_history[key], 'r', label=f'Training {label}')
    ax.plot(epochs, metric_history[f'val_{key}'], 'b', label=f'Validation {label}')
    ax.set_title(f'Training and validation {label}')
    ax.set_xlabel('Epoch')
    ax.legend(loc=0)
    plt.show()

### Evaluation (on test set)

In [None]:
def intersection_over_union(y_true, y_pred):
    """Intersection-over-Union (Jaccard index) of the positive class for binary masks.

    Args:
        y_true: ground-truth mask array; non-zero entries count as positive.
        y_pred: predicted mask array with the same number of elements.

    Returns:
        TP / (TP + FP + FN) as a float. Returns 1.0 when neither mask has any
        positive pixel (perfect agreement on an empty mask) instead of dividing by zero.
    """
    truth = np.asarray(y_true).ravel().astype(bool)
    pred = np.asarray(y_pred).ravel().astype(bool)
    # Counting directly avoids sklearn's confusion_matrix, which returns a 1x1 matrix
    # (breaking a 4-way unpack) when only one class appears in both arrays.
    intersection = np.count_nonzero(truth & pred)  # TP
    union = np.count_nonzero(truth | pred)         # TP + FP + FN
    if union == 0:
        return 1.0
    return intersection / union


def save_results(image, mask, y_pred, save_image_path):
    """Write [image | ground-truth mask | predicted mask] side by side to one file.

    Args:
        image: RGB image of shape (h, w, 3) with values in 0-255.
        mask: ground-truth mask of shape (h, w) with values 0/1.
        y_pred: predicted mask of shape (h, w) with values 0/1.
        save_image_path: output path; missing parent directories are created.

    Raises:
        IOError: if OpenCV fails to write the file.
    """
    # cv2.imwrite expects BGR channel order
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # 10-px grey separator, sized from the image itself instead of the global H
    line = np.full((image.shape[0], 10, 3), 128, dtype=np.uint8)

    def _mask_to_bgr(binary_mask):
        # (h, w) 0/1 mask -> (h, w, 3) uint8 image with values 0/255
        stacked = np.repeat(np.expand_dims(binary_mask, axis=-1), 3, axis=-1)
        return (stacked * 255).astype(np.uint8)

    cat_images = np.concatenate(
        [image, line, _mask_to_bgr(mask), line, _mask_to_bgr(y_pred)], axis=1
    )

    out_dir = os.path.dirname(save_image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite reports failure through its return value rather than raising
    if not cv2.imwrite(save_image_path, cat_images):
        raise IOError(f"Failed to write {save_image_path}")

In [None]:
model_torch.eval()

# Evaluation on test set: predict each test image, save an [image | truth | prediction]
# strip, and collect per-image pixel metrics that are averaged at the end.
os.makedirs(os.path.join("results", "pytorch"), exist_ok=True)

SCORE = []
for image_path, mask_path in tqdm(zip(test_images, test_masks), total=len(test_masks)):
    # File name without directory or extension (os.path handles / and \ separators)
    name = os.path.splitext(os.path.basename(image_path))[0]

    # Read the image as in training: RGB, resized, CHW, scaled to [0, 1] float32
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (W, H))  # dsize is (width, height)
    x = np.moveaxis(image, -1, 0)
    x = (x / 255.0).astype(np.float32)
    x = torch.from_numpy(np.expand_dims(x, axis=0)).to(device)

    # Read the mask and binarise at 0.5. A bare astype(int32) would truncate every
    # value below 1.0 (interpolated edges, lossy-compression noise) to 0.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.resize(mask / 255.0, (W, H))
    mask = (mask > 0.5).astype(np.int32)

    # Prediction: first (only) batch item, channel axis removed -> (H, W) binary mask
    with torch.no_grad():
        y_pred = model_torch(x)[0]
    y_pred = np.squeeze(y_pred.cpu().numpy(), axis=0)
    y_pred = (y_pred > 0.5).astype(np.int32)

    # Saving the prediction
    save_image_path = os.path.join("results", "pytorch", f"{name}.png")
    save_results(image, mask, y_pred, save_image_path)

    # Pixel-level metrics on the flattened masks
    mask_flat = mask.flatten()
    pred_flat = y_pred.flatten()
    acc_value = accuracy_score(mask_flat, pred_flat)
    iou = intersection_over_union(mask_flat, pred_flat)
    recall_value = recall_score(mask_flat, pred_flat, labels=[0, 1], average="binary")
    precision_value = precision_score(mask_flat, pred_flat, labels=[0, 1], average="binary")
    SCORE.append([name, acc_value, iou, recall_value, precision_value])

# Mean of each metric over all test images
score = np.mean([s[1:] for s in SCORE], axis=0)
print(f"Accuracy: {score[0]:0.5f}")
print(f"IOU: {score[1]:0.5f}")
print(f"Recall: {score[2]:0.5f}")
print(f"Precision: {score[3]:0.5f}")

### Webcam Inference

In [None]:
model_torch.eval()

# Background fill colour. It is applied to the original OpenCV frame, which is BGR,
# so this list is [B, G, R] (pure green either way).
BG_COLOR = [0, 255, 0]

# NOTE(review): needs a local camera and display; cv2.imshow / VideoCapture(0) are not
# usable on a hosted Colab runtime.
cap = cv2.VideoCapture(0)
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        h, w, _ = frame.shape
        ori_frame = frame

        # Preprocess like training: BGR -> RGB (the model was trained on RGB input),
        # resize, HWC -> NCHW, scale to [0, 1] float32
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (W, H))  # dsize is (width, height)
        batch = np.expand_dims(np.moveaxis(rgb, -1, 0), axis=0) / 255.0
        batch = torch.from_numpy(batch.astype(np.float32)).to(device)

        # prediction
        with torch.no_grad():
            mask = model_torch(batch)[0]

        # Back to camera resolution, binarise, add a channel axis for broadcasting
        mask = np.squeeze(mask.cpu().numpy(), axis=0)
        mask = cv2.resize(mask, (w, h))
        mask = (mask > 0.5).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)

        # Keep foreground pixels from the original frame and paint the rest with BG_COLOR
        masked_frame = ori_frame * mask
        background = np.repeat(1.0 - mask, 3, axis=-1) * BG_COLOR
        final_frame = (masked_frame + background).astype(np.uint8)

        cv2.imshow('Window', final_frame)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the window, even if inference raises.
    cap.release()
    cv2.destroyAllWindows()

### Save as a TorchScript model

In [None]:
# Both exports are meant for inference. Tracing records BatchNorm in whatever mode the
# module is currently in, so force eval() here rather than relying on an earlier cell.
model_torch.eval()

# 1. using script
#    -> converts the model by analysing its Python code
#    -> preserves control flow (conditions, loops, e.g. an if/for inside forward)
#    -> may not support 100% of operators

print("-----[USING SCRIPT]-----")
scripted_model = torch.jit.script(model_torch)
scripted_model.save('./saved_model/torchscript/scripted_model.pt')
# print(scripted_model.code)
print("-----[SCRIPT SAVED]-----")
print("-" * 80)


# 2. using trace
#    -> requires a sample input to trace the computation path
#    -> does not preserve control flow (only the path taken for the sample is recorded)
#    -> works with just about any code

print("-----[USING TRACE]-----")
sample_input = torch.rand(INPUT_SHAPE).unsqueeze(0).to(device=device)
traced_model = torch.jit.trace(model_torch, sample_input)
traced_model.save('./saved_model/torchscript/traced_model.pt')
print("-----[TRACE SAVED]-----")

In [None]:
# Load an exported TorchScript model (uncomment the first line to load the scripted one).
# map_location places the weights on the current device, so a model exported in a GPU
# session can be loaded on a CPU-only runtime.
# scriptedmodel = torch.jit.load('./saved_model/torchscript/scripted_model.pt', map_location=device)
tracedmodel = torch.jit.load('./saved_model/torchscript/traced_model.pt', map_location=device)