In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np

# Define a custom transformer model for regression using a simple transformer architecture
class EEGTransformerForRegression(nn.Module):
    def __init__(self, input_dim, model_dim=64, num_labels=1, num_heads=8, num_layers=4):
        super(EEGTransformerForRegression, self).__init__()

        # Transformer layers
        self.embedding = nn.Linear(input_dim, model_dim)
        self.positional_encoding = nn.Parameter(torch.randn(1, 512, model_dim))  # Max sequence length = 512
        self.transformer = nn.Transformer(d_model=model_dim, nhead=num_heads, num_encoder_layers=num_layers)

        # Regression head
        self.fc = nn.Linear(model_dim, num_labels)

    def forward(self, x):
        # x shape: (batch_size, seq_len, input_dim)
        seq_len = x.size(1)

        # Embed input into the transformer model's dimension
        x = self.embedding(x)

        # Add positional encoding (to give the transformer a sense of order)
        x = x + self.positional_encoding[:, :seq_len, :]

        # Permute the input to (seq_len, batch_size, model_dim)
        x = x.permute(1, 0, 2)

        # Pass through transformer
        x = self.transformer(x, x)

        # Use the output corresponding to the [CLS] token (first token)
        x = x[0, :, :]

        # Pass through the regression head
        x = self.fc(x)

        return x

# Function to preprocess EEG data
def preprocess_eeg_data(eeg_data, sequence_length=128):
    # Assume eeg_data is of shape (samples, channels, time_steps)
    # Reshape it into (samples, time_steps * channels) and normalize it
    num_samples, num_channels, num_time_points = eeg_data.shape
    processed_data = eeg_data.reshape(num_samples, num_channels, num_time_points).transpose(1, 2)

    # Optional: Normalize the data (for example, using min-max normalization)
    processed_data = (processed_data - processed_data.min()) / (processed_data.max() - processed_data.min())

    return processed_data

# Load EEG data and labels (use your actual EEG data here)
eeg_data = np.random.randn(1000, 32, 128)  # 1000 samples, 32 channels, 128 time points
labels = np.random.randn(1000, 1)  # 1000 continuous regression labels

# Preprocess the EEG data
processed_eeg_data = preprocess_eeg_data(eeg_data)

# Convert to PyTorch tensors
inputs = torch.tensor(processed_eeg_data, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.float32)

# Split into train and validation sets
train_inputs, val_inputs, train_labels, val_labels = train_test_split(inputs, labels, test_size=0.2)

# Create DataLoader for batching
train_dataset = TensorDataset(train_inputs, train_labels)
val_dataset = TensorDataset(val_inputs, val_labels)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize the model
input_dim = train_inputs.shape[1]  # Time steps
model = EEGTransformerForRegression(input_dim=input_dim)

# Set up optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Training loop
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()

        inputs, labels = batch

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader):.4f}")

# Evaluation
model.eval()
with torch.no_grad():
    all_preds = []
    all_labels = []
    for batch in val_loader:
        inputs, labels = batch
        preds = model(inputs)
        all_preds.append(preds)
        all_labels.append(labels)

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Calculate Mean Squared Error
    mse = mean_squared_error(all_labels.numpy(), all_preds.numpy())
    print(f"Validation MSE: {mse:.4f}")


ValueError: axes don't match array