# Evaluate a Pre-Trained Segmentation Model in Colab

Demonstrates image pre-processing, prediction and validation statistics. But first, some preliminaries...

__Note:__ To maintain a high-priority Colab user status, so that sufficient GPU resources remain available to you in the future, be sure to free the runtime when you have finished running this notebook. To do so, open 'Runtime > Manage Sessions' and click 'Terminate'.

In [None]:
# Check if notebook is running in Colab or local workstation
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
    !pip install gputil
    !pip install psutil
    !pip install humanize

import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()

try:
    # XXX: only one GPU on Colab and isn’t guaranteed
    gpu = GPUs[0]
    def printm():
        process = psutil.Process(os.getpid())
        print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
        print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
    printm() 

    # Check if GPU capacity is sufficient to proceed
    if gpu.memoryFree < 10000:
        print("\nInsufficient memory! Some cells may fail. Please try restarting the runtime using 'Runtime → Restart Runtime...' from the menu bar. If that doesn't work, terminate this session and try again later.")
    else:
        print('\nGPU memory is sufficient to proceeed.')
except:
    print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
    print('and then re-execute this cell.')

In [None]:
if IN_COLAB:

    from google.colab import drive
    drive.mount('/content/drive')
    DATA_PATH = r'/content/drive/My Drive/Data'
    
    # cd into git repo so python can find utils
    %cd '/content/drive/My Drive/cciw-zebra-mussel/predict'

    sys.path.append('/content/drive/My Drive')

In [None]:
import os
import os.path as osp

import csv
import glob

# for manually reading high resolution images
import cv2
import numpy as np

# for comparing predictions to lab analysis data frames
import pandas as pd

# for plotting
import matplotlib
# enable LaTeX style fonts
matplotlib.rc('text', usetex=True)
import matplotlib.pyplot as plt
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

# pytorch core library
import torch
# pytorch neural network functions
from torch import nn

from sklearn.metrics import jaccard_score as jsc

# various helper functions, metrics that can be evaluated on the GPU
from task_3_utils import eval_binary_iou, pretty_image, img_to_nchw_tensor, mask_and_preds_to_1hot

sys.path.append("..") # Adds higher directory to python modules path.

from utils.dataset_2_utils import colour_fmt_crop_and_resize

In [None]:
"""Confim that this cell prints "Found GPU, cuda". If not, select "GPU" as 
"Hardware Accelerator" under the "Runtime" tab of the main menu.
"""
# Pick the compute device. Fall back to the CPU when no GPU is present so that
# `device` is always defined; later cells pass it to img_to_nchw_tensor and
# .to(device), which would otherwise raise NameError.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Found GPU,', device)
else:
    device = torch.device('cpu')
    print('No GPU found, using', device)

sig = nn.Sigmoid()  # element-wise sigmoid applied to the raw network output

## 1. Load the dataset

The architecture is a fully convolutional network (FCN-8s).

In [None]:
# Validation set: JPEG images and their cropped PNG label masks. The evaluation
# loop pairs jpeg_files[i] with png_files[i], so both lists are sorted and must
# have the same length.
root_path = '/scratch/ssd/gallowaa/cciw/VOCdevkit/Validation-v101-originals/'
jpeg_files = sorted(glob.glob(osp.join(root_path, 'JPEGImages', '*.jpg')))
png_files = sorted(glob.glob(osp.join(root_path, 'SegmentationClass', '*_crop.png')))

# Both should equal 55 for v1.0.1 in-situ dataset
print(len(jpeg_files))
print(len(png_files))

if not jpeg_files:
    print('WARNING: no images found under', root_path)
if len(jpeg_files) != len(png_files):
    print('WARNING: found %d images but %d label masks; image/label pairs '
          'will be misaligned.' % (len(jpeg_files), len(png_files)))

In [None]:
# Figure styling shared by the per-image plots.
fontsize = 16

# Subplot layout in the fractional figure coordinates expected by
# plt.subplots_adjust.
left, right = 0.02, 0.98    # left and right edges of the subplot area
bottom, top = 0.05, 0.95    # bottom and top edges of the subplot area
wspace = 0.15  # horizontal gap between subplots, as a fraction of the average axis width
hspace = 0.1   # vertical gap between subplots, as a fraction of the average axis height

## 2. Load a pre-trained model checkpoint

In [None]:
os.environ.update(DATA_PATH='/scratch/gallowaa')  # base directory for training logs/checkpoints

In [None]:
# Log directory for models trained on train_v120. The trailing '' component
# keeps the trailing slash that the checkpoint glob patterns append to.
root = osp.join(os.environ['DATA_PATH'], 'cciw', 'logs', 'cmp-dataset', 'train_v120', '')
print(root)

In [None]:
#6 (resnet50) + 12 (slim) + 3 (unet)

In [None]:
# Checkpoints saved at epoch 79. Swap in one of the commented patterns to
# evaluate every architecture or only deeplabv3_resnet50 instead of unet.
# glob returns files in arbitrary order; sort so that index `f` in the
# evaluation loop refers to the same checkpoint on every run.
#files = sorted(glob.glob(root + '*/*/*/*/*/*/checkpoint/*epoch79.ckpt'))
#files = sorted(glob.glob(root + 'deeplabv3_resnet50/*/*/*/*/*/checkpoint/*epoch79.ckpt'))
files = sorted(glob.glob(root + 'unet/*/*/*/*/*/checkpoint/*epoch79.ckpt'))

print(len(files))
files

