In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tsai.all import PatchTST

from utils import *

### Defining Patient List

In [None]:
# NOTE(review): absolute, machine-specific Windows paths; consider reading these from a config
# cell or an environment variable so the notebook runs on other machines.
file_directory_res = r"D:\Thesis Project\Github Upload\Project 1\Data\Non Clean Data\Residuals"
file_directory_art = r"D:\Thesis Project\Github Upload\Project 1\Data\Artifact Data\All Artifact"

files = os.listdir(file_directory_res)

# Patient IDs are the first 7 characters of each residual file name.
# os.listdir returns entries in arbitrary order, and several files may share the same
# 7-character prefix. De-duplicating and sorting guarantees each patient is processed
# exactly once and that the downstream window order is reproducible.
patient_ids = sorted({name[:7] for name in files})

### Collecting Windows

In [None]:
# Broad suppression keeps the per-patient loop output readable.
# NOTE(review): this also hides useful warnings (e.g. 0/0 divisions); consider narrowing it
# by category or module.
warnings.filterwarnings('ignore')

parameters = ["RAP"]
window_size, step_size, window_threshold = 60, 15, 0.5

X_parts = []
y_parts = []
patient_window_map = []  # keeps track of which window belongs to which patient

for pid in patient_ids:
    # NOTE(review): the whole `patient_ids` list is passed as the first argument and the
    # current patient as the third; confirm this matches get_features' signature in utils.
    result = get_features(patient_ids, file_directory_art, pid, parameters,
                          window_size, step_size, window_threshold)

    X = result["X"]          # [num_windows, channels, seq_len]
    y = result["y"]          # [num_windows, 1]

    X_parts.append(X)
    y_parts.append(y)

    # One patient label per window so windows can later be grouped by patient
    patient_window_map.extend([pid] * X.shape[0])

# Stacking everything
X_all = torch.cat(X_parts, dim=0)               # [TOTAL_WINDOWS, channels, seq_len]
y_all = torch.cat(y_parts, dim=0)               # [TOTAL_WINDOWS, 1]
patient_window_map = np.array(patient_window_map)
assert X_all.shape[0] == y_all.shape[0] == len(patient_window_map), \
    "window / label / patient-map counts disagree"

# Patient-wise train/validation split: all windows of a patient land on the same side,
# so validation measures generalisation to unseen patients rather than unseen windows.
unique_patients = np.unique(patient_window_map)
train_patients, valid_patients = train_test_split(
    unique_patients,
    test_size=0.2,
    random_state=42,
)

# Converting the patient split into window indices over X_all / y_all
train_idx = np.where(np.isin(patient_window_map, train_patients))[0]
valid_idx = np.where(np.isin(patient_window_map, valid_patients))[0]

### Training and Getting Predictions

In [None]:
# Seed torch so PatchTST's weight initialisation is repeatable across runs.
# NOTE(review): data-loader shuffling may draw from other RNGs (Python `random` / NumPy);
# seed those too if bit-for-bit reproducible training is required.
torch.manual_seed(42)

# train_idx / valid_idx index windows stacked across *all* patients, so model dimensions
# and training data must come from X_all / y_all, not from the last patient's X / y.
n_channels, seq_len = X_all.shape[1], X_all.shape[2]   # X_all is [TOTAL_WINDOWS, channels, seq_len]
model = PatchTST(
    c_in=n_channels,   # presumably one channel per entry in `parameters`
    c_out=1,           # single logit for binary classification
    seq_len=seq_len,
    pred_dim=1,        # one prediction per window
    n_layers=2,
    n_heads=4,
    d_model=64,
    patch_len=16,
    stride=8,
)

epoch, batch_size, learning_rate = 50, 16, 1e-3
training_result = training_model(X_all, y_all, train_idx, valid_idx,
                                 model, epoch, batch_size, learning_rate)

# Prediction and evaluation
pred_threshold = 0.5
y_pred = get_prediction_labels(training_result['Model'], training_result['DLS'], pred_threshold)

# NOTE(review): `result` is whatever the collection loop assigned last, so its
# 'Window_indices' likely cover only the final patient, while `valid_idx` indexes windows
# stacked across all patients. This lookup is misaligned and may raise IndexError.
# Collect the window indices per patient, in X_all order, and index that array instead.
window_indices = result['Window_indices']
valid_window_indices = np.array(window_indices)[valid_idx]

