Load the data stored in the SpectrogramDataset folder

In [18]:
import os
import numpy as np
import json

# --- Define paths ---
base_dir = "C:\\Users\\jorge\\Desktop\\ECG_PROJECT\\OnGoingDataset\\SpectrogramDataset"
spectrogram_dir = os.path.join(base_dir, "spectrograms")
annotations_dir = os.path.join(base_dir, "annotations")

# --- Initialize storage ---
all_specs = []
all_labels = []
all_metadata = []
record_names = []

# --- Loop through spectrogram files ---
for file in os.listdir(spectrogram_dir):
    if not file.endswith("_spec.npy"):
        continue
    
    record_name = file.replace("_spec.npy", "")
    spec_path = os.path.join(spectrogram_dir, f"{record_name}_spec.npy")
    label_path = os.path.join(annotations_dir, f"{record_name}_labels.npy")
    json_path = os.path.join(annotations_dir, f"{record_name}_annotations.json")

    # Load files
    try:
        spec = np.load(spec_path)
        labels = np.load(label_path)
        with open(json_path, "r") as f_json:
            metadata = json.load(f_json)

        # Store
        all_specs.append(spec)
        all_labels.append(labels)
        all_metadata.append(metadata)
        record_names.append(record_name)

        print(f"Loaded: {record_name} | spec: {spec.shape} | labels: {labels.shape} | waves: {list(metadata.keys())[1:]}")

    except Exception as e:
        print(f"Failed to load {record_name}: {e}")

# --- Summary ---
print(f"\nTotal loaded records: {len(all_specs)}")


Loaded: 100 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 101 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 102 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 103 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 104 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 105 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 106 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 107 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 108 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 109 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 10 | spec: (20, 80) | labels: (80,) | waves: ['p_wav

Loaded: 187 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 188 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 189 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 18 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 190 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 191 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 192 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 193 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 194 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 195 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 196 | spec: (20, 80) | labels: (80,) | waves: ['p_wav

Loaded: 92 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 93 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 94 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 95 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 96 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 97 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 98 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 99 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']
Loaded: 9 | spec: (20, 80) | labels: (80,) | waves: ['p_waves', 'qrs_complexes', 't_waves']

Total loaded records: 200


In [19]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(cm, class_names, epoch, save_path=None):
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix - Epoch {epoch}")
    
    if save_path:
        plt.savefig(save_path)
        print(f"📊 Confusion matrix saved: {save_path}")
    else:
        plt.show()
    plt.close()

def plot_training_curves(train_losses, val_accuracies):
    epochs = range(1, len(train_losses) + 1)

    fig, ax1 = plt.subplots(figsize=(8, 5))

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, train_losses, color='tab:red', label="Loss")
    ax1.tick_params(axis='y', labelcolor='tab:red')

    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation Accuracy (%)', color='tab:blue')
    ax2.plot(epochs, val_accuracies, color='tab:blue', label="Val Accuracy")
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    fig.tight_layout()
    plt.title("Training Loss & Validation Accuracy")
    plt.savefig("training_curves.png")
    print("📈 Training curves saved: training_curves.png")
    plt.close()

# === Dataset ===
class SpectrogramDataset(Dataset):
    def __init__(self, spec_list, label_list):
        self.specs = spec_list
        self.labels = label_list

    def __len__(self):
        return len(self.specs)

    def __getitem__(self, idx):
        spec = torch.tensor(self.specs[idx].T, dtype=torch.float32)  # (80, 20)
        label = torch.tensor(self.labels[idx], dtype=torch.long)     # (80,)
        return spec, label





# === Load Your Data ===
def load_dataset(base_dir):
    spectrogram_dir = os.path.join(base_dir, "spectrograms")
    annotations_dir = os.path.join(base_dir, "annotations")

    all_specs, all_labels = [], []

    for file in os.listdir(spectrogram_dir):
        if not file.endswith("_spec.npy"):
            continue

        record_name = file.replace("_spec.npy", "")
        spec_path = os.path.join(spectrogram_dir, f"{record_name}_spec.npy")
        label_path = os.path.join(annotations_dir, f"{record_name}_labels.npy")

        try:
            spec = np.load(spec_path)  # shape (20, 80)
            labels = np.load(label_path)  # shape (80,)
            all_specs.append(spec)
            all_labels.append(labels)
        except Exception as e:
            print(f"Error loading {record_name}: {e}")

    return all_specs, all_labels

from collections import defaultdict
from sklearn.metrics import confusion_matrix
import numpy as np





In [38]:

#Input (B, T, 20) → Linear → (B, T, 64)
        #       → Transformer → (B, T, 64)
          #     → Linear Classifier → (B, T, 4)

# === Transformer Model ===
class LatentClassifier(nn.Module):
    def  __init__(self, input_dim=20, model_dim=64, num_classes=4, nhead=, num_l10ayers=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, model_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=nhead,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):  # x: (batch, seq_len, input_dim)
        x = self.input_proj(x)         # -> (batch, seq_len, model_dim)
        x = self.encoder(x)            # -> (batch, seq_len, model_dim)
        out = self.classifier(x)       # -> (batch, seq_len, num_classes)
        return out
    

import torch
import torch.nn as nn
# One alternative
class LatentClassifier_Pos(nn.Module):
    def __init__(self, input_dim=20, model_dim=64, num_classes=4, nhead=10, num_layers=2, max_len=512):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, model_dim)

        # === Learnable Positional Encoding ===
        self.positional_encoding = nn.Parameter(torch.randn(1, max_len, model_dim))

        # === Transformer Encoder ===
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=nhead,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(model_dim, num_classes)

    def forward(self, x):  # x: (batch, seq_len, input_dim)
        B, T, _ = x.size()
        x = self.input_proj(x)                    # → (B, T, model_dim)

        # Add positional encoding (truncate or pad as needed)
        x = x + self.positional_encoding[:, :T, :]  # → (B, T, model_dim)

        x = self.encoder(x)                       # → (B, T, model_dim)
        out = self.classifier(x)                  # → (B, T, num_classes)
        return out
    
    
    
