In [1]:
import serial
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from matplotlib import rc
from sklearn.model_selection import train_test_split
from torch import nn, optim
import torch.nn.functional as F
import dash
import dash_bootstrap_components as dbc
from dash import html, dcc
import base64
import scipy
from collections import deque
import plotly.graph_objects as go
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
import os
%matplotlib inline

class LSTMModel(nn.Module):
    """Two stacked batch-first LSTMs followed by a two-layer linear head.

    ``forward`` reshapes its input to ``(-1, seq_len, n_features)``, runs both LSTMs,
    takes the last hidden state of the second one and projects it
    ``2*embedding_dim -> embedding_dim -> 1``, giving one value per sequence.
    """

    def __init__(self, seq_len, n_features, embedding_dim=64):
        super().__init__()
        hidden = 2 * embedding_dim
        self.seq_len, self.n_features = seq_len, n_features
        self.embedding_dim, self.hidden_dim = embedding_dim, hidden

        # Sub-modules are created in this exact order (rnn1, rnn2, linear, linear2);
        # pickled checkpoints and state_dicts refer to these attribute names.
        self.rnn1 = nn.LSTM(n_features, hidden, num_layers=1, batch_first=True)
        self.rnn2 = nn.LSTM(hidden, hidden, num_layers=1, batch_first=True)
        self.linear = nn.Linear(hidden, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, 1)

    def forward(self, x):
        batch = x.view((-1, self.seq_len, self.n_features))
        features, _ = self.rnn1(batch)
        _, (h_last, _) = self.rnn2(features)
        # h_last is (num_layers, batch, hidden); keep the top layer.
        projected = self.linear(h_last[-1])
        return self.linear2(projected)

class Encoder(nn.Module):
    """Two stacked LSTMs that compress each input sequence into one embedding.

    Parameters
    ----------
    seq_len : int
        Time steps per sequence; the input is reshaped to ``(-1, seq_len, n_features)``.
    n_features : int
        Channels per time step.
    embedding_dim : int
        Width of the second LSTM, i.e. of the returned embedding. The first LSTM
        is ``2 * embedding_dim`` wide.

    ``forward`` returns a tensor of shape ``(batch, 1, embedding_dim)``: the final
    hidden state of the second LSTM for every sequence, with a length-1 time axis
    so that Decoder can tile it along dim 1.
    """

    def __init__(self, seq_len, n_features, embedding_dim=64):
        super(Encoder, self).__init__()

        self.seq_len, self.n_features = seq_len, n_features
        self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim

        self.rnn1 = nn.LSTM(
            input_size=n_features,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True
        )

        self.rnn2 = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=embedding_dim,
            num_layers=1,
            batch_first=True
        )

    def forward(self, x):
        x = x.view((-1, self.seq_len, self.n_features))

        x, (_, _) = self.rnn1(x)
        _, (hidden_n, _) = self.rnn2(x)

        # hidden_n is (num_layers=1, batch, embedding_dim). Take the top layer and add
        # a time axis. The previous `view((-1, n_features, embedding_dim))` gave the
        # same result only for n_features == 1; for more features it interleaved
        # different batch items (or raised when batch % n_features != 0).
        return hidden_n[-1].unsqueeze(1)

class Decoder(nn.Module):
    """LSTM decoder that expands one embedding per sequence back to ``seq_len`` steps.

    ``forward`` tiles its input ``seq_len`` times along dim 1, runs two stacked
    batch-first LSTMs (widths ``input_dim`` then ``2 * input_dim``) and maps every
    step to ``n_features`` outputs with a linear layer, giving
    ``(batch, seq_len, n_features)``.
    """

    def __init__(self, seq_len, input_dim=64, n_features=1):
        super().__init__()
        hidden = 2 * input_dim
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.hidden_dim = hidden
        self.n_features = n_features

        # Creation order (rnn1, rnn2, output_layer) and attribute names are kept:
        # saved models and state_dicts depend on them.
        self.rnn1 = nn.LSTM(input_dim, input_dim, num_layers=1, batch_first=True)
        self.rnn2 = nn.LSTM(input_dim, hidden, num_layers=1, batch_first=True)
        self.output_layer = nn.Linear(hidden, n_features)

    def forward(self, x):
        # Presumably x is (batch, 1, input_dim) as produced by Encoder — TODO confirm.
        tiled = x.repeat(1, self.seq_len, 1).view((-1, self.seq_len, self.input_dim))
        stage1, _ = self.rnn1(tiled)
        stage2, _ = self.rnn2(stage1)
        return self.output_layer(stage2.view((-1, self.seq_len, self.hidden_dim)))

