In [2]:
import numpy as np 
import matplotlib.pyplot as plt 
import cv2
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.gridspec import GridSpec
from sklearn.decomposition import PCA

from AQSM_SW1PerS.utils.data_processing import *
from AQSM_SW1PerS.SW1PerS import *
from AQSM_SW1PerS.utils.paths import get_data_path


In [8]:

from collections import deque

class RealTimeRollingMedian:
    """Streaming median over the most recent `window_size` samples, ignoring NaNs."""

    def __init__(self, window_size):
        # Maximum number of samples retained (NaN samples count towards this limit).
        self.window_size = window_size
        # Retained samples, oldest on the left.
        self.window = deque()

    def update(self, value):
        """Push `value` into the window and return the median of its non-NaN samples.

        At most one sample (the oldest) is evicted per call. Returns NaN when the
        window currently holds no non-NaN sample.
        """
        self.window.append(value)
        # Evict the oldest sample once the window has grown past its limit.
        if self.window_size < len(self.window):
            self.window.popleft()

        finite = list(filter(lambda sample: not np.isnan(sample), self.window))
        return np.median(finite) if finite else np.nan
            

def interpolate_and_detrend(t_x, t_y, t_vals, num_points):
    """Sample two callables on `t_vals`, detrend the samples and re-fit cubic splines.

    Parameters
    ----------
    t_x, t_y : callables evaluated as ``t_x(t_vals)`` / ``t_y(t_vals)``.
    t_vals : sample locations; also used as the knots of the returned splines.
    num_points : accepted for interface compatibility; not used by this function.

    Returns
    -------
    (keypoint_x, keypoint_y, cs) where the first two are the detrended samples and
    ``cs`` is the list ``[spline_x, spline_y]`` of CubicSplines through them.
    """
    # NOTE(review): `signal` is presumably scipy.signal (it arrives via a star-import) — confirm.
    detrended = [signal.detrend(source(t_vals)) for source in (t_x, t_y)]
    cs = [CubicSpline(t_vals, samples) for samples in detrended]

    keypoint_x, keypoint_y = detrended
    return keypoint_x, keypoint_y, cs
    

def _sw_periodicity(keypoint_x, keypoint_y, cs, t_vals, sampling_rate,
                    period_filter, d, prime_coeff, method, label):
    """Run the sliding-window / persistence pipeline on one 2-D signal.

    Steps, in order: estimate the period, smooth it with `period_filter`, build the
    sliding-window point cloud, project it to 2-D with PCA (used for display only),
    compute persistence diagrams with `ripser`, and score the H1 diagram with
    `compute_PS`.

    Returns
    -------
    (proj_2D, diagrams, score)
        If any step raises, a warning is emitted and placeholders are returned
        (a single point at the origin, an empty H1 diagram, and a score of ``[0]``)
        so that the frame for this segment can still be drawn.
    """
    import warnings

    try:
        period = estimate_period(keypoint_x, keypoint_y, sampling_rate)
        smoothed_period = period_filter.update(period)
        tau = smoothed_period / (d + 1)
        SW = SW_cloud_nD(cs, t_vals, tau, d, 300, 2)
        proj_2D = PCA(n_components=2).fit_transform(SW)

        diagrams = ripser(SW, coeff=prime_coeff, maxdim=1)['dgms']
        score = compute_PS(np.array(diagrams[1]), method=method)
    except Exception as exc:
        # Best effort: keep the video going, but do not hide the failure
        # (and let KeyboardInterrupt / SystemExit propagate).
        warnings.warn(
            f'{label} periodicity pipeline failed; drawing placeholder: {exc!r}',
            RuntimeWarning,
            stacklevel=2,
        )
        diagrams = [np.array([[0.0, 0.0]]), np.empty((0, 2))]  # H0 and H1
        score = [0]
        proj_2D = np.array([[0.0, 0.0]])

    return proj_2D, diagrams, score


