# Task 1 a) Process Quadrat in MP4 Video
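This notebook is configured to read a video specified by `file`, then extract the lines that compose the quadrat and estimate the appropriate corner points for cropping its contents. The output is a lower resolution MP4 video that is annotated with the found lines and corner points.

**Note:** in the code below the line-extraction calls (`draw_lines`) are commented out. Instead, each frame is passed through a segmentation network loaded from a checkpoint, and the predicted mask is blended into the frames of the output MP4 video.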

In [None]:
# Detect whether the notebook runs inside Google Colab (the `google.colab`
# module is already loaded there) or on a local workstation, and set
# DATA_PATH accordingly.
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Mount Google Drive so the data folder stored there becomes visible
    from google.colab import drive
    drive.mount('/content/drive')
    DATA_PATH = r'/content/drive/My Drive/Data'
    
    # cd into git repo so python can find utils
    %cd '/content/drive/My Drive/cciw-zebra-mussel'
    
    # NOTE(review): cloning the repo / installing packages is not done here;
    # the repo is expected to exist on Drive already.
else:
    # Local workstation: data lives under a fixed scratch directory
    DATA_PATH = r'/scratch/gallowaa/cciw/Data'

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from glob import glob
#from utils import draw_lines

# https://ipython.org/ipython-doc/3/config/extensions/autoreload.html
#%load_ext autoreload
#%autoreload 2

# Figure size handed to plt.figure(figsize=...) by the debugging plots
# (those plots are currently disabled at the end of the notebook).
figsize = (8, 6)
# When True, the (currently disabled) debugging plot also saves its figure to disk.
save_figures = False

In [None]:
import os
import os.path as osp

#import glob

# for manually reading high resolution images
import cv2
import numpy as np

# for comparing predictions to lab analysis data frames
import pandas as pd

# for plotting
import matplotlib
# enable LaTeX style fonts
matplotlib.rc('text', usetex=True)
import matplotlib.pyplot as plt
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

# pytorch core library
import torch
# pytorch neural network functions
from torch import nn
# pytorch dataloader
from torch.utils.data import DataLoader

# for post-processing model predictions by conditional random field 
import pydensecrf.densecrf as dcrf
import pydensecrf.utils as utils

from tqdm import tqdm  # progress bar

# evaluation metrics
from sklearn.metrics import r2_score
from sklearn.metrics import jaccard_score as jsc

# local imports (files provided by this repo)
#import transforms as T

# various helper functions, metrics that can be evaluated on the GPU
#from task_3_utils import evaluate, evaluate_loss, eval_binary_iou, pretty_image

# Custom dataloader for rapidly loading images from a single LMDB file
#from folder2lmdb import VOCSegmentationLMDB

In [None]:
import sys
# Add the parent directory to the module search path
# (presumably so repo-level modules such as `utils` can be imported — confirm).
sys.path.append("..")
#from utils import draw_lines

In [None]:
from fcn import FCN8slim

# `device` is otherwise only assigned in a much later cell (and only when CUDA
# is present), so resolve it here to keep a top-to-bottom run of the notebook
# from raising NameError on the line below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Freshly initialised FCN8slim (n_class=1) placed on the chosen device.
# NOTE(review): `net` is re-assigned from the checkpoint in the next cell, so
# this instance appears to be a placeholder — confirm it is still needed.
net = FCN8slim(n_class=1).to(device)

In [None]:
# `root` is the directory that contains the checkpoint file selected below.
# Exactly one assignment should be active per branch; the commented
# alternatives (tagged a-e) pair with the commented `ckpt_file` options.
if IN_COLAB:
    root = osp.join(DATA_PATH, 'Checkpoints/fcn8slim_lr1e-03_wd5e-04_bs32_ep50_seed1')
else:
    #root = '/scratch/gallowaa/cciw/logs/lab-v1.0.0/fcn8slim/lr1e-03/wd5e-04/bs32/ep50/seed1/checkpoint' # a
    #root = '/scratch/gallowaa/cciw/logs/v1.0.1-debug/fcn8s/lr1e-03/wd5e-04/bs25/ep80/seed4/checkpoint' # b
    #root = '/scratch/gallowaa/cciw/logs/v1.1.0-debug/fcn8s/lr1e-03/wd5e-04/bs25/ep80/seed9/checkpoint/' # c
    #root = '/scratch/gallowaa/cciw/logs/v111/trainval/fcn8s/lr1e-03/wd5e-04/bs40/ep80/seed2/checkpoint/' # d
    #root = '/scratch/gallowaa/cciw/logs/v111/trainval/fcn8slim/lr1e-04/wd5e-04/bs40/ep80/seed1/checkpoint/' # e
    # Active choice: deeplabv3_resnet50 run (matches the active ckpt_file)
    root = '/scratch/gallowaa/cciw/logs/cmp-dataset/train_v120/deeplabv3_resnet50/lr1e-01/wd5e-04/bs40/ep80/seed3/checkpoint'
    #root = '/scratch/gallowaa/cciw/logs/cmp-dataset/trainval_v120/fcn8slim/lr1e-03/wd5e-04/bs50/ep80/seed1/checkpoint/'
    #root = '/scratch/gallowaa/cciw/logs/cmp-dataset/trainval_v120/fcn8slim/lr1e-03/wd5e-04/bs50/ep80/seed1/checkpoint/'

#ckpt_file = 'fcn8slim_lr1e-03_wd5e-04_bs32_ep50_seed1_epoch40.ckpt' # a
#ckpt_file = 'fcn8s_lr1e-03_wd5e-04_bs25_ep80_seed4_epoch70.ckpt' # b
#ckpt_file = 'fcn8s_lr1e-03_wd5e-04_bs25_ep80_seed9_epoch10.ckpt'
#ckpt_file = 'fcn8s_lr1e-03_wd5e-04_bs40_ep80_seed2amp_epoch79.pt' # d
#ckpt_file = 'fcn8slim_lr1e-04_wd5e-04_bs40_ep80_seed1amp_epoch79.pt' # e
# Active checkpoint file name; it is joined onto `root` when loading, and its
# stem (text before the first '.') is reused in the output file name.
ckpt_file = 'deeplabv3_resnet50_lr1e-01_wd5e-04_bs40_ep80_seed3_epoch79.ckpt'
#ckpt_file = 'fcn8slim_lr1e-03_wd5e-04_bs50_ep80_seed1_epoch79.pt'
#ckpt_file = 'fcn8slim_lr1e-03_wd5e-04_bs50_ep80_seed1amp_epoch79.pt'

"""Feel free to try these other checkpoints later after running epoch40 to get a 
feel for how the evaluation metrics change when model isn't trained as long."""
#ckpt_file = 'fcn8slim_lr1e-03_wd5e-04_bs32_ep50_seed1_epoch10.ckpt'
#ckpt_file = 'fcn8slim_lr1e-03_wd5e-04_bs32_ep50_seed1_epoch0.ckpt'

# Load the checkpoint; it is indexed like a dict with the keys 'trn_loss',
# 'val_loss', 'net', 'epoch' and 'rng_state'.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# trusted sources.
checkpoint = torch.load(osp.join(root, ckpt_file))
train_loss = checkpoint['trn_loss']
val_loss = checkpoint['val_loss']
print('==> Resuming from checkpoint..')
# The stored object is used directly as the model (it replaces any `net`
# created earlier) — presumably the whole module was saved, not a state_dict.
net = checkpoint['net']

# AMP variant (disabled): restore weights / amp state from state dicts instead
#net.load_state_dict(checkpoint['net'])
#amp.load_state_dict(checkpoint['amp'])
# The stored epoch is incremented by one for reporting
last_epoch = checkpoint['epoch'] + 1
# Restore the CPU random number generator state saved with the checkpoint
torch.set_rng_state(checkpoint['rng_state'])

# checkpoint name up to the first '.', later used in the output video file name
model_stem = ckpt_file.split('.')[0]

print('Loaded model %s trained to epoch ' % model_stem, last_epoch)
print('Cross-entropy loss {:.4f} for train set, {:.4f} for validation set'.format(train_loss, val_loss))

# Switch to evaluation mode; as the last expression of the cell, the returned
# module is also displayed by the notebook.
net.eval()

In [None]:
from apex import amp

In [None]:
"""Set to True to save the model predictions in PNG format, 
otherwise proceed to predict biomass without saving images"""
SAVE_PREDICTIONS = True

# Output directory: a `videos` folder inside the directory part of `root`
# (everything before its last '/').  The path is computed unconditionally
# because it is printed below and used when opening the output video, even
# when SAVE_PREDICTIONS is False; previously that case raised NameError.
prediction_path = osp.join(osp.dirname(root), 'videos')

if SAVE_PREDICTIONS:
    # exist_ok=True replaces the separate exists() check followed by mkdir()
    os.makedirs(prediction_path, exist_ok=True)

print(prediction_path)

In [None]:
from apex import amp
# Pass the loaded model through Apex AMP with opt_level 'O3' and keep the
# returned model.
# NOTE(review): re-running this cell feeds the already-initialised `net` to
# amp.initialize again — confirm that is safe, or reload the checkpoint first.
net = amp.initialize(net, opt_level='O3')

In [None]:
def segmentation(im):
    """Run the global model `net` on one frame and return its sigmoid output.

    Args:
        im: Image array in OpenCV layout (height x width x 3, BGR channel
            order); pixel values are assumed to lie in 0-255.

    Returns:
        NumPy array holding the sigmoid of the network output, with all
        singleton dimensions squeezed out.
    """
    rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    # Scale to [0, 1], then normalise each channel with mean 0.5 and std 0.5
    rgb = rgb / 255.
    rgb = (rgb - np.array([0.5, 0.5, 0.5])) / np.array([0.5, 0.5, 0.5])

    # Place the input on whichever device the model lives on, so the function
    # does not depend on a `device` global defined in another cell.
    model_device = next(net.parameters()).device
    batch = torch.from_numpy(rgb).float().to(model_device).unsqueeze(0)

    # Note: need to call contiguous after the permute (NHWC -> NCHW),
    # else max pooling will fail
    nchw_tensor = batch.permute(0, 3, 1, 2).contiguous()

    with torch.no_grad():
        out = net(nchw_tensor)
        # Some models return a dict with the prediction under 'out'; others
        # return the tensor itself.  Accept both.
        if isinstance(out, dict):
            out = out['out']
        pred = torch.sigmoid(out)

    return pred.cpu().numpy().squeeze()

In [None]:
# Notebook-level sigmoid module (torch.nn.Sigmoid) for mapping network outputs to (0, 1)
sig = nn.Sigmoid()

In [None]:
#file = 'GLNI_456-CloseUp_2016-07-11_video-1.mp4'
#file = 'GLNI_456-2_2014-05-27_video-1.mp4'

#file = 'GLNI_456-1_2014-05-27_video-1.mp4'
#file = 'GLNI_1208-2_2014-08-27_video-1.mp4'
#file = 'GLNI_1208-1_2014-08-27_video-1.mp4'
#file = 'GLNI_1208-2_2014-06-04_video-2.mp4'
# Active selection: name of the video to process.  It is matched against the
# 'Name' column of the video table to find the file on disk.
file = 'GLNI_1342-1_2014-08-21_video-1.mp4'
#file = 'GLNI_1342-3_2014-09-25_video-1.mp4'
#file = 'GLNI_1342-2_2014-09-25_video-1.mp4'
#file = 'GLNI_1342-2_2014-05-28_video-1.mp4'
#file = 'GLNI_1354-1_2014-05-28_video-1.mp4'

#file = 'GLNI_1347-3_2013-09-13_video-1.mp4'
#file = 'GLNI_1347-2_2013-09-13_video-1.mp4'
#file = 'GLNI_1347-1_2013-09-13_video-1.mp4'

#file = 'GLNI_503-1_2013-05-08_video-1.mp4'
#file = 'GLNI_1347-2_2013-05-01_video-1.mp4'
#file = 'GLNI_1347-3_2016-07-05_video-1.mp4'
#file = 'GLNI_1347-1_2016-07-05_video-1.mp4'
#file = 'GLNI_12-3_2016-07-11_video-1.mp4'
#file = 'GLNI_12-1_2016-07-11_video-1.mp4'
#file = 'GLNI_12-1_2016-07-11_video-1.mp4'
#file = 'GLNI_456-3_2015-07-17_video-1.mp4'

In [None]:
#all_videos[-2]

In [None]:
# All close-up videos found under DATA_PATH (listed for inspection only)
#all_videos = glob(os.path.join(DATA_PATH, 'Videos_and_stills/GLNI/*/*/*/Videos/Quad*/*.mp4'))
all_videos = glob(os.path.join(DATA_PATH, 'Videos_and_stills/GLNI/*/*/*/Videos/CloseUp/*.mp4'))

# Table with one row per video; the first CSV column becomes the row index
videotable_path = os.path.join(DATA_PATH, 'Tables', 'QuadratVideos.csv')
video_df = pd.read_csv(videotable_path, index_col=0)

# Select the recorded path of the requested video with a boolean mask.
# .loc is used because the mask/labels come from the table's own index;
# feeding those labels to the positional .iloc only works when the index
# happens to be 0..n-1.
vpath = video_df.loc[video_df['Name'] == file, 'Quadrat Video Path']
if vpath.empty:
    raise ValueError('No video named %r found in %s' % (file, videotable_path))

# The recorded path is backslash-separated
tokens = vpath.values[0].split('\\')

# Rebuild the location under DATA_PATH: skip the first four components of the
# recorded path and its last component, then append the file name.
video_path = '/'.join([DATA_PATH + '/Videos_and_stills/GLNI', *tokens[4:-1]])
video_path = os.path.join(video_path, file)
print('Loading video: ', video_path)
#all_videos[-2]


In [None]:
# Display the full list of close-up video paths found by glob (may be long)
all_videos

In [None]:
"""These are meta-parameters of the Probabilistic Hough Line Transform, 
note there are additional params in cell 6 which we set according to 
the input resolution.

@param rho Distance resolution of the accumulator (pixels)"""
# NOTE: only consumed by the draw_lines calls, which are commented out in this notebook
rho = 1  

"""
@param theta Angle resolution of the accumulator (radians)

Suggest to use a value no less than np.pi/90 else too many 
spurious lines will be found. theta=np.pi/45 means lines 
must differ by 4 degrees to be considered distinct."""
# np.pi / 45 radians == 4 degrees.
# NOTE: only consumed by the draw_lines calls, which are commented out in this notebook
theta = np.pi / 45

In [None]:
#file

# Open the selected video for frame-by-frame reading
cap = cv2.VideoCapture(video_path)
# Frame size as (width, height), as reported by the capture object
sz = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
      int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
print('Raw input resolution', sz)

In [None]:
# Read the first frame so the output resolution can be taken from a frame that
# went through the same cropping as the frames processed later.
# NOTE(review): this frame is consumed here and is therefore not part of the
# frames handled by the processing loop — confirm that is intended.
ret, im = cap.read()
if not ret:
    # Without this check a failed read leaves im as None and the code below
    # fails with an unrelated-looking error.
    raise RuntimeError('cap.read() returned no frame; check that the video opened')

# set additional meta-parameters according to input res
if sz[0] == 1440:
    # x_trim and y_trim are used to remove black padding
    # which triggers spurious edges
    x_trim, y_trim = 1, 145
    im = im[y_trim:-y_trim, x_trim:-x_trim, :]
    crop_frame_border = True
    # Line-detection parameters for this resolution are disabled (kept for reference):
    #   canny_thresh1, canny_thresh2 = 60, 300
    #       hysteresis values for Canny edge detector, input to HoughLines
    #   threshold = 125
    #       accumulator threshold, return lines with more than threshold
    #       of votes (intersection points)
    #   mLL = 400
    #       minimum line length; shorter line segments are rejected (pixels)
    #   mLG = 150
    #       maximum allowed gap between points on the same line to link them (pixels)
else:
    # params as described above
    canny_thresh1, canny_thresh2 = 30, 300
    threshold = 125
    mLG, mLL = 250, 600
    crop_frame_border = False

# draw_lines may downsample, so set the video writer resolution to the
# processed image resolution (call currently disabled):
#img, edges = draw_lines(im, rho=rho, theta=theta, mll=mLL,
#                        mlg=mLG, threshold=threshold, ds=1)

# (width, height) of the possibly cropped frame
sz = (im.shape[1], im.shape[0])
print(sz)

In [None]:
#osp.join(prediction_path, file.split('.')[0] + '_' + model_stem + '-quadrat-demo.mp4')

In [None]:
# Open the output video stream, named after the input video and the model.
if cap.isOpened():
    # NOTE(review): the output frame rate is fixed at 20 rather than taken
    # from the input video — confirm this is intended.
    fps = 20
    vout = cv2.VideoWriter()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    outfile = file.split('.')[0] + '_' + model_stem + '-quadrat-demo.mp4'
    vout.open(osp.join(prediction_path, outfile), fourcc, fps, sz, True)
    # Only report success when the writer really opened; a failed open would
    # otherwise go unnoticed and no video would be written.
    if vout.isOpened():
        print('Opened stream for writing, output resolution is', sz)
    else:
        print('Failed to open', osp.join(prediction_path, outfile), 'for writing')
else:
    print('cap is not open')

In [None]:
"""Confim that this cell prints "Found GPU, cuda". If not, select "GPU" as 
"Hardware Accelerator" under the "Runtime" tab of the main menu.
"""
# Select the compute device.  `device` is always assigned so that cells using
# it do not fail with NameError on a machine without CUDA.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Found GPU,', device)
else:
    device = torch.device('cpu')
    print('No GPU found, falling back to', device)

In [None]:
currentFrame = 0

# It can take 30s-1min to process the entire video.  To process only a small
# number of frames, replace `while True:` with e.g. `for _ in range(10):`.
#
# Before re-running this cell, first re-run the cells that open `cap` and
# `vout`; otherwise ret is False on the first read and this cell does nothing.
try:
    while True:

        # Capture frame-by-frame; stop when no frame is returned
        ret, im = cap.read()
        if not ret:
            break

        # (disabled) saving still images:
        #name = 'pframe_pi45_' + str(currentFrame) + '.jpg'
        #save_path = os.path.join(out_path, name)
        #print ('Creating...' + name)

        if crop_frame_border:
            im = im[y_trim:-y_trim, x_trim:-x_trim, :]

        pred_np = segmentation(im)

        # (disabled) quadrat line detection:
        #img, _ = draw_lines(im, rho=rho, theta=theta, mll=mLL,
        #                    mlg=mLG, threshold=threshold, ds=1,
        #                    canny_1=canny_thresh1, canny_2=canny_thresh2)

        # (disabled) save still image in jpeg format:
        #cv2.imwrite(save_path, img)

        # (disabled) frame-number annotation; putText arguments are: image,
        # text, org (bottom-left corner of the text, (x, y)), fontFace,
        # fontScale, color, thickness, lineType:
        #cv2.putText(
        #    img, str(currentFrame), (50, 50), cv2.FONT_HERSHEY_PLAIN, 2, (0,255,0), 1, cv2.LINE_AA)

        # Scale the prediction to 8-bit and put it in channel 0 of an
        # otherwise black 3-channel image (channel 0 is blue in OpenCV's BGR
        # order), then blend it 50/50 with the frame.
        #im[:, :, 0] += (pred_np * 255).astype('uint8')
        p = (pred_np * 255).astype('uint8')
        #p = cv2.cvtColor(p, cv2.COLOR_GRAY2RGB)
        src2 = np.zeros((p.shape[0], p.shape[1], 3), np.uint8)
        src2[:, :, 0] = p
        #p[:, :, 0] = p[:, :, 0] * 2
        dst = cv2.addWeighted(im, 0.5, src2, 0.5, 0)
        vout.write(dst)

        # increment frame counter
        currentFrame += 1
finally:
    # Release the capture and the writer even if processing fails part-way,
    # so the video stream is flushed and the handles are not leaked.
    cap.release()
    vout.release()

You should see a new mp4 file in the output directory printed earlier (`prediction_path`, a `videos` folder). You may not be able to view it directly in Google Drive,
but it can be downloaded to your local machine.

In [None]:
#net.eval()
# Display the model's training flag (expected to be False after net.eval()
# was called when the checkpoint was loaded)
net.training

In [None]:
#im.shape

In [None]:
#plt.imshow((pred_np * 255).astype('uint8'))
#plt.imshow(im)

# End Demo

What follows is additional code for debugging

In [None]:
"""
# to seek into a specific frame
for _ in range(28):
    ret, im = cap.read()
"""    
"""
ret, im = cap.read()
if ret is not None:
    plt.imshow(im)
"""

In [None]:
'''
if crop_frame_border:
    im = im[y_trim:-y_trim, x_trim:-x_trim, :]

img, edges = draw_lines(im, rho=1, theta=np.pi/45, mll=mLL,
                        mlg=mLG, threshold=125, ds=1, 
                        canny_1=canny_thresh1, canny_2=canny_thresh2)
'''                        

In [None]:
#plt.figure(figsize=(14, 12))
#plt.imshow(img)
#plt.imshow(im[y_trim:-y_trim, x_trim:-x_trim, :])

In [None]:
#plt.imshow(edges)

In [None]:
"""it can be helpful to save particularly difficult
frames for processing in a different notebook"""
#cv2.imwrite('test.jpg', im)

In [None]:
"""
#cv2.circle(img, (800, 400), 10, (255, 0, 0), thickness=2, lineType=8, shift=0)
plt.figure(figsize=figsize)
plt.imshow(img)
plt.tight_layout()
if save_figures:
    plt.savefig('img/' + outpath + '-Step-2.png')
plt.show()
"""