In [6]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import pytorch_lightning as pl
import torch.nn.functional as F
import torchaudio.transforms as T

In [3]:
# Location of the processed metadata CSV, relative to the notebook's working directory.
data_pairs_path = Path('..', 'data', 'processed', 'meta', 'data_pairs.csv')
# TODO: loading is currently disabled; re-enable with
#   data_pairs_df = pd.read_csv(data_pairs_path)

In [None]:
class CNNModel(pl.LightningModule):
    """Small 1D CNN classifier trained with PyTorch Lightning.

    Two ``Conv1d -> ReLU -> MaxPool1d(2)`` stages are followed by a flatten
    and two fully connected layers. The convolutions preserve sequence length
    (kernel 3, stride 1, padding 1) and each pooling stage halves it, so
    ``fc1`` receives ``32 * (seq_len // 4)`` features.

    Args:
        in_channels: Channels of the input signal. Defaults to 8.
        seq_len: Length of each input sequence. Defaults to 10000, which
            yields the original ``32 * 2500`` input size for ``fc1``.
        num_classes: Number of output logits. Defaults to 3.
        lr: Adam learning rate. Defaults to 1e-3.

    Raises:
        ValueError: If ``seq_len`` is too short to survive both pooling stages.
    """

    def __init__(self, in_channels: int = 8, seq_len: int = 10000,
                 num_classes: int = 3, lr: float = 1e-3):
        super().__init__()
        # Two MaxPool1d(kernel_size=2, stride=2) stages halve the length twice.
        pooled_len = seq_len // 4
        if pooled_len < 1:
            raise ValueError(f"seq_len must be >= 4, got {seq_len}")
        self.lr = lr
        self.conv1 = nn.Conv1d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * pooled_len, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch to class logits.

        Args:
            x: Batched input of shape ``(batch, in_channels, seq_len)``; the
                flatten + ``fc1`` size requires exactly this length.

        Returns:
            Unnormalized logits of shape ``(batch, num_classes)``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=True, batch_size=len(y))
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracy for one batch.

        Metrics are logged with ``on_epoch=True`` so Lightning aggregates them
        across the epoch (weighted by ``batch_size``). This replaces a manual
        ``validation_epoch_end`` hook, which Lightning 2.0 no longer supports.

        Returns:
            Dict with tensor entries ``"val_loss"`` and ``"val_acc"``.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Keep accuracy as a tensor (not a Python float via .item()) so it can
        # be logged and reduced like any other metric.
        acc = (preds == y).float().mean()
        self.log("val_loss", loss, on_step=False, on_epoch=True,
                 prog_bar=True, batch_size=len(y))
        self.log("val_acc", acc, on_step=False, on_epoch=True,
                 prog_bar=True, batch_size=len(y))
        return {"val_loss": loss, "val_acc": acc}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Training configuration.
SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
MAX_EPOCHS = 10

# Seed python/numpy/torch RNGs so weight init and shuffling are reproducible.
pl.seed_everything(SEED)

# Split the dataset into train and validation sets.
# NOTE(review): `ECGDataset`, `ecg_data` and `ecg_labels` are not defined in any
# visible cell, so this cell fails on Restart & Run All.
# TODO: define/load them in an earlier cell (e.g. from `data_pairs_path`).
# NOTE(review): `transforms.ToTensor()` is an image transform; for a 2-D numpy
# array it likely adds a leading dimension, which would hand Conv1d a 4-D
# batch. Confirm how ECGDataset applies `transform`.
dataset = ECGDataset(ecg_data, ecg_labels, transform=transforms.ToTensor())
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator makes the train/val split identical across runs.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Instantiate the Lightning model
model = CNNModel()

# `gpus=` was removed in Lightning 2.0; `accelerator="auto"` uses a GPU when
# available and falls back to CPU instead of crashing.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator="auto", devices=1)
trainer.fit(model, train_loader, val_loader)