In [None]:
from config import get_dl_config
import xarray as xr
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Display whether PyTorch can see a CUDA device. The result is only shown as the
# cell output; it is not stored in a variable.
torch.cuda.is_available()

In [None]:
# Fetch the configuration mapping from the project-local `config` module.
conf = get_dl_config()

# Path of the dataset file; it is passed to xarray in the loading cell.
INPUT_FILE = conf['input_file']

### load data

In [None]:
# Read the DataArray stored at INPUT_FILE (load_dataarray loads it into memory).
data = xr.load_dataarray(INPUT_FILE)
data   # bare expression so the notebook renders the array summary

In [None]:
# Convert the string annotations stored under 'label' into numeric class ids.
labels = data['label'].values
# presumably "bckg" = background and "seiz" = seizure -- confirm against the dataset docs
labels_map = {"bckg": 0, "seiz": 1}

# Fail early with a readable message: dict.get would turn an unexpected label into
# None, which surfaces later as a cryptic dtype/conversion error.
unknown_labels = set(np.asarray(labels).ravel().tolist()) - set(labels_map)
if unknown_labels:
    raise ValueError(
        f"Unexpected label values {sorted(unknown_labels, key=repr)}; "
        f"expected only {sorted(labels_map)}"
    )

# Explicit otypes pins the result dtype and lets np.vectorize accept empty input.
labels_numeric = np.vectorize(labels_map.__getitem__, otypes=[np.int64])(labels)
labels_numeric

In [None]:
# Copy the array values into a float32 tensor. unsqueeze(1) inserts a singleton axis
# at position 1, which lines up with the single input channel of EEGNet's Conv2d.
data_tensor = torch.tensor(data.values, dtype=torch.float32).unsqueeze(1)
# Targets are stored as float32 because the training cell feeds them to nn.BCELoss.
labels_tensor = torch.tensor(labels_numeric, dtype=torch.float32)

In [None]:
# Pair every sample with its target so both are indexed and batched together.
dataset = TensorDataset(data_tensor, labels_tensor)
# Mini-batches of 32 samples; shuffle=True reshuffles the data on each pass.
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
class EEGNet(nn.Module):
    """Small CNN for binary classification.

    Pipeline: 5x5 convolution (1 -> 6 channels) -> ReLU -> 2x2 mean pooling ->
    flatten -> single linear unit -> sigmoid.

    Args:
        input_height: Size of the first spatial axis of the input (default 19).
        input_width: Size of the second spatial axis of the input (default 5250).

    Raises:
        ValueError: If the input is too small to leave at least one value per
            spatial axis after the convolution and pooling stages.
    """

    def __init__(self, input_height: int = 19, input_width: int = 5250):
        super().__init__()
        conv_kernel = 5
        pool_kernel = 2
        conv_out_channels = 6

        # First Convolutional Layer: input channels = 1, output channels = 6
        self.conv1 = nn.Conv2d(
            1, conv_out_channels, kernel_size=(conv_kernel, conv_kernel), stride=1
        )
        # Mean Pooling Layer with a 2x2 kernel and stride 2
        self.pool = nn.AvgPool2d(kernel_size=(pool_kernel, pool_kernel), stride=pool_kernel)

        # Derive the flattened size from the input shape instead of hard-coding it.
        # Unpadded stride-1 convolution: size - kernel + 1.
        # Pooling with kernel == stride and no padding: floor(size / kernel).
        # Defaults: (19, 5250) -> conv (15, 5246) -> pool (7, 2623) -> 6 * 7 * 2623.
        pooled_height = (input_height - conv_kernel + 1) // pool_kernel
        pooled_width = (input_width - conv_kernel + 1) // pool_kernel
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"Input of size ({input_height}, {input_width}) is too small for a "
                f"{conv_kernel}x{conv_kernel} convolution followed by "
                f"{pool_kernel}x{pool_kernel} pooling"
            )
        self.fc_input_size = conv_out_channels * pooled_height * pooled_width

        # Fully Connected Layer: one output unit for seizure vs. non-seizure
        self.fc1 = nn.Linear(self.fc_input_size, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape (batch_size, 1).

        Args:
            x: Tensor of shape (batch_size, 1, input_height, input_width).
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)  # Flatten all dimensions except batch
        x = self.fc1(x)
        x = torch.sigmoid(x)  # Sigmoid squashes the logit into [0, 1]
        return x

# Instantiate the network that the training cell optimises.
net = EEGNet()

In [None]:
# Loss and optimizer
# nn.BCELoss expects probabilities, which the network's final sigmoid provides.
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss for binary classification
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    # NOTE(review): the loop variable `labels` rebinds the notebook-level `labels`
    # array created in the label-mapping cell.
    for i, (inputs, labels) in enumerate(data_loader):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = net(inputs)
        loss = criterion(outputs, labels.unsqueeze(1))  # Reshape labels to match output
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()  # read the scalar once and reuse it below
        running_loss += batch_loss
        print(f'Epoch {epoch + 1}, Iteration {i + 1}, Loss: {batch_loss}')

    # Mean of the per-batch losses over the epoch
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(data_loader)}')

print('Finished Training')