In [None]:
"""Set to True to save the model predictions in PNG format, 
otherwise proceed to predict biomass without saving images"""
# Evaluation switches.
PLOT, SAVE_PREDICTIONS = False, False  # per-image figures / write val_preds.csv (+ figures when PLOT)
IOU_C2 = True  # True: sklearn jaccard_score on one-hot masks; False: eval_binary_iou on torch tensors

# Dataset tags used in saved figure filenames:
# src names the training set, tgt the set being evaluated.
src, tgt = 'train_v120', 'val_v101'

In [None]:
# Evaluate every matched checkpoint on every validation image, printing the
# per-image IoU and the mean IoU per checkpoint. Iterating over `files` (rather
# than a hard-coded range(3)) avoids an IndexError when fewer than three
# checkpoints match and no longer silently skips extra ones.
for f, ckpt_file in enumerate(files):
    print('==> Resuming from checkpoint..')

    # torch.load unpickles the whole model object: only load trusted checkpoints.
    # map_location restores GPU-saved tensors onto `device`, so the loop also
    # works on a CPU-only runtime.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses to
    # unpickle a full model object; pass weights_only=False on those versions.
    checkpoint = torch.load(ckpt_file, map_location=device)
    train_loss = checkpoint['trn_loss']
    val_loss = checkpoint['val_loss']

    net = checkpoint['net'].to(device)
    last_epoch = checkpoint['epoch']
    # torch.set_rng_state only accepts a CPU ByteTensor, but map_location may
    # have moved the saved state onto the GPU.
    torch.set_rng_state(checkpoint['rng_state'].cpu())

    net.eval()

    # later appended to figure filenames
    model_stem = ckpt_file.split('/')[-1].split('.')[0]
    print('Loaded model %s trained to epoch ' % model_stem, last_epoch)
    print(
        'Cross-entropy loss {:.4f} for train set, {:.4f} for validation set'.format(train_loss, val_loss))

    if SAVE_PREDICTIONS:
        # Outputs go in a predictions/ folder inside the checkpoint's directory.
        prediction_path = osp.join(osp.dirname(ckpt_file), 'predictions')
        if osp.exists(prediction_path):
            print('Folder', prediction_path, 'already exists')
        else:
            os.makedirs(prediction_path)
        csv_path = osp.join(prediction_path, 'val_preds.csv')

        # Start a fresh log; per-image rows and the mean are appended below so
        # each row is on disk as soon as it is computed.
        with open(csv_path, 'w', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(['image index', 'miou'])

    iou_list = []

    # jpeg_files[i] is paired with png_files[i]; both are sorted glob results.
    for i, jpeg_file in enumerate(jpeg_files):

        image_stem = jpeg_file.split('/')[-1].split('.')[0]
        # glob already returned full paths, so read them directly (joining them
        # onto root_path again would double the prefix if root_path were relative).
        bgr_lab = cv2.imread(png_files[i])
        bgr_img = cv2.imread(jpeg_file)
        # cv2.imread returns None instead of raising when a file can't be read.
        if bgr_img is None or bgr_lab is None:
            raise FileNotFoundError(
                'Could not read image %s or label %s' % (jpeg_file, png_files[i]))

        imgc, mask = colour_fmt_crop_and_resize(
            bgr_img, bgr_lab, 0, scale_percent=100)

        nchw_tensor = img_to_nchw_tensor(imgc, device)

        with torch.no_grad():
            pred = sig(net(nchw_tensor))
            #pred = sig(net(nchw_tensor)['out']) # torchvision models

        pred_np = pred.detach().cpu().numpy().squeeze()

        if IOU_C2:
            p_one_hot, t_one_hot = mask_and_preds_to_1hot(pred_np, mask)

            iou = jsc(p_one_hot.reshape(1, -1),
                      t_one_hot.reshape(1, -1), average='samples')
        else:
            targets = torch.LongTensor(mask).to(device)
            iou = eval_binary_iou(pred, targets).item()

        print('Image %d of %d, IoU %.4f' % (i, len(png_files), iou))
        iou_list.append(iou)

        if SAVE_PREDICTIONS:
            with open(csv_path, 'a', newline='') as logfile:
                logwriter = csv.writer(logfile, delimiter=',')
                logwriter.writerow([i, np.round(iou, 4)])

        if PLOT:
            plt.close('all')
            image = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
            fig, axes = plt.subplots(1, 3, figsize=(10, 4))
            axes = axes.flatten()
            axes[0].imshow(image)
            axes[0].set_title('Input', fontsize=fontsize)
            axes[1].imshow(image, alpha=0.75)
            axes[1].imshow(pred_np, alpha=0.5)
            # Raw string: '\&' (escaped ampersand for LaTeX) is an invalid
            # Python escape sequence in a normal string literal.
            axes[1].set_title(r'Input \& Preds, IoU = %.4f' %
                              iou, fontsize=fontsize)
            axes[2].imshow(mask)
            axes[2].set_title('Ground Truth', fontsize=fontsize)
            plt.subplots_adjust(left=left, bottom=bottom,
                                right=right, top=top, wspace=wspace, hspace=hspace)
            pretty_image(axes)

            if SAVE_PREDICTIONS:
                filename = src + '-' + tgt + '__' + image_stem + '__' + model_stem
                out_file = osp.join(prediction_path, filename)
                fig.savefig(out_file + '.png', format='png')
                #fig.savefig(out_file + '.eps', format='eps')

    mean_iou = np.asarray(iou_list).mean()
    if SAVE_PREDICTIONS:
        with open(csv_path, 'a', newline='') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(['mean', np.round(mean_iou, 4)])
    else:
        print('Seed %d %.4f' % (f, mean_iou))