In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from model import NICUBradycardiaModel
from datasets import load_from_disk
from tqdm import tqdm

In [7]:
# Load the training data that was written to disk with `Dataset.save_to_disk`.
# NOTE(review): the directory name suggests a class-balanced split — confirm how it was built.
DATASET_DIR = "bradycardia-balanced"
dataset = load_from_disk(DATASET_DIR)

In [8]:
# Define a collate function
def collate_fn(batch):
    inputs = torch.tensor([item['input'] for item in batch], dtype=torch.float32)  # Shape: (batch, 2, 3750)
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
    return inputs, labels

# Create DataLoader
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_fn
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Seed the global torch RNG (CPU and CUDA) so weight initialization is reproducible.
SEED = 42
torch.manual_seed(SEED)

model = NICUBradycardiaModel(
    in_channels=2,
    seq_length=3750,   # 15s at 250Hz
    hidden_size=1536,  # Large LSTM hidden
    lstm_layers=2,
    out_channels=2
).to(device)

# Binary task framed as 2-class classification: the model emits 2 logits and
# CrossEntropyLoss expects integer class labels (collate_fn builds them as int64).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [11]:
num_epochs = 2

# Train for `num_epochs`, reporting running (training-set) loss and accuracy.
# NOTE: these metrics are accumulated while the weights are still changing and
# only on `train_loader`; no held-out evaluation is performed in this cell.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    # Defaults so the epoch summary below is well-defined even if the loader
    # yields no batches (previously this raised NameError at the print).
    avg_loss = float("nan")
    accuracy = float("nan")

    # One tqdm bar per epoch; `leave=True` keeps finished bars in the output.
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    # collate_fn always yields an (inputs, labels) tuple.
    for inputs, labels in progress_bar:
        inputs = inputs.to(device)   # (batch, 2, 3750)
        labels = labels.to(device)   # (batch,)
        n_in_batch = labels.size(0)

        optimizer.zero_grad()

        # NOTE(review): the model returns a tuple whose first element is used as
        # the logits — confirm the second element's meaning against model.py.
        outputs, _ = model(inputs)  # outputs: (batch, 2)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # CrossEntropyLoss defaults to reduction='mean', so re-weight by batch
        # size to get a true per-sample mean over the epoch (last batch may be smaller).
        total_loss += loss.item() * n_in_batch

        preds = outputs.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += n_in_batch

        avg_loss = total_loss / total_samples
        accuracy = total_correct / total_samples * 100.0
        progress_bar.set_postfix({"Loss": f"{avg_loss:.4f}", "Accuracy": f"{accuracy:.2f}%"})

    # Print epoch-level results
    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%")

Epoch 1/2: 100%|██████████| 3280/3280 [13:43<00:00,  3.98it/s, Loss=0.4387, Accuracy=78.08%]


Epoch [1/2] | Loss: 0.4387 | Accuracy: 78.08%


Epoch 2/2: 100%|██████████| 3280/3280 [13:41<00:00,  3.99it/s, Loss=0.3108, Accuracy=86.01%]

Epoch [2/2] | Loss: 0.3108 | Accuracy: 86.01%





In [13]:
# Persist only the learned weights (state_dict), not the pickled model object.
# NOTE(review): tensors are saved on `device`; load with map_location on CPU-only hosts.
MODEL_PATH = "nicu_bradycardia_model.pth"
weights = model.state_dict()
torch.save(weights, MODEL_PATH)