In [47]:
# IMPORTS
# -------
import os
import subprocess
import time

import cv2
import numpy as np
from pydicom import dcmread

In [48]:
# BEAT IDENTIFICATION
# -------------------

def find_beats(video):
    """ Description: Returns a list of frames where beats start in the video
        Input: A video (sequence of RGB images)
        Output: A list containing the indices of frames in the video that are the start of a heartbeat.
                The list is empty when the video has no frames or no frame is classified as a beat.
    """
    # nothing to analyse: a signal height cannot be estimated from zero frames
    if len(video) == 0:
        return []

    signal_height = find_signal_height(video)
    candidates = [i for i in range(len(video)) if is_beat(video[i], signal_height)]

    # condense nearby beats into first one: a candidate starts a new beat only when it is
    # at least 2 frames after the previous candidate (kept or dropped), so a run of
    # adjacent candidate frames collapses to its first frame
    beats = []
    last_frame = None
    for frame in candidates:
        if last_frame is None or frame - last_frame >= 2:
            beats.append(frame)
        last_frame = frame

    return beats


def find_signal_height(video):
    """ Description: Finds the typical height of heartbeat spike for a video (Helper for find_beats)
        Input: A video (sequence of RGB images)
        Output: An integer number of pixels indicating the typical height of a heartbeat peak
        Raises: ValueError if no frame yields any signal pixels (including an empty video)
    """
    pcts = []
    for img in video:
        mask = signal_mask(img)
        x, _, _ = locate_line(mask)
        # remove vertical line for location in signal; the left edge is clamped because a
        # negative start index would wrap around and leave the line in place
        mask[:, max(x - 4, 0):x + 4] = 0
        ys, _ = np.where(mask > 0)
        if ys.size == 0:
            # nothing but the vertical line in this frame: it carries no height information
            continue
        y = np.median(ys)  # use median y-position of pixels for baseline of signal
        low, high = np.percentile(ys, [5, 95])
        # use highest and lowest 5% marks for height
        pcts.append(max(abs(y - low), abs(y - high)))

    if not pcts:
        raise ValueError("find_signal_height: no signal pixels found in any frame")

    signal_height = int(np.mean(pcts))
    return signal_height


def is_beat(img, signal_height=None):
    """ Description: Returns True if image contains a beat and False otherwise (Helper for find_beats)
        Input: An RGB image and an optional height for typical heartbeat peak in pixels
        Output: True if frame overlaps with a heartbeat and False otherwise (plain Python bool)
    """
    mask = signal_mask(img)
    x, y, h = locate_line(mask)
    mask = mask//255  # cv2.inRange masks are 0/255; rescale to 0/1 so sums count pixels

    # remove any extra of vertical line: walk left while the column is fully set
    # over the 2*h rows centred on the baseline
    end = x
    while np.sum(mask[y-h:y+h, end]) == 2*h and end > 0:
        end -= 1

    # look at region to left of vertical green line (up to 5 columns); the start is
    # clamped because a negative index would wrap around and give an empty window
    # NOTE(review): y-h is not clamped; it is assumed non-negative - confirm against inputs
    target = mask[y-h:y+h, max(end - 5, 0):end]

    # if vertical variance of horizontal line exceeds threshold, consider it a beat
    ys, _ = np.where(target > 0)
    if ys.size == 0:
        # no signal pixels next to the line: cannot be a beat (also avoids NaN variance)
        return False
    is_a_beat = np.var(ys) > 5

    # require line to also be a certain percentage of typical peak height if provided
    if signal_height is not None:
        half_height = int(0.5*signal_height)
        # row h of target is the baseline; a beat must reach at least half_height rows away.
        # Clamp the upper bound: a negative stop would select almost the whole window
        # instead of "nothing is that far above the baseline".
        above = target[:max(h - half_height, 0), :]
        below = target[h + half_height:, :]
        is_a_beat = is_a_beat and (np.sum(above) > 0 or np.sum(below) > 0)

    return bool(is_a_beat)