class RecurrentAutoencoder(nn.Module):
    """Sequence autoencoder: ``Encoder`` compresses, ``Decoder`` reconstructs.

    The decoder is built with ``input_dim=embedding_dim`` so it consumes the
    encoder's embedding, and with ``n_features`` outputs per step.
    """

    def __init__(self, seq_len, n_features, embedding_dim=64):
        super().__init__()
        # Encoder first, then decoder: attribute names and order are what saved
        # models refer to.
        self.encoder = Encoder(seq_len, n_features, embedding_dim)
        self.decoder = Decoder(seq_len, embedding_dim, n_features)

    def forward(self, x):
        return self.decoder(self.encoder(x))


RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Base window length in samples: create_dataset feeds the forecaster windows of
# 3 * TIME_STEP samples, and preprocess cuts the autoencoder input into
# TIME_STEP-long sequences.
TIME_STEP = 116

current_directory = os.getcwd()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both checkpoints are whole pickled nn.Module objects (they are moved with .to()
# and called directly), so unpickling presumably needs the class definitions above
# under the same names. torch.load unpickles arbitrary objects: only load trusted
# files. NOTE: PyTorch >= 2.6 defaults to weights_only=True, which refuses
# full-module checkpoints; pass weights_only=False there.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine, and
# eval() switches the networks to inference mode before the dashboard uses them.
file_path_model_pred = os.path.join(current_directory, 'Model', 'predict_ecg_100ep_x3_step.pth')
model_pred = torch.load(file_path_model_pred, map_location=device)
model_pred = model_pred.to(device).eval()

file_path_model = os.path.join(current_directory, 'Model', 'ecg_abnormal_new2_20.pth')
model = torch.load(file_path_model, map_location=device)
model = model.to(device).eval()

def find_peak(signal):
    """Return the indices of all local maxima of ``signal``.

    ``signal`` must provide ``to_numpy()`` (e.g. a pandas Series or a single-column
    DataFrame). It is reshaped to a single row and passed to
    ``scipy.signal.find_peaks``; the peak-index array is returned.
    """
    row = signal.to_numpy().reshape(1, len(signal))[0]
    peak_positions, _properties = scipy.signal.find_peaks(row)
    return peak_positions

def make_data(signal, lst_peak):
    """Keep the samples of ``signal`` whose index is in ``lst_peak`` and whose value exceeds 15000.

    Entries that end up NaN are dropped. For a DataFrame, the boolean mask marks
    rows as NaN instead of removing them, so ``dropna`` does the removal.
    """
    at_peaks = signal[signal.index.isin(lst_peak)]
    tall = at_peaks[at_peaks > 15000]
    return tall.dropna(axis=0)

def denoise(signal):
    """Replace dropout samples (value < 8000) by the mean of their two neighbours.

    Works on a copy: previously ``s = signal`` aliased the argument, so the caller's
    frame was modified in place as a side effect.

    Rows are addressed with ``.loc[i]`` for ``i`` in ``1 .. len - 2``, so the frame
    is assumed to have a default RangeIndex, and the row-wise ``< 8000`` test
    requires a single column. Repairs are sequential: a sample that was just
    repaired serves as the "previous" neighbour of the next one. The first and last
    rows are never changed.
    """
    s = signal.copy()
    for i in range(1, len(s) - 1):
        sig = s.loc[i].values
        pre = s.loc[i - 1].values
        nex = s.loc[i + 1].values
        if sig < 8000:
            s.loc[i] = (pre + nex) / 2
    return s

def create_dataset(data, step):
    """Slice ``data`` into overlapping windows of ``3 * step`` samples.

    ``data`` may be any array-like (ndarray, DataFrame, list); it is flattened
    first. Window ``i`` covers samples ``i .. i + 3*step - 1`` for
    ``i in range(n - 3*step)``, so the last possible window (ending on the final
    sample) is not produced.
    NOTE(review): this matches a training-style split that reserves the next sample
    as a target; confirm it is wanted at inference time.

    Returns a float32 tensor of shape ``(n - 3*step, 3*step, 1)``, or an empty
    float32 tensor of shape ``(0,)`` when there are no more than ``3*step`` samples.
    """
    window = 3 * step
    flat = np.asarray(data).reshape(-1, 1)
    windows = [flat[i:i + window] for i in range(len(flat) - window)]
    # Stack into one ndarray first: torch.tensor() on a list of ndarrays is very
    # slow (PyTorch warns about it).
    return torch.tensor(np.array(windows)).float()

def split_data(data, step):
    """Return every ``step``-long slice ``data[i:i+step]`` for ``i < len(data) - step`` as an ndarray.

    The final slice ending on the last element is not included. When ``data``
    has ``step`` or fewer elements, the result is an empty array of shape ``(0,)``.
    """
    windows = [data[start:start + step] for start in range(len(data) - step)]
    return np.array(windows)

def preprocess(data, step):
    """Turn ``data`` into a list of ``(step, 1)`` float32 tensors for the autoencoder.

    Windows come from ``split_data(data, step)``. Returns
    ``(dataset, seq_len, n_features)``, where the two sizes are read from the
    stacked dataset. ``torch.stack`` raises if there are no windows.
    """
    tensors = []
    for window in split_data(data, step):
        tensors.append(torch.tensor(window).unsqueeze(1).float())
    _, seq_len, n_features = torch.stack(tensors).shape
    return tensors, seq_len, n_features

def make_data_predicted(signal):
    """Return the indices of local maxima of ``signal`` whose value exceeds 0.5.

    NOTE(review): a second ``make_data_predicted`` is defined further down in this
    cell and replaces this one, so callers never reach this version. It looks like
    dead code (presumably an earlier peak-index helper) — consider removing or
    renaming it.
    """
    peak_positions, _ = scipy.signal.find_peaks(signal)
    return [pos for pos in peak_positions if signal[pos] > 0.5]

def predict(model, dataset):
    """Run ``model`` on every sequence in ``dataset`` and score the reconstruction.

    Each item is moved to the module-level ``device`` and passed through ``model``
    in eval mode with gradients disabled. Its loss is the summed L1 distance
    between the output and the item itself.

    Returns
    -------
    (predictions, losses)
        ``predictions``: list of flattened output arrays (numpy, on CPU);
        ``losses``: list of Python floats, one per sequence.
    """
    l1_sum = nn.L1Loss(reduction='sum').to(device)
    outputs = []
    scores = []
    with torch.no_grad():
        model = model.eval()
        for sequence in dataset:
            sequence = sequence.to(device)
            reconstructed = model(sequence)
            scores.append(l1_sum(reconstructed, sequence).item())
            outputs.append(reconstructed.cpu().numpy().flatten())
    return outputs, scores

def make_data_predicted(signal):
    """Rough heart-rate figure for the raw ECG trace shown in graph1.

    Finds local maxima (``scipy.signal.find_peaks``) with amplitude above 15000.
    Let ``b`` be the distance in samples between the last two such peaks. Returns
    ``2 * b`` when ``b < 100`` and ``int(b / 2)`` otherwise, or 65 when fewer than
    two qualifying peaks exist.

    NOTE(review): the ``2*b`` / ``b/2`` mapping is an empirical heuristic, not
    ``60 * fs / b``; confirm against the actual sampling rate.
    """
    peaks, _ = scipy.signal.find_peaks(signal)
    tall = [i for i in peaks if signal[i] > 15000]
    # Explicit check instead of a bare `except:`, which also hid unrelated errors
    # (including KeyboardInterrupt).
    if len(tall) < 2:
        return 65
    b = tall[-1] - tall[-2]
    if b < 100:
        return 2 * b
    return int(b / 2)
    
def make_data_predicted_AI(signal):
    """Rough heart-rate figure for the model's forecast trace shown in graph2.

    Finds local maxima (``scipy.signal.find_peaks``) with ``0.6 < value <= 1``; the
    forecast is on the min-max scaled axis. Let ``b`` be the distance in samples
    between the last two such peaks. Returns ``2 * b`` when ``b < 100`` and
    ``int(b / 2)`` otherwise, or 65 when fewer than two qualifying peaks exist.

    NOTE(review): same empirical mapping as make_data_predicted; confirm against
    the sampling rate.
    """
    peaks, _ = scipy.signal.find_peaks(signal)
    tall = [i for i in peaks if 0.6 < signal[i] <= 1]
    # Explicit check instead of a bare `except:`, which also hid unrelated errors.
    if len(tall) < 2:
        return 65
    b = tall[-1] - tall[-2]
    if b < 100:
        return 2 * b
    return int(b / 2)

def process_data(data_df, model_pred, model):
    """Forecast an ECG window with ``model_pred`` and screen the forecast with ``model``.

    Steps:
      1. Min-max scale ``data_df`` to [0, 1], fitted on this window only.
      2. Cut it into ``3 * TIME_STEP``-sample windows (create_dataset), run
         ``model_pred``, and keep the last ``3 * TIME_STEP`` outputs as the forecast.
      3. Take the forecast between its first and last peak with
         ``0.6 < value <= 1``, split that segment into ``TIME_STEP``-long sequences
         (preprocess) and score them with ``model`` (predict).

    Parameters
    ----------
    data_df : pandas.DataFrame
        Raw samples (update_graph2 passes a slice of ``data``).
    model_pred : torch.nn.Module
        Forecasting network.
    model : torch.nn.Module
        Reconstruction model whose per-sequence L1 losses are thresholded.

    Returns
    -------
    (status, x, y)
        ``status`` is "Có bất thường" when any loss exceeds THRESHOLD_HIGH and
        "Bình thường" otherwise. It is also "Bình thường" when there are too few
        peaks to build a segment or torch cannot process it. ``x``/``y`` are the
        plot coordinates of the forecast trace.
    """
    THRESHOLD_HIGH = 4.8
    NORMAL, ABNORMAL = "Bình thường", "Có bất thường"
    window = 3 * TIME_STEP

    scaled = MinMaxScaler().fit_transform(data_df)
    X = create_dataset(scaled, TIME_STEP)
    with torch.no_grad():
        y_pred = model_pred(X.to(device)).cpu()
    signal_predicted = y_pred.numpy().flatten()[-window:]
    signal_formated = list(signal_predicted)
    x_formated = list(np.arange(len(signal_formated)))

    # Peak *indices* are needed here. The previous code called
    # make_data_predicted_AI, which returns a heart-rate number, so `p[0]` always
    # raised; a bare `except:` hid that and `model` was never run. The amplitude
    # filter is the same one make_data_predicted_AI uses.
    peaks, _ = scipy.signal.find_peaks(signal_predicted)
    peak_idx = [i for i in peaks if 0.6 < signal_predicted[i] <= 1]
    if len(peak_idx) < 2:
        return NORMAL, x_formated, signal_formated

    segment = signal_predicted[peak_idx[0]:peak_idx[-1] + 1]
    if len(segment) <= TIME_STEP:
        # split_data would yield no windows, and preprocess cannot stack an empty list.
        return NORMAL, x_formated, signal_formated

    try:
        train, _, _ = preprocess(segment, TIME_STEP)
        _, pred_losses = predict(model, train)
    except RuntimeError:
        # torch raises RuntimeError for shape/device mismatches with the loaded
        # model; keep the dashboard's best-effort "normal" fallback instead of
        # failing the callback.
        return NORMAL, x_formated, signal_formated

    status = ABNORMAL if any(loss > THRESHOLD_HIGH for loss in pred_losses) else NORMAL
    return status, x_formated, signal_formated

c:\Users\PC\AppData\Local\Programs\Python\Python39\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
c:\Users\PC\AppData\Local\Programs\Python\Python39\lib\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll


In [2]:
status = "Bình thường"
status_pred = ""
beat = 60
beat_pred = 0
app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])
sliced = 0  # first row of the current data window; advanced by update_graph2
slicedd = 0  # NOTE(review): not read by any code visible here
color = {
    "text": "#b9cfbb",
    "plot_color": "#83adb5",
    "paper_color": "#b9def3",
    "title_color": "#00ced1",
    "box_color": "#85b72b",
    "box_text_color": "#fbe102"
}
# [steady card colour, [two alarm colours alternated by update_box_color]]
color_box = ['#b8e6f2', ['#f03a26', '#f5695e']]
# Load and base64-encode the logo so the layout can inline it as a data: URI.
# `with` closes the file handle, which was previously left open.
image_filename = 'Pic/logo.png'  # Replace with your own image
with open(image_filename, 'rb') as image_file:
    encoded_image = base64.b64encode(image_file.read())
