In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
from scipy import ndimage
from typing import Tuple, Optional
import numpy as np
import math
import matplotlib.pyplot as plt
#import tensorflow as tf
#from get_models import rotnet

class ecg_rotation():
    """Take a scanned ECG and rotate it to get the correct orientation.

    :meth:`rotate` detects straight line segments, rotates the image by the
    median segment angle, crops the empty border added by the rotation and
    resizes the result to a fixed size.
    """

    def __init__(self, model):
        """
        :Parameters:
        ------------
            model:
                rotation detection model; only used by
                :meth:`up_down_detection`, which calls ``model.predict``.
        """
        self.model = model

    @staticmethod
    def edge_detection(img: np.ndarray):
        """Detect straight line segments in the scanned ECG image.

        The segments are later used to estimate how far the scan is skewed.

        :Parameters:
        ------------
            img:
                Scanned ECG image, converted with ``cv2.COLOR_BGR2GRAY``.
                NOTE(review): ``cv2.Canny`` needs 8-bit input, so this presumably
                expects a uint8 BGR image -- confirm for float images from mpimg.

        :Returns:
        ---------
            The ``cv2.HoughLinesP`` result (entries shaped ``[[x1, y1, x2, y2]]``),
            or ``None`` when no segment is found.
        """
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_edges = cv2.Canny(img_gray, 500, 500, apertureSize=3)
        # 1 px / 1 degree accumulator resolution, >= 100 votes, segments of at
        # least 100 px, gaps of up to 20 px bridged.
        lines = cv2.HoughLinesP(img_edges, 1, math.pi / 180.0, 100, minLineLength=100, maxLineGap=20)
        return lines

    @staticmethod
    def _line_angle(x1, y1, x2, y2) -> float:
        """Return the angle of segment (x1, y1)-(x2, y2) in degrees, folded into [-90, 90).

        A segment has no inherent direction: swapping its endpoints changes the
        ``atan2`` angle by 180 degrees. Folding maps both orders to the same
        value, so a near-horizontal segment cannot appear as ~0, ~180 or ~-180
        and skew the median.
        """
        angle = math.degrees(math.atan2(y2 - y1, x2 - x1))
        return (angle + 90.0) % 180.0 - 90.0

    @staticmethod
    def get_rotation_angle(img: np.ndarray, lines):
        """Get possible rotation angles, one per detected segment.

        :Parameters:
        -------------
            img:
                Scanned ECG image. Each segment is drawn onto it IN PLACE (colour
                (255, 0, 0), thickness 3) for visual inspection; pass a copy to
                keep the original untouched.
            lines:
                horizontal lines used to orient image, as returned by
                :meth:`edge_detection` (may be ``None``)

        :Returns:
        ---------
            List of angles in degrees within [-90, 90); empty when ``lines`` is ``None``.
        """
        # cv2.HoughLinesP returns None when it finds no segment.
        if lines is None:
            return []
        angles = []
        for [[x1, y1, x2, y2]] in lines:
            cv2.line(img, (x1, y1), (x2, y2), (255, 0, 0), 3)
            angles.append(ecg_rotation._line_angle(x1, y1, x2, y2))
        return angles

    @staticmethod
    def _crop_image(img: np.ndarray, height: int, width: int) -> np.ndarray:
        """Trim outer rows/columns whose pixel sum is not positive.

        This removes the zero fill that ``ndimage.rotate`` adds around a rotated
        image. Only the first ``height`` rows and ``width`` columns are
        considered as crop limits. If no row/column has a positive sum, ``img``
        is returned unchanged instead of an empty array.
        """
        # Per-column / per-row totals over all remaining axes.
        col_sums = img.sum(axis=(0, 2))[:width]
        row_sums = img.sum(axis=(1, 2))[:height]
        cols = np.flatnonzero(col_sums > 0)
        rows = np.flatnonzero(row_sums > 0)
        if cols.size == 0 or rows.size == 0:
            return img
        left, right = cols[0], cols[-1]
        top, bottom = rows[0], rows[-1]
        return img[top:bottom + 1, left:right + 1, :]

    def rotate_and_crop_image(self, img: np.ndarray, angles) -> np.ndarray:
        """Rotate ``img`` by the median of ``angles`` (degrees), then crop the border.

        With no angles (no segments detected) the image is not rotated.
        """
        # np.median([]) is NaN, and rotating by NaN does not give a usable
        # image, so fall back to a zero-degree rotation.
        median_angle = float(np.median(angles)) if len(angles) > 0 else 0.0
        img_rotated = ndimage.rotate(img, median_angle)
        height, width, _ = img_rotated.shape
        return self._crop_image(img_rotated, height, width)

    @staticmethod
    def resize_image(img: np.ndarray, width: int = 2339, height: int = 1654) -> np.ndarray:
        """Resize ``img`` to ``width`` x ``height`` pixels (``cv2.resize`` takes (width, height))."""
        ECG_image_resized = cv2.resize(img, (width, height))
        return ECG_image_resized

    def up_down_detection(self, img: np.ndarray) -> np.ndarray:
        """Rotate ``img`` by 180 degrees when the model predicts label 1.

        Label 0 leaves the orientation unchanged.
        Assumes ``self.model.predict`` returns a single value for the
        one-image batch -- TODO confirm for the model actually used.

        :Raises:
        --------
            ValueError: the prediction has more than one element, or its label
            is neither 0 nor 1.
        """
        pred = self.model.predict(np.expand_dims(img, 0))
        # .item() extracts the single scalar and raises for multi-element
        # output; int() on an ndim > 0 array is deprecated in NumPy.
        label = int(np.asarray(pred).item())
        if label == 1:
            return ndimage.rotate(img, 180)
        if label == 0:
            return ndimage.rotate(img, 0)
        raise ValueError(f"unexpected up/down prediction: {pred!r}")

    def rotate(self, img: np.ndarray) -> np.ndarray:
        """Deskew, crop and resize the image (up/down correction currently disabled)."""
        lines = self.edge_detection(img)
        # Draw on a copy so the caller's image is not annotated.
        angles = self.get_rotation_angle(img.copy(), lines)
        img = self.rotate_and_crop_image(img, angles)
        img = self.resize_image(img)

        #img = self.up_down_detection(img)
        return img