### Exporting the Training Log

In [None]:
# Extracting recorder values
learn = training_result["Learner"]

training_log_df = pd.DataFrame(
    learn.recorder.values,
    columns=learn.recorder.metric_names
)

training_log_df.to_csv("training_log.csv", index=False)

### Visualizing the Training Log

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

# Losses share the left y-axis
ax.plot(training_log_df["epoch"], training_log_df["train_loss"], label="train loss")
ax.plot(training_log_df["epoch"], training_log_df["valid_loss"], label="valid loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

# Recall is on a different scale, so it gets its own right-hand y-axis.
# A twinx axis restarts the colour cycle; pin a distinct colour so recall is not drawn
# in the same colour as the train-loss curve.
ax_recall = ax.twinx()
ax_recall.plot(
    training_log_df["epoch"],
    training_log_df["recall_label1"],
    linestyle="--",
    color="C2",
    label="recall",
)
ax_recall.set_ylabel("Recall")

# Each axis tracks its own legend entries; merge them into a single legend
loss_handles, loss_names = ax.get_legend_handles_labels()
recall_handles, recall_names = ax_recall.get_legend_handles_labels()
ax.legend(
    loss_handles + recall_handles,
    loss_names + recall_names,
    loc="center right",
)

ax.set_title("Training / validation loss and recall per epoch")
fig.tight_layout()
plt.show()


### Visualizing the Confusion Matrix

In [None]:
# Ground-truth labels of the validation windows, in valid_idx order
y_true = y_all[valid_idx].cpu().numpy().astype(int).ravel()

# get_prediction_labels' return type is not visible here (list, ndarray or tensor).
# Normalise it to a flat int NumPy array so the element-wise comparisons below are valid.
# Comparing a plain Python list with `== 0` yields a single bool and silently corrupts the counts.
y_pred_labels = (
    y_pred.detach().cpu().numpy() if torch.is_tensor(y_pred) else np.asarray(y_pred)
).astype(int).ravel()

# NOTE(review): assumes predictions come back in valid_idx order (un-shuffled validation
# loader); confirm in utils.get_prediction_labels.
assert y_true.shape == y_pred_labels.shape, (
    f"{y_true.size} validation labels vs {y_pred_labels.size} predictions"
)

TN = np.sum((y_true == 0) & (y_pred_labels == 0))
FP = np.sum((y_true == 0) & (y_pred_labels == 1))
FN = np.sum((y_true == 1) & (y_pred_labels == 0))
TP = np.sum((y_true == 1) & (y_pred_labels == 1))

# Rows = true class, columns = predicted class
cm = np.array([
    [TN, FP],   # True Clean
    [FN, TP],   # True Artifact
])

In [None]:
# Plotting the confusion matrix: rows are the true class, columns the predicted class
labels = ["Clean", "Artifact"]

fig, ax = plt.subplots(figsize=(5.5, 4.5))
heatmap = ax.imshow(cm)
fig.colorbar(heatmap, ax=ax)

tick_positions = range(len(labels))
ax.set_xticks(tick_positions)
ax.set_xticklabels(labels)
ax.set_yticks(tick_positions)
ax.set_yticklabels(labels)

# Annotate every cell with its raw count (text x = column, y = row)
for (row, col), count in np.ndenumerate(cm):
    ax.text(col, row, f"{count}", ha="center", va="center", fontsize=11)

ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_title(f"Confusion Matrix (Total samples = {cm.sum():,})")

fig.tight_layout()
plt.show()

In [None]:
# Report 0.0 instead of NaN when a denominator is zero, e.g. when no window is predicted as
# an artifact. This matches scikit-learn's zero_division=0 behaviour. Warnings are globally
# ignored above, so a silent 0/0 NaN would otherwise go unnoticed.
predicted_positive = TP + FP
actual_positive = TP + FN

precision = TP / predicted_positive if predicted_positive > 0 else 0.0
recall = TP / actual_positive if actual_positive > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

pd.Series({"precision": precision, "recall": recall, "f1": f1})