data = pd.read_csv("ecg_predict.csv")
# data = pd.read_csv("ecg_114_clr.csv")

max_points = 240  # NOTE(review): not read by any code visible here

# Function to create graph1
def create_graph1():
    """Build the ``dcc.Graph`` with id ``graph1`` for the raw ECG trace.

    The figure starts with one empty line trace, which update_graph1 fills in. The
    y-axis is fixed to 5000-30000 and the figure is 870x350 px.
    """
    trace = go.Scatter(x=[], y=[], mode='lines', name='ECG Signal')
    layout_options = dict(
        title={'text': 'TÍN HIỆU ĐIỆN TIM (ECG)', 'y': 0.98, 'x': 0.5},
        xaxis={'showgrid': True, 'tickvals': []},
        yaxis={'title': 'Value', 'showgrid': True, 'range': [5000, 30000]},
        plot_bgcolor=color["plot_color"],
        paper_bgcolor=color["paper_color"],
        margin={'t': 40, 'b': 25, 'l': 25, 'r': 10},
        font={'color': color["title_color"]},
        autosize=False,
        width=870,  # fixed figure size in pixels
        height=350,
    )
    figure = go.Figure()
    figure.add_trace(trace)
    figure.update_layout(**layout_options)
    # The container stretches to its column; the figure itself keeps the fixed size.
    return dcc.Graph(id='graph1', figure=figure, style={'width': '100%', "height": "100%"})

