In [1]:
# 2023 Gabriel J. Diaz @ RIT

import os
import sys
import numpy as np
import av
import logging
import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from pathlib import Path, PurePath, PurePosixPath


logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(stream=sys.stdout))


In [2]:
    
# file_path = Path('D:\\Github\\retinal_flow_toolkit\\videos\\Yoyo-LVRA.mp4')

file_path_a = 'testVid.mp4'
file_path_b = 'testVid_tvl1_hsv_overlay.mp4'
file_path_c = 'testVid_tvl1_hsv_notpts.mp4'

# file_path_a = PurePath('D:\\Github\\retinal_flow_toolkit\\sandbox\\cb1_original.mp4')
# file_path_b = PurePath('D:\\Github\\retinal_flow_toolkit\\sandbox\\cb1_flow.mp4')


In [3]:
def get_timing_info(file_path):

    container_in = av.open(file_path, mode="r", timeout = None)
    time_base_out = container_in.streams.video[0].time_base
    num_frames = container_in.streams.video[0].frames
    average_fps = container_in.streams.video[0].average_rate
    
    container_out = av.open('vid_out.mp4', mode="w", timeout = None)
    stream = container_out.add_stream("libx264", framerate = average_fps)
    stream.options["crf"] = "20"
    stream.pix_fmt = container_in.streams.video[0].pix_fmt
    stream.time_base = container_in.streams.video[0].time_base

    packet_pts_out = []
    pts_out = []
    dts_out = []
    frame_time_base_out = []
    relative_time_out = []
    
    count = 0
    
    #for raw_frame in container_in.decode(video=0):
    for raw_frame in tqdm(container_in.decode(video=0), desc="Working.", unit= 'frames',total = num_frames):
            
        pts_out.append(raw_frame.pts)
        dts_out.append(raw_frame.dts)
        frame_time_base_out.append(raw_frame.time_base)
        relative_time_out.append( np.float32(raw_frame.pts * raw_frame.time_base) )
        
        for packet in stream.encode(raw_frame):
            packet_pts_out.append((raw_frame.pts, packet.pts))
            container_out.mux(packet)
        
        count = count+1
        
    
    for packet in stream.encode(raw_frame):
        container_out.mux(packet)
                
    print('Time base ' + str( container_in.streams.video[0].time_base))
    print('Num frames ' + str(container_in.streams.video[0].frames))
    print('Avg rate ' + str(container_in.streams.video[0].average_rate))
    print('Start time ' + str(container_in.start_time))
    print('Duration ' + str(container_in.duration))

    container_in.close()
    container_out.close()
    return time_base_out, pts_out, dts_out, packet_pts_out

In [4]:
# container_out = av.open(file_path_b, mode="r", timeout = None)    

In [5]:
def set_timing_info(filename_in, filename_out, pts, dts, time_base):

    container_in = av.open(filename_in)
    average_fps = container_in.streams.video[0].average_rate
    num_frames = container_in.streams.video[0].frames
    time_base = container_in.streams.video[0].time_base
    encoded_frame_count = container_in.streams.video[0].encoded_frame_count
    # container_in.sort_dts = True
    # container_in.flush_packets = True
    # container_in.streams.video[0].codec_context.skip_frame = "NONKEY"
    
    container_out = av.open(filename_out, mode="w", timeout = None)

    stream = container_out.add_stream("libx264", framerate = average_fps)
    stream.options["crf"] = "20"
    stream.pix_fmt = container_in.streams.video[0].pix_fmt
    stream.time_base = container_in.streams.video[0].time_base
    
    idx = 0
    
    for raw_frame in tqdm(container_in.decode(video=0), desc="Working.", unit= 'frames',total = num_frames):

        raw_frame.pts = pts[idx]
        raw_frame.time_base = pts[idx]
        # raw_frame.dts = dts[idx]
        
        for packet in stream.encode(raw_frame):
            packet.stream = stream
            packet.time_base = time_base
            packet.pts = pts[idx]
            packet.pts = dts[idx]
            container_out.mux(packet)
        
        idx = idx + 1

    # # Flush stream
    for packet in stream.encode():
        
        packet.stream = stream
        packet.time_base = time_base
        packet.pts = pts[idx]
        container_out.mux(packet)


    container_in.close()
    container_out.close()