def train_one_epoch(model, dataloader, optimizer, device, num_classes=4):
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, num_classes), targets.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"✅ Training Loss: {total_loss / len(dataloader):.4f}")
    return total_loss / len(dataloader);
    
    


SyntaxError: invalid syntax (<ipython-input-38-0ff2c9e7a886>, line 7)

In [39]:
from collections import defaultdict
from sklearn.metrics import confusion_matrix

def evaluate_model(model, dataloader, device, num_classes=4,  epoch=0):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=-1)

            all_preds.append(preds.cpu().view(-1))
            all_targets.append(targets.cpu().view(-1))

    preds_flat = torch.cat(all_preds).numpy()
    targets_flat = torch.cat(all_targets).numpy()

    # Accuracy
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    for c in range(num_classes):
        mask = targets_flat == c
        class_total[c] += mask.sum()
        class_correct[c] += (preds_flat[mask] == c).sum()

    total_correct = (preds_flat == targets_flat).sum()
    total_count = len(targets_flat)

    print(f"\n📊 Validation Accuracy Report")
    print(f"  Overall Accuracy: {100 * total_correct / total_count:.2f}%")

    for c in range(num_classes):
        if class_total[c] > 0:
            acc = 100 * class_correct[c] / class_total[c]
            label_name = ["background", "p_wave", "qrs", "t_wave"][c]
            print(f"  Class {c} ({label_name}): {acc:.2f}%")
        else:
            print(f"  Class {c}: No samples")

    # Confusion matrix
    cm = confusion_matrix(targets_flat, preds_flat, labels=list(range(num_classes)))
    print("\n  Confusion Matrix (rows = true, cols = predicted):")
    print(cm)
        # Plot heatmap
    class_names = ["background", "p_wave", "qrs", "t_wave"]
    #plot_confusion_matrix(cm, class_names, epoch, save_path=f"confusion_matrix_epoch_{epoch}.png")

    return 100 * total_correct / total_count  # <== return this value


In [40]:
def save_model(model, path="best_model.pt"):
    torch.save(model.state_dict(), path)
    print(f"📦 Model saved to: {path}")

In [41]:
if __name__ == "__main__":
    base_dir = r"C:\Users\jorge\Desktop\ECG_PROJECT\OnGoingDataset\SpectrogramDataset"
    specs, labels = load_dataset(base_dir)

    # Split into train/val
    from sklearn.model_selection import train_test_split
    train_specs, val_specs, train_labels, val_labels = train_test_split(
        specs, labels, test_size=0.2, random_state=42, shuffle=True
    )

    train_dataset = SpectrogramDataset(train_specs, train_labels)
    val_dataset = SpectrogramDataset(val_specs, val_labels)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = LatentClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_accuracy = 0.0
    epochs_without_improvement = 0
    patience = 10  # Stop if no improvement after 3 epochs
    train_losses = []
    val_accuracies = []
    
    for epoch in range(500):
        print(f"\n🔁 Epoch {epoch + 1}")
        loss = train_one_epoch(model, train_loader, optimizer, device)
        val_accuracy = evaluate_model(model, val_loader, device, epoch=epoch + 1)

        train_losses.append(loss)
        val_accuracies.append(val_accuracy)
        # Early stopping check
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            epochs_without_improvement = 0
            save_model(model, "best_model.pt")
        else:
            epochs_without_improvement += 1
            print(f"⏳ No improvement for {epochs_without_improvement} epoch(s).")

        if epochs_without_improvement >= patience:
            print(f"\n🛑 Early stopping: no improvement after {patience} epochs.")
            break