# Function to create graph2
def create_graph2():
    """Build the ``dcc.Graph`` with id ``graph2`` for the forecast ECG trace.

    The figure starts with one empty line trace, which update_graph2 fills in. The
    y-axis is fixed to -0.1..1.1 (the forecast is on a min-max scaled axis) and the
    figure is 870x350 px.
    """
    trace = go.Scatter(x=[], y=[], mode='lines', name='ECG Signal')
    layout_options = dict(
        title={'text': 'TÍN HIỆU ĐIỆN TIM DỰ ĐOÁN (ECG)', 'y': 0.98, 'x': 0.5},
        xaxis={'showgrid': True, 'tickvals': []},
        yaxis={'title': 'Value', 'showgrid': True, 'range': [-0.1, 1.1]},
        plot_bgcolor=color["plot_color"],
        paper_bgcolor=color["paper_color"],
        margin={'t': 40, 'b': 25, 'l': 25, 'r': 10},
        font={'color': color["title_color"]},
        autosize=False,
        width=870,  # fixed figure size in pixels
        height=350,
    )
    figure = go.Figure()
    figure.add_trace(trace)
    figure.update_layout(**layout_options)
    # The container stretches to its column; the figure itself keeps the fixed size.
    return dcc.Graph(id='graph2', figure=figure, style={'width': '100%', "height": "100%"})

