# Main Pipeline File

The pipeline is as follows:
<ol>
   <li>Read in the files and handle missing data.</li>
   <li>Preprocess the signals.</li>
   <li>Extract features.</li>
   <li>Feed the data into a model.</li>
   <li>Optimize, compare, and iterate — try other models.</li>
</ol>

In [None]:
# Import libraries. 
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt 
import os
import wfdb
import pickle
import sys
import glob
from scipy.signal import butter, lfilter
import pprint


In [None]:
# Load helper files.
import dataloaders
import visualize
import preprocess
import segment 
# import cart_model # @Henry what is this?

In [None]:
# Report compute-device information via the project helper
# (presumably CPU/GPU availability — see dataloaders.get_device_info to confirm).
dataloaders.get_device_info()

In [None]:
# Root directory of the MIT-BIH Arrhythmia Database (trailing '/' kept as in
# the original paths). Set the MITBIH_PATH environment variable to point at a
# local copy without editing the notebook. Known local copies:
# 'C:/Users/henry/OneDrive/Desktop/ELEC 872 - AI and Interactive Systems/Project/mit-bih-arrhythmia-database-1.0.0/'
# 'G:/Datasets/mit-bih-arrhythmia-database-1.0.0/'
file_path = os.environ.get('MITBIH_PATH', 'G:/Datasets/mit-bih-arrhythmia-database-1.0.0/')

In [None]:
# Load data.
# Load the records found under file_path. The result is indexed by record id
# as a string (e.g. '103' below); the per-record value appears to be sliceable
# (see the [:3] in the next cell) — confirm its exact structure in dataloaders.
patient_data = dataloaders.load_all_records(file_path)

In [None]:
# Shared pretty-printer with a 2-space indent; later cells reuse `pp`.
pp = pprint.PrettyPrinter(indent=2)

# Print the first three entries of record '103' to inspect the structure of patient_data.
pp.pprint(patient_data['103'][:3])

### Print out the rhythm count (as beats per rhythm type)

In [None]:

# for i in range (100,125):
#     visualize.summarize_rhythm_counts(patient_data, str(i))

# for i in range (200,225):
#     visualize.summarize_rhythm_counts(patient_data, str(i))

#visualize.summarize_rhythm_counts(patient_data, "203")


# Preprocessing Stage

Note: the signals are downsampled before this stage. Downsampling is typically done **after** preprocessing; I've elected to do it first because of the way the annotation object is structured.

0. Convert the dictionary into arrays for use with the SciPy filter functions.
1. High-pass filter to remove baseline wander.
2. Notch filter to remove powerline interference (if present).
3. Low-pass filter to remove high-frequency noise (cutoff set to 40 Hz for now; to be confirmed).
4. Moving-average filter to smooth the remaining signal.
5. Normalization to the range [0, 1], because the leads differ from one another.

<p>We may want to consider an FFT or a wavelet transform, since this is time-series data; this can be decided later.</p>

In [None]:
# aggregate data into arrays from the dict first, note this is a progress test.
# aggregated_patient_data = preprocess.aggregate_signals(patient_data)
# pp.print(aggregated_data['103'])

In [None]:
# Run the preprocessing chain described above (see preprocess.preprocess_patient_data).
# Each processed record is a dict; later cells read its 'signals_lead_1',
# 'signals_lead_2' and 'labels' keys.
processed_data = preprocess.preprocess_patient_data(patient_data)
# Dump record '103' after preprocessing for a visual sanity check.
pp.pprint(processed_data['103'])

In [None]:
# Plot record '100' without rhythm labels. The third argument is presumably the
# number of seconds to display (cf. display_seconds in the next cell) — confirm.
visualize.visualize_patient_data(processed_data,'100',10)

In [None]:
# Plot record '103' with rhythm labels overlaid (display_seconds=10).
# NOTE: label positions are placed with rough arithmetic in the general area of
# each segment, so they may not line up 1:1 with the signal.
visualize.visualize_patient_data_with_rhythm(processed_data, patient_id='103', display_seconds=10)

## Feature Extraction (Lead Selection)

Now that the signals have been filtered and normalized, we have two leads and can select which one is more useful.
Possible criteria include ICA, correlation, SNR, etc.

For now, I've implemented SNR, standard deviation, a high-frequency check, and entropy.

In [None]:
# Maps the lead name returned by visualize.compute_noise_metrics_for_patient to
# the processed_data key that holds that lead's signal.
LEAD_SIGNAL_KEYS = {
    'Lead 1': 'signals_lead_1',
    'Lead 2': 'signals_lead_2',
}

# Per patient: the signal of the lead chosen by the noise-metric helper, plus labels.
best_leads_data = {}

for patient_id in processed_data.keys():
    # Ask the helper which lead scores better on its noise metrics.
    better_lead = visualize.compute_noise_metrics_for_patient(processed_data, patient_id)
    print(f"Patient {patient_id}: Best Lead is {better_lead}")

    # Fail loudly on an unexpected lead name instead of silently reusing the
    # previous patient's signal (or hitting NameError on the first patient).
    if better_lead not in LEAD_SIGNAL_KEYS:
        raise ValueError(
            f"Unexpected lead name {better_lead!r} for patient {patient_id}; "
            f"expected one of {sorted(LEAD_SIGNAL_KEYS)}"
        )

    best_signal = processed_data[patient_id][LEAD_SIGNAL_KEYS[better_lead]]
    labels = processed_data[patient_id]['labels']

    # Save the best signal and labels into the dictionary.
    best_leads_data[patient_id] = {
        'signal': best_signal,
        'labels': labels,
    }

