In [None]:
import sys
import os
import glob
import numpy as np
from tqdm import tqdm
import cv2
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
print(os.getcwd())


def get_records(data_dir='mitbih'):
    """Return the sorted base paths (without extension) of every WFDB record in `data_dir`.

    Each record is stored as several files that share one base name; the ``.atr``
    annotation file is used to discover the records. The returned base paths are the
    form passed to ``wfdb.rdsamp`` / ``wfdb.rdann`` later in this notebook.

    Args:
        data_dir: directory holding the record files, relative to the current working
            directory unless absolute. Defaults to ``'mitbih'``. Nothing is downloaded:
            the files must already be there.

    Returns:
        Sorted list of base paths such as ``'mitbih/100'``; empty if no ``.atr`` file matches.
    """
    annotation_files = glob.glob(os.path.join(data_dir, '*.atr'))
    # splitext strips the extension whatever its length (the old [:-4] assumed '.atr').
    return sorted(os.path.splitext(path)[0] for path in annotation_files)


records = get_records()
print ('There are {} record files'.format(len(records)))
print (records)

In [None]:
def beat_annotations(annotation, type):
    """Return the sample positions of the annotations whose symbol matches `type`.

    All other annotations (other beat classes and non-beat markers) are dropped.

    Args:
        annotation: annotation object with parallel ``symbol`` (annotation codes) and
            ``sample`` (sample indices) attributes, e.g. from ``wfdb.rdann``.
        type: the code to keep, e.g. 'N' (normal), 'L' (left bundle branch block),
            'R' (right bundle branch block), 'A' (atrial premature contraction),
            'V' (premature ventricular contraction), '/' (paced), 'E' (ventricular
            escape). An iterable of codes such as ['L', 'R'] keeps every listed code.
            The parameter name shadows the builtin; it is kept for keyword callers.

    Returns:
        numpy array of the sample indices of the matching annotations, in input order.
    """
    # Accept a single code or a collection of codes.
    wanted = [type] if isinstance(type, str) else list(type)
    # np.isin is the supported replacement for the deprecated np.in1d.
    ids = np.isin(annotation.symbol, wanted)

    # Only the positions are needed downstream.
    return np.asarray(annotation.sample)[ids]

In [None]:
import wfdb

# Inspect a single record: the fifth base path returned by get_records().
e = records[4]
print(e[-3:])  # last three characters of the base path (the record name for paths like 'mitbih/100')

# Read only the first signal channel; later cells index it as signals[sample, 0].
signals, fields = wfdb.rdsamp(e, channels = [0])
print('signals=\n{}\nfields=\n{}'.format(signals,fields))
# Plot the first 720 samples. NOTE(review): that is 2 s only if fields['fs'] is 360 Hz — confirm.
plt.plot(signals[0:720],linewidth=0.5)

In [None]:
"""
A  --  Atrial premature beat
f  --  Fusion of paced and normal beat
L  --  Left bundle branch block beat
N or .  --  Normal beat
Q  --  Unclassifiable beat
R  --  Right bundle branch block beat
V  --  Premature ventricular contraction
!  --  Ventricular flutter wave
/  --  Paced beat
|  --  Isolated QRS-like artifact
~  --  
+  --
"""

ann = wfdb.rdann(e, 'atr')

print('annotator symbol=\n{}\nannotator sample=\n{}'.format(ann.symbol,ann.sample))
print('annotator symbol length=\n{}\nannotator sample length=\n{}'.format(len(ann.symbol),len(ann.sample)))

In [None]:
# Sample positions of the paced ('/') beats in this record.
imp_beats = beat_annotations(ann, '/')
print('imp_beats=\n{}'.format(imp_beats))
# Sample positions of every annotation (any code); used below to find each beat's neighbours.
beats = (ann.sample)
print('beats=\n{}'.format(beats))

In [None]:
import math

# Cut each selected beat out of channel 0. A segment runs from 20 samples after the
# previous annotation to 20 samples before the next one.
result = []   # one 1-D signal segment per selected beat
x_axis = []   # the matching absolute sample indices, for plotting against sample number
beats = np.array(beats)  # loop-invariant: convert once, not once per beat
# First position of each sample value among all annotations, built in one O(n) pass.
# It replaces a per-beat np.where scan, which made the loop O(n^2).
first_index = {}
for pos, sample in enumerate(beats.tolist()):
    first_index.setdefault(sample, pos)
for i in imp_beats:
    # First match, exactly as np.where(beats == i)[0][0] returned.
    j = first_index[int(i)]
    # The first and last annotations lack a neighbour on one side, so skip them.
    if j != 0 and j != len(beats) - 1:
        sig_start = beats[j-1] + 20
        sig_end = beats[j+1] - 20
        data = signals[sig_start:sig_end, 0]
        result.append(data)
        x_axis.append(list(range(sig_start, sig_end)))

In [None]:
# Preview the image pipeline on one beat (result[5]): draw it as a borderless line plot,
# read it back in grayscale, pad it to a square, and resize it to 128x128.
fig = plt.figure(dpi=300, frameon=False, figsize=(1.0,0.5))  # 1.0 x 0.5 in at 300 dpi -> 300 x 150 px
plt.plot(result[5], linewidth=0.5)
plt.xticks([]), plt.yticks([])  # hide ticks and tick labels
for spine in plt.gca().spines.values():
    spine.set_visible(False)