plot_training_curves(train_losses, val_accuracies)



🔁 Epoch 1
✅ Training Loss: 1.2232

📊 Validation Accuracy Report
  Overall Accuracy: 55.09%
  Class 0 (background): 80.39%
  Class 1 (p_wave): 0.00%
  Class 2 (qrs): 65.51%
  Class 3 (t_wave): 29.33%

  Confusion Matrix (rows = true, cols = predicted):
[[1148    0  105  175]
 [ 327    0    0   56]
 [ 133    1  376   64]
 [ 555    0   21  239]]
📦 Model saved to: best_model.pt

🔁 Epoch 2
✅ Training Loss: 1.0413

📊 Validation Accuracy Report
  Overall Accuracy: 58.12%
  Class 0 (background): 80.60%
  Class 1 (p_wave): 0.00%
  Class 2 (qrs): 79.44%
  Class 3 (t_wave): 31.04%

  Confusion Matrix (rows = true, cols = predicted):
[[1151    0  130  147]
 [ 341    0    3   39]
 [ 110    1  456    7]
 [ 531    0   31  253]]
📦 Model saved to: best_model.pt

🔁 Epoch 3
✅ Training Loss: 0.9881

📊 Validation Accuracy Report
  Overall Accuracy: 59.69%
  Class 0 (background): 81.65%
  Class 1 (p_wave): 0.00%
  Class 2 (qrs): 81.88%
  Class 3 (t_wave): 33.62%

  Confusion Matrix (rows = true, cols = pre

✅ Training Loss: 0.7394

📊 Validation Accuracy Report
  Overall Accuracy: 69.38%
  Class 0 (background): 73.81%
  Class 1 (p_wave): 44.91%
  Class 2 (qrs): 91.11%
  Class 3 (t_wave): 57.79%

  Confusion Matrix (rows = true, cols = predicted):
[[1054   86  115  173]
 [ 186  172    1   24]
 [  20   12  523   19]
 [ 242   54   48  471]]
⏳ No improvement for 1 epoch(s).

🔁 Epoch 24
✅ Training Loss: 0.7355

📊 Validation Accuracy Report
  Overall Accuracy: 69.81%
  Class 0 (background): 70.31%
  Class 1 (p_wave): 43.86%
  Class 2 (qrs): 91.64%
  Class 3 (t_wave): 65.77%

  Confusion Matrix (rows = true, cols = predicted):
[[1004   72  117  235]
 [ 165  168    1   49]
 [  12   10  526   26]
 [ 189   41   49  536]]
📦 Model saved to: best_model.pt

🔁 Epoch 25
✅ Training Loss: 0.7278

📊 Validation Accuracy Report
  Overall Accuracy: 69.16%
  Class 0 (background): 73.60%
  Class 1 (p_wave): 46.21%
  Class 2 (qrs): 90.59%
  Class 3 (t_wave): 57.06%

  Confusion Matrix (rows = true, cols = predicte

✅ Training Loss: 0.6498

📊 Validation Accuracy Report
  Overall Accuracy: 71.44%
  Class 0 (background): 71.29%
  Class 1 (p_wave): 51.44%
  Class 2 (qrs): 92.33%
  Class 3 (t_wave): 66.38%

  Confusion Matrix (rows = true, cols = predicted):
[[1018   84  120  206]
 [ 159  197    2   25]
 [  14   10  530   20]
 [ 191   34   49  541]]
⏳ No improvement for 1 epoch(s).

🔁 Epoch 46
✅ Training Loss: 0.6465

📊 Validation Accuracy Report
  Overall Accuracy: 71.16%
  Class 0 (background): 70.24%
  Class 1 (p_wave): 51.17%
  Class 2 (qrs): 92.51%
  Class 3 (t_wave): 67.12%

  Confusion Matrix (rows = true, cols = predicted):
[[1003   81  123  221]
 [ 152  196    2   33]
 [  13    7  531   23]
 [ 183   35   50  547]]
⏳ No improvement for 2 epoch(s).

🔁 Epoch 47
✅ Training Loss: 0.6459

📊 Validation Accuracy Report
  Overall Accuracy: 71.25%
  Class 0 (background): 70.17%
  Class 1 (p_wave): 51.44%
  Class 2 (qrs): 93.03%
  Class 3 (t_wave): 67.12%

  Confusion Matrix (rows = true, cols = predict