In [11]:
import cv2
import torch
import torchvision.models as models
import numpy as np
import skimage
from torch.autograd import Variable

In [4]:
def preprocess_frame(image, target_height=224, target_width=224):
    """Resize and center-crop a frame to ``target_height`` x ``target_width``.

    The frame is first brought to a 3-D ``(height, width, channels)`` layout,
    converted to ``float32`` through ``skimage.img_as_float``, scaled so that
    it fully covers the target rectangle while keeping its aspect ratio, and
    finally center-cropped along the axis that overflows.

    Args:
        image: Array with 2, 3 or 4 dimensions. A 2-D (grayscale) array is
            replicated into three channels; for a 4-D array only index 0 of
            the last axis is kept.
        target_height: Number of rows of the returned frame.
        target_width: Number of columns of the returned frame.

    Returns:
        ``float32`` array with ``target_height`` rows and ``target_width``
        columns.

    Raises:
        ValueError: If the array is not 3-D after the 2-D / 4-D handling.
    """
    if image.ndim == 2:
        # Replicate the single grayscale channel three times to get a
        # three-channel image.
        image = np.tile(image[:, :, None], 3)
    elif image.ndim == 4:
        image = image[:, :, :, 0]

    if image.ndim != 3:
        raise ValueError(
            'expected an image with 2, 3 or 4 dimensions, got shape %r'
            % (image.shape,))

    image = skimage.img_as_float(image).astype(np.float32)
    height, width = image.shape[:2]

    # cv2.resize expects dsize as (width, height) -- the reverse of the
    # (rows, cols) order used by numpy shapes.
    target_size = (target_width, target_height)

    # Compare aspect ratios by cross-multiplication to stay in integers.
    if width * target_height == height * target_width:
        # Same aspect ratio as the target: a plain resize is enough.
        resized_image = cv2.resize(image, target_size)
    elif width * target_height > height * target_width:
        # Frame is relatively wider than the target: fit the height, then
        # trim the same number of columns from the left and the right.
        scaled_width = width * target_height // height
        resized_image = cv2.resize(image, (scaled_width, target_height))
        cropping_length = (scaled_width - target_width) // 2
        resized_image = resized_image[:, cropping_length:scaled_width - cropping_length]
    else:
        # Frame is relatively taller than the target: fit the width, then
        # trim the same number of rows from the top and the bottom.
        scaled_height = height * target_width // width
        resized_image = cv2.resize(image, (target_width, scaled_height))
        cropping_length = (scaled_height - target_height) // 2
        resized_image = resized_image[cropping_length:scaled_height - cropping_length]

    # The symmetric crop can leave one extra row/column when the overflow is
    # odd; this last resize pins the output to the exact target size.
    return cv2.resize(resized_image, target_size)


In [5]:
# Instantiate the VGG-16 architecture and restore its parameters from the
# checkpoint stored next to the notebook.
vgg = models.vgg16()
vgg.load_state_dict(torch.load('./models/vgg16-00b39a1b.pth'))

# Switch to inference mode and move the model to the GPU. Both calls return
# the module itself, so the chained expression is still the cell's output.
vgg.eval().cuda()

VGG (
  (features): Sequential (
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU (inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU (inplace)
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU (inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU (inplace)
    (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU (inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU (inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU (inplace)
    (16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), pa

In [13]:
video_path = './video/video0.mp4'
num_frames = 80  # upper bound on the number of frames fed to the network

# Decode every frame of the clip; always release the capture handle, even if
# decoding fails part-way through.
cap = cv2.VideoCapture(video_path)
frame_list = []
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_list.append(frame)
finally:
    cap.release()

frame_count = len(frame_list)
if frame_count == 0:
    # Fail here with a clear message instead of a cryptic transpose error
    # further down when the file is missing or cannot be decoded.
    raise IOError('no frames could be read from %s' % video_path)

frame_list = np.array(frame_list)
if frame_count > num_frames:
    # Keep num_frames evenly spaced frames over the whole clip.
    frame_indices = np.linspace(0, frame_count, num=num_frames, endpoint=False).astype(int)
    frame_list = frame_list[frame_indices]

# Preprocess each frame and reorder the axes from (N, H, W, C) to (N, C, H, W).
cropped_frame_list = np.array([preprocess_frame(x) for x in frame_list]).transpose((0, 3, 1, 2))
# NOTE: volatile=True is the pre-0.4 PyTorch way of disabling autograd for
# inference; newer releases ignore the flag and expect the forward pass to be
# wrapped in torch.no_grad() instead.
cropped_frame_list = Variable(torch.from_numpy(cropped_frame_list), volatile=True).cuda()

In [16]:
# Sanity check: dump the preprocessed batch that will be fed to the network.
print(cropped_frame_list)

Variable containing:
( 0 , 0 ,.,.) = 
  0.0921  0.0930  0.0931  ...   0.0555  0.0806  0.1046
  0.3706  0.3716  0.3755  ...   0.3329  0.3537  0.3681
  0.2971  0.2987  0.3059  ...   0.3220  0.3455  0.3657
           ...             ⋱             ...          
  0.7143  0.4808  0.4060  ...   0.0807  0.0620  0.0566
  0.7545  0.5010  0.3926  ...   0.0807  0.0621  0.0570
  0.7725  0.5125  0.3849  ...   0.0807  0.0621  0.0570

( 0 , 1 ,.,.) = 
  0.2333  0.2342  0.2343  ...   0.1421  0.1708  0.1966
  0.5172  0.5182  0.5222  ...   0.4166  0.4409  0.4570
  0.4892  0.4909  0.4981  ...   0.3812  0.4083  0.4285
           ...             ⋱             ...          
  0.7077  0.4651  0.3724  ...   0.0386  0.0306  0.0262
  0.7479  0.4853  0.3590  ...   0.0386  0.0308  0.0265
  0.7659  0.4968  0.3513  ...   0.0386  0.0308  0.0265

( 0 , 2 ,.,.) = 
  0.3195  0.3204  0.3206  ...   0.2548  0.2727  0.2940
  0.6148  0.6158  0.6198  ...   0.5382  0.5521  0.5641
  0.6814  0.6831  0.6902  ...   0.5766  0.5965

In [19]:
# Run the batch through the `features` sub-module of VGG only; the rest of
# the model is not applied here.
feats = vgg.features(cropped_frame_list)

In [None]:
# `f` is not defined anywhere in the notebook; inspect the extracted features
# instead, printing only their size rather than the full tensor.
print(feats.size())