filename = 'fig_example.png'
fig.savefig(filename)
#plt.close()
# NOTE(review): cv2.imread returns None if the file cannot be read; .shape below would then fail.
im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
im_h, im_w = im_gray.shape # grayscale images are 2-D (height, width); there is no channel axis
# im_gray = im_gray[20:im_h-20, 30:im_w-30] #crop image
im_gray = cv2.copyMakeBorder(im_gray,75,75,0,0, cv2.BORDER_REPLICATE) # replicate 75 px rows top and bottom: presumably 150x300 -> 300x300 (square)
# cv2.imshow('', im_gray)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
print(im_gray.shape)
im_gray = cv2.resize(im_gray, (128,128), interpolation=cv2.INTER_LANCZOS4)
cv2.imwrite('fig_example_resize.png', im_gray)

In [None]:
# Show the resized preview in a native OpenCV window; waitKey(0) blocks until a key is pressed.
# NOTE(review): this needs a GUI display and usually fails on headless or remote kernels.
cv2.imshow('image',im_gray)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [None]:
# Plot every segment in `result` on a grid with noOfCols columns.
noOfCols = 5
# At least one row, so plt.subplots does not reject an empty `result`.
noOfRows = max(1, math.ceil(len(result)/noOfCols))
# squeeze=False keeps `axes` 2-D even for a single row. With the default, one row
# returns a 1-D array and 2-D indexing fails.
fig,axes = plt.subplots(noOfRows,noOfCols, figsize=(15,30), squeeze=False)  # figsize=(width, height) in inches
for k, ax in enumerate(axes.flat):
    if k < len(result):
        ax.plot(result[k])
    else:
        ax.axis('off')  # hide the unused cells at the end of the last row

In [None]:
import wfdb

#def segmentation(records, type,record_index,output_dir=''):
def _first_positions(values):
    """Map each value to the index of its first occurrence in `values` (single O(n) pass)."""
    positions = {}
    for pos, value in enumerate(values):
        positions.setdefault(value, pos)
    return positions


def _save_beat_image(data, filename, floor, ceil, kernel):
    """Draw `data` as a borderless line plot, save it, then erode and resize the PNG to 128x128 grayscale in place."""
    # Fixed axes so every beat image shares one scale. Samples past index 192 fall
    # outside the x-range and are not visible.
    plt.axis([0, 192, floor, ceil])
    plt.plot(data, linewidth=0.5)
    plt.xticks([]), plt.yticks([])
    for spine in plt.gca().spines.values():
        spine.set_visible(False)
    plt.savefig(filename)
    plt.close()
    im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    if im_gray is None:
        raise IOError('could not read back {}'.format(filename))
    # Erosion thickens the dark trace before the downscale.
    im_gray = cv2.erode(im_gray, kernel, iterations=1)
    im_gray = cv2.resize(im_gray, (128, 128), interpolation=cv2.INTER_LANCZOS4)
    cv2.imwrite(filename, im_gray)


def segmentation(records, type,output_dir=''):
    """Cut every beat of class `type` out of `records` and save each one as a 128x128 PNG.

    A selected beat is cut from channel 0. The window runs from half the distance to
    the previous annotation before the beat to half the distance to the next one after
    it. The first and last annotation of a record are skipped because they lack a
    neighbour.

    Each segment is saved as ``fig_<record position>_<running count>.png`` in
    `output_dir`. The file is then read back in grayscale, eroded with a 4x4 kernel
    and resized to 128x128, overwriting the original.

    Args:
        records: sequence of WFDB record base paths (e.g. from get_records()).
        type: annotation code passed to beat_annotations(). The name shadows the
            builtin; it is kept for keyword callers.
        output_dir: directory for the images, created if missing. '' writes into
            the current directory.

    Returns:
        list of the 1-D signal segments, in the order their images were written.
    """
    if output_dir:  # os.makedirs('') raises, so the default must not reach it
        os.makedirs(output_dir, exist_ok=True)
    results = []
    kernel = np.ones((4, 4), np.uint8)
    count = 1

    # Hard-coded mean / population std of channel 0 over all records, computed once
    # offline. The y-range is clipped to mean +/- 3 std for every image.
    mean_v = -0.33859
    std_v = 0.472368
    floor = mean_v - 3*std_v
    ceil = mean_v + 3*std_v

    # enumerate replaces records.index(e), which rescanned the list for every beat.
    for record_index, e in enumerate(tqdm(records)):
        signals, fields = wfdb.rdsamp(e, channels = [0])
        ann = wfdb.rdann(e, 'atr')
        imp_beats = beat_annotations(ann, type)
        beats = np.array(ann.sample)
        # First-occurrence lookup matches np.where(beats == i)[0][0] without an O(n) scan per beat.
        first_pos = _first_positions(beats.tolist())
        for i in tqdm(imp_beats):
            j = first_pos[int(i)]
            if j != 0 and j != len(beats) - 1:
                diff1 = abs(beats[j-1] - beats[j]) // 2
                diff2 = abs(beats[j+1] - beats[j]) // 2
                data = signals[beats[j] - diff1: beats[j] + diff2, 0]
                results.append(data)

                # os.path.join works whether or not output_dir ends with '/'.
                filename = os.path.join(output_dir, 'fig_{}_{}'.format(record_index, count) + '.png')
                _save_beat_image(data, filename, floor, ceil, kernel)
                print('img writtten {}'.format(filename))
                count += 1
        print('img completed {}'.format(e))
    return results
    

