#### This notebook generates output videos annotated with our models' ASL (American Sign Language) classification results

In [None]:
from google.colab import drive
# Attach Google Drive at /content/drive/, re-mounting even if it is already
# attached (presumably so project data on Drive is reachable from this runtime).
drive.mount(mountpoint='/content/drive/', force_remount=True)

In [None]:
# Root that every "data/asl/..." path in this notebook is joined onto.
# NOTE(review): Drive is mounted at /content/drive/ but this is the current
# working directory — assumes the notebook already runs inside the project
# folder; confirm.
path = './'

In [None]:
import torch
import numpy as np
import cv2
from torchvision.transforms import transforms   
from PIL import Image
from glob import glob
import math
import os

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def _reset_dir(dirpath):
    """Create *dirpath* if it is missing, otherwise delete the files directly inside it.

    Only the top level is cleared (``glob(dirpath + '/*')``); a sub-directory
    inside *dirpath* would make ``os.remove`` raise.
    """
    if not os.path.isdir(dirpath):
        os.makedirs(dirpath)
        print("created folder : ", dirpath)
        return
    files = glob(dirpath + '/*')
    print('Clearing folder : ', dirpath)
    for f in files:
        os.remove(f)


def eval_video(model, transform, videoFile, datasetclasses, outfrateratio=10):
    """Classify sampled frames of a video and save annotated copies of them.

    Every ``outfrateratio``-th frame (by decoder position) is written as a JPEG
    to ``<video dir>/temp``, re-read with PIL, passed through *transform* and
    *model*, and the arg-max class name from *datasetclasses* is drawn onto the
    frame unless it is ``'nothing'``. The frame (annotated or not) is then
    written to ``<video dir>/pred`` under the same ``_frame<N>.jpg`` name.
    Both folders are created if missing and emptied of files if present.

    Args:
        model: network called on a batch of one image; assumed to return
            scores shaped (batch, len(datasetclasses)) — TODO confirm.
        transform: callable mapping a PIL image to a tensor (the batch
            dimension is added here).
        videoFile: path to the input video.
        datasetclasses: class names indexed by the model's output position.
        outfrateratio: keep one frame out of every this many (floored; values
            below 1 are treated as 1). Defaults to 10, the previously
            hard-coded value.

    Returns:
        ``(frameRate, frame_width, frame_height, preddir)``. ``frameRate`` is
        the *source* video's FPS as reported by OpenCV.
        NOTE(review): only one frame in ``outfrateratio`` is saved, so a video
        written at this rate plays that many times faster than real time —
        confirm that is intended.
    """
    videodir = os.path.dirname(videoFile)
    tempdir = os.path.join(videodir, 'temp')
    preddir = os.path.join(videodir, 'pred')
    _reset_dir(tempdir)
    _reset_dir(preddir)

    cap = cv2.VideoCapture(videoFile)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frameRate = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(frame_width, frame_height)

    # Floor like the original; guard against a zero step (modulo by zero).
    step = max(1, math.floor(outfrateratio))
    model.eval()  # loop-invariant, so set it once rather than per frame
    count = 0
    try:
        # Inference only: don't build an autograd graph for every frame.
        with torch.no_grad():
            while cap.isOpened():
                # 0-based index of the frame the following read() returns.
                frameId = cap.get(cv2.CAP_PROP_POS_FRAMES)
                ret, frame = cap.read()
                if not ret:
                    # End of stream or decode failure: frame is None, so stop
                    # here instead of crashing in imwrite or looping forever.
                    break
                if frameId % step == 0:
                    framename = "_frame%d.jpg" % count
                    count += 1
                    framepath = os.path.join(tempdir, framename)
                    # The model is fed the JPEG-encoded frame re-read by PIL
                    # (RGB), as in the original pipeline; OpenCV frames are BGR.
                    cv2.imwrite(framepath, frame)
                    with Image.open(framepath) as img:
                        img_tensor = transform(img).to(device).unsqueeze(0)

                    output = model(img_tensor)
                    _, index = torch.max(output, 1)
                    prediction = datasetclasses[int(index.item())]
                    if prediction != 'nothing':
                        # NOTE(review): fixed position/size; text is off-screen
                        # for frames shorter than ~500 px.
                        cv2.putText(frame, prediction, (20, 500),
                                    cv2.FONT_HERSHEY_SIMPLEX, 8, (0, 200, 0), 40)

                    cv2.imwrite(os.path.join(preddir, framename), frame)
                # Kept from the original: stop once the reported frame count is
                # reached (only relevant if the reported count is too low).
                if frameId == frame_count:
                    break
    finally:
        cap.release()
    return frameRate, frame_width, frame_height, preddir