def _draw_signal_row(fig, gs, row, t_vals, keypoint_x, keypoint_y, proj_2D, diagrams, score,
                     signal_title, signal_ylabel, cloud_color):
    """Draw one row of four panels: signal, PCA point cloud, persistence diagram, score bar."""
    ax_signal = fig.add_subplot(gs[row, 0])
    ax_signal.plot(t_vals, keypoint_x, color='r', label='X')
    ax_signal.plot(t_vals, keypoint_y, color='g', label='Y')
    ax_signal.set_title(signal_title)
    ax_signal.set_xlabel("Time")
    ax_signal.set_ylabel(signal_ylabel)
    ax_signal.set_yticks([])
    ax_signal.legend()

    ax_cloud = fig.add_subplot(gs[row, 1])
    ax_cloud.scatter(proj_2D[:, 0], proj_2D[:, 1], s=10, alpha=0.7, color=cloud_color)
    ax_cloud.set_xticks([])
    ax_cloud.set_yticks([])
    ax_cloud.set_title('PCA SW Point Cloud')

    ax_dgm = fig.add_subplot(gs[row, 2])
    plot_diagrams(diagrams, plot_only=[1], xy_range=[0, 2, 0, 2], ax=ax_dgm)
    ax_dgm.set_xticks([])
    ax_dgm.set_yticks([])
    ax_dgm.set_title('Persistence Diagram')

    ax_score = fig.add_subplot(gs[row, 3])
    ax_score.bar(range(1), score, alpha=0.5)
    ax_score.set_title(r'$PS_1$')
    ax_score.set_xlim(-0.5, 0.5)
    ax_score.set_ylim(0, 1)
    ax_score.set_xticks([])


def _render_frame(sensor, t_vals, spatial_row, accel_row, figsize, dpi):
    """Render the 2x4 plot grid for one segment and return it as a BGR uint8 image.

    `spatial_row` and `accel_row` are ``(keypoint_x, keypoint_y, proj_2D, diagrams, score)``.
    The figure is always closed, even if drawing fails.
    """
    # dpi is pinned so the rendered pixel size does not depend on the ambient rcParams.
    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        gs = GridSpec(2, 4, figure=fig)
        _draw_signal_row(fig, gs, 0, t_vals, *spatial_row,
                         f'{sensor} Spatial Positions', "Normalized Position", 'deepskyblue')
        _draw_signal_row(fig, gs, 1, t_vals, *accel_row,
                         f'{sensor} Acceleration', "Acceleration", 'darkcyan')
        fig.tight_layout()

        # Attach an Agg canvas and rasterise the figure off-screen.
        canvas = FigureCanvas(fig)
        canvas.draw()

        # Extract the RGBA buffer from the canvas and convert it to OpenCV's BGR layout.
        width, height = canvas.get_width_height()
        plot_image = np.frombuffer(canvas.buffer_rgba(), dtype='uint8').reshape(height, width, 4)
        return cv2.cvtColor(plot_image[..., :3], cv2.COLOR_RGB2BGR)
    finally:
        # Close Matplotlib figures to avoid memory issues across many frames.
        plt.close(fig)
        plt.close('all')


