# File to experiment with different machine learning models.

In [1]:
from lib.DataObject import DataObject
import lib.DataObjectUtils as util
import torch
import pickle
import torch.nn as nn
from lib.DataHandler import DataAcquisitionHandler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import WeightedRandomSampler
import matplotlib.pyplot as plt

pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Restore the pickled acquisition handler from data/*.
# NOTE(review): pickle.load runs whatever code the file embeds, so only point
# this at recordings produced by a trusted source.
filename = 'data/handler_box_data_full_Oct_30_2023.pkl'
with open(filename, 'rb') as pickle_file:
    handler = pickle.load(pickle_file)

In [3]:
# Data filter visitor test

data = DataObject(handler.get_data())

# Run the filter visitors over the recording: band-pass first, then band-stop.
bandpass = util.BandpassFilterVisitor(low=0.1, high=10)
data.accept(bandpass)
bandstop = util.BandstopFilterVisitor(low=49, high=51)
data.accept(bandstop)

key_data, box_data = data.get_data(decorator=util.MakeTensorWindowsDataDecorator())

sample = box_data[0]

# Structure of the windowed data: list -> (channels, label) -> channel -> readings
print("List of samples:", len(box_data))
print("Sample - (channels, label):", len(sample))
print("Channels:", len(sample[0]))

# Scan every channel of every sample for the longest and shortest reading count.
max_channel_len = 0
min_channel_len = 10000000
for sample in box_data:
    for channel in sample[0]:
        readings = len(channel)
        max_channel_len = max(max_channel_len, readings)
        min_channel_len = min(min_channel_len, readings)

print("Max Channel Len - [reading_1, ...]:", max_channel_len)
print("Min Channel Len - [reading_1, ...]:", min_channel_len)
# NOTE: the loop above rebinds `sample`, so this shows the LAST sample of box_data.
print("Sample example:", sample)

List of samples: 83
Sample - (channels, label): 2
Channels: 24
Max Channel Len - [reading_1, ...]: 250
Min Channel Len - [reading_1, ...]: 250
Sample example: (tensor([[ 2.4500e+02,  2.4600e+02,  2.4700e+02,  ...,  2.3600e+02,
          2.3700e+02,  2.3800e+02],
        [ 8.3703e+03,  8.4933e+03,  8.5348e+03,  ...,  8.4436e+03,
          8.5442e+03,  8.3925e+03],
        [-6.4812e+00, -6.1533e+00, -5.2448e+00,  ..., -3.8378e+00,
         -3.7535e+00, -3.5997e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.6987e+09,  1.6987e+09,  1.6987e+09,  ...,  1.6987e+09,
          1.6987e+09,  1.6987e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]]), tensor(1))


In [4]:
# Dataset

class EEGDataset(torch.utils.data.Dataset):
    """
    In-memory dataset of fixed-length, multi-channel windows and their labels.

    Indexing returns a (window, label) pair, where window is a 2D tensor
    (channels, readings) restricted to the selected channel rows.
    """

    def __init__(self, data, channels_idx=(1, 9)):
        """
        data: list of tuples (data, label), see parse_data
        channels_idx: (start, stop) range of channel rows kept from every
                      sample; stop is exclusive, the default keeps rows 1..8
        """
        self.data, self.labels = self.parse_data(data, channels_idx)
        # Number of readings per channel (last dimension of the stacked data)
        self.window_size = self.data.shape[2]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def parse_data(self, data, channels_idx=(1, 9)):
        """
        Data comes in the form of a list of tuples (data, label)
        data is a 2D tensor (channels, readings)
        label is a scalar tensor holding the class index

        channels_idx: (start, stop) range of channel rows to keep, stop exclusive

        return:
        data: 3D tensor (samples, selected channels, readings)
        label: 1D tensor (samples)
        """
        start, stop = channels_idx

        data_list = [sample[0][start:stop] for sample in data]
        label_list = [sample[1] for sample in data]

        return torch.stack(data_list), torch.stack(label_list)

    def downsample(self):
        """
        Balance the dataset in place by truncating every class to the size of
        the least represented class.

        Samples are kept in their original order; for each class the first
        `min_count` occurrences survive and the rest are dropped.
        """
        label_counts = torch.bincount(self.labels)
        # bincount also reports 0 for label ids that never occur (e.g. labels
        # {0, 2} -> [n0, 0, n2]); only classes that are present may set the target size.
        present_counts = label_counts[label_counts > 0]
        if present_counts.numel() == 0:
            # No samples at all, nothing to balance
            return
        min_count = int(present_counts.min())

        kept_indices = []
        # How many samples of each class have been kept so far
        samples_per_class = dict()
        for index, label in enumerate(self.labels.tolist()):
            taken = samples_per_class.get(label, 0)
            if taken < min_count:
                kept_indices.append(index)
                samples_per_class[label] = taken + 1

        # Replace the dataset with the downsampled selection
        keep = torch.tensor(kept_indices, dtype=torch.long, device=self.labels.device)
        self.data = self.data[keep]
        self.labels = self.labels[keep]

    # Function to create a balanced sampler
    @staticmethod
    def make_balanced_sampler(labels):
        """
        labels: 1D tensor of non-negative integer class indices

        return:
        WeightedRandomSampler drawing len(labels) indices with replacement,
        each sample weighted by the inverse frequency of its class
        """
        class_counts = torch.bincount(labels)
        # Clamp so absent label ids do not produce inf weights; those entries
        # are never selected by the indexing below anyway.
        class_weights = 1. / class_counts.clamp(min=1)
        weights = class_weights[labels]
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        return sampler

    # Function to print class distribution per batch
    @staticmethod
    def print_class_distribution_per_batch(dataloader):
        """
        Print, for every batch yielded by the dataloader, how many samples of
        each class it contains. Batches must unpack as (data, labels).
        """
        for i, (_, labels) in enumerate(dataloader):
            class_counts = torch.bincount(labels)
            class_distribution = {f"class_{class_idx}": count.item() for class_idx, count in enumerate(class_counts)}
            print(f"Batch {i}: class distribution: {class_distribution}")

In [14]:
# Data loader

# Seed the global RNG so the random split and the shuffling are reproducible
torch.manual_seed(0)

# Build the dataset from the windowed samples and balance the classes
dataset = EEGDataset(box_data)
dataset.downsample()

# 80% of the samples go to training, the remainder to testing
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Check if data is loaded correctly (loader sizes are numbers of batches)
for _caption, _value in (
    ("Train dataset size:", len(train_dataset)),
    ("Test dataset size:", len(test_dataset)),
    ("Train loader size:", len(train_loader)),
    ("Test loader size:", len(test_loader)),
    ("Sample data shape:", len(dataset)),
):
    print(_caption, _value)

def print_class_distribution_per_batch(dataloader):
    """
    Print the number of samples of each class in every batch of the dataloader.

    dataloader: iterable of (data, labels) batches, labels being a 1D tensor
                of non-negative integer class indices
    """
    for batch_idx, batch in enumerate(dataloader):
        _features, batch_labels = batch
        # One entry per class index from 0 up to the largest label in the batch
        per_class = torch.bincount(batch_labels).tolist()
        class_distribution = {}
        for class_idx, occurrences in enumerate(per_class):
            class_distribution[f"class_{class_idx}"] = occurrences
        print(f"Batch {batch_idx}: class distribution: {class_distribution}")

# Print the class balance of every batch produced by train_loader
# (iterates the loader once).
print_class_distribution_per_batch(train_loader)

Train dataset size: 62
Test dataset size: 16
Train loader size: 2
Test loader size: 1
Sample data shape: 78
Batch 0: class distribution: {'class_0': 16, 'class_1': 16}
Batch 1: class distribution: {'class_0': 15, 'class_1': 15}


In [15]:
# create p300Model
class EEG_Net_CNN(torch.nn.Module):
    """
    Pytorch implementation of an EEGNet-style 1D CNN

    Expecting input of shape (batch_size, channels, readings)
    input = [32, 8, 250] = [batch_size, channels, readings]
    batch_size: number of samples in a batch
    channels: number of channels in a sample (8)
    readings: number of readings in a channel (len())

    Output: raw class scores of shape (batch_size, num_classes)
    """

    # Layer geometry, shared by the layer definitions and by the computation
    # of the flattened feature size so the two cannot drift apart.
    NUM_FILTERS = 32
    CONV1_KERNEL = 50
    POOL1_SIZE = 4
    CONV2_KERNEL = 15
    POOL2_SIZE = 8

    def __init__(self, num_channels=8, num_classes=2, input_length=250):
        """
        num_channels: number of input channels per sample
        num_classes: number of output scores
        input_length: number of readings per channel; fixes the size of the
                      fully connected layer

        raises:
        ValueError if input_length is too short to survive both blocks
        """
        super(EEG_Net_CNN, self).__init__()

        filters = self.NUM_FILTERS

        # Validate before building the layers so a bad length fails with a
        # clear message instead of a shape error inside nn.Linear / forward.
        conv1_out_length = self._conv_output_length(input_length, self.CONV1_KERNEL)
        pool1_out_length = self._conv_output_length(conv1_out_length, self.POOL1_SIZE, self.POOL1_SIZE)
        conv2_out_length = self._conv_output_length(pool1_out_length, self.CONV2_KERNEL)
        pool2_out_length = self._conv_output_length(conv2_out_length, self.POOL2_SIZE, self.POOL2_SIZE)
        if pool2_out_length < 1:
            raise ValueError(
                f"input_length={input_length} is too short for the network: "
                "no readings are left after the second pooling layer"
            )

        self.block1 = torch.nn.Sequential(
            # Temporal Conv1D across all input channels
            nn.Conv1d(num_channels, filters, kernel_size=self.CONV1_KERNEL, stride=1, padding=0, bias=False),

            # Batch norm
            nn.BatchNorm1d(filters),

            # Grouped Conv1D with kernel_size=1 and one group per feature map:
            # a learned scalar weight per feature map
            nn.Conv1d(filters, filters, kernel_size=1, groups=filters, bias=False),

            # Batch norm
            nn.BatchNorm1d(filters),

            # ELU Activation
            nn.ELU(alpha=1.0),

            # Avg Pooling 1D
            nn.AvgPool1d(kernel_size=self.POOL1_SIZE, stride=self.POOL1_SIZE, padding=0),

            # Dropout
            nn.Dropout(p=0.15)
        )

        self.block2 = torch.nn.Sequential(
            # Regular (dense) Conv1D over the feature maps of block1
            nn.Conv1d(filters, filters, kernel_size=self.CONV2_KERNEL, stride=1, padding=0, bias=False),

            # Batch norm
            nn.BatchNorm1d(filters),

            # ELU Activation
            nn.ELU(alpha=1.0),

            # Avg Pooling 1D
            nn.AvgPool1d(kernel_size=self.POOL2_SIZE, stride=self.POOL2_SIZE, padding=0),

            # Dropout
            nn.Dropout(p=0.15)
        )

        # Flattened size = remaining readings * number of feature maps after block2
        linear_input_features = pool2_out_length * filters

        # Fully Connected Layer
        self.fc = nn.Linear(in_features=linear_input_features, out_features=num_classes, bias=True)

    @staticmethod
    def _conv_output_length(input_length, kernel_size, stride=1, padding=0):
        """Length of a signal after a conv/pool layer with the given geometry."""
        return (input_length - kernel_size + 2*padding) // stride + 1

    def forward(self, x):
        """
        x: tensor of shape (batch_size, num_channels, input_length)

        return:
        tensor of shape (batch_size, num_classes) with unnormalized scores
        """
        # block 1
        x = self.block1(x)

        # block 2
        x = self.block2(x)

        # flatten everything but the batch dimension
        x = x.flatten(start_dim=1)

        # fc
        x = self.fc(x)

        return x

In [16]:
# Build the network with its default arguments (8 channels, 2 classes, 250 readings)
model = EEG_Net_CNN()

# Cross-entropy loss computed on the raw class scores returned by the model
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Testing function
def test(dataloader, model, loss_fn):
    """
    Evaluate the model on every batch of the dataloader and print a summary.

    dataloader: iterable of (X, y) batches that supports len()
    model: module mapping X to class scores of shape (batch, classes)
    loss_fn: callable(pred, y) returning a scalar loss tensor

    return:
    (accuracy, avg_loss) - accuracy is the fraction of correctly classified
    samples in [0, 1], avg_loss is the mean of the per-batch losses

    raises:
    ValueError if the dataloader yields no samples

    Note: the model is switched to eval mode and left in that mode.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            # Count the samples actually seen: len(dataloader.dataset) is wrong
            # when the loader drops the last batch or uses a sub-sampling sampler.
            total += y.size(0)
    if total == 0:
        raise ValueError("dataloader yielded no samples to evaluate")
    test_loss /= num_batches
    correct /= total
    print(f'Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n')
    return correct, test_loss


# The name starts with "test", so pytest would try to collect this helper as a
# test case whenever it is imported into a test module; opt out explicitly.
test.__test__ = False

In [None]:
def train(train_dataloader, val_dataloader, model, loss_fn, optimizer, num_epochs, print_every=100):
    """
    Train the model for num_epochs epochs, validating after every epoch, then
    plot the loss and accuracy curves.

    train_dataloader / val_dataloader: iterables of (X, y) batches supporting len()
    model: module mapping X to class scores of shape (batch, classes)
    loss_fn: callable(pred, y) returning a scalar loss tensor
    optimizer: optimizer already bound to the model parameters
    num_epochs: number of passes over train_dataloader
    print_every: print the batch loss every `print_every` batches;
                 0 or a negative value disables the per-batch output

    return:
    (train_losses, val_losses, train_accuracies, val_accuracies), four lists
    with one entry per epoch; losses are means of the per-batch losses and
    accuracies are fractions in [0, 1]
    """
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch, (X, y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            total += y.size(0)

            # Check print_every first: evaluating the modulo with
            # print_every == 0 would raise ZeroDivisionError.
            if print_every > 0 and batch % print_every == 0:
                print(f'Epoch {epoch+1} - Batch {batch+1}/{len(train_dataloader)} - Loss: {loss.item():.4f}')

        avg_train_loss = running_loss / len(train_dataloader)
        avg_train_acc = correct / total

        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in val_dataloader:
                pred = model(X)
                loss = loss_fn(pred, y)

                val_loss += loss.item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                total += y.size(0)

        avg_val_loss = val_loss / len(val_dataloader)
        avg_val_acc = correct / total

        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)

        print(f'Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} - Train Accuracy: {avg_train_acc:.4f} - Val Loss: {avg_val_loss:.4f} - Val Accuracy: {avg_val_acc:.4f}')

    _plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies)

    return train_losses, val_losses, train_accuracies, val_accuracies


def _plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies):
    """Show per-epoch loss (left) and accuracy (right) curves for train and validation."""
    plt.figure(figsize=(12, 5))

    panels = (
        (1, "Loss", train_losses, val_losses),
        (2, "Accuracy", train_accuracies, val_accuracies),
    )
    for position, metric, train_values, val_values in panels:
        plt.subplot(1, 2, position)
        plt.title(metric)
        plt.plot(train_values, label='Train')
        plt.plot(val_values, label='Validation')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.legend()

    plt.show()

In [None]:
# Train for 10 epochs. test_loader is passed as the validation loader, so the
# reported validation metrics are computed on the held-out test split.
train_losses, val_losses, train_accuracies, val_accuracies = train(
    train_loader, test_loader, model, loss_fn, optimizer, 10, print_every=100)