In [None]:
%matplotlib qt

In [None]:
import os
import numpy as np
import cv2
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import ipdb

In [None]:
def interactive_image_click(image):
    """
    Interactive function to click on an image and store coordinates.

    The image is shown in a blocking matplotlib window. Every mouse-button
    press inside the image axes is recorded and echoed to stdout. After the
    window is closed the user is asked on stdin whether to select more points;
    if so, the image is shown again and clicks keep accumulating.

    Parameters:
    image: numpy array or torch tensor
        (H, W) or (H, W, 3) image, or a channel-first (3, H, W) image, which is
        moved to channel-last for display.

    Returns:
    clicked_points: list of tuples
        (x, y) integer pixel coordinates of the clicks in click order; x is the
        column index and y the row index, clipped to the image bounds.
    """
    clicked_points = []

    # matplotlib's imshow expects colour images channel-last.
    if len(image.shape) == 3 and image.shape[0] == 3:
        if hasattr(image, "permute"):  # torch.Tensor
            image = image.permute(1, 2, 0)
        else:  # numpy array
            image = np.transpose(image, (1, 2, 0))

    height, width = image.shape[0], image.shape[1]

    def onclick(event):
        # Ignore presses outside the image axes (toolbar, margins, ...).
        if not event.inaxes or event.xdata is None or event.ydata is None:
            return
        # imshow centres pixel k on coordinate k (it covers k - 0.5 .. k + 0.5),
        # so round instead of truncating to get the pixel under the cursor.
        x = min(max(int(round(event.xdata)), 0), width - 1)
        y = min(max(int(round(event.ydata)), 0), height - 1)
        clicked_points.append((x, y))
        print(f"{x}\t {y}")

    def collect_points():
        # A fresh figure per round so the click handler is attached to the
        # figure that actually shows the image.
        fig = plt.figure()
        plt.imshow(image)
        plt.title("Click on points. Close the window to finish.")
        fig.canvas.mpl_connect('button_press_event', onclick)
        plt.show(block=True)

    # Keep showing the image until the user declines to add more points.
    more_points = True
    while more_points:
        collect_points()
        response = input("Do you want to select more points? (yes/no): ").strip().lower()
        more_points = response in ["yes", "y"]

    return clicked_points


def crop_blocks(image, points, block_size=20):
    """
    Crop square blocks centered around given points from every frame of a stack.

    Parameters:
        image (numpy.ndarray or torch.Tensor): Frame stack of shape
            (batch, C, H, W) or (batch, H, W). Only the last two axes are cropped.
        points (iterable): (x, y) integer pixel coordinates of the block centres;
            x indexes the width axis and y the height axis.
        block_size (int): Nominal block size (default: 20). The block spans
            ``block_size // 2`` pixels on each side of the centre, so its edge
            is ``2 * (block_size // 2) + 1`` pixels (e.g. 2 -> 3x3, 20 -> 21x21).
            Blocks are clipped at the image borders; a centre far outside the
            image yields an empty block.

    Returns:
        list: One cropped slice per point, shaped (batch, C, h, w) or
        (batch, h, w) with h, w <= 2 * (block_size // 2) + 1.

    Raises:
        ValueError: If ``image`` is not 3-D or 4-D.
    """
    if len(image.shape) not in (3, 4):
        raise ValueError("Image shape must be either (batch, 3, H, W) or (batch, H, W).")

    h, w = image.shape[-2], image.shape[-1]
    half_block = block_size // 2
    cropped_blocks = []

    for (x, y) in points:
        # Clamp both ends into [0, size]; an unclamped negative end index
        # would wrap around and select a wrong (much larger) region.
        y_start = min(max(y - half_block, 0), h)
        y_end = min(max(y + half_block + 1, 0), h)
        x_start = min(max(x - half_block, 0), w)
        x_end = min(max(x + half_block + 1, 0), w)

        # The ellipsis keeps every leading axis (batch and, if present, channel).
        cropped_blocks.append(image[..., y_start:y_end, x_start:x_end])

    return cropped_blocks

from scipy.signal import resample