print("new dictionary with specifc labels + lead for given patient added.")

# Test a Simple RNN

In [None]:
# Segmentation settings.
segment_length_sec = 10  # window length in seconds
# Sampling rate handed to the segmenter (Hz).
# NOTE(review): MIT-BIH is recorded at 360 Hz; 250 presumably matches the rate
# produced by the downsampling step before preprocessing — confirm.
fs = 250

# Gather fixed-length windows (and their labels) from every patient's chosen lead.
rnn_signals, rnn_labels = [], []
for patient_id, data in best_leads_data.items():
    signal, labels = data['signal'], data['labels']
    segments, segment_labels = segment.prepare_rnn_data(
        signal, labels, segment_length_sec, fs
    )
    rnn_signals.extend(segments)
    rnn_labels.extend(segment_labels)

# Stack the accumulated windows and labels into NumPy arrays.
rnn_signals, rnn_labels = np.array(rnn_signals), np.array(rnn_labels)

print(f"RNN Signals Shape: {rnn_signals.shape}")
print(f"RNN Labels Shape: {rnn_labels.shape}")


In [None]:
# Partition the segmented dataset into train / validation / test splits.
data_splits = segment.split_data(rnn_signals, rnn_labels)

# Wrap each split in a PyTorch DataLoader (batch_size=32).
train_loader, val_loader, test_loader = dataloaders.create_dataloaders(
    data_splits, batch_size=32
)

# Report how many batches each split yields.
for split_name, split_loader in (
    ("Train", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} Loader Size: {len(split_loader)} batches")


In [None]:
import torch.nn as nn

class ECG_RNN(nn.Module):
    """LSTM classifier that maps an ECG sequence to per-class logits.

    The LSTM is built with ``batch_first=True``, so the canonical input is
    ``(batch, seq_len, input_size)``. When ``input_size == 1`` a 2-D
    ``(batch, seq_len)`` tensor is also accepted and given a trailing feature
    axis. Unbatched inputs are not supported.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input ``x``."""
        # A single-lead batch shaped (batch, seq_len) would otherwise be read by
        # nn.LSTM as one *unbatched* sequence whose feature width is seq_len,
        # failing the input_size check; add the missing feature axis instead.
        if x.dim() == 2 and self.rnn.input_size == 1:
            x = x.unsqueeze(-1)
        out, _ = self.rnn(x)
        # Classify from the top LSTM layer's output at the final time step.
        out = out[:, -1, :]
        return self.fc(out)

def train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, device):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    After every epoch, prints the mean training loss, the mean validation loss
    and the validation accuracy. A metric whose loader yielded nothing is
    reported as ``nan`` instead of raising ``ZeroDivisionError``.

    Args:
        model: ``nn.Module`` to train; moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches (may be empty).
        num_epochs: Number of passes over ``train_loader``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that the model and every batch are moved to.

    Returns:
        The trained model (the same object that was passed in).
    """
    model.to(device)
    nan = float('nan')

    for epoch in range(num_epochs):
        # --- Training pass ---
        model.train()
        running_loss, n_train_batches = 0.0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_train_batches += 1

        # --- Validation pass ---
        model.eval()
        val_loss, n_val_batches, correct, total = 0.0, 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()
                n_val_batches += 1
                # Predicted class = index of the largest logit.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Counting batches (rather than len(loader)) avoids ZeroDivisionError on
        # empty loaders and works for iterables without __len__.
        mean_train_loss = running_loss / n_train_batches if n_train_batches else nan
        mean_val_loss = val_loss / n_val_batches if n_val_batches else nan
        val_accuracy = 100 * correct / total if total else nan

        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {mean_train_loss:.4f}, "
              f"Val Loss: {mean_val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.2f}%")

    return model

def evaluate_model(model, test_loader, device):
    """Print and return ``model``'s classification accuracy on ``test_loader``.

    Args:
        model: Trained ``nn.Module``; moved to ``device`` so the function also
            works when called on a model that was not trained via ``train_model``.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device that the model and every batch are moved to.

    Returns:
        Accuracy as a percentage, or ``nan`` if ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy



In [None]:
# --- Hyperparameters ---
input_size = 1        # ECG data is 1D: one feature per time step
hidden_size = 64      # LSTM hidden-state width
num_layers = 2        # stacked LSTM layers
# One logit per distinct label value. NOTE(review): CrossEntropyLoss expects
# integer class indices in [0, num_classes) — confirm rnn_labels satisfies that.
num_classes = len(set(rnn_labels))
num_epochs = 20
batch_size = 32
learning_rate = 0.001

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model, loss function and optimizer.
model = ECG_RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Rebuild the loaders with `batch_size` from above, then train and evaluate.
train_loader, val_loader, test_loader = dataloaders.create_dataloaders(data_splits, batch_size)
model = train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, device)
evaluate_model(model, test_loader, device)
