In [1]:
import os
import pickle
import numpy as np
from utils import mAP_f1_p_fix_r
from utils import evaluate_scenes, predictions_to_scenes
from utils import get_frames, get_batches, scenes2zero_one_representation, visualize_predictions
import ffmpeg
# PT EVALUATION
import os
import pickle

import numpy as np
import torch
from transnetv2.transnetv2_pytorch import TransNetV2
from supernet_flattransf_3_8_8_8_13_12_0_16_60 import TransNetV2Supernet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# load data & build fnm - path dict
fnm_path_dict = {}
dir_list = [
    "../original_videos/original_videos/",
    "../merged_ads_game_videos/",
    "../merged_video_download/"
]
for cur_dir in dir_list:
    for fnm in os.listdir(cur_dir):
        if fnm.endswith(".mp4"):
            fnm_path_dict[fnm[:-len(".mp4")]] = cur_dir + fnm

# load test annotation keep only one shot
with open('./gt_scenes_dict_baseline_v2.pickle', 'rb') as handle:
    gt_scenes_dict = pickle.load(handle)
handle.close()

print(sum( [len(annot) for _, annot in gt_scenes_dict.items()]))

2716


# Load the model

In [3]:
baseline_model = TransNetV2()
baseline_model.load_state_dict(torch.load('./transnetv2/transnetv2-pytorch-weights.pth'))
baseline_model.eval().cuda()
print("Model loaded")

Model loaded


# Make prediction

In [5]:
baseline_one_hot_pred_dict = {}
i = 0
for fnm, annot in gt_scenes_dict.items():
    i += 1
    print(i, fnm)
    
    video_frames, single_frame_predictions, all_frame_predictions = baseline_model.predict_video(fnm_path_dict[fnm])

    baseline_one_hot_pred_dict[fnm] = (single_frame_predictions > 0.5).astype(np.uint8)

1 31602670982
[TransNetV2] Extracting frames from ../merged_video_download/31602670982.mp4
2 33338203782
[TransNetV2] Extracting frames from ../original_videos/original_videos/33338203782.mp4
3 33803189290
[TransNetV2] Extracting frames from ../original_videos/original_videos/33803189290.mp4
4 34996642719
[TransNetV2] Extracting frames from ../original_videos/original_videos/34996642719.mp4
5 35087786411
[TransNetV2] Extracting frames from ../original_videos/original_videos/35087786411.mp4
6 35219009690
[TransNetV2] Extracting frames from ../original_videos/original_videos/35219009690.mp4
7 35460496768
[TransNetV2] Extracting frames from ../original_videos/original_videos/35460496768.mp4
8 38667643656
[TransNetV2] Extracting frames from ../merged_video_download/38667643656.mp4
9 40355287299
[TransNetV2] Extracting frames from ../merged_video_download/40355287299.mp4
10 41268001962
[TransNetV2] Extracting frames from ../merged_video_download/41268001962.mp4
11 42394122179
[TransNetV2] E

In [3]:
# with open('../baseline_one_hot_pred_dict_baseline.pickle', 'rb') as handle:
#     baseline_one_hot_pred_dict = pickle.load(handle)
# handle.close()

In [6]:
mAP, metric_F1, precision, recall, threshold, miou = mAP_f1_p_fix_r(baseline_one_hot_pred_dict, gt_scenes_dict)
print("Baseline 0.5", metric_F1, precision, recall, threshold)

Baseline 0.5 0.7992903082723443 0.9041645760160562 0.7162162162162162 0.9990234375


# Better Model

In [4]:
supernet_best_f1 = TransNetV2Supernet().eval()

pretrained_path = os.path.join("../ckpt_0_200_0.pth")
if os.path.exists(pretrained_path):
    print('Loading pretrained_path from %s' % pretrained_path)
    model_dict = supernet_best_f1.state_dict()
    pretrained_dict = torch.load(pretrained_path, map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict['net'].items() if k in model_dict}
    print("Current model has %d paras, Update paras %d " % (len(model_dict), len(pretrained_dict)))
    model_dict.update(pretrained_dict)
    supernet_best_f1.load_state_dict(model_dict)
else:
    raise Exception("Error: Can NOT find pretrained best model!!")

supernet_best_f1 = supernet_best_f1.cuda(0)

Loading pretrained_path from ../ckpt_0_200_0.pth
Current model has 90 paras, Update paras 90 