# commenting/uncommenting a block of code: select the block of code then press  'ctrl'+'/'

# testing with one signal file that the segmentation function will work fine
# import math

# test_records = [records[2]]
# signalArr = segmentation(test_records)
# print(len(signalArr))
# noOfCols = 5
# noOfRows = math.ceil(len(signalArr)/noOfCols)
# # plt.rcParams['figure.figsize'] = [5,5] # [width (in inches), height (in inches)]
# fig,axes = plt.subplots(noOfRows,noOfCols, figsize=(15,30))  # figsize=(width (in inches),height (in inches)) of frame 
# k=0
# for i in range (0,noOfRows): 
#     for j in range (0,noOfCols):
#         axes[i,j].plot(signalArr[k])
#         k+=1
#         if k==len(signalArr):
#             break
# #print(signalArr)

In [None]:
# Create the image dataset: segment every record once per beat class, one output
# directory per class.
labels = ['A', 'L', 'N', '/', 'V', 'R', 'E', '!']
output_dirs = ['APC/', 'LBBB/', 'NOR/', 'PAB/', 'PVC/', 'RBBB/', 'VEB/', 'VFE/']
#one_record=[records[47]]
# The loop variable is `beat_type`, not `type`, so the builtin `type` is not rebound
# for the rest of the notebook.
for beat_type, output_dir in zip(labels, output_dirs):
    sgs = segmentation(records, beat_type, output_dir='./MIT-BIH_AD/'+output_dir)
    #sgs = segmentation(one_record, beat_type, record_index=47, output_dir='./MIT-BIH_AD/'+output_dir)

In [None]:
# count the length of different directory inside dataset
# Count the files in each class directory and show the class balance as a donut chart.
# NOTE(review): this reads MIT-BIH_AD_BASE/, but the dataset cell writes ./MIT-BIH_AD/ and
# the train/valid/test split reads MIT-BIH_AD/. Confirm which directory is intended.
Database_DIR = 'MIT-BIH_AD_BASE/'
image_dirs = ['NOR/', 'LBBB/', 'RBBB/', 'APC/', 'PVC/', 'VEB/','PAB/', 'VFE/']
no_of_files_in_dir=[]
for image_dir in image_dirs:
    # next(os.walk(...)) yields only the top level: (path, subdirs, files); every file counts.
    path, dirs, files = next(os.walk(os.path.join(Database_DIR,image_dir)))
    no_of_files_in_dir.append(len(files)) 

print('Number of images in each directory={} and total number of images={}'.format(no_of_files_in_dir, sum(no_of_files_in_dir)))
plt.figure(figsize=(10,10))
my_circle=plt.Circle((0,0), 0.7, color='white')  # white disc over the centre turns the pie into a donut
# Pie labels must stay in the same order as image_dirs.
plt.pie(no_of_files_in_dir, labels=['NOR', 'LBBB', 'RBBB', 'APC', 'PVC', 'VEB','PAB', 'VFE'], colors=['red','green','blue','skyblue','orange', 'yellow','magenta', 'cyan'],autopct='%1.1f%%')
p=plt.gcf()
p.gca().add_artist(my_circle)
plt.show()
#p.savefig('data_distribution.png', dpi=400)

In [None]:
# data augmentation paper function
# (suffix, row slice, column slice) for the 3x3 grid of 96x96 crops of a 128x128 image.
# Rows: top [:96], center [16:112], bottom [32:]; columns likewise for left/center/right.
# Written in the original order: leftTop, centerTop, rightTop, leftCenter, ... rightBottom.
CROP_WINDOWS = [
    (col_name + row_name, row_slice, col_slice)
    for row_name, row_slice in (('Top', slice(None, 96)), ('Center', slice(16, 112)), ('Bottom', slice(32, None)))
    for col_name, col_slice in (('left', slice(None, 96)), ('center', slice(16, 112)), ('right', slice(32, None)))
]


def cropping(image, filename):
    """Data augmentation: save nine crops of `image`, each resized to 128x128, next to `filename`.

    Crop i is written to ``<filename without extension><suffix>.png``, with the
    suffixes and windows listed in CROP_WINDOWS.

    Args:
        image: image array (e.g. from cv2.imread) of at least 128x128 pixels.
        filename: path of the source image; only its directory and stem are reused.

    Raises:
        ValueError: if `image` is None (typically a failed cv2.imread).
    """
    if image is None:
        raise ValueError('cropping() got no image for {!r} (did cv2.imread fail?)'.format(filename))
    # splitext handles any extension length (the old filename[:-4] assumed '.png').
    stem = os.path.splitext(filename)[0]
    for suffix, rows, cols in CROP_WINDOWS:
        crop = cv2.resize(image[rows, cols], (128, 128))
        cv2.imwrite(stem + suffix + '.png', crop)

