Load the data stored in the SpectrogramDataset folder

In [7]:
import os
import numpy as np
import json

# --- Define paths ---
# NOTE(review): absolute, machine-specific location; adjust when running on another machine.
base_dir = "C:\\Users\\jorge\\Desktop\\ECG_PROJECT\\OnGoingDataset\\SpectrogramDataset\\"
spectrogram_dir = os.path.join(base_dir, "spectrograms")
annotations_dir = os.path.join(base_dir, "annotations")

# Suffix that identifies a spectrogram file; the part before it is the record name.
spec_suffix = "_spec.npy"

# --- Initialize storage ---
# The four lists are index-aligned: entry i of each one belongs to the same record.
all_specs = []
all_labels = []
all_metadata = []
record_names = []

# --- Loop through spectrogram files ---
# os.listdir() returns entries in arbitrary order, so sort to get the same
# record order on every run and every filesystem.
for file in sorted(os.listdir(spectrogram_dir)):
    if not file.endswith(spec_suffix):
        continue

    # Strip the suffix from the end only (str.replace would also remove an
    # occurrence in the middle of the name).
    record_name = file[: -len(spec_suffix)]
    spec_path = os.path.join(spectrogram_dir, f"{record_name}_spec.npy")
    label_path = os.path.join(annotations_dir, f"{record_name}_labels.npy")
    json_path = os.path.join(annotations_dir, f"{record_name}_annotations.json")

    # Load files
    try:
        spec = np.load(spec_path)
        labels = np.load(label_path)
        with open(json_path, "r") as f_json:
            metadata = json.load(f_json)

        # Store only after all three files loaded, so the lists stay aligned
        # when one of the files of a record is missing or unreadable.
        all_specs.append(spec)
        all_labels.append(labels)
        all_metadata.append(metadata)
        record_names.append(record_name)

        # The first metadata key is skipped in the report.
        # NOTE(review): presumably it is a non-wave entry - confirm against the JSON schema.
        print(f"Loaded: {record_name} | spec: {spec.shape} | labels: {labels.shape} | waves: {list(metadata.keys())[1:]}")

    except Exception as e:
        # Best-effort loading: report the failing record and continue with the rest.
        print(f"Failed to load {record_name}: {e}")

# --- Summary ---
print(f"\nTotal loaded records: {len(all_specs)}")


Loaded: 100 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 101 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 102 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 103 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 104 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 105 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 106 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 107 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 108 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 109 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 10 | spec: (

Loaded: 26 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 27 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 28 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 29 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 2 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 30 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 31 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 32 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 33 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 34 | spec: (40, 2500) | labels: (2500,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 35 | spec: (40, 2500) |

In [8]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(cm, class_names, epoch, save_path=None):
    """Draw a confusion matrix as an annotated heatmap.

    Args:
        cm: Confusion matrix to draw; cells are annotated with integer
            formatting (``fmt="d"``).
        class_names: Tick labels used on both axes.
        epoch: Epoch number shown in the figure title.
        save_path: When truthy, the figure is written to this path and a
            message is printed; otherwise the figure is shown with
            ``plt.show()``.

    The figure is closed before returning in both cases.
    """
    heatmap_options = dict(
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, **heatmap_options)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix - Epoch {epoch}")

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path)
        print(f"📊 Confusion matrix saved: {save_path}")

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()

def plot_training_curves(train_losses, val_accuracies, save_path="training_curves.png"):
    """Plot training loss and validation accuracy per epoch and save the figure.

    Both series are drawn against epoch numbers ``1..len(train_losses)`` on a
    shared x-axis: loss on the left y-axis (red), accuracy on a twin right
    y-axis (blue).

    Args:
        train_losses: One training-loss value per epoch.
        val_accuracies: One validation-accuracy value per epoch; must have the
            same length as ``train_losses`` (matplotlib raises otherwise).
        save_path: Destination of the saved figure. Defaults to
            ``"training_curves.png"``, the location used before this
            parameter existed.
    """
    epochs = range(1, len(train_losses) + 1)

    fig, ax1 = plt.subplots(figsize=(8, 5))

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, train_losses, color='tab:red', label="Loss")
    ax1.tick_params(axis='y', labelcolor='tab:red')

    # Second y-axis sharing the epoch axis, so both curves fit one figure.
    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation Accuracy (%)', color='tab:blue')
    ax2.plot(epochs, val_accuracies, color='tab:blue', label="Val Accuracy")
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    # Set the title before tight_layout() so the layout reserves room for it.
    ax1.set_title("Training Loss & Validation Accuracy")
    fig.tight_layout()

    fig.savefig(save_path)
    print(f"📈 Training curves saved: {save_path}")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)