# Temporal stride (every `time_s`-th frame is used) and nominal spatial block
# size handed to crop_blocks() (space_s // 2 pixels on each side of a click).
time_s, space_s = 1, 2
# Maps (time_s, space_s) -> list of per-point traces; re-running this cell
# resets it, so run it once before collecting traces.
time_series_data = {}

## variable from mtlnet_train_rppg.py
# frames = torch.from_numpy(data); frames = frames.permute(0,3,1,2)[:,:,:,:]
# frames from MTLnet_train_rppg

def plot_cos_or_dot(cosine_sim, clicked_points=None, ref_frame=200, ref_channel=1):
    """
    Plot the mean-removed, block-averaged trace around each chosen point.

    Parameters:
        cosine_sim: frame stack of shape (T, C, H, W) or (T, H, W) (torch tensor
            or numpy array), e.g. the BGR video frames.
        clicked_points: list of (x, y) pixel coordinates, or None to pick them
            interactively on frame ``ref_frame`` via interactive_image_click().
        ref_frame: index of the frame shown for point picking (default 200);
            clamped to the last available frame for shorter stacks.
        ref_channel: channel shown for point picking when the stack has a
            channel axis (default 1).

    Returns:
        (lines_, clicked_points): lines_ holds one trace per point, averaged
        over the block's pixels (module-level ``space_s`` sets the block size)
        and centred by subtracting its temporal mean; each trace is (T, C) for
        4-D input and (T,) for 3-D input.
    """
    if clicked_points is None:
        # Short stacks have no frame 200; fall back to the last frame.
        frame_idx = min(ref_frame, cosine_sim.shape[0] - 1)
        if len(cosine_sim.shape) == 4:
            reference = cosine_sim[frame_idx, ref_channel]
        else:
            reference = cosine_sim[frame_idx]
        clicked_points = interactive_image_click(reference)

    cropped_blocks = crop_blocks(cosine_sim, clicked_points, int(space_s))

    lines_ = []
    for point_idx, block in enumerate(cropped_blocks):
        # Average over the block's height and width, leaving (T, C) or (T,).
        trace = block.mean(-1).mean(-1)
        centred = trace - trace.mean(0)
        # A (T, C) trace draws C lines; label each so the legend matches them.
        handles = plt.plot(centred)
        for chan_idx, handle in enumerate(handles):
            if len(handles) == 1:
                handle.set_label(str(point_idx))
            else:
                handle.set_label(f"point {point_idx}, ch {chan_idx}")
        lines_.append(centred)

    plt.legend()
    plt.show()

    return lines_, clicked_points

from scipy.signal import butter, filtfilt

# Band-pass filtering helpers for the per-point traces (used by filt_sig below).

