In [195]:
import os
import sys
import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

sys.path.append('..')
import data
import model
import utils

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [196]:
params = utils.yaml_to_dict('../config.yml')
params['data_dir'] = os.path.join('..', params['data_dir'])
params['videos_folder'] = os.path.join('..', params['videos_folder'])
params['feature_map_folder'] = os.path.join('..', params['feature_map_folder'])
params['json_data_path'] = os.path.join('..', params['json_data_path'])
params['json_metadata_path'] = os.path.join('..', params['json_metadata_path'])

In [197]:
with open(params['json_metadata_path']) as data_file:
    metadata_json = json.load(data_file)

label_by_idx = {
    'level_3': {},
    'level_2': {},
    'level_1': {},
    'level_0': {}
}


In [198]:
for key, item in metadata_json.items():
    if key != 'classes_amount':
        label_by_idx['level_3'][item['idx']] = key
        for level in ['level_2', 'level_1', 'level_0']:
            label_by_idx[level][item[level]['idx']] = item[level]['name']


In [199]:
data_gen_test = data.DataGenerator(params, 'validation')

In [200]:
estimator = tf.estimator.Estimator(
    # Custom model function
    model_fn=model.model_fn,
    params=params,
    # Model directory
    model_dir=params['model_dir'],
    # warm_start_from=cfg.PRE_TRAIN,
    config=tf.estimator.RunConfig(
        keep_checkpoint_max=params['keep_checkpoint_max'],
        save_checkpoints_steps=params['save_checkpoints_steps'],
        save_summary_steps=params['save_summary_steps'],
        log_step_count_steps=params['log_step_count_steps']
    )
)

In [201]:
predictions = estimator.predict(
    input_fn = lambda: data.input_fn(data_gen_test, False, params)
)

In [202]:
prediction_results = {
    "results": {}
}
available_formats = ['.mkv', '.webm', '.mp4'] 
predictions_by_video = {}

for item in predictions:
    video_id = item['metadata'].decode('utf-8')
    batch_num = int(video_id.split('batch')[-1].replace('_', ''))
    video_id = video_id.split('batch')[0].replace('_', '')
    if video_id not in prediction_results['results']:
        prediction_results['results'][video_id] = {}

    for vformat in available_formats:
        video_path = os.path.join(params['videos_folder'] + '/validation', video_id + vformat)
        if os.path.isfile(video_path):
            break

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_id = 0
    classes_pred = np.ones(item['probabilities'].shape[0]) * item['classes']
    classes_score = np.ones(item['probabilities'].shape[0])

    for frame_pred in classes_pred:
        frame_number = (params['batch_size'] * (batch_num - 1) + frame_id) * params['skip_frames']
        frame_seconds = frame_number / fps
        if frame_pred != 0:
            if frame_pred not in prediction_results['results'][video_id]:
                prediction_results['results'][video_id][frame_pred] = [frame_seconds]
            else:
                prediction_results['results'][video_id][frame_pred].append(frame_seconds)
        frame_id += 1

prediction_results

*******INPUTS.SHAPE Before FC******* (?, 15, 38400)
*******INPUTS.SHAPE After FC******* (?, 15, 1000)
*******LOGITS.SHAPE Before FC******* (?, 15, 101)


{'results': {'4R37E4Kevs4': {2.0: [7.507501585288523,
    7.607601606425703,
    7.707701627562884,
    7.807801648700064,
    7.907901669837244,
    8.008001690974424,
    8.108101712111605,
    8.208201733248785,
    8.308301754385965,
    8.408401775523146,
    8.508501796660326,
    8.608601817797506,
    8.708701838934687,
    8.808801860071867,
    8.908901881209047,
    12.012002536461637,
    12.112102557598817,
    12.212202578735997,
    12.312302599873178,
    12.412402621010358,
    12.512502642147538,
    12.612602663284719,
    12.712702684421899,
    12.81280270555908,
    12.91290272669626,
    13.01300274783344,
    13.11310276897062,
    13.2132027901078,
    13.31330281124498,
    13.413402832382161],
   24.0: [24.024005072923273,
    24.124105094060454,
    24.224205115197634,
    24.324305136334814,
    24.424405157471995,
    24.524505178609175,
    24.624605199746355,
    24.724705220883536,
    24.824805242020716,
    24.924905263157896,
    25.025005284295077,