In [None]:
# data augmentation
# Augment every class except NOR with nine crops per image (see cropping()).
# Files that are already crops (their stem ends with one of cropping()'s suffixes) are
# skipped, so re-running this cell does not produce crops of crops.
augment_dirs = ['LBBB/', 'RBBB/', 'APC/', 'PVC/', 'VEB/','PAB/', 'VFE/']
crop_suffixes = ('leftTop', 'centerTop', 'rightTop', 'leftCenter', 'centerCenter',
                 'rightCenter', 'leftBottom', 'centerBottom', 'rightBottom')
for image_dir in augment_dirs:
    path, dirs, files = next(os.walk(os.path.join(Database_DIR,image_dir)))
    source_files = [f for f in files if not os.path.splitext(f)[0].endswith(crop_suffixes)]
    for file in tqdm(source_files):
        imagefilepath = os.path.join(Database_DIR,image_dir,file)
        image = cv2.imread(imagefilepath)
        if image is None:  # unreadable or non-image file
            print('skipping unreadable file {}'.format(imagefilepath))
            continue
        cropping(image, imagefilepath)

In [None]:
# count the length of different directory inside dataset
# Recount each class directory after augmentation. This reuses Database_DIR and
# image_dirs as set by the first count cell.
no_of_files_in_dir_new=[]
for image_dir in image_dirs:
    path, dirs, files = next(os.walk(os.path.join(Database_DIR,image_dir)))
    no_of_files_in_dir_new.append(len(files)) 

print('Number of images in each directory={} and total number of images={}'.format(no_of_files_in_dir_new, sum(no_of_files_in_dir_new)))
plt.figure(figsize=(10,10))
my_circle=plt.Circle((0,0), 0.7, color='white')  # donut-chart centre
# Pie labels must stay in the same order as image_dirs.
plt.pie(no_of_files_in_dir_new, labels=['NOR', 'LBBB', 'RBBB', 'APC', 'PVC', 'VEB','PAB', 'VFE'], colors=['red','green','blue','skyblue','orange', 'yellow','magenta', 'cyan'],autopct='%1.1f%%')
p1=plt.gcf()
p1.gca().add_artist(my_circle)
plt.show()
#p1.savefig('data_distribution_after_augmentation.png', dpi=400)

In [None]:
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
# Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Activation, Conv2D, Dense, Dropout, Flatten, MaxPool2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy

from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator

In [None]:
# List the visible GPUs and enable memory growth on each one.
public_devices = tf.config.experimental.list_physical_devices('GPU')
print('Number of GPU available', len(public_devices))

if len(public_devices) > 0:
    for gpu in public_devices:
        tf.config.experimental.set_memory_growth(gpu, True)  # allocate GPU memory on demand instead of reserving it all up front

In [None]:
# only a pc with rtx2070 super gpu need this block

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# TF1-compatibility session with GPU memory growth enabled.
# NOTE(review): this looks redundant with the set_memory_growth cell above under TF2 eager execution. Confirm before removing.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [None]:
# divide data images into train, test, validation directory
import glob
import random
import shutil

# os.chdir('Dataset')

image_dirs = ['NOR', 'LBBB', 'RBBB', 'APC', 'PVC', 'VEB','PAB', 'VFE']
SPLIT_SEED = 42  # fixed seed so the train/valid/test split is reproducible

# Move about 70% / 10% / 20% of the fig* images of each class MIT-BIH_AD/<class> into
# MIT-BIH_AD/{train,valid,test}/<class>. Guarded so it only runs once (when the split does
# not exist yet); it moves files, so it is not meant to be re-run.
# NOTE(review): images are split individually, so beats from the same record (and crops of
# the same beat, if augmented files live here) can end up in both train and test. A
# record-level split would avoid that leakage.
if os.path.isdir('MIT-BIH_AD/train/APC') is False:
    rng = random.Random(SPLIT_SEED)
    for i in image_dirs:
        current_path = 'MIT-BIH_AD/'+i
        path_train = 'MIT-BIH_AD/train/'+i
        path_valid = 'MIT-BIH_AD/valid/'+i
        path_test = 'MIT-BIH_AD/test/'+i
        os.makedirs(path_train, exist_ok=True)
        os.makedirs(path_valid, exist_ok=True)
        os.makedirs(path_test, exist_ok=True)
        # Count and split the same file list, so the three shares always add up
        # (counting every file but sampling only fig* could make random.sample fail).
        # Sorting first makes the seeded shuffle independent of directory-listing order.
        candidates = sorted(glob.glob(current_path+'/fig*'))
        no_of_files = len(candidates)
        no_of_valid_dir_files = round(no_of_files*0.1)
        no_of_test_dir_files = round(no_of_files*0.2)
        no_of_train_dir_files = no_of_files - (no_of_test_dir_files + no_of_valid_dir_files)
        print(no_of_files)
        rng.shuffle(candidates)
        valid_end = no_of_train_dir_files + no_of_valid_dir_files
        for j in candidates[:no_of_train_dir_files]:
            shutil.move(j,path_train)
        for j in candidates[no_of_train_dir_files:valid_end]:
            shutil.move(j,path_valid)
        for j in candidates[valid_end:]:
            shutil.move(j,path_test)