def create_concept_video(t_x, t_y, t_x_a, t_y_a, segments, fps, sensor = 'Chest'):
    """Render one plot frame per segment and write them to ``'{sensor}_Plots.avi'``.

    For each segment the position signal (`t_x`, `t_y`) and the acceleration signal
    (`t_x_a`, `t_y_a`) are sampled, detrended and pushed through the sliding-window
    persistence pipeline; the resulting 2x4 panel figure becomes one video frame and
    is also shown in an OpenCV window named 'Plots'.

    Parameters
    ----------
    t_x, t_y, t_x_a, t_y_a : callables evaluated on an array of sample times.
    segments : iterable of array-likes; the min and max of each give the time span
        sampled for that segment.
    fps : frame rate of the output video; ``int(fps)`` is also the window length of
        the rolling-median period filters.
    sensor : label used in plot titles and in the output file name.

    Returns
    -------
    int
        Index of the last segment rendered, or -1 if `segments` is empty.
    """
    # Rolling-median period filters keep successive period estimates continuous.
    window_size = int(fps)
    period_filter_spatial = RealTimeRollingMedian(window_size=window_size)
    period_filter_accel = RealTimeRollingMedian(window_size=window_size)

    d = 23
    prime_coeff = next_prime(2 * d)
    method = 'PS1'
    num_points = 150
    # Rendering stops after the first segment whose start is at or beyond this value.
    max_start = 40

    # Figure size in inches and dpi together fix the frame size the writer expects.
    figsize, dpi = (20, 8), 100
    frame_size = (figsize[0] * dpi, figsize[1] * dpi)  # (2000, 800)

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(f'{sensor}_Plots.avi', fourcc, fps, frame_size) #Study 1: (352, 288), Study 2: (640,480)

    last_index = -1
    try:
        for last_index, segment in enumerate(segments):
            start = np.min(segment)
            end = np.max(segment)
            t_vals = np.linspace(start, end, num_points)

            keypoint_x, keypoint_y, cs_spatial = interpolate_and_detrend(t_x, t_y, t_vals, num_points)
            keypoint_x_accel, keypoint_y_accel, cs_accel = interpolate_and_detrend(t_x_a, t_y_a, t_vals, num_points)

            sampling_rate = num_points / (end - start)

            proj_2D_spatial, diagrams_spatial, score_spatial = _sw_periodicity(
                keypoint_x, keypoint_y, cs_spatial, t_vals, sampling_rate,
                period_filter_spatial, d, prime_coeff, method, 'Spatial')
            proj_2D_accel, diagrams_accel, score_accel = _sw_periodicity(
                keypoint_x_accel, keypoint_y_accel, cs_accel, t_vals, sampling_rate,
                period_filter_accel, d, prime_coeff, method, 'Acceleration')

            plot_image_bgr = _render_frame(
                sensor, t_vals,
                (keypoint_x, keypoint_y, proj_2D_spatial, diagrams_spatial, score_spatial),
                (keypoint_x_accel, keypoint_y_accel, proj_2D_accel, diagrams_accel, score_accel),
                figsize, dpi)

            out.write(plot_image_bgr)
            cv2.imshow('Plots', plot_image_bgr)
            cv2.waitKey(1)

            if start >= max_start:
                break
    finally:
        # Always finalise the video file and close the preview window.
        out.release()
        cv2.destroyAllWindows()

    return last_index


In [4]:

# Load the dataset file and select the recording visualised in this notebook.
pkl_file = get_data_path("dataset.pkl")
# NOTE(review): open_pickle presumably unpickles the file — only load trusted data.
data = open_pickle(pkl_file)

# Entry 13 is the recording used for every later cell.
entry = data[13] 
# fps is later used as the video frame rate, frame_times as the spline knots, and
# segments as the per-frame analysis windows.
# NOTE(review): the meaning of the second argument (4) is not visible here — confirm against segment_video.
fps, frame_times, segments, annotated_segments = segment_video(entry, 4)


In [5]:

# Keypoint trajectories.
# NOTE(review): the integer arguments look like pose-landmark indices — confirm against extract_keypoints.
head = extract_keypoints(entry, 0, frame_times, fps)
lshoulder = extract_keypoints(entry, 11, frame_times, fps)
rshoulder = extract_keypoints(entry, 12, frame_times, fps)
rwrist = extract_keypoints(
    entry, 16, frame_times, fps,
    do_wrists=True, elbow_index=14, shoulder_index=12,
)
lwrist = extract_keypoints(
    entry, 15, frame_times, fps,
    do_wrists=True, elbow_index=13, shoulder_index=11,
)

chest = getChest(head, lshoulder, rshoulder)

# Acceleration counterpart of every trajectory.
lshoulder_accel = getAccel(lshoulder, fps)
rshoulder_accel = getAccel(rshoulder, fps)
rwrist_accel = getAccel(rwrist, fps)
lwrist_accel = getAccel(lwrist, fps)
head_accel = getAccel(head, fps)
chest_accel = getAccel(chest, fps)


# Turn keypoints into a spline representation for the main algorithm to interpolate values.
def _xy_splines(times, track):
    """Return cubic splines through column 0 (x) and column 1 (y) of `track`, with knots at `times`."""
    return CubicSpline(times, track[:, 0]), CubicSpline(times, track[:, 1])


