# Video prediction using the trained model

## Navigate to the mmsegmentation directory

In [1]:
import os

# Work from inside the mmsegmentation checkout so the relative config,
# checkpoint and data paths used below resolve. Guarded so that re-running
# this cell does not try to enter mmsegmentation/ a second time (which would
# raise FileNotFoundError unless a nested directory of that name exists).
if os.path.basename(os.getcwd()) != 'mmsegmentation':
    os.chdir('mmsegmentation')

In [2]:
# Confirm the working directory is now the mmsegmentation checkout.
os.getcwd()

'/home/featurize/work/CFA/mmsegmentation'

## MMSegmentation Official Video/Camera Prediction

In [2]:
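# Optional: MMSegmentation's bundled video demo, kept for reference.
# Uncomment to run it; --show opens a preview window, so drop it on a headless server.
# NOTE: this command uses ZihaoDataset_KNet_20230712.py, while the OpenCV pipeline
# below loads ZihaoDataset_KNet_20230818.py with the same checkpoint -- use the
# config the checkpoint was actually trained with.
# !python3 demo/video_demo.py \
#          data/video_watermelon_3.mov \
#          Zihao-Configs/ZihaoDataset_KNet_20230712.py \
#          checkpoint/KNet_best_mIoU_iter_14000.pth \
#          --device cpu \
#          --opacity 0.5 \
#          --show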

## Custom Frame-by-Frame Prediction with OpenCV

### Import Packages

In [3]:
import time

import cv2
import numpy as np
import torch
from tqdm import tqdm

import mmcv
from mmseg.apis import init_model, inference_model

### Load the trained model

In [4]:
# Model config file
config_file = 'Zihao-Configs/ZihaoDataset_KNet_20230818.py'

# Model checkpoint file
checkpoint_file = 'checkpoint/KNet_best_mIoU_iter_14000.pth'

# Computing device
# device = 'cpu'
device = 'cuda:0'

In [6]:
# Build the segmentor described by the config and load the trained weights
# onto the selected device.
model = init_model(
    config_file,
    checkpoint_file,
    device=device,
)

Loads checkpoint by local backend from path: checkpoint/KNet_best_mIoU_iter_14000.pth


### Color scheme for each category (BGR)

In [7]:
# (class name, BGR color) pairs. A class's position in this list is the
# label id its color is mapped to; presumably the order must match the class
# order in the dataset config -- verify.
palette = [
    ['background', [127, 127, 127]],  # grey
    ['sandwich',   [0, 0, 200]],      # red (BGR order)
    ['fries',      [0, 200, 0]],      # green
    ['drink',      [200, 0, 0]],      # blue (BGR order)
]

In [8]:
# Label id (position in `palette`) -> BGR color used to paint that class.
palette_dict = {class_idx: entry[1] for class_idx, entry in enumerate(palette)}

In [9]:
# Inspect the label id -> BGR color mapping.
palette_dict

{0: [127, 127, 127], 1: [0, 0, 200], 2: [0, 200, 0], 3: [200, 0, 0]}

### Frame-by-Frame Processing Function

In [10]:
# Weight of the original frame in the cv2.addWeighted blend inside
# process_frame; the color mask gets 1 - opacity, so higher values keep the
# result closer to the original image.
opacity = 0.3

In [11]:
def process_frame(img_bgr):
    """Segment one video frame and overlay the colorized prediction on it.

    Relies on the notebook-level ``model``, ``palette_dict`` and ``opacity``.

    Args:
        img_bgr: BGR frame as returned by ``cv2.VideoCapture.read`` (uint8).

    Returns:
        uint8 BGR image ``opacity * img_bgr + (1 - opacity) * color_mask``,
        where ``color_mask`` paints each pixel with its predicted class's
        ``palette_dict`` color (black for class ids not in ``palette_dict``).
        Assumes the model returns a mask at the input resolution, since
        ``cv2.addWeighted`` requires both images to have the same size.
    """
    # Semantic segmentation prediction: one integer class id per pixel.
    result = inference_model(model, img_bgr)
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()

    # Paint each class id with its palette color. Allocating uint8 up front
    # avoids building a float64 buffer only to cast it afterwards.
    pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_id, color in palette_dict.items():
        pred_mask_bgr[pred_mask == class_id] = color

    # Overlay the prediction on the original frame.
    return cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)

### Frame-by-Frame Video Processing

In [12]:
def _count_frames(input_path):
    """Return the exact number of decodable frames in the video at ``input_path``.

    Decodes the whole video once. This is slower than reading
    ``cv2.CAP_PROP_FRAME_COUNT``, but that property comes from container
    metadata and may be inaccurate for some files. Returns 0 if the video
    cannot be opened.
    """
    cap = cv2.VideoCapture(input_path)
    frame_count = 0
    try:
        while cap.isOpened():
            success, _ = cap.read()
            if not success:
                break
            frame_count += 1
    finally:
        cap.release()
    return frame_count


def generate_video(input_path='data/demo_v0.mp4', output_path=None):
    """Run ``process_frame`` on every frame of a video and save the result.

    Args:
        input_path: Path of the video to segment.
        output_path: Where to write the annotated video (``mp4v`` codec).
            Defaults to ``'out-' + <input file name>`` in the current
            working directory.

    Raises:
        OSError: If the input video cannot be opened or the output writer
            cannot be created.

    A frame that ``process_frame`` fails on is written unmodified (the error
    is printed), so the output has the same number of frames as the input.
    Interrupting the kernel stops processing early; frames written so far
    are still saved.
    """
    if output_path is None:
        # basename also handles platform-specific separators, unlike split('/').
        output_path = 'out-' + os.path.basename(input_path)

    print('Video processing started:', input_path)

    # Exact frame count, used as the progress-bar total.
    frame_count = _count_frames(input_path)
    print('Total frames:', frame_count)

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f'Could not open video: {input_path}')

    # The output keeps the input's resolution and frame rate.
    frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # 'mp4v' suits .mp4 output; e.g. 'XVID' is an alternative for .avi files.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
    if not out.isOpened():
        out.release()
        cap.release()
        raise OSError(f'Could not create output video: {output_path}')

    try:
        with tqdm(total=frame_count) as pbar:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                try:
                    frame = process_frame(frame)
                except Exception as error:
                    # Best effort: keep the unprocessed frame so the output
                    # stays frame-aligned with the input.
                    print('Error:', error)

                out.write(frame)
                pbar.update(1)
    except KeyboardInterrupt:
        # Kernel interrupt: stop early but keep what has been written.
        print('Interrupted')
    finally:
        out.release()
        cap.release()

    print('Video saved:', output_path)

In [None]:
# Segment every frame of the demo clip; the annotated video is written to
# out-demo_v0.mp4 in the current (mmsegmentation) directory.
generate_video(input_path='data/demo_v0.mp4')