def _frame_sort_key(fname):
    """Numeric sort key for the ``_frame<N>.jpg`` names produced by eval_video.

    Plain lexicographic (or ``os.listdir``) order would place ``_frame10``
    before ``_frame2`` and scramble the output video. Names without digits
    sort first; ties fall back to the name itself.
    """
    digits = ''.join(ch for ch in fname if ch.isdigit())
    return (int(digits) if digits else -1, fname)


def get_outmp4(videoFile, output_dir, model_path, test_transforms, classes):
    """Annotate *videoFile* with one TorchScript model's predictions and save an MP4.

    Loads the model at *model_path*, runs ``eval_video`` on the video, then
    stitches the annotated frames from its ``pred`` folder, in frame order,
    into ``<output_dir>/<model file stem>out.mp4``.

    Args:
        videoFile: path to the input video.
        output_dir: folder that receives the output video.
        model_path: path to a TorchScript (``torch.jit``) model file.
        test_transforms: PIL-image -> tensor transform applied before inference.
        classes: class names indexed by the model's output position.

    Returns:
        The output video path, or None if OpenCV could not open the writer
        (callers that ignored the previous None return are unaffected).
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model = torch.jit.load(model_path, map_location=device)
    model = model.to(device)
    # splitext keeps inner dots ("net.v2.pt" -> "net.v2"), so differently
    # named checkpoints no longer overwrite each other's output video.
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    output_name = model_name + 'out.mp4'

    frameRate, frame_width, frame_height, preddir = eval_video(
        model, test_transforms, videoFile, classes)

    output_path = os.path.join(output_dir, output_name)
    # 'mp4v' is the conventional lowercase FourCC for MPEG-4 in an .mp4
    # container; FFmpeg-backed OpenCV builds warn and fall back to it for 'MP4V'.
    # NOTE(review): frameRate is the source FPS while eval_video keeps only one
    # frame in 10 by default, so the output plays faster than real time.
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'),
                          frameRate, (frame_width, frame_height))
    if not out.isOpened():
        print('Could not open video writer for : ', output_path)
        return None
    try:
        # Stream frames one at a time in numeric frame order instead of
        # buffering the whole clip in memory.
        for fname in sorted(os.listdir(preddir), key=_frame_sort_key):
            img = cv2.imread(os.path.join(preddir, fname))
            if img is None:  # unreadable / non-image file: don't pass None on
                continue
            out.write(img)
    finally:
        out.release()
    return output_path

In [None]:
# Input clip, and the folder that receives one "<model>out.mp4" per checkpoint.
videoFile = os.path.join(path, "data/asl/test_video/Cropped_Video.mp4")
output_dir = os.path.join(path, "data/asl/test_video/")

# Label names indexed by the model's output position; presumably this must
# match the class order used at training time — confirm.
classes = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    'del', 'nothing', 'space',
]

# PIL image -> float tensor, per-channel normalisation, then resize so the
# shorter side is 224 px.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.5190, 0.4993, 0.5141),  # per-channel mean
        (0.2280, 0.2555, 0.2637),  # per-channel std
    ),
    transforms.Resize(224),
])

# Produce one annotated video for every TorchScript checkpoint in the folder.
models_root = os.path.join(path, "data/asl/models/")
for f in glob(os.path.join(models_root, '*.pt')):
    print(f)
    get_outmp4(videoFile, output_dir, f, test_transforms, classes)