In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataloader import TUHZDataloaderBinary
import warnings

warnings.filterwarnings("ignore")


class TinyEEGCNN(nn.Module):
    """Compact 1-D convolutional network that emits one binary logit per example.

    Two strided ``Conv1d`` + ``ReLU`` stages feed a global average pool over
    the last axis; a single linear unit then maps the 128 pooled features to
    a raw logit (no sigmoid is applied inside the model).

    Args:
        in_ch: Number of input channels accepted by the first convolution.
    """

    def __init__(self, in_ch=41):
        super().__init__()
        # Layers are listed in execution order; the Sequential indices (and
        # therefore the state_dict keys) follow this order.
        feature_layers = [
            nn.Conv1d(in_ch, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # collapses the last axis: [B, 128, 1]
        ]
        self.net = nn.Sequential(*feature_layers)
        self.fc = nn.Linear(128, 1)  # single output unit = binary logit

    def forward(self, x):
        """Return one raw logit per example for a batch ``x``."""
        pooled = self.net(x)
        features = torch.squeeze(pooled, dim=-1)  # [B, 128]
        logit = self.fc(features)
        return torch.squeeze(logit, dim=-1)  # [B]


# --- Configuration: the values a reader is most likely to tune -------------
SEED = 0  # fixed so weight initialisation is repeatable across kernel restarts
DATA_INDEX_PATH = "eeg_seizure_only.json"  # argument handed to TUHZDataloaderBinary
LEARNING_RATE = 1e-3  # Adam step size

# Seed PyTorch before the model is constructed so its initial weights are the
# same on every Restart & Run All.
# NOTE(review): any shuffling done inside the project dataloader may draw from
# a different RNG (e.g. `random` / NumPy) — confirm and seed there if needed.
torch.manual_seed(SEED)

# Build the project dataloader and obtain the iterable of training batches.
dataloader = TUHZDataloaderBinary(DATA_INDEX_PATH)
loader = dataloader.return_loader()

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyEEGCNN().to(device)

opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


# def eval_loader(loader):
#     model.eval()
#     total = correct = 0
#     tp = fp = fn = 0
#     loss_sum = 0.0

#     with torch.no_grad():
#         for batch in loader:
#             print("Started a batch")
#             x = batch["x"].to(device)
#             y = batch["y"].float().to(device)  # BCE expects float {0,1}

#             logits = model(x)
#             loss = F.binary_cross_entropy_with_logits(logits, y)
#             loss_sum += loss.item() * y.numel()

#             pred = (torch.sigmoid(logits) >= 0.5).long()
#             y_long = y.long()

#             total += y.numel()
#             correct += (pred == y_long).sum().item()

#             tp += ((pred == 1) & (y_long == 1)).sum().item()
#             fp += ((pred == 1) & (y_long == 0)).sum().item()
#             fn += ((pred == 0) & (y_long == 1)).sum().item()

#     acc = correct / max(total, 1)
#     prec = tp / max(tp + fp, 1)
#     rec = tp / max(tp + fn, 1)
#     f1 = 2 * prec * rec / max(prec + rec, 1e-12)
#     return loss_sum / max(total, 1), acc, f1


In [None]:
epochs = 1
for ep in range(1, epochs + 1):
    model.train()

    # Running totals for a sample-weighted mean of the training loss, so the
    # epoch produces a number that shows whether optimisation is progressing.
    loss_sum = 0.0
    n_seen = 0

    for idx, batch in enumerate(loader):
        print(f"Batch: {idx}")
        x = batch["x"].to(device)
        y = batch["y"].float().to(device)  # BCE-with-logits needs float targets

        logits = model(x)
        loss = F.binary_cross_entropy_with_logits(logits, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # `loss` is the batch mean (default reduction), so weight it by the
        # number of targets to keep a smaller final batch from being over-counted.
        loss_sum += loss.item() * y.numel()
        n_seen += y.numel()

    # Mean of the per-batch losses observed *while* the weights were being
    # updated; max(..., 1) keeps an empty loader from dividing by zero.
    print(f"Epoch {ep} | train loss {loss_sum / max(n_seen, 1):.4f}")

    # tr_loss, tr_acc, tr_f1 = eval_loader(loader)
    # print(f"Epoch {ep} | train loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f}")

Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Batch: 11
Batch: 12
Batch: 13
Batch: 14
Batch: 15
Batch: 16
Batch: 17
Batch: 18
Batch: 19
Batch: 20
Batch: 21
Batch: 22
Batch: 23
Batch: 24
Batch: 25
Batch: 26
Batch: 27
Batch: 28
Batch: 29
Batch: 30
Batch: 31
Batch: 32
Batch: 33
Batch: 34
Batch: 35
Batch: 36
Batch: 37
Batch: 38
Batch: 39
Batch: 40
Batch: 41
Batch: 42
Batch: 43
Batch: 44
Batch: 45
Batch: 46
Batch: 47
Batch: 48
Batch: 49
Batch: 50
Batch: 51
Batch: 52
Batch: 53
Batch: 54
Batch: 55
Batch: 56
Batch: 57
Batch: 58
Batch: 59
Batch: 60
Batch: 61
Batch: 62
Batch: 63
Batch: 64
Batch: 65
Batch: 66
Batch: 67
Batch: 68
Batch: 69
Batch: 70
Batch: 71
Batch: 72
Batch: 73
Batch: 74
Batch: 75
Batch: 76
Batch: 77
Batch: 78
Batch: 79
Batch: 80
Batch: 81
Batch: 82
Batch: 83
Batch: 84
Batch: 85
Batch: 86
Batch: 87
Batch: 88
Batch: 89
Batch: 90
Batch: 91
Batch: 92
Batch: 93
Batch: 94
Batch: 95
Batch: 96
Batch: 97
Batch: 98
Batch: 99
Batch: 100