# update color box status
@app.callback(
    Output('box1', 'color'),
    Input('interval-component', 'n_intervals'),
    State('box1', 'color')
)
def update_box_color(n, current_color):
    """Colour card ``box1``: steady while the measured rate is not fast, flashing when fast.

    Reads the module-level ``beat``. While flashing, the colour alternates between
    the two entries of ``color_box[1]`` on every interval tick. ``current_color`` is
    supplied by Dash but not needed.
    """
    # Same threshold as update_box1_status, which reports "Tim đập nhanh" only for
    # beat > 114. The previous `beat < 114` test flashed at exactly 114 while the
    # status line still read normal.
    if beat > 114:
        return color_box[1][n % 2]
    return color_box[0]

""" Update first row """
# Refresh graph1 with the raw ECG window that starts at row `sliced` of `data`.
@app.callback(
    Output('graph1', 'figure'),
    Input('graph1', 'figure'),
    Input('interval-component', 'n_intervals')
)
def update_graph1(fig, n):
    """Plot the raw samples ``data.loc[sliced:sliced + 348]`` in graph1.

    ``.loc`` slicing is label-inclusive, so with the default RangeIndex from
    ``read_csv`` this is 349 rows. ``sliced`` is only read here; update_graph2
    advances it.
    """
    window = data.loc[sliced:sliced + 348].to_numpy().flatten()
    trace = fig['data'][0]
    trace['x'] = list(np.arange(len(window)))
    trace['y'] = list(window)
    return fig