# Raw keypoint positions
h_x, h_y = _xy_splines(frame_times, head)
r_x, r_y = _xy_splines(frame_times, rwrist)
l_x, l_y = _xy_splines(frame_times, lwrist)
rs_x, rs_y = _xy_splines(frame_times, rshoulder)
ls_x, ls_y = _xy_splines(frame_times, lshoulder)
c_x, c_y = _xy_splines(frame_times, chest)

# Acceleration representation of keypoint movement
h_x_a, h_y_a = _xy_splines(frame_times, head_accel)
r_x_a, r_y_a = _xy_splines(frame_times, rwrist_accel)
l_x_a, l_y_a = _xy_splines(frame_times, lwrist_accel)
rs_x_a, rs_y_a = _xy_splines(frame_times, rshoulder_accel)
ls_x_a, ls_y_a = _xy_splines(frame_times, lshoulder_accel)
c_x_a, c_y_a = _xy_splines(frame_times, chest_accel)



In [None]:

# Render the per-segment plot video for the chest signals; `sensor` is used in the
# plot titles and in the output file name ('{sensor}_Plots.avi').
sensor = "Chest"
# NOTE: despite the plural name, this is a single integer — the index of the last
# segment rendered — and the video-combining cell uses it as its stop condition.
segment_indices = create_concept_video(c_x, c_y, c_x_a, c_y_a, segments, fps, sensor)


In [17]:

# Create the full concept video: the annotated MediaPipe video next to the plot video.
# For every MediaPipe frame, the plot frame of the segment covering the current time is
# fetched (the plot video holds one frame per segment), both are placed side by side and
# the result is written to 'concept_video_final.avi'. Press 'q' in the preview to stop.
mediapipe_video = data_file = get_data_path("MPVideos", "004-01-17-08_0_study1_mp.avi")

fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(f'concept_video_final.avi', fourcc, fps, (1600, 800))

cap1 = cv2.VideoCapture(mediapipe_video)
cap2 = cv2.VideoCapture(f'{sensor}_Plots.avi')

segment_index = 0
frame_index = 0

y_class = entry['annotations']

try:
    # Fail loudly instead of silently producing an empty output video.
    if not cap1.isOpened():
        raise RuntimeError(f'Could not open MediaPipe video: {mediapipe_video}')
    if not cap2.isOpened():
        raise RuntimeError(f'Could not open plot video: {sensor}_Plots.avi')

    # Stop cleanly if the video has more frames than there are annotations.
    while frame_index < len(y_class):
        ret1, frame1 = cap1.read()
        if not ret1:
            break

        frame1 = cv2.resize(frame1, (1600, 800))

        # Add the annotation text to the MediaPipe frame.
        annotation = y_class[frame_index]
        if annotation == 0:
            label = 'No Stereotypy'
        elif annotation == 1:
            label = 'Rocking'
        elif annotation == 2:
            label = 'Flapping'
        else:
            label = 'Flap-Rock'
        cv2.putText(frame1, f'Annotation: {label}', (50, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1, cv2.LINE_AA)
        frame_index += 1

        # Advance to the next segment once the video time has passed the current one.
        current_frame = cap1.get(cv2.CAP_PROP_POS_FRAMES)
        current_time = current_frame / fps
        if current_time > np.max(segments[segment_index]):
            segment_index += 1
        # Frame k of the plot video corresponds to segment k.
        cap2.set(cv2.CAP_PROP_POS_FRAMES, segment_index)

        ret2, frame2 = cap2.read()
        if ret2:
            frame2_resized = cv2.resize(frame2, (1600, 800))
            combined_frame = cv2.hconcat([frame1, frame2_resized])
            # Scale the 3200x800 side-by-side image down to the writer's frame size.
            combined_frame = cv2.resize(combined_frame, (1600, 800))

            out.write(combined_frame)
            cv2.imshow('Combined Video', combined_frame)

        # Poll the keyboard once per frame so a 'q' press always ends the loop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # `segment_indices` is the index of the last segment that was rendered.
        if segment_index == segment_indices:
            break
finally:
    # Release the captures/writer and close windows even if an error occurred.
    cap1.release()
    cap2.release()
    out.release()
    cv2.destroyAllWindows()
        