In [23]:
def cut_video(video_path, frame_ranges):
    # Get the base name of the video file without extension
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    # Create a directory with the same name as the video
    output_dir = os.path.join(os.path.dirname(video_path), base_name)
    os.makedirs(output_dir, exist_ok=True)
    
    # Get the video's frame rate
    probe = ffmpeg.probe(video_path)
    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
    frame_rate = eval(video_streams[0]['r_frame_rate'])
    
    for i, (start_frame, end_frame) in enumerate(frame_ranges):
        start_time = start_frame / frame_rate
        end_time = (end_frame + 1) / frame_rate
        output_file = os.path.join(output_dir, f"{base_name}_part_{i+1}.mp4")
        
        ffmpeg.input(video_path, ss=start_time, to=end_time).output(output_file).run()

    print(f"All video parts have been saved in the folder: {output_dir}")

In [27]:
def save_images(video_path, frame_ranges, output_dir=None, timeinterval=1):
    # Get the base name of the video file without extension
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    # Use the provided output directory or create one with the same name as the video
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(video_path), base_name)
    else:
        output_dir = os.path.join(output_dir, base_name)
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Get the video's frame rate
    probe = ffmpeg.probe(video_path)
    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
    frame_rate = eval(video_streams[0]['r_frame_rate'])
    
    for i, (start_frame, end_frame) in enumerate(frame_ranges):
        start_time = start_frame / frame_rate
        end_time = (end_frame + 1) / frame_rate
        segment_duration = end_time - start_time
        
        if segment_duration < timeinterval:
            # Save one image from the middle point if segment is shorter than 500ms
            middle_time = start_time + segment_duration / 2
            output_file = os.path.join(output_dir, f"{base_name}_part_{i+1}_middle.png")
            ffmpeg.input(video_path, ss=middle_time).output(output_file, vframes=1).run()
        else:
            # Save images every 500ms
            current_time = start_time
            img_count = 1
            while current_time < end_time:
                output_file = os.path.join(output_dir, f"{base_name}_part_{i+1}_img_{img_count}.png")
                ffmpeg.input(video_path, ss=current_time).output(output_file, vframes=1).run()
                current_time += timeinterval
                img_count += 1

    print(f"All images have been saved in the folder: {output_dir}")

In [28]:
# Evaluation
def predict(batch):
    batch = torch.from_numpy(batch.transpose((3, 0, 1, 2))[np.newaxis, ...]) * 1.0
    batch = batch.to(device)
    one_hot = supernet_best_f1(batch)
    if isinstance(one_hot, tuple):
        one_hot = one_hot[0]
    return torch.sigmoid(one_hot[0])

supernet_best_f1_one_hot_pred_dict = {}
i = 0
for fnm, annot in gt_scenes_dict.items():
    i += 1
    print(i, fnm)
    
    predictions = []
    frames = get_frames(fnm_path_dict[fnm])

    for batch in get_batches(frames):
        one_hot = predict(batch)
        one_hot = one_hot.detach().cpu().numpy()
        
        predictions.append(one_hot[25:75])

    predictions = np.concatenate(predictions, 0)[:len(frames)]
    shot_cut = predictions_to_scenes((predictions > np.array([0.5])).astype(np.uint8))
    supernet_best_f1_one_hot_pred_dict[fnm] = shot_cut
    # cut_video(fnm_path_dict[fnm], shot_cut)
    save_images(fnm_path_dict[fnm], shot_cut, './outputs')
    
    break

1 31602670982


ffmpeg version 4.3 Copyright (c) 2000-2020 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/opt/conda/conda-bld/ffmpeg_1597178665428/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh --cc=/opt/conda/conda-bld/ffmpeg_1597178665428/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --disable-openssl --enable-avresample --enable-gnutls --enable-hardcoded-tables --enable-libfreetype --enable-libopenh264 --enable-pic --enable-pthreads --enable-shared --disable-static --enable-version3 --enable-zlib --enable-libmp3lame
  libavutil      56. 51.100 / 56. 51.100
  libavcodec     58. 91.100 / 58. 91.100
  libavformat    58. 45.100 / 58. 45.100
  libavdevice    58. 10.100 / 58. 10.100
  libavfilter     7. 85.100 /  7. 85.100
  libavresample   4.  0.  0 /  4.  0.  0
  libsw

All images have been saved in the folder: ./outputs/31602670982


frame=    1 fps=0.0 q=-0.0 Lsize=N/A time=00:00:00.01 bitrate=N/A speed=0.0939x    
video:849kB audio:0kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: unknown
