In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torchvision.transforms as T

In [2]:
# Data locations, all resolved relative to the directory the notebook runs from.
# (The video folder name is spelled exactly as it exists on disk.)
videos_path = os.path.join(os.getcwd(), 'surgery.videos.hernitia')  # folder holding the <name>.mp4 recordings
csv_path = os.path.join(os.getcwd(), 'video.phase.trainingData.clean.StudentVersion.csv')
labels_path = os.path.join(os.getcwd(), 'labels/labels.pkl')  # pickled per-frame labels table

# Run tensors on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
def transforms(resize = (60, 80)):
    """
    Description
    -------------
    Build the per-frame preprocessing pipeline applied before a frame is fed
    to a neural network: image array -> PIL image -> resized image -> tensor.

    Parameters
    -------------
    resize : tuple, target size handed to T.Resize; torchvision interprets a
             2-tuple as (height, width) (default=(60,80))

    Returns
    -------------
    torchvision.transforms.transforms.Compose object, the composed transformations.
    """
    to_pil = T.ToPILImage()
    shrink = T.Resize(resize)
    # ToTensor also reorders to channels-first and rescales uint8 pixels to [0, 1]
    to_tensor = T.ToTensor()
    return T.Compose([to_pil, shrink, to_tensor])

In [4]:
def get_frames(filename, resize = (60,80)):
    """
    Description
    -------------
    Read every frame of a video, convert it from BGR to RGB, resize it and
    stack all frames into a single tensor placed on `device`.

    Parameters
    -------------
    filename    : name of the video inside `videos_path`, WITHOUT the '.mp4'
                  extension (it is appended here)
    resize      : tuple, (height, width) of the resized frame (default=(60,80))

    Returns
    -------------
    tensor of dimension (#frames, channels, height, width) on `device`

    Raises
    -------------
    RuntimeError : if the video cannot be opened or yields no readable frame
    """
    video_file = os.path.join(videos_path, filename + '.mp4')
    # build the preprocessing pipeline once, not once per frame
    preprocess = transforms(resize)
    frames = []
    video = cv2.VideoCapture(video_file)
    try:
        if not video.isOpened():
            raise RuntimeError(f'could not open video {video_file}')
        # read() returns success=False once no further frame can be decoded
        success, image = video.read()
        while success:
            # OpenCV decodes frames as BGR; convert to RGB before preprocessing
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            frames.append(preprocess(image))
            success, image = video.read()
    finally:
        # release the capture even if decoding or preprocessing raised
        video.release()
    if not frames:
        # torch.stack would otherwise fail with an unhelpful message
        raise RuntimeError(f'no frames could be read from {video_file}')
    # stack the CPU frames, then move the whole clip to the target device at once
    return torch.stack(frames, dim=0).to(device)

In [5]:
def count_frames(filename):
    """
    Description
    -------------
    Count number of frames in video, as reported by OpenCV's
    CAP_PROP_FRAME_COUNT property.

    Parameters
    -------------
    filename    : path to the video file (.mp4 or .mov). Unlike get_frames,
                  the path is opened exactly as given: it is NOT joined with
                  videos_path and no extension is appended.

    Returns
    -------------
    int, #frames in the video, or 0 if OpenCV cannot open the file

    Notes
    -------------
    NOTE(review): this is the count reported by the OpenCV backend; it may not
    equal the number of frames get_frames actually decodes — confirm if exact
    alignment with per-frame labels matters.
    """
    video = cv2.VideoCapture(filename)
    try:
        # an unopened capture reports 0 for its properties; make that explicit
        if not video.isOpened():
            return 0
        return int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # always release the capture so repeated calls don't leave it open
        video.release()

In [6]:
def get_labels(filename, labels_file=None):
    """
    Description
    -------------
    Produces tensor of labels for each frame of a video

    Parameters
    -------------
    filename    : video name matched (with ==) against the 'videoName' column
                  of the labels table, so it must be spelled exactly as stored
    labels_file : optional path to the pickled labels DataFrame; defaults to
                  the module-level labels_path

    Returns
    -------------
    1-D tensor of dimension (#frames) holding the 'label' values of the rows
    whose 'videoName' equals filename, in table order (empty if none match),
    or the string 'no labels stored' if the labels file does not exist.
    """
    if labels_file is None:
        labels_file = labels_path
    # sentinel string kept as-is: existing callers may compare against it
    if not os.path.exists(labels_file):
        return 'no labels stored'
    # NOTE(review): read_pickle can execute arbitrary code; only load trusted label files
    all_labels = pd.read_pickle(labels_file)
    # single .loc call selects rows and column (no chained indexing)
    video_labels = all_labels.loc[all_labels['videoName'] == filename, 'label']
    # go through NumPy: handing the filtered Series (or a list wrapping it) to
    # torch.tensor indexes it by label and breaks unless its index starts at 0
    return torch.tensor(video_labels.to_numpy())