In [6]:
time_base_a, pts_a, dts_a, packet_pts_out_a = get_timing_info(file_path_a)
time_base_d, pts_d, dts_d, packet_pts_out_d = get_timing_info('vid_out.mp4')

Working.: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 2333/2333 [00:05<00:00, 443.11frames/s]
non-strictly-monotonic PTS


Time base 1/65535
Num frames 2333
Avg rate 152893155/2662726
Start time 16999
Duration 40631000


Working.:   0%|                                                                                                                   | 0/2293 [00:00<?, ?frames/s]Packet corrupt (stream = 0, dts = 54216)
.
Invalid NAL unit size (4025 > 1336).
Error splitting the input into NAL units.
Working.:   2%|██                                                                                                      | 46/2293 [00:00<00:01, 1472.07frames/s]


InvalidDataError: [Errno 1094995529] Invalid data found when processing input; last error log: [h264] Error splitting the input into NAL units.

In [None]:
packet_pts_out_a[:20]

In [None]:
set_timing_info(file_path_b, 'remuxed.mp4', pts_a, dts_a, time_base_a)

In [None]:
packet_pts_out[:10]

In [None]:
time_base_b, pts_b, dts_b = get_timing_info(file_path_b)

In [None]:
time_base_c, pts_c, dts_c = get_timing_info(file_path_c)

In [None]:
print(len(pts_a))
print(len(pts_b))
print(len(pts_c))

In [None]:
# plt.rc('xtick', labelsize=20) 
# plt.rc('ytick', labelsize=20) 

# dts_a = np.array(dts_a,dtype=np.float32)
# dts_b = np.array(dts_b,dtype=np.float32)

# fig, ax = plt.subplots(figsize=(15, 10), layout='constrained',dpi=300)

# a_time = np.array(pts_a) * float(time_base_a)
# b_time = np.array(pts_b) * float(time_base_b)
# c_time = np.array(pts_c) * float(time_base_c)

# plt.subplot(1,1,1)
# plt.plot(np.arange(0,len(a_time[1:])) ,a_time[1:]-b_time,'r')
# plt.plot(np.arange(0,len(a_time[1:])) ,a_time[1:]-c_time,':b')

# plt.ylabel('time')
# plt.xlabel('pts * time_base')
# plt.legend(['original','transcoded'])

# # plt.subplot(2,1,2)
# # plt.plot(np.arange(0,len(dts_a)) ,np.array(dts_a) * float(time_base_a),'r')
# # plt.plot(np.arange(0,len(dts_b)) ,np.array(dts_b) * float(time_base_b),':b')
# # plt.plot(np.arange(0,len(dts_c)) ,np.array(dts_c) * float(time_base_c),':g')

# plt.ylabel('time')
# plt.xlabel('dts * time_base')
# plt.legend(['original','transcoded'])

# plt.savefig( 'frame_timings.png')



In [None]:
plt.rc('xtick', labelsize=20) 
plt.rc('ytick', labelsize=20) 

dts_a = np.array(dts_a,dtype=np.float32)
dts_b = np.array(dts_b,dtype=np.float32)

fig, ax = plt.subplots(figsize=(15, 10), layout='constrained',dpi=300)

plt.subplot(2,1,1)
plt.plot(np.arange(0,len(pts_a)) ,np.array(pts_a) * float(time_base_a),'r')
plt.plot(np.arange(0,len(pts_b)) ,np.array(pts_b) * float(time_base_b),':b')
plt.plot(np.arange(0,len(pts_c)) ,np.array(pts_c) * float(time_base_c),':g')
plt.ylabel('time')
plt.xlabel('pts * time_base')
plt.legend(['original','transcoded'])

plt.subplot(2,1,2)
plt.plot(np.arange(0,len(dts_a)) ,np.array(dts_a) * float(time_base_a),'r')
plt.plot(np.arange(0,len(dts_b)) ,np.array(dts_b) * float(time_base_b),':b')

# plt.plot(np.arange(0,len(dts_c)) ,np.array(dts_c) * float(time_base_c),':g')

plt.ylabel('time')
plt.xlabel('dts * time_base')
plt.legend(['original','transcoded'])

plt.savefig( 'frame_timings.png')



In [None]:
container_out = av.open(file_path_b, mode="w", timeout = None)

In [None]:
container_out.options