# update heart beat real
@app.callback(
    Output('box1-text', 'children'),
    Input('graph1', 'figure')
)
def update_box1_text(fig):
    """Show the heart-rate estimate for graph1's samples and store it in the global ``beat``.

    Skips the update (PreventUpdate) when the figure has no traces.
    """
    global beat
    if not (fig and fig['data']):
        raise PreventUpdate
    samples = fig['data'][0]['y']
    beat = make_data_predicted(samples)
    return 'Nhịp tim: {0} bpm'.format(beat)

# update heart status real
@app.callback(
    Output('box1-text_state', 'children'),
    Input('graph1', 'figure')
)
def update_box1_status(fig):
    """Classify the measured heart rate of graph1 and render the status line.

    The rate is recomputed from the figure with make_data_predicted and stored in
    the global ``beat``; if the figure has no traces, the current ``beat`` is reused.
    Classification: > 114 gives "Tim đập nhanh", < 50 gives "Tim đập chậm", and
    anything else gives "Bình thường". Also updates the global ``status``.
    """
    global status
    global beat
    # update_box1_text also fires on this same graph1 update and writes `beat`.
    # Dash does not order independent callbacks, so reading that global could pick
    # up the previous window's value; derive it from the figure instead.
    if fig and fig.get('data'):
        beat = make_data_predicted(fig['data'][0]['y'])
    if beat > 114:
        status = "Tim đập nhanh"
    elif beat < 50:
        status = "Tim đập chậm"
    else:
        status = "Bình thường"

    return 'Trạng thái: {0}'.format(status)

""" Update second row"""
# Refresh graph2 with the model's forecast for the current data window.
@app.callback(
    Output('graph2', 'figure'),
    Input('graph2', 'figure'),
    Input('interval-component-2', 'n_intervals')
)
def update_graph2(fig, n):
    """Run process_data on rows ``sliced .. sliced + 1000`` of ``data`` and plot the forecast.

    Side effects: sets the global ``status_pred`` from process_data and advances the
    shared window start ``sliced`` (also read by update_graph1) by 50 rows. The
    window wraps back to row 0 once it would run past the end of ``data``.
    """
    global sliced, status_pred
    window = data.loc[sliced:1000 + sliced]
    status_pred, x_predict, predict_signal = process_data(window, model_pred, model)
    fig['data'][0]['x'] = x_predict
    fig['data'][0]['y'] = predict_signal
    # Previously `sliced` grew without bound. The window then became truncated and
    # finally empty, and MinMaxScaler.fit raised on every tick. Restart from row 0
    # instead (this assumes `data` has the default RangeIndex, so labels == positions).
    sliced += 50
    if sliced + 1000 >= len(data):
        sliced = 0

    return fig
    

# update heart beat predict
@app.callback(
    Output('box2-text', 'children'),
    Input('graph2', 'figure')
)
def update_box2_text(fig):
    """Show the rate estimated from graph2's forecast trace and store it in the global ``beat_pred``.

    Skips the update (PreventUpdate) when the figure has no traces.
    """
    global beat_pred
    if not (fig and fig['data']):
        raise PreventUpdate
    forecast = fig['data'][0]['y']
    beat_pred = make_data_predicted_AI(forecast)
    return 'Nhịp tim: {0} bpm'.format(beat_pred)

# update heart status predict
@app.callback(
    Output('box2-text_state', 'children'),
    Input('graph2', 'figure')
)
def update_box2_status(fig):
    """Classify the forecast heart rate of graph2 and render the status line.

    The rate is recomputed from the figure with make_data_predicted_AI and stored in
    the global ``beat_pred``; if the figure has no traces, the current ``beat_pred``
    is reused. A rate > 114 or < 50 gives "Có bất thường"; anything else gives
    "Bình thường". Also updates the global ``status_pred``.
    """
    global status_pred, beat_pred
    # update_box2_text also fires on this same graph2 update and writes `beat_pred`.
    # Dash does not order independent callbacks, so derive the value here rather
    # than reading a possibly stale global.
    if fig and fig.get('data'):
        beat_pred = make_data_predicted_AI(fig['data'][0]['y'])
    if beat_pred > 114 or beat_pred < 50:
        status_pred = "Có bất thường"
    else:
        status_pred = "Bình thường"
    return 'Trạng thái: {0}'.format(status_pred)

