In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload all imported modules before each
# cell executes, so edits to project code are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Request high-resolution ("retina") inline figures.
%config InlineBackend.figure_format = 'retina'

import sys
# Add the parent of the current working directory to the import path.
# NOTE(review): presumably this is what makes the `feature_extraction`
# package importable below -- confirm against the repository layout.
sys.path.append("..")

In [32]:
import os
import glob
from tqdm import tqdm

import numpy as np
import cv2

import torch
from torch.autograd import Variable as V
from torchvision import transforms as trn
from feature_extraction.resnet import load_resnet

def get_video_from_mp4(file, sampling_rate):
    """Load every ``sampling_rate``-th frame of a video as RGB uint8 arrays.

    Frames are numbered from 1; frame ``k`` is kept when
    ``k % sampling_rate == 0`` (so with ``sampling_rate=4`` frames 4, 8, 12,
    ... are kept).

    Args:
        file: Path of the video file handed to ``cv2.VideoCapture``.
        sampling_rate: Positive integer; keep one frame out of this many.

    Returns:
        A tuple ``(video, num_frames)`` where ``video`` is a uint8 array of
        shape ``(1, num_frames, height, width, 3)`` in RGB channel order and
        ``num_frames`` is the number of frames actually stored. If the
        container reports more frames than can be decoded, only the frames
        that were really read are returned. A file that cannot be opened
        yields an empty array and ``num_frames == 0``.

    Raises:
        ValueError: If ``sampling_rate`` is smaller than 1.
    """
    if sampling_rate < 1:
        raise ValueError("sampling_rate must be a positive integer")

    cap = cv2.VideoCapture(file)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Upper bound on sampled frames, based on the container metadata.
        max_sampled = int(frame_count / sampling_rate)
        buf = np.empty((max_sampled, frame_height, frame_width, 3),
                       dtype=np.uint8)

        num_sampled = 0
        for frame_idx in range(1, frame_count + 1):
            ret, frame = cap.read()
            if not ret:
                # The reported frame count can exceed what is decodable;
                # stop instead of passing a None frame to cvtColor.
                break
            if frame_idx % sampling_rate == 0 and num_sampled < max_sampled:
                # Convert only frames that are kept (OpenCV decodes as BGR).
                buf[num_sampled] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                num_sampled += 1
    finally:
        # Release the capture handle even if decoding raised.
        cap.release()

    # Drop buffer slots that were never filled (early end of stream).
    buf = buf[:num_sampled]
    return np.expand_dims(buf, axis=0), num_sampled


def get_predictions(model, video_list, sampling_rate=4, max_videos=1):
    """Print and return top-5 class predictions for sampled video frames.

    Each selected video is decoded with ``get_video_from_mp4``; every sampled
    frame is resized to 224x224, normalised, and passed through ``model``.
    The top-5 class indices of each frame are printed.

    Args:
        model: Network whose ``forward`` returns a sequence; the last element
            is treated as the class scores of shape ``(1, num_classes)``.
            NOTE(review): assumed to already be in eval mode and, when CUDA
            is available, already on the GPU -- confirm in ``load_resnet``.
        video_list: Sequence of video file paths.
        sampling_rate: Keep one frame out of this many per video.
        max_videos: Number of leading entries of ``video_list`` to process.
            Defaults to 1 (the original behaviour); pass ``None`` to process
            every video.

    Returns:
        The top-5 index tensor (shape ``(1, 5)``, best class first) of the
        last frame processed, or ``None`` if no frame was processed.
    """
    # Resize (not crop) to the network input size and apply the mean/std
    # normalisation used here for the ResNet input.
    preprocess = trn.Compose([
        trn.ToPILImage(),
        trn.Resize((224, 224)),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406],
                      [0.229, 0.224, 0.225])])

    topk = (1, 5)
    use_cuda = torch.cuda.is_available()
    selected = video_list if max_videos is None else video_list[:max_videos]

    # Stays None when there is nothing to process, instead of raising
    # NameError at the return statement.
    pred = None

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for video_file in tqdm(selected):
            print(video_file)
            vid, num_frames = get_video_from_mp4(video_file, sampling_rate)

            for frame in range(num_frames):
                img = vid[0, frame]
                input_img = preprocess(img).unsqueeze(0)  # add batch dim
                if use_cuda:
                    input_img = input_img.cuda()

                preds = model.forward(input_img)[-1]
                _, pred = preds.topk(max(topk), dim=1, largest=True,
                                     sorted=True)
                print(pred)
    return pred


In [33]:
# Collect every .mp4 in the video directory in sorted path order, load the
# ResNet-50 wrapper, and run frame-level prediction with the default
# single-video setting of get_predictions.
video_dir = '../data/AlgonautsVideos268_All_30fpsmax/'
video_list = sorted(glob.glob(video_dir + '/*.mp4'))
model = load_resnet("resnet50")
pred = get_predictions(model, video_list, sampling_rate=4)

  0%|          | 0/1 [00:00<?, ?it/s]

../data/AlgonautsVideos268_All_30fpsmax/0001_0-0-1-6-7-2-8-0-17500167280.mp4
tensor([[ 97,  99,  98, 137, 975]])
tensor([[137,  98,  99,  97, 148]])
tensor([[ 98,  99, 137, 148, 142]])
tensor([[ 98, 137,  99, 195, 143]])
tensor([[ 98, 137, 232,  99, 143]])
tensor([[ 98, 137,  97,  99, 148]])
tensor([[ 98, 137,  99, 975, 148]])
tensor([[ 98,  97, 137, 148,  99]])
tensor([[148,  98, 137, 975,  99]])
tensor([[ 98, 137, 148,  97,  99]])
tensor([[ 98, 137,  97, 148,  99]])
tensor([[ 98, 137,  99,  97, 975]])
tensor([[137,  98,  99,  97, 148]])
tensor([[ 98,  97,  99, 137, 143]])
tensor([[ 98,  97, 472, 137,  99]])
tensor([[ 98, 148, 975,  99, 472]])
tensor([[ 98,  97, 148, 472,  99]])
tensor([[ 98,  97, 148, 975, 472]])
tensor([[ 98,  97,  99, 975, 137]])
tensor([[ 98,  97,  99, 975, 472]])
tensor([[ 98,  97,  99, 128, 137]])


100%|██████████| 1/1 [00:05<00:00,  5.19s/it]

tensor([[ 98,  97,  99, 137, 148]])





In [42]:
# `pred` is the sorted top-k index tensor returned by get_predictions (last
# frame processed); element [0][0] is therefore its best-scoring class index.
pred[0][0].item()

98

In [43]:
import ast

# imagenet_labels.txt is read as a Python dict literal keyed by class index.
# The context manager closes the file even if reading fails, and
# ast.literal_eval parses the literal without executing arbitrary code.
with open('imagenet_labels.txt', "r") as labels_file:
    contents = labels_file.read()
dictionary = ast.literal_eval(contents)

# Human-readable label for class index 98.
dictionary[98]

'red-breasted merganser, Mergus serrator'