# Define the bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a Butterworth band-pass filter in transfer-function (b, a) form.

    Parameters:
        lowcut, highcut: pass-band edges in Hz.
        fs: sampling rate in Hz.
        order: Butterworth order (default 5).

    Returns:
        (b, a): numerator and denominator polynomial coefficients.
    """
    # scipy expects edges normalised to the Nyquist frequency (fs / 2).
    half_rate = fs / 2.0
    band_edges = [lowcut / half_rate, highcut / half_rate]
    return butter(order, band_edges, btype='band')

def apply_bandpass_filter(data, lowcut, highcut, fs, order=10):
    """
    Zero-phase Butterworth band-pass filter ``data`` along its last axis.

    The filter is designed and applied as second-order sections rather than
    the (b, a) polynomials returned by butter_bandpass. The (b, a) form can
    become numerically inaccurate or unstable at high orders when the pass
    band is low relative to fs, e.g. order 10 at 0.05-0.7 Hz with fs ~30 Hz
    (see the scipy.signal.butter notes).

    Parameters:
        data: array-like signal, filtered along its last axis.
        lowcut, highcut: pass-band edges in Hz.
        fs: sampling rate in Hz.
        order: Butterworth order (default 10).

    Returns:
        numpy.ndarray with the same shape as ``data``.

    Raises:
        ValueError: from scipy if the normalised edges are not inside (0, 1) or
            ``data`` is too short for the default edge padding.
    """
    from scipy.signal import butter, sosfiltfilt

    nyquist = 0.5 * fs  # Nyquist frequency is half of the sampling rate
    sos = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    # Forward-backward filtering gives zero phase shift.
    return sosfiltfilt(sos, data)

def rank_sig(sig_):
    '''
    Rank signals by their energy (sum of squared samples), highest first.

    Parameters
    ----------
    sig_ : list of equal-length 1-D signals (or a 2-D array, one signal per row)

    Returns
    -------
    numpy.ndarray
        Indices into ``sig_`` ordered from the highest-energy signal to the
        lowest. An empty input returns an empty index array.
    '''
    if len(sig_) == 0:
        # np.array([]) is 1-D, so summing over axis 1 would raise.
        return np.array([], dtype=np.intp)
    energy = (np.array(sig_) ** 2).sum(axis=1)
    return energy.argsort()[::-1]
    

def filt_sig(signal, lowcut=0.8, highcut=4, ch=1, fs=10):
    '''
    Band-pass filter one channel of every trace and rank the results.

    Parameters
    ----------
    signal : list
        Traces with time along axis 0 and channels along axis 1 (numpy arrays
        or CPU tensors), e.g. the per-point traces from plot_cos_or_dot.
    lowcut : float
        Lower pass-band edge in Hz (default 0.8).
    highcut : float
        Upper pass-band edge in Hz (default 4).
    ch : int
        Channel (column) to filter (default 1).
    fs : float
        Sampling rate of the traces in Hz (default 10).

    Returns
    -------
    sig_ : list of numpy.ndarray
        Filtered 1-D signal for each trace, in input order.
    ranking : numpy.ndarray
        Trace indices sorted by filtered-signal energy, highest first
        (see rank_sig).
    '''
    sig_ = [
        apply_bandpass_filter(np.asarray(sig)[:, ch], lowcut, highcut, fs=fs, order=5)
        for sig in signal
    ]
    return sig_, rank_sig(sig_)

def save_filtered_graphs(time_series_BR, video_name, i, data_key=(1, 2)):
    """
    Save per-point and averaged plots of the band-pass filtered signals.

    - One figure per point showing its B, G and R filtered signals.
    - One figure overlaying the B/G/R-averaged signal of every point
      (always saved as "Filtered_9_Points_Graph.png").
    Figures are written as PNGs into the folder ``video_name``. Points are
    labelled in click order as Abdomen, Chest, Shoulder and Control points;
    points beyond the ninth are labelled "Point <n>".

    Parameters:
        time_series_BR: dict mapping (channel, data_key) to a list with one
            1-D filtered signal per point; channels 0, 1, 2 are plotted as
            B, G, R.
        video_name: output folder, created if missing.
        i: unused; kept for call compatibility.
        data_key: the (time_s, space_s) key the signals were stored under
            (default (1, 2)).
    """
    folder_name = video_name
    os.makedirs(folder_name, exist_ok=True)

    # ROI labels in the order the points are clicked
    roi_labels = [
        "Abdomen Point 1",
        "Abdomen Point 2",
        "Chest Point 1",
        "Chest Point 2",
        "Shoulder Point 1",
        "Shoulder Point 2",
        "Control Point 1",
        "Control Point 2",
        "Control Point 3"
    ]
    channel_names = ["B", "G", "R"]

    # signals_by_channel[c][p] is the filtered signal of point p in channel c.
    signals_by_channel = [time_series_BR[(ch, data_key)] for ch in range(3)]
    n_points = len(signals_by_channel[0])
    # Fewer clicks than named ROIs must not raise; extra clicks get generic names.
    labels = roi_labels[:n_points] + [
        f"Point {k}" for k in range(len(roi_labels) + 1, n_points + 1)
    ]

    # --- One figure per point with its B, G, R signals ---
    for point_idx, roi_name in enumerate(labels):
        plt.figure()
        for ch_idx, ch_name in enumerate(channel_names):
            plt.plot(signals_by_channel[ch_idx][point_idx], label=ch_name)
        plt.title(f"Filtered Signals - {roi_name}")
        plt.xlabel("Frame")
        plt.ylabel("Amplitude")
        plt.legend()
        plt.grid(True)
        filename = f"{roi_name.replace(' ', '_')}_Filtered.png"
        plt.savefig(os.path.join(folder_name, filename))
        plt.close()

    # --- Averaged BGR signal of every point in one figure ---
    plt.figure()
    for point_idx, roi_name in enumerate(labels):
        avg_signal = np.mean(
            [signals_by_channel[ch][point_idx] for ch in range(3)], axis=0
        )
        plt.plot(avg_signal, label=roi_name)

    plt.title(f"Filtered {n_points} Points Graph (BGR Averaged)")
    plt.xlabel("Frame")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.grid(True)
    # File name kept fixed so existing consumers of the output folder still find it.
    plt.savefig(os.path.join(folder_name, "Filtered_9_Points_Graph.png"))
    plt.close()


# Start from here: the cells below load the video, extract the per-point traces and analyse them.

In [None]:
video_path = "input.mp4"
bgr_frames = [] # BGR frames
frames = [] # frames for processing
cap = cv2.VideoCapture(video_path)

start_frame = 69
mid_frame = 200 - start_frame
last_frame = mid_frame + 200

# video sampling rate
fs = cap.get(cv2.CAP_PROP_FPS)
# print(f"Video sampling rate: {fs} FPS")

cap = cv2.VideoCapture(video_path)

# Skip first 69 frames
for _ in range(start_frame - 1):
    ret, _ = cap.read()
    if not ret:
        raise ValueError("Video too short, cannot reach frame 70.")

# Read the 70th frame
ret, frame70 = cap.read()
if not ret:
    raise ValueError("Cannot read the 70th frame from the video.")

# Show frame and let user select ROI
roi = cv2.selectROI("Select ROI", frame70, showCrosshair=True, fromCenter=False)
cv2.destroyWindow("Select ROI")

# roi returns (x, y, w, h), convert to (x1, y1) and (x2, y2)
x, y, w, h = roi
x1, y1 = x, y
x2, y2 = x + w, y + h

print(f"Selected ROI: (x1, y1)=({x1}, {y1}), (x2, y2)=({x2}, {y2})")

upscale = 0 # change it to 1, when we want upscaling. 

idx = 0

# Skip first 69 frames
cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame - 1)


while(cap.isOpened()):
    ret, frame = cap.read()

    # to observe: run the line in the console: plt.imshow(frame)
    # if observed proceed: y1:y2, x1:x2 see the observation after video path
    
    if not ret:
        break

    # w and h, we will look into this later!! we need to do upsampling
    # for baseline: no upscaline
    
    frame =  frame[ y1:y2, x1:x2, :]

    if upscale: 
        w = 600
        h = 600
        frame =  cv2.resize(frame, (w, h))

        
    frame = np.float32(frame/255)
    bgr_frames.append(frame) # H, W, C
    idx += 1
    if idx == 2000: 
        break

cap.release()
cv2.destroyAllWindows()
del cap

# List to Numpy Array to Tensor
# From: N, H, W, C = Batch Dimensions, Height, Width, Channels
# To: N, C, H, W = Batch Dimensions, Channels, Height, Width
bgr_frames = torch.from_numpy(np.transpose(bgr_frames, axes=(0, 3, 1, 2))) 
print(f"Number of frames: {bgr_frames.shape[0]} or {idx}")
print(f"Shape of each frame: {bgr_frames.shape[1:]}")



In [None]:
# Let the user click points on a reference frame, then store their centred,
# block-averaged traces under the (temporal stride, spatial block size) key.
lines_, clicked_points = plot_cos_or_dot(
    bgr_frames[::time_s],
    clicked_points=None,
)
time_series_data[time_s, space_s] = lines_

In [None]:
#%% Signal extraction and filtering
# For every stored (time_s, space_s) run, band-pass each colour channel of
# every point trace and keep the result under (channel, (time_s, space_s)).
time_series_BR = {}

for i in time_series_data:
    lines_ = list(np.array(time_series_data[i]))
    for chan in (0, 1, 2):
        # i[0] is the temporal stride, so the traces are sampled at fs / i[0];
        # pass band is 0.05-0.7 Hz (presumably the breathing band - confirm).
        sig_R, sig_R_pos = filt_sig(lines_, 0.05, 0.7, chan, fs=fs / i[0])
        time_series_BR[(chan, i)] = sig_R
        

In [None]:
# Write the per-point and averaged filtered-signal plots into the "6m_h" folder.
save_filtered_graphs(time_series_BR, video_name="6m_h", i=0)

In [None]:
ranges = [(start_frame, mid_frame), (mid_frame, last_frame), (last_frame, idx)]
channel_names = {0: "B", 1: "G", 2: "R"}
points_labels = [
    "Abdomen 1", "Abdomen 2", "Chest 1", "Chest 2",
    "Shoulder 1", "Shoulder 2", "Control 1", "Control 2", "Control 3"
]
# NOTE(review): signal index 0 is the first processed frame, so the first range
# (start_frame, mid_frame) skips the first start_frame samples - confirm whether
# (0, mid_frame) was intended.

# Key under which the extraction/filtering cells stored the signals.
signal_key = (time_s, space_s)
# The three segments followed by the whole recording, printed in that order.
all_ranges = ranges + [(0, idx)]

# print("Channel\tFrame From\tFrame To\t" + "\t".join(points_labels))

# --- Per channel: |mean| of each point's filtered signal over each range ---
results = []
for chan, chan_name in channel_names.items():
    signals = time_series_BR[(chan, signal_key)]
    for start, end in all_ranges:
        means = [abs(np.mean(sig[start:end])) for sig in signals]
        means_str = "\t".join(f"{m:.6f}" for m in means)
        print(f"{chan_name}\t{start}\t{end}\t{means_str}")
        results.append([chan_name, start, end] + means)

# --- BGR combined: |average of the three per-channel means| per point ---
# Use however many points were clicked instead of assuming nine.
n_points = len(time_series_BR[(0, signal_key)])
for start, end in all_ranges:
    combined_means = []
    for i in range(n_points):
        channel_means = [
            np.mean(time_series_BR[(c, signal_key)][i][start:end]) for c in channel_names
        ]
        combined_means.append(abs(np.mean(channel_means)))
    combined_means_str = "\t".join(f"{m:.6f}" for m in combined_means)
    print(f"BGR\t{start}\t{end}\t{combined_means_str}")

# The last pass above covered the full recording.
combined_full = combined_means


In [None]:
ranges = [(start_frame, mid_frame), (mid_frame, last_frame), (last_frame, idx)]
roi_labels = [
    "Abdomen 1", "Abdomen 2", "Chest 1", "Chest 2",
    "Shoulder 1", "Shoulder 2", "Control 1", "Control 2", "Control 3"
]
# Key under which the extraction/filtering cells stored the signals.
signal_key = (time_s, space_s)

control_indices = [6, 7, 8]  # Control 1, 2, 3
roi_indices = [0, 1, 2, 3, 4, 5]

n_points = len(time_series_BR[(0, signal_key)])
if n_points <= max(control_indices):
    raise ValueError(
        f"Expected at least {max(control_indices) + 1} clicked points "
        f"(6 ROIs + 3 controls), got {n_points}."
    )

# --- Step 1: |average of the three per-channel means| per point and range ---
bgr_averages = {}
for start, end in ranges:
    combined_means = []
    for i in range(n_points):
        channel_means = [
            np.mean(time_series_BR[(c, signal_key)][i][start:end]) for c in (0, 1, 2)
        ]
        avg_val = np.mean(channel_means)
        combined_means.append(abs(avg_val))
    bgr_averages[(start, end)] = combined_means

# --- Step 2: ratio of each ROI magnitude to each control magnitude ---
print("ROI\tRange\tControl 1\tControl 2\tControl 3")
for roi_idx in roi_indices:
    for (start, end) in ranges:
        segment = bgr_averages[(start, end)]
        roi_val = segment[roi_idx]
        # A zero control magnitude gives NaN instead of a division error.
        ratios = [
            roi_val / segment[c_idx] if segment[c_idx] != 0 else np.nan
            for c_idx in control_indices
        ]
        ratios_str = "\t".join(f"{r:.4f}" for r in ratios)
        print(f"{roi_labels[roi_idx]}\t{start}-{end}\t{ratios_str}")