# os.chdir('../')

In [None]:
# count the length of different directory inside dataset
# Class balance of the validation split.
# NOTE(review): this rebinds Database_DIR, image_dirs and no_of_files_in_dir, which earlier cells also set and use.
Database_DIR = 'MIT-BIH_AD/valid/'
image_dirs = ['NOR/', 'LBBB/', 'RBBB/', 'APC/', 'PVC/', 'VEB/','PAB/', 'VFE/']
no_of_files_in_dir=[]
for image_dir in image_dirs:
    path, dirs, files = next(os.walk(os.path.join(Database_DIR,image_dir)))
    no_of_files_in_dir.append(len(files)) 

print('Number of images in each directory={} and total number of images={}'.format(no_of_files_in_dir, sum(no_of_files_in_dir)))
plt.figure(figsize=(10,10))
my_circle=plt.Circle((0,0), 0.7, color='white')  # donut-chart centre
# Pie labels must stay in the same order as image_dirs.
plt.pie(no_of_files_in_dir, labels=['NOR', 'LBBB', 'RBBB', 'APC', 'PVC', 'VEB','PAB', 'VFE'], colors=['red','green','blue','skyblue','orange', 'yellow','magenta', 'cyan'],autopct='%1.1f%%')
p=plt.gcf()
p.gca().add_artist(my_circle)
plt.show()
#p.savefig('data_distribution.png', dpi=400)