# === Dataset ===
class SpectrogramDataset(Dataset):
    """Index-aligned spectrogram / label pairs served as torch tensors.

    ``spec_list[i]`` and ``label_list[i]`` describe the same record. Each
    spectrogram is transposed on access, so a ``(40, 2500)`` array (the shape
    printed for the records in this notebook) is returned as a
    ``(2500, 40)`` float32 tensor, paired with its ``(2500,)`` int64 labels.
    """

    def __init__(self, spec_list, label_list):
        # The sequences are kept by reference; no copy or validation is done here.
        self.specs, self.labels = spec_list, label_list

    def __len__(self):
        # The number of spectrograms defines the dataset size.
        n_records = len(self.specs)
        return n_records

    def __getitem__(self, idx):
        # Swap the two spectrogram axes so the label axis comes first.
        features = torch.tensor(self.specs[idx].T, dtype=torch.float32)
        # CrossEntropyLoss expects integer class indices, hence torch.long.
        targets = torch.tensor(self.labels[idx], dtype=torch.long)
        return features, targets





# === Load Your Data ===
def load_dataset(base_dir):
    """Load every spectrogram / label pair found under ``base_dir``.

    Expects ``<base_dir>/spectrograms/<record>_spec.npy`` and
    ``<base_dir>/annotations/<record>_labels.npy``. Files in the spectrogram
    folder that do not end in ``_spec.npy`` are ignored.

    Records are visited in sorted filename order. ``os.listdir`` gives no
    ordering guarantee, and a seeded train/validation split is only
    reproducible when the input order is.

    Args:
        base_dir: Dataset root containing the ``spectrograms`` and
            ``annotations`` sub-folders.

    Returns:
        ``(all_specs, all_labels)``: two index-aligned lists of numpy arrays.
        A record whose spectrogram or label file cannot be loaded is reported
        on stdout and skipped, so the lists always stay aligned.
    """
    spectrogram_dir = os.path.join(base_dir, "spectrograms")
    annotations_dir = os.path.join(base_dir, "annotations")
    spec_suffix = "_spec.npy"

    all_specs, all_labels = [], []

    for file in sorted(os.listdir(spectrogram_dir)):
        if not file.endswith(spec_suffix):
            continue

        # Strip the suffix from the end only; str.replace would also remove
        # an occurrence inside the record name.
        record_name = file[: -len(spec_suffix)]
        spec_path = os.path.join(spectrogram_dir, file)
        label_path = os.path.join(annotations_dir, f"{record_name}_labels.npy")

        try:
            spec = np.load(spec_path)
            labels = np.load(label_path)
        except Exception as e:
            # Best-effort loading: report and continue with the next record.
            print(f"Error loading {record_name}: {e}")
            continue

        # Append only after both files loaded successfully.
        all_specs.append(spec)
        all_labels.append(labels)

    return all_specs, all_labels

from collections import defaultdict
from sklearn.metrics import confusion_matrix
import numpy as np





In [9]:

# Input (B, T, 40) → Linear → (B, T, 20)
#       → Transformer → (B, T, 20)
#       → Linear Classifier → (B, T, 4)