""" Define the layout of the dashboard """
# Inline styles shared by the two parameter cards. The CSS property names and
# their order are unchanged; the 1-px black text outline is set through both the
# prefixed and the unprefixed text-stroke property.
_TEXT_OUTLINE = {
    '-webkit-text-stroke': '1px black',
    'text-stroke': '1px black',
}
_CARD_TITLE_STYLE = {
    'color': color["box_text_color"],
    'text-align': 'center',
    'font-weight': 'bold',
    **_TEXT_OUTLINE,
    'padding': '10px',
}
_CARD_LINE_STYLE = {
    'color': color["box_text_color"],
    'font-size': '25px',
    'font-weight': 'bold',
    'margin-left': '10px',
    **_TEXT_OUTLINE,
    'padding': '10px',  # padding keeps the outlined text readable
}


def _info_card(card_id, title, lines):
    """Build one bordered ``dbc.Card`` with a heading and one paragraph per line.

    ``lines`` is a sequence of ``(text, extra_props)`` pairs; each becomes an
    ``html.P`` with the shared line style plus ``extra_props`` (id / className).
    """
    children = [html.H2(title, style=_CARD_TITLE_STYLE)]
    children += [html.P(text, style=_CARD_LINE_STYLE, **props) for text, props in lines]
    return dbc.Card(
        html.Div(children),
        id=card_id,
        color=color_box[0],
        inverse=True,
        className="mb-4",
        style={'border': '8px solid #006991'},
    )


app.layout = html.Div([
    # Logo in the top-left corner, inlined as a base64 data URI
    html.Img(src='data:image/png;base64,{}'.format(encoded_image.decode()),
             style={'width': '290px', 'height': '60px'}),
    html.Br(),
    html.Br(),

    # Bootstrap grid: graph in an 8-wide column, parameter card in a 4-wide column
    dbc.Container([
        # Row 1: raw ECG + measured parameters (card id 'box1')
        dbc.Row([
            dbc.Col(create_graph1(), width=8, className="mb-4"),
            dbc.Col(
                _info_card('box1', 'THÔNG SỐ', [
                    ('Nhịp tim: {0} bpm'.format(beat),
                     {'id': 'box1-text', 'className': 'card-text'}),
                    ('Trạng thái: {0}'.format(status),
                     {'id': 'box1-text_state', 'className': 'card-text'}),
                    ('Tần số lấy mẫu: 116 Hz',
                     {'className': 'card-text'}),
                ]),
                width=4
            ),
        ]),
        # Row 2: forecast ECG + predicted parameters (card id 'box-2')
        dbc.Row([
            dbc.Col(create_graph2(), width=8, className="mb-4"),
            dbc.Col(
                _info_card('box-2', 'THÔNG SỐ DỰ ĐOÁN', [
                    ('Nhịp tim: {0} bpm'.format(beat_pred),
                     {'id': 'box2-text', 'className': 'card-text'}),
                    ('Trạng thái: {0}'.format(status_pred),
                     {'id': 'box2-text_state', 'className': 'card-text'}),
                    ('Thời gian dự đoán: 5s',
                     {'className': 'card-text-2'}),
                ]),
                width=4
            ),
        ]),
    ]),
    # Timers (ms): 'interval-component' drives graph1 and the box1 colour,
    # 'interval-component-2' drives graph2.
    dcc.Interval(id='interval-component', interval=3000, n_intervals=0),
    dcc.Interval(id='interval-component-2', interval=5000, n_intervals=0)
])

In [3]:
if __name__ == '__main__':
    # `app.run_server` is deprecated in later Dash 2.x releases and removed in Dash 3
    # in favour of `app.run`; fall back to `run_server` on versions without `run`.
    run = getattr(app, 'run', None) or app.run_server
    run(port=9900)

Dash is running on http://127.0.0.1:9900/

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:9900
Press CTRL+C to quit
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash/deps/react-dom@16.v2_6_1m1688700425.14.0.min.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash/deps/prop-types@15.v2_6_1m1688700425.8.1.min.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash/deps/react@16.v2_6_1m1688700425.14.0.min.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash/deps/polyfill@7.v2_6_1m1688700425.12.1.min.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash/dcc/dash_core_components-shared.v2_6_1m1688700425.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "GET /_dash-component-suites/dash_bootstrap_components/_components/dash_bootstrap_components.v1_4_1m1688742597.min.js HTTP/1.1" 200 -
127.0.0.1 - - [08/Dec/2023 09:02:15] "