In [None]:
def plotImages(images_arr, batchSize, subplot_dim=(1, 10)):
    """Show images on a (rows, cols) grid with axes hidden.

    Args:
        images_arr: sequence of images accepted by ``Axes.imshow``.
        batchSize: unused; kept so existing calls still work.
        subplot_dim: (rows, cols) of the grid. Images beyond rows*cols are not shown,
            and grid cells without an image stay blank.
    """
    rows, cols = subplot_dim
    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid works as well.
    fig, axes = plt.subplots(rows, cols, figsize=(20,20), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for img, ax in zip(images_arr, axes.flat):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()

In [None]:
def crop_image(image, crop_tuple):
    """Crop `image` to a rectangle and resize the crop back to the image's original size.

    Args:
        image: array of shape (height, width, 3): channels-last RGB, checked by the asserts.
        crop_tuple: (y_start, y_end, x_start, x_end) in pixels, end-exclusive. A falsy start
            (None/0) means the top/left edge; a falsy end means the bottom/right edge.
            NOTE: a later cell redefines crop_image with (x_start, y_start, x_end, y_end) order.

    Returns:
        the crop resized to (height, width).
    """
    assert len(image.shape) == 3 # (image_h, image_w, no_of_channels)
    assert image.shape[2] == 3 # channels_last with 3 (RGB) channels

    img_height, img_width = image.shape[0], image.shape[1]
    y_start, y_end, x_start, x_end = crop_tuple

    # Slice ends are exclusive, so "to the edge" means the full width/height. The old
    # width-1 / height-1 fallback dropped the last column/row.
    x_start = x_start or 0
    y_start = y_start or 0
    x_end = x_end or img_width
    y_end = y_end or img_height

    cropped_img = image[y_start:y_end, x_start:x_end]
    # cv2.resize takes dsize as (width, height); (height, width) transposed non-square images.
    cropped_img = cv2.resize(cropped_img, (img_width, img_height))

    return cropped_img

In [None]:
def crop_generator(batches, crop_tuples_arr):
    """Yield cropped versions of the batches from a Keras image iterator, with debug plots.

    Args:
        batches: iterator yielding (batch_x, batch_y), where batch_x is channels-last
            (batch_size, h, w, 3).
        crop_tuples_arr: either a single crop tuple (as accepted by crop_image), applied to
            every image, or a sequence of crop tuples, from which one is picked at random
            for each image.

    Yields:
        (batch_crops, batch_y), where batch_crops has batch_x's shape.

    Each batch is plotted on a 4x8 grid, with its labels printed, both before and after
    cropping. This is a debugging aid and is slow inside a training loop.
    """
    # A single crop tuple has scalar entries (ints or None); a collection holds tuples.
    first = crop_tuples_arr[0]
    if first is None or np.isscalar(first):
        crop_choices = [tuple(crop_tuples_arr)]
    else:
        crop_choices = [tuple(c) for c in crop_tuples_arr]

    while True:
        batch_x, batch_y = next(batches)

        assert batch_x.shape[3] == 3 # channels_last (batch_size, image_h, image_w, 3)

        # plot the batch images before cropping
        plotImages(batch_x, batch_x.shape[0], [4,8])
        print(batch_y)
        print('separator\n')

        batch_crops = np.zeros(batch_x.shape)
        for i in range(batch_x.shape[0]):
            if len(crop_choices) == 1:
                crop_tuple = crop_choices[0]
            else:
                crop_tuple = crop_choices[np.random.randint(len(crop_choices))]
            batch_crops[i] = crop_image(batch_x[i], crop_tuple)

        # plot the batch images after cropping
        plotImages(batch_crops, batch_x.shape[0], [4,8])
        print(batch_y)

        yield (batch_crops, batch_y) # generator: keeps its local state between batches

In [None]:
train_path = "MIT-BIH_DATABASE/train"
test_path = "MIT-BIH_DATABASE/test"

In [None]:
# dataset_size*epoch = number_of_iteration*batch_size

batchSize = 4 #32

# rescale=1./255 maps 8-bit pixel values into [0, 1].
# NOTE(review): validation_split only applies when flow_from_directory is called with
# subset='training'/'validation'. The calls below pass no subset, so it looks unused. Confirm.
train_gen = ImageDataGenerator(rescale=1./255, validation_split=0.1)
test_gen = ImageDataGenerator(rescale=1./255)

In [None]:
# One iterator per split over the class subdirectories; `classes` fixes the label-index order.
# shuffle=False on the test iterator keeps its order aligned with test_batches.classes.
train_batches_uncropped = train_gen.flow_from_directory(directory=train_path, target_size=(128,128), classes=['APC', 'LBBB', 'NOR', 'PAB', 'PVC', 'RBBB', 'VEB', 'VFE'], batch_size=batchSize, seed=32)
test_batches = test_gen.flow_from_directory(directory=test_path, target_size=(128,128), classes=['APC', 'LBBB', 'NOR', 'PAB', 'PVC', 'RBBB', 'VEB', 'VFE'], batch_size=batchSize, seed=32, shuffle=False)

In [None]:
# Choose crop tuples with x_end-x_start == y_end-y_start; otherwise the crop is distorted when resized back to a square.
# Tuples are (y_start, y_end, x_start, x_end), the order used by the first crop_image definition.
#crop_coordinates_arr = [actual image, Left Top Crop, Center Top Crop, Right Top Crop, Left Center Crop, Center Center Crop, Right Center Crop, Left Bottom Crop, Center Bottom Crop, Right Bottom Crop]

crop_coordinates_arr = [(0,128,0,128), (0,96,0,96), (0,96,16,112), (0,96,32,128), (16,112,0,96), (16,112,16,112), (16,112,32,128), (32,128,0,96), (32,128,16,112), (32,128,32,128)]

# train_batches = crop_generator(train_batches_uncropped, crop_coordinates_arr)

# Single fixed crop: the top-left 96x96 region.
train_batches = crop_generator(train_batches_uncropped, (0,96,0,96))

In [None]:
# Pull one batch from crop_generator. This version plots the batch before and after cropping.
next(train_batches)

In [None]:
def crop_image(image, crop_tuple):
    """Crop `image` to a rectangle and resize the crop back to the image's original size.

    Args:
        image: array of shape (height, width, 3): channels-last RGB, checked by the asserts.
        crop_tuple: (x_start, y_start, x_end, y_end) in pixels, end-exclusive. A falsy start
            (None/0) means the left/top edge; a falsy end means the right/bottom edge.
            This redefinition replaces an earlier crop_image that took
            (y_start, y_end, x_start, x_end).

    Returns:
        the crop resized to (height, width).
    """
    assert len(image.shape) == 3 # (image_h, image_w, no_of_channels)
    assert image.shape[2] == 3 # channels_last with 3 (RGB) channels

    img_height, img_width = image.shape[0], image.shape[1]
    x_start, y_start, x_end, y_end = crop_tuple

    # Slice ends are exclusive, so "to the edge" means the full width/height. The old
    # width-1 / height-1 fallback dropped the last column/row.
    x_start = x_start or 0
    y_start = y_start or 0
    x_end = x_end or img_width
    y_end = y_end or img_height

    cropped_img = image[y_start:y_end, x_start:x_end]
    # cv2.resize takes dsize as (width, height); (height, width) transposed non-square images.
    cropped_img = cv2.resize(cropped_img, (img_width, img_height))

    return cropped_img

In [None]:
def crop_generator(batches, crop_tuple):
    """Wrap a Keras image iterator so every image it yields is cropped with `crop_tuple`.

    The same fixed crop (tuple layout as in crop_image) is applied to every image of
    every batch. Each crop is resized back to the input size, so the yielded images keep
    batch_x's shape; the labels pass through unchanged. The generator never stops: each
    step pulls exactly one batch from `batches`.
    """
    while True:
        images, targets = next(batches)
        # channels-last RGB batches only: (batch_size, height, width, 3)
        assert images.shape[3] == 3

        # float64 output buffer with the same shape as the incoming batch
        cropped = np.zeros(images.shape)
        for idx, single in enumerate(images):
            cropped[idx] = crop_image(single, crop_tuple)
        yield (cropped, targets)

In [None]:
# Scratch directory for trying the cropping generator (classes 'c1' and 'c2' below).
train_path = 'Augmentation_test_folder/train'

In [None]:
# One image per batch, with no rescaling or augmentation, so crops can be inspected on raw pixel values.
batchSize = 1
train_gen = ImageDataGenerator()

In [None]:
# Iterator over the two test classes, in the order given by `classes`.
train_batches = train_gen.flow_from_directory(directory=train_path, target_size=(128,128), classes=['c1', 'c2'], batch_size=batchSize)

In [None]:
# (x_start, y_start, x_end, y_end) for the crop_image redefined above: keep the top-left 96x96 region.
train_batches_crops = crop_generator(train_batches, (0,0,96,96))

In [None]:
# Pull one batch from the cropping generator (not the raw iterator, which would show uncropped images).
batch_crops, batch_y = next(train_batches_crops)
print(batch_y)
plotImages(batch_crops, batchSize, [2,1])

In [None]:
# Load one image and add a leading batch axis so it can be fed to ImageDataGenerator.flow.
img = load_img(train_path+'/fig_100_18.png')
x = img_to_array(img) # numpy array with shape (128,128,3)
print(x.shape)
x = x.reshape((1,)+x.shape) # numpy array with shape (1,128,128,3)
print(x.shape)

In [None]:
# One-hot label (class index 0 of 8) for the single image above.
# NOTE(review): this rebinds `labels`, which earlier held the list of beat symbols.
labels = [1,0,0,0,0,0,0,0]
y = np.array(labels)
print(y.shape)
y = y.reshape((1,)+y.shape) # numpy array with shape (1,8)
print(y.shape)

In [None]:
# In-memory iterator over the single (image, label) pair. With batchSize = 1 it yields that pair on every step.
train_batches = train_gen.flow(x,
                              y,
                              batch_size=batchSize,
                              seed=5) 
#                               save_to_dir=train_path,
#                               save_prefix='aug',
#                               save_format='png')

# Crop (x_start, y_start, x_end, y_end) = full width, top 100 rows.
train_batches_crops = crop_generator(train_batches, (0,0,128,100))

# Pull two batches through the generator to check that it runs; nothing is displayed or kept.
i = 0
for batch in train_batches_crops:
    i += 1
    if i >= 2:
        break

In [None]:
# Batches per epoch of the in-memory iterator (one sample, batch size 1 -> 1).
len(train_batches)

In [None]:
train_path = "MIT-BIH_AD/train"
valid_path = 'MIT-BIH_AD/valid'
test_path = 'MIT-BIH_AD/test'

In [None]:
# dataset_size*epoch = number_of_iteration*batch_size

batchSize = 32

# Same preprocessing for every split: scale 8-bit pixels into [0, 1]; no augmentation.
train_gen = ImageDataGenerator(rescale=1./255)
valid_gen = ImageDataGenerator(rescale=1./255)
test_gen = ImageDataGenerator(rescale=1./255)

In [None]:
# The class order is fixed explicitly, so label indices are the same for every split (APC=0 ... VFE=7).
class_names = ['APC', 'LBBB', 'NOR', 'PAB', 'PVC', 'RBBB', 'VEB', 'VFE']
train_batches = train_gen.flow_from_directory(directory=train_path, target_size=(128,128), classes=class_names, batch_size=batchSize, seed=32)
valid_batches = valid_gen.flow_from_directory(directory=valid_path, target_size=(128,128), classes=class_names, batch_size=batchSize, seed=32)
# shuffle=False keeps predictions in the same order as test_batches.classes. The confusion
# matrix and classification report compare the two element by element.
test_batches = test_gen.flow_from_directory(directory=test_path, target_size=(128,128), classes=class_names, batch_size=batchSize, seed=32, shuffle=False)

In [None]:
# Show one training batch (batchSize images on a 4x8 grid) and its labels.

imgs, labels = next(train_batches)
plotImages(imgs, batchSize, [4,8])
print(labels)

In [None]:
# model

def proposed_model(input_h, input_w, nb_classes):
    """Build the (uncompiled) VGG-style CNN used for beat classification.

    Three stages of [Conv3x3 (ELU, 'same') -> BatchNorm] x 2 -> MaxPool 2x2, with 64, 128
    and 256 filters, followed by Flatten -> Dense(2048, ELU) -> BatchNorm -> Dropout(0.5)
    -> Dense(nb_classes, softmax).

    Args:
        input_h, input_w: input image height and width; inputs have 3 channels.
        nb_classes: number of output classes.

    Returns:
        a tf.keras Sequential model.
    """
    stack = []
    for stage, filters in enumerate((64, 128, 256)):
        if stage == 0:
            # Only the very first layer declares the input shape and an explicit initializer.
            stack.append(Conv2D(filters=filters, kernel_size=(3,3), activation='elu', padding='same',
                                input_shape=(input_h, input_w, 3), kernel_initializer='glorot_uniform'))
        else:
            stack.append(Conv2D(filters=filters, kernel_size=(3,3), activation='elu', padding='same'))
        stack.append(BatchNormalization())
        stack.append(Conv2D(filters=filters, kernel_size=(3,3), activation='elu', padding='same'))
        stack.append(BatchNormalization())
        stack.append(MaxPool2D(pool_size=(2, 2), strides=2))

    stack += [
        Flatten(),
        Dense(units=2048, activation='elu'),
        BatchNormalization(),
        Dropout(rate=0.5),
        Dense(units=nb_classes, activation='softmax'),
    ]
    return Sequential(stack)

In [None]:
model = proposed_model(128, 128, 8)
model.summary()  # summary() prints the table itself and returns None; print() would add a stray "None"
# dot_img_file = '/tmp/model_1.png'
# tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

In [None]:
# Adam at learning rate 1e-3. Categorical cross-entropy pairs with the softmax output and the one-hot labels shown above.
lr = 0.001
model.compile(optimizer=Adam(learning_rate= lr), loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# train model
Epoch = 1
Verbose = 1

history = model.fit(x=train_batches,validation_data=valid_batches,epochs=Epoch,verbose=Verbose)

# Save the whole model (architecture, optimizer state, weights). An existing file is not
# overwritten here (the continue-training cell below overwrites it explicitly), and a
# skipped save is reported instead of passing silently.
model_path = 'models/ecg_arrgythmia_detection_model.h5'
os.makedirs('models', exist_ok=True)
if not os.path.isfile(model_path):
    model.save(model_path)
    print('model saved successfully.')
else:
    print('model not saved: {} already exists'.format(model_path))

# 100/3510 [.............................] ETA: 50:00 - loss: 1.18 - accuracy: 0.68
# running batch no / total number of training batches [.....progressbar (shown when verbose=1)......] ETA (estimated time of arrival)

In [None]:
# load model
from tensorflow.keras.models import load_model

prev_saved_model = load_model('models/ecg_arrgythmia_detection_model.h5')
#prev_saved_model = load_model('models/cnn.h5')

prev_saved_model.summary()  # prints directly; wrapping it in print() would also print "None"
# print(prev_saved_model.get_weights())
# print(prev_saved_model.optimizer)

In [None]:
# train model
# Train the reloaded model for 10 more epochs, then overwrite the saved file.
Epoch = 10
Verbose = 1

prev_saved_model.fit(x=train_batches,validation_data=valid_batches,epochs=Epoch,verbose=Verbose)
prev_saved_model.save('models/ecg_arrgythmia_detection_model.h5')
print('model saved successfully.')

In [None]:
# test model

# Class probabilities, one row per test image. Rows line up with test_batches.classes only when test_batches was built with shuffle=False.
predictions = prev_saved_model.predict(x = test_batches, verbose=1)

In [None]:
# True class index of every test image, in the iterator's file order.
test_labels = test_batches.classes
print(test_labels)

In [None]:
# Probabilities rounded to 0/1 for a quick look; not used for the metrics below.
print(np.round(predictions))

In [None]:
# Predicted class index per image (argmax over classes; despite the name, nothing is rounded).
rounded_predictions = np.argmax(predictions, axis=-1)

In [None]:
# confusion matrix plot function
from sklearn.metrics import confusion_matrix
import itertools

def plot_confusion_matrix_custom(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues, save_path='confusion.jpg'):
    """Print and plot a confusion matrix, then save the figure.

    Args:
        cm: square numpy array, rows = true labels and columns = predicted labels
            (as returned by sklearn.metrics.confusion_matrix).
        classes: tick labels, in the same order as cm's rows/columns.
        normalize: if True, divide each row by its total so the cells show per-class
            fractions. Rows with no samples stay 0 instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap for the cells.
        save_path: file the figure is written to (dpi=400).
    """
    # Normalize before drawing, so the colours, colour bar and cell text all show the same values.
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        print('Normalized confusion matrix')
    else:
        print('Confusion matrix, without normalization ')

    print(cm)

    plt.imshow(cm, interpolation = 'nearest', cmap = cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation = 45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed as integers; fractions with two decimals.
    fmt = '{:d}' if np.issubdtype(cm.dtype, np.integer) else '{:0.2f}'
    thresh = cm.max() / 2  # switch to white text on the darker half of the colormap
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, fmt.format(cm[i, j]), horizontalalignment="center", color="white" if cm[i,j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(save_path, dpi=400)

In [None]:
# confusion matrix output
# Tick labels must follow the generator's class-index order (0..n-1), which is the
# order of `classes` given to flow_from_directory, not the pie-chart order.
cm_plot_lables = sorted(test_batches.class_indices, key=test_batches.class_indices.get)
# Passing labels keeps the matrix n x n even if some class is absent from both y_true and y_pred.
cm = confusion_matrix(y_true = test_labels, y_pred = rounded_predictions, labels = list(range(len(cm_plot_lables))))
# non normalized confusion matrix
#plot_confusion_matrix_custom(cm = cm, classes = cm_plot_lables)

# normalized confusion matrix
plot_confusion_matrix_custom(cm = cm, classes = cm_plot_lables, normalize = True)

In [None]:
# classification report
from sklearn.metrics import classification_report

# target_names must be in class-index order; labels pins one report row per class even if a class is absent.
report_target_names = sorted(test_batches.class_indices, key=test_batches.class_indices.get)
classification_report_result = classification_report(test_labels, rounded_predictions, labels=list(range(len(report_target_names))), target_names=report_target_names)
print(classification_report_result)