def signal_mask(img):
    """ Description: Builds a mask that keeps only the green heartbeat signal (Helper for is_beat)
        Input: An RGB image
        Output: A binary image containing the heartbeat signal
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)

    # HSV bounds used to pick out the green signal
    lower_green = (30, 10, 50)
    upper_green = (90, 255, 255)
    signal = cv2.inRange(hsv, lower_green, upper_green)

    # everything in the first 600 rows is treated as noise above the signal
    signal[0:600, :] = 0
    return signal


def locate_line(mask):
    """ Description: Returns coordinates of current location in signal (Helper for is_beat)
        Input: A binary image containing heartbeat signal (output of signal_mask)
        Output: X coordinate of the vertical line, Y coordinate of the horizontal signal, and the
                height of the tallest contour found above row 650 (taken to be the vertical line)
        Raises: ValueError if the mask has no signal pixels or no contour above row 650
    """
    mask = np.array(mask)  # work on a copy so the caller's mask is left untouched

    # find the y position of the horizontal green signal
    ys, _ = np.where(mask > 0)
    if ys.size == 0:
        raise ValueError("locate_line: mask contains no signal pixels")
    y = int(np.median(ys))

    # remove the horizontal line, leaving just the top of the vertical green line
    mask[650:,:] = 0

    # identify the line (approximately): bounding box with the greatest height
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        raise ValueError("locate_line: no vertical line found above row 650")
    x,_,w,h = max([cv2.boundingRect(c) for c in contours], key=lambda b: b[3])

    # remove all noise from outside the relevant area; the left bound is clamped because
    # a negative stop index would wrap around and erase the line itself
    mask[:, x+w:] = 0
    mask[:, :max(x - w, 0)] = 0

    # find the x position of the vertical green signal
    _, xs = np.where(mask > 0)
    x = int(np.mean(xs))

    return x, y, h

In [49]:
# CYCLE NORMALIZATION
# -------------------

def normalize(video, sec_per_cycle, name):
    """ Description: Sets each cycle in a heartbeat video to the same length
        Input: A video (sequence of RGB images), the number of seconds per heartbeat cycle, and a name to save to
        Output: Returns the number of heart cycles in the normalized video, saves video to ./{name}.avi.
                Returns 0 and writes nothing when no complete cycle is found.
        Raises: subprocess.CalledProcessError if ffmpeg exits with an error
    """
    beats = find_beats(video)
    beat_videos = split_video(video, beats)

    # no complete cycle: there is nothing to write or concatenate
    if not beat_videos:
        return 0

    cycle_files = [f"./{name}-{i}.avi" for i in range(len(beat_videos))]
    try:
        # write out individual cycle videos; the fps is chosen so that every cycle,
        # whatever its frame count, lasts sec_per_cycle seconds
        for fname, vid in zip(cycle_files, beat_videos):
            write_video(fname, vid, len(vid) / sec_per_cycle)

        # combine into single video. An argument list (no shell) keeps names with spaces or
        # shell metacharacters safe; -y overwrites output left over from a previous run.
        # NOTE(review): confirm that the concat protocol preserves per-cycle frame rates for AVI.
        subprocess.run(
            ["ffmpeg", "-y", "-i", "concat:" + "|".join(cycle_files), "-c", "copy", f"./{name}.avi"],
            check=True,
        )
    finally:
        # delete old individual videos, also when writing or ffmpeg failed part-way
        for fname in cycle_files:
            if os.path.exists(fname):
                os.remove(fname)

    return len(beat_videos)


def split_video(video, beats):
    """ Description: Cuts a video into one segment per heart beat cycle (Helper for normalize)
        Input: A video (sequence of RGB images) and a list of beats (output of find_beats)
        Output: A list of videos; segment k runs from beat k (inclusive) to beat k+1 (exclusive),
                so frames before the first beat and from the last beat onwards are dropped
    """
    # pair every beat with its successor; fewer than two beats yields no pairs
    return [video[start:stop] for start, stop in zip(beats, beats[1:])]


def write_video(fname, video, fps):
    """ Description: Saves a video to file with a given fps (MJPG codec)
        Input: Filename, a video (non-empty sequence of RGB images), and an fps
        Output: None, saves video to computer
        Raises: ValueError if the video has no frames, IOError if the output file cannot be opened
    """
    if len(video) == 0:
        raise ValueError("write_video: cannot write a video with no frames")

    h,w = video[0].shape[:2]
    fourcc = cv2.VideoWriter.fourcc(*'MJPG')
    out = cv2.VideoWriter(fname, fourcc, fps, (w, h))
    try:
        if not out.isOpened():
            raise IOError(f"write_video: could not open {fname} for writing")
        for img in video:
            # frames are RGB but OpenCV's writer expects BGR; without the swap the
            # red and blue channels are exchanged in the saved file
            if img.ndim == 3 and img.shape[2] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            out.write(img)
    finally:
        # always release the writer so the file is finalised even on error
        out.release()

In [50]:
# VIDEO ALIGNMENT
# ---------------

def align_videos(fnames, sec_per_cycle, names=[]):
    """ Description: Normalizes every video, then trims them all to the length of the shortest one
        Input: List of video paths, the number of seconds a cycle should last, and a list of names to save the videos as
        Output: None, saves the set of normalized and trimmed videos to computer
    """
    # fall back to names derived from the paths unless exactly one name per path was given
    if len(fnames) != len(names):
        names = []  # fresh list; the default argument itself is never modified
        for path in fnames:
            base = path.split('/')[-1]
            names.append(base.split('.')[0])

    fewest_cycles = normalize_videos(fnames, sec_per_cycle, names)
    trim_videos(names, fewest_cycles * sec_per_cycle)
    

def normalize_videos(fnames, sec_per_cycle, names=[]):
    """ Description: Normalizes a set of videos
        Input: List of video paths, the number of seconds a cycle should last, and a list of names to save the videos as
               (names are derived from the paths when the list does not have one entry per path)
        Output: Returns the number of cycles in the shortest video (0 when no paths are given)
                and saves the normalized videos to computer
    """
    # the default list is only ever rebound, never mutated
    if len(names) != len(fnames):
        names = [fname.split('/')[-1].split('.')[0] for fname in fnames]

    num_cycles = []
    for fname, name in zip(fnames, names):
        video = load_video(fname)
        num_cycles.append(normalize(video, sec_per_cycle, name))

    # default=0 so an empty batch is a no-op instead of a ValueError from min([])
    return min(num_cycles, default=0)


def trim_videos(names, sec, delete_old=True):
    """ Description: Trims a set of videos to a specified number of seconds
        Input: A list of video names (each refers to ./{name}.avi), the number of seconds to trim to
               (int or float), and an optional boolean to indicate deleting the old videos
        Output: None, saves trimmed videos to computer. With delete_old the trimmed file replaces
                ./{name}.avi, otherwise it is left next to it as ./{name}_trimmed.avi
        Raises: subprocess.CalledProcessError if ffmpeg exits with an error (the original is then kept)
    """
    # plain seconds: keeps fractional durations and does not wrap around at 24 hours
    max_time = f"{sec:.3f}"
    for name in names:
        source = f"./{name}.avi"
        trimmed = f"./{name}_trimmed.avi"
        # argument list (no shell) so names need no quoting; -y overwrites leftovers from earlier runs
        subprocess.run(
            ["ffmpeg", "-y", "-i", source, "-ss", "00:00:00", "-t", max_time,
             "-c:v", "copy", "-c:a", "copy", trimmed],
            check=True,
        )
        if delete_old:
            # only reached when ffmpeg succeeded, so the original is never lost;
            # os.replace overwrites the target and works on every platform (no `mv`)
            os.replace(trimmed, source)


def load_video(fname):
    """ Description: Reads a DICOM file and returns its frames
        Input:  The filename of a video (string)
        Output: The video (sequence of RGB images), one list entry per item along the first axis of the pixel array
    """
    dataset = dcmread(fname)
    return list(dataset.pixel_array)

In [51]:
# EXAMPLE
# -------

data_path = './data'

# one entry per patient folder under data_path
patients = [
    "10026_20171208",
    "10651_20171118",
    "14369_20180420",
    "15751_20180728",
    "20198_20180511",
    "20218_20171201",
    "20596_20171118",
    "21104_20180322",
    "22444_20180404",
]

# file-number suffixes to load for the patient at the same position in `patients`
indices = [
    ["02", "03", "04"],
    ["02", "03"],
    ["41", "42"],
    ["02", "03"],
    ["03", "04"],
    ["02", "05", "06", "07"],
    ["02", "03"],
    ["02", "03", "04"],
    ["02", "03", "04"],
]

# flatten to one DICOM path per (patient, suffix) pair, keeping patient order
video_paths = [
    f"{data_path}/{patient}/IM-0001-00{index}.dcm"
    for patient, suffixes in zip(patients, indices)
    for index in suffixes
]
video_names = [f"video{n}" for n in range(len(video_paths))]

# normalize every cycle to 2 seconds and trim all videos to a common length
align_videos(video_paths, 2, video_names)