# === Transformer Model ===
class LatentClassifier(nn.Module):
    """Per-time-step classifier: linear projection, Transformer encoder, linear head.

    Args:
        input_dim: Size of the last axis of the input features.
        model_dim: Width the features are projected to and the encoder works in.
        num_classes: Number of logits produced for every time step.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.

    No positional information is added to the sequence.
    """

    def __init__(self, input_dim=40, model_dim=20, num_classes=4, nhead=2, num_layers=2):
        super().__init__()

        # Feature projection into the encoder width.
        self.input_proj = nn.Linear(in_features=input_dim, out_features=model_dim)

        # batch_first=True: tensors are laid out as (batch, seq_len, features).
        layer_template = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

        # Per-time-step classification head.
        self.classifier = nn.Linear(in_features=model_dim, out_features=num_classes)

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` features to ``(batch, seq_len, num_classes)`` logits."""
        projected = self.input_proj(x)      # (batch, seq_len, model_dim)
        encoded = self.encoder(projected)   # (batch, seq_len, model_dim)
        return self.classifier(encoded)     # (batch, seq_len, num_classes)
    

import torch
import torch.nn as nn
# One alternative
class LatentClassifier_Pos(nn.Module):
    """Per-time-step Transformer classifier with a learnable positional encoding.

    Same pipeline as ``LatentClassifier`` (linear projection, Transformer
    encoder, linear head), plus a learned ``(1, max_len, model_dim)`` table
    that is added to the projected sequence.

    Args:
        input_dim: Size of the last axis of the input features.
        model_dim: Width the features are projected to; must be divisible by
            ``nhead``.
        num_classes: Number of logits produced for every time step.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_len: Longest sequence the positional table covers. The records
            loaded in this notebook have 2500 time steps, so pass
            ``max_len >= 2500`` for them; the default of 512 is too short.
    """

    def __init__(self, input_dim=40, model_dim=20, num_classes=4, nhead=10, num_layers=2, max_len=512):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, model_dim)

        # === Learnable Positional Encoding ===
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, model_dim))

        # === Transformer Encoder ===
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=nhead,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):  # x: (batch, seq_len, input_dim)
        """Return ``(batch, seq_len, num_classes)`` logits.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the positional
                table was built with.
        """
        T = x.size(1)
        max_len = self.positional_encoding.size(1)
        if T > max_len:
            # Slicing past the table would silently return only max_len rows
            # and the addition below would fail with an opaque broadcast error.
            raise ValueError(
                f"Sequence length {T} exceeds max_len={max_len} of the positional "
                f"encoding; construct the model with max_len >= {T}."
            )

        x = self.input_proj(x)                    # → (B, T, model_dim)

        # Use the first T positions of the table; broadcast over the batch axis.
        x = x + self.positional_encoding[:, :T, :]  # → (B, T, model_dim)

        x = self.encoder(x)                       # → (B, T, model_dim)
        out = self.classifier(x)                  # → (B, T, num_classes)
        return out
    
    
    
def train_one_epoch(model, dataloader, optimizer, device, num_classes=4):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is moved to ``device``, passed through ``model``, and scored
    with an unweighted ``nn.CrossEntropyLoss`` over all time steps. Logits
    are flattened to ``(-1, num_classes)`` and targets to ``(-1,)``.

    Args:
        model: Module producing logits whose last axis has ``num_classes`` entries.
        dataloader: Iterable of ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batches are moved to.
        num_classes: Size of the logit axis.

    Returns:
        The average of the per-batch losses as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        # reshape (not view) so non-contiguous logits/targets are accepted too.
        loss = criterion(outputs.reshape(-1, num_classes), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Averaging over zero batches has no meaningful result.
        raise ValueError("dataloader yielded no batches; cannot compute a training loss")

    mean_loss = total_loss / num_batches
    print(f"✅ Training Loss: {mean_loss:.4f}")
    return mean_loss
    
    


In [13]:
from collections import defaultdict
from sklearn.metrics import confusion_matrix

def evaluate_model(model, dataloader, device, num_classes=4,  epoch=0):
    """Evaluate ``model`` on ``dataloader`` and print an accuracy report.

    Predictions are the argmax over the last axis of the model output. All
    time steps of all batches are pooled before computing the metrics.

    Args:
        model: Module returning logits with the class axis last.
        dataloader: Iterable of ``(inputs, targets)`` tensor pairs.
        device: Device the batches are moved to for the forward pass.
        num_classes: Number of classes reported; classes ``0..3`` are named
            background / p_wave / qrs / t_wave, any further class is
            reported as ``class_<index>``.
        epoch: Only referenced by the (currently disabled) confusion-matrix
            plot at the end of the function.

    Returns:
        Overall accuracy in percent.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            preds = outputs.argmax(dim=-1)

            # reshape (not view) so non-contiguous tensors are accepted too.
            all_preds.append(preds.cpu().reshape(-1))
            all_targets.append(targets.cpu().reshape(-1))

    if not all_targets:
        # torch.cat() on an empty list fails with an unrelated-looking error.
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    preds_flat = torch.cat(all_preds).numpy()
    targets_flat = torch.cat(all_targets).numpy()

    # Names for the four known classes; extra classes get a generic name
    # instead of raising IndexError when num_classes > 4.
    known_names = ["background", "p_wave", "qrs", "t_wave"]
    class_names = [
        known_names[c] if c < len(known_names) else f"class_{c}"
        for c in range(num_classes)
    ]

    # Accuracy
    total_correct = (preds_flat == targets_flat).sum()
    total_count = len(targets_flat)
    overall_accuracy = 100 * total_correct / total_count

    print(f"\n📊 Validation Accuracy Report")
    print(f"  Overall Accuracy: {overall_accuracy:.2f}%")

    for c in range(num_classes):
        mask = targets_flat == c
        class_total = mask.sum()
        if class_total > 0:
            # Share of the time steps whose true class is c that were predicted as c.
            class_correct = (preds_flat[mask] == c).sum()
            acc = 100 * class_correct / class_total
            print(f"  Class {c} ({class_names[c]}): {acc:.2f}%")
        else:
            print(f"  Class {c}: No samples")

    # Confusion matrix
    cm = confusion_matrix(targets_flat, preds_flat, labels=list(range(num_classes)))
    print("\n  Confusion Matrix (rows = true, cols = predicted):")
    print(cm)

    # Heatmap plotting is disabled; re-enable the next line to save one image per epoch.
    #plot_confusion_matrix(cm, class_names, epoch, save_path=f"confusion_matrix_epoch_{epoch}.png")

    return overall_accuracy  # <== return this value


In [14]:
def save_model(model, path="best_model.pt"):
    """Serialize ``model.state_dict()`` to ``path`` and print the destination.

    Only the state dict is written, not the module object itself, so loading
    requires constructing the model first and calling ``load_state_dict``.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"📦 Model saved to: {path}")

In [15]:
if __name__ == "__main__":
    # NOTE(review): absolute, machine-specific location; adjust when running on another machine.
    base_dir = "C:\\Users\\jorge\\Desktop\\ECG_PROJECT\\OnGoingDataset\\SpectrogramDataset"
    specs, labels = load_dataset(base_dir)

    # Split into train/val
    from sklearn.model_selection import train_test_split
    train_specs, val_specs, train_labels, val_labels = train_test_split(
        specs, labels, test_size=0.2, random_state=42, shuffle=True
    )

    train_dataset = SpectrogramDataset(train_specs, train_labels)
    val_dataset = SpectrogramDataset(val_specs, val_labels)

    # Sanity check: shape of one raw spectrogram before the dataset transposes it.
    print(train_specs[0].shape)

    # Seed torch as well, so weight initialisation and batch shuffling are as
    # reproducible as the seeded train/validation split above.
    torch.manual_seed(42)

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = LatentClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    max_epochs = 500  # upper bound; early stopping normally ends training sooner
    patience = 10     # stop after this many consecutive epochs without a new best accuracy
    best_accuracy = 0.0
    epochs_without_improvement = 0
    train_losses = []
    val_accuracies = []

    for epoch in range(max_epochs):
        print(f"\n🔁 Epoch {epoch + 1}")
        loss = train_one_epoch(model, train_loader, optimizer, device)
        val_accuracy = evaluate_model(model, val_loader, device, epoch=epoch + 1)

        train_losses.append(loss)
        val_accuracies.append(val_accuracy)

        # Early stopping check: checkpoint on every strict improvement.
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            epochs_without_improvement = 0
            save_model(model, "best_model.pt")
        else:
            epochs_without_improvement += 1
            print(f"⏳ No improvement for {epochs_without_improvement} epoch(s).")

        if epochs_without_improvement >= patience:
            print(f"\n🛑 Early stopping: no improvement after {patience} epochs.")
            break

    # Kept inside the __main__ guard: the history lists only exist when the
    # training loop above has run.
    plot_training_curves(train_losses, val_accuracies)


(40, 2500)

🔁 Epoch 1
✅ Training Loss: 0.7655

📊 Validation Accuracy Report
  Overall Accuracy: 75.44%
  Class 0 (background): 95.63%
  Class 1 (p_wave): 0.00%
  Class 2 (qrs): 80.78%
  Class 3 (t_wave): 6.35%

  Confusion Matrix (rows = true, cols = predicted):
[[66776     0  2734   321]
 [ 6225     0    71    32]
 [ 1846     0  7761     0]
 [12822     0   508   904]]
📦 Model saved to: best_model.pt

🔁 Epoch 2
✅ Training Loss: 0.6179

📊 Validation Accuracy Report
  Overall Accuracy: 78.91%
  Class 0 (background): 90.15%
  Class 1 (p_wave): 10.40%
  Class 2 (qrs): 90.54%
  Class 3 (t_wave): 46.40%

  Confusion Matrix (rows = true, cols = predicted):
[[62950   461  2914  3506]
 [ 5110   658    91   469]
 [  904     0  8698     5]
 [ 7520    13    96  6605]]
📦 Model saved to: best_model.pt

🔁 Epoch 3
✅ Training Loss: 0.5609

📊 Validation Accuracy Report
  Overall Accuracy: 79.42%
  Class 0 (background): 90.85%
  Class 1 (p_wave): 17.53%
  Class 2 (qrs): 91.20%
  Class 3 (t_wave): 42.95%


✅ Training Loss: 0.3955

📊 Validation Accuracy Report
  Overall Accuracy: 81.77%
  Class 0 (background): 85.28%
  Class 1 (p_wave): 62.18%
  Class 2 (qrs): 89.45%
  Class 3 (t_wave): 68.05%

  Confusion Matrix (rows = true, cols = predicted):
[[59553  2734  2383  5161]
 [ 2207  3935     3   183]
 [ 1014     0  8593     0]
 [ 4489     7    52  9686]]
⏳ No improvement for 3 epoch(s).

🔁 Epoch 23
✅ Training Loss: 0.3949

📊 Validation Accuracy Report
  Overall Accuracy: 82.09%
  Class 0 (background): 87.04%
  Class 1 (p_wave): 59.77%
  Class 2 (qrs): 88.00%
  Class 3 (t_wave): 63.73%

  Confusion Matrix (rows = true, cols = predicted):
[[60782  2523  2235  4291]
 [ 2440  3782     0   106]
 [ 1153     0  8454     0]
 [ 5135     8    20  9071]]
📦 Model saved to: best_model.pt

🔁 Epoch 24
✅ Training Loss: 0.3936

📊 Validation Accuracy Report
  Overall Accuracy: 82.12%
  Class 0 (background): 87.38%
  Class 1 (p_wave): 53.62%
  Class 2 (qrs): 91.01%
  Class 3 (t_wave): 63.03%

  Confusion Matr

⏳ No improvement for 6 epoch(s).

🔁 Epoch 43
✅ Training Loss: 0.3673

📊 Validation Accuracy Report
  Overall Accuracy: 82.53%
  Class 0 (background): 88.61%
  Class 1 (p_wave): 57.70%
  Class 2 (qrs): 87.49%
  Class 3 (t_wave): 60.43%

  Confusion Matrix (rows = true, cols = predicted):
[[61875  2055  2161  3740]
 [ 2612  3651     0    65]
 [ 1202     0  8405     0]
 [ 5629     4     0  8601]]
⏳ No improvement for 7 epoch(s).

🔁 Epoch 44
✅ Training Loss: 0.3662

📊 Validation Accuracy Report
  Overall Accuracy: 83.06%
  Class 0 (background): 88.30%
  Class 1 (p_wave): 55.66%
  Class 2 (qrs): 88.70%
  Class 3 (t_wave): 65.77%

  Confusion Matrix (rows = true, cols = predicted):
[[61658  1816  2220  4137]
 [ 2752  3522     0    54]
 [ 1086     0  8521     0]
 [ 4866     3     3  9362]]
📦 Model saved to: best_model.pt

🔁 Epoch 45
✅ Training Loss: 0.3649

📊 Validation Accuracy Report
  Overall Accuracy: 82.57%
  Class 0 (background): 89.48%
  Class 1 (p_wave): 44.36%
  Class 2 (qrs): 89.14%

  Confusion Matrix (rows = true, cols = predicted):
[[61273  1952  2400  4206]
 [ 2463  3822     0    43]
 [  854     0  8753     0]
 [ 5016     4     3  9211]]
⏳ No improvement for 7 epoch(s).

🔁 Epoch 64
✅ Training Loss: 0.3490

📊 Validation Accuracy Report
  Overall Accuracy: 83.17%
  Class 0 (background): 87.23%
  Class 1 (p_wave): 62.20%
  Class 2 (qrs): 90.00%
  Class 3 (t_wave): 67.94%

  Confusion Matrix (rows = true, cols = predicted):
[[60914  2152  2336  4429]
 [ 2347  3936     1    44]
 [  961     0  8646     0]
 [ 4546    11     6  9671]]
⏳ No improvement for 8 epoch(s).

🔁 Epoch 65
✅ Training Loss: 0.3480

📊 Validation Accuracy Report
  Overall Accuracy: 83.56%
  Class 0 (background): 88.34%
  Class 1 (p_wave): 58.96%
  Class 2 (qrs): 89.20%
  Class 3 (t_wave): 67.28%

  Confusion Matrix (rows = true, cols = predicted):
[[61687  1873  2265  4006]
 [ 2570  3731     0    27]
 [ 1038     0  8569     0]
 [ 4647     5     6  9576]]
📦 Model saved to: best_model.pt

🔁 Epoch 66
✅