## Taking a look at one ECG

In [None]:
# Ground-truth waveform table for one training record.
ecg_csv_path = "/kaggle/input/physionet-ecg-image-digitization/train/1006427285/1006427285.csv"
ecg = pd.read_csv(ecg_csv_path)

In [None]:
def plot_ecg_waveform(waveform_arr, length, figsize=(10, 10)):
    """Plot each ECG lead on its own subplot, stacked vertically with a shared x-axis.

    Parameters
    ----------
    waveform_arr:
        Iterable of per-lead 1-D sequences, matched in order to the labels
        I, II, III, aVR, aVL, aVF, V1-V6. Only the first 12 are plotted; with
        fewer, the remaining subplots stay empty.
    length:
        Upper x-axis limit (samples shown, starting at 0).
    figsize:
        Figure (width, height) in inches.
    """
    lead_labels = [
        'I',  'II', 'III',
        'aVR', 'aVL', 'aVF',
        'V1', 'V2', 'V3',
        'V4', 'V5', 'V6',
    ]

    fig, axes = plt.subplots(12, 1, sharex=True, figsize=figsize)

    for ax, lead_waveform, lead_name in zip(axes, waveform_arr, lead_labels):
        ax.plot(lead_waveform, label=lead_name)
        ax.set_xlim((0, length))
        ax.legend(loc='center left')

    # plt.show() rather than fig.show(): Figure.show() only warns on non-GUI
    # backends such as the notebook inline backend.
    plt.show()



In [None]:
# Transpose the (samples x columns) table so each row is one column's series.
# Assumes the CSV columns are the 12 leads in plot order -- TODO confirm.
ecg_leads = np.asarray(ecg).T
plot_ecg_waveform(ecg_leads, len(ecg))

In [None]:
# Show (up to) the first 8 PNG scans of record 1006427285 in a 2x4 grid.
path = "/kaggle/input/physionet-ecg-image-digitization/train/1006427285/"
# os.listdir order is arbitrary; sort so the "first 8" are reproducible.
images = sorted(f for f in os.listdir(path) if f.endswith(".png"))[:8]

fig, axes = plt.subplots(2, 4, figsize=(20, 12))
axes = axes.flatten()
for ax, img_name in zip(axes, images):
    img = mpimg.imread(os.path.join(path, img_name))
    ax.imshow(img)
    ax.set_title(img_name, fontsize=8)
    ax.axis("off")

# Hide the axes of unused grid cells when there are fewer than 8 images.
for ax in axes[len(images):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Show (up to) the first 8 PNG scans of record 1006867983 in a 2x4 grid.
path = "/kaggle/input/physionet-ecg-image-digitization/train/1006867983/"
# os.listdir order is arbitrary; sort so the "first 8" are reproducible.
images = sorted(f for f in os.listdir(path) if f.endswith(".png"))[:8]

fig, axes = plt.subplots(2, 4, figsize=(20, 12))
axes = axes.flatten()
for ax, img_name in zip(axes, images):
    img = mpimg.imread(os.path.join(path, img_name))
    ax.imshow(img)
    ax.set_title(img_name, fontsize=8)
    ax.axis("off")

# Hide the axes of unused grid cells when there are fewer than 8 images.
for ax in axes[len(images):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
#ecg_rotation()