# Load Tensors

In [1]:
import torch

# List of tensor paths
tensor_paths = [
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/arnold_tongues_rotation_numbers_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/dspm_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/higuchi_fractal_dimensions_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/Hurst_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/mfdfa_concatd_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/mfdfa_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/short_time_fourier_transform_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_granular_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_hemispheric_avg_input_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_regional_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_entropy_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_centroids_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/freq_max_power_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_edge_freqs_tensor.pt",
]

# Initialize an empty dictionary to store the tensors and another for their shapes
tensors = {}
tensor_shapes = {}

# Load the tensors into a dictionary and collect their shapes
for path in tensor_paths:
    tensor_name = path.split('/')[-1].replace('.pt', '').replace('.pth', '')

    # Remove the 'h' from the end, if it exists
    if tensor_name.endswith("h"):
        tensor_name = tensor_name[:-1]

    # Load the tensor
    data = torch.load(path)
    tensors[tensor_name] = data

    # Check the type of the loaded data
    if isinstance(data, torch.Tensor):
        tensor_shapes[tensor_name] = data.shape
    elif isinstance(data, dict):  # Likely a state_dict
        tensor_shapes[tensor_name] = "state_dict (model parameters)"
    else:
        tensor_shapes[tensor_name] = "unknown type"

# Print the shapes of all loaded tensors
for name, shape in tensor_shapes.items():
    print(f"{name}: {shape}")


arnold_tongues_rotation_numbers_tensor: torch.Size([32, 300, 300])
dspm_tensor: torch.Size([19, 18840, 10])
higuchi_fractal_dimensions_tensor: torch.Size([1, 1, 4, 8])
Hurst_tensor: torch.Size([1, 1, 32, 1])
mfdfa_concatd_tensor: torch.Size([32, 1, 30, 2])
mfdfa_tensor: torch.Size([9, 32, 10, 1])
short_time_fourier_transform_tensor: torch.Size([32, 1001, 4229])
transfer_entropy_granular_tensor: torch.Size([4, 4])
transfer_entropy_hemispheric_avg_input_tensor: torch.Size([92, 92])
transfer_entropy_regional_tensor: torch.Size([4, 4])
spectral_entropy_tensor: torch.Size([1, 1, 32, 1])
spectral_centroids_tensor: torch.Size([1, 1, 32, 1])
freq_max_power_tensor: torch.Size([1, 1, 32, 1])
spectral_edge_freqs_tensor: torch.Size([1, 1, 32, 1])


# Match dimensions, reshape, and normalize

In [3]:
import torch
import torch.nn.functional as F

def preprocess_and_resize_tensor(tensor, target_shape):
    # Add missing batch and channel dimensions
    while len(tensor.shape) < 4:
        tensor = tensor.unsqueeze(0)

    # Reduce the channel dimension to 1 by taking the mean along that axis
    tensor = torch.mean(tensor, dim=1, keepdim=True)

    # Normalize
    mean = tensor.mean()
    std = tensor.std()
    if std != 0:
        tensor = (tensor - mean) / std

    # Reshape/resize to target_shape
    tensor = F.interpolate(tensor, size=target_shape[2:], mode='bilinear', align_corners=True)
    
    return tensor

arnold_tongues_rotation_numbers_tensor = torch.rand([32, 300, 300])
dspm_tensor = torch.rand([19, 18840, 10])
higuchi_fractal_dimensions_tensor = torch.rand([1, 1, 4, 8])
Hurst_tensor = torch.rand([1, 1, 32, 1])
mfdfa_concatd_tensor = torch.rand([32, 1, 30, 2])
mfdfa_tensor = torch.rand([9, 32, 10, 1])
short_time_fourier_transform_tensor = torch.rand([32, 1001, 4229])
transfer_entropy_granular_tensor = torch.rand([4, 4])
transfer_entropy_hemispheric_avg_input_tensor = torch.rand([92, 92])
transfer_entropy_regional_tensor = torch.rand([4, 4])
spectral_entropy_tensor = torch.rand([1, 1, 32, 1])
spectral_centroids_tensor = torch.rand([1, 1, 32, 1])
freq_max_power_tensor = torch.rand([1, 1, 32, 1])
spectral_edge_freqs_tensor = torch.rand([1, 1, 32, 1])

target_shape = [1, 1, 32, 32]

# List of all your tensors 
all_tensors = [
    arnold_tongues_rotation_numbers_tensor,
    dspm_tensor,
    higuchi_fractal_dimensions_tensor,
    Hurst_tensor,
    mfdfa_concatd_tensor,
    mfdfa_tensor,
    short_time_fourier_transform_tensor,
    transfer_entropy_granular_tensor,
    transfer_entropy_hemispheric_avg_input_tensor,
    transfer_entropy_regional_tensor,
    spectral_entropy_tensor,
    spectral_centroids_tensor,
    freq_max_power_tensor,
    spectral_edge_freqs_tensor,
]

# Preprocess all tensors
processed_tensors = [preprocess_and_resize_tensor(tensor, target_shape) for tensor in all_tensors]

# Print out the new shapes
for i, tensor in enumerate(processed_tensors):
    print(f"Processed tensor {i+1} shape: {tensor.shape}")

Processed tensor 1 shape: torch.Size([1, 1, 32, 32])
Processed tensor 2 shape: torch.Size([1, 1, 32, 32])
Processed tensor 3 shape: torch.Size([1, 1, 32, 32])
Processed tensor 4 shape: torch.Size([1, 1, 32, 32])
Processed tensor 5 shape: torch.Size([32, 1, 32, 32])
Processed tensor 6 shape: torch.Size([9, 1, 32, 32])
Processed tensor 7 shape: torch.Size([1, 1, 32, 32])
Processed tensor 8 shape: torch.Size([1, 1, 32, 32])
Processed tensor 9 shape: torch.Size([1, 1, 32, 32])
Processed tensor 10 shape: torch.Size([1, 1, 32, 32])
Processed tensor 11 shape: torch.Size([1, 1, 32, 32])
Processed tensor 12 shape: torch.Size([1, 1, 32, 32])
Processed tensor 13 shape: torch.Size([1, 1, 32, 32])
Processed tensor 14 shape: torch.Size([1, 1, 32, 32])


# CNN

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

tensor_names = [
    'arnold_tongues_rotation_numbers_tensor',
    'dspm_tensor',
    'higuchi_fractal_dimensions_tensor',
    'Hurst_tensor',
    'mfdfa_concatd_tensor',
    'mfdfa_tensor',
    'short_time_fourier_transform_tensor',
    'transfer_entropy_granular_tensor',
    'transfer_entropy_hemispheric_avg_input_tensor',
    'transfer_entropy_regional_tensor',
    'spectral_entropy_tensor',
    'spectral_centroids_tensor',
    'freq_max_power_tensor',
    'spectral_edge_freqs_tensor',
]

class BaseEmbeddingNet(nn.Module):
    def __init__(self, input_channels, conv_output_channels, reduce_to_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, conv_output_channels, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(conv_output_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_reduce = nn.Linear(conv_output_channels, reduce_to_dim)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc_reduce(x)
        return x

processed_tensors_dict = {name: tensor for name, tensor in zip(tensor_names, processed_tensors)}

net_params = {name: {'input_channels': 1, 'conv_output_channels': 16, 'reduce_to_dim': 8} 
              for name in processed_tensors_dict.keys()}

# Create BaseEmbeddingNets for each tensor
embedding_nets = {name: BaseEmbeddingNet(**params) for name, params in net_params.items()}

# Move networks to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for net in embedding_nets.values():
    net.to(device)

# Create custom dataset and dataloader
class CustomDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        first_tensor = next(iter(self.tensors.values()))
        return first_tensor.size(0)

    def __getitem__(self, idx):
        result = {}
        for key, val in self.tensors.items():
            if val.shape[0] > idx:
                result[key] = val[idx]
        return result

def custom_collate(batch):
    collated_batch = {}
    all_keys = set([key for item in batch for key in item.keys()])
    
    for key in all_keys:
        collated_batch[key] = torch.stack([item[key] for item in batch if key in item.keys()], dim=0)
    
    return collated_batch

# Use processed_tensors for your CustomDataset
dataset = CustomDataset(processed_tensors_dict)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=custom_collate)

# Collect feature embeddings
all_features = []
for i, batch in enumerate(dataloader):
    features_list = [net(batch[key].to(device, dtype=torch.float32)) for key, net in embedding_nets.items()]
    concatenated_features = torch.cat(features_list, dim=1)
    all_features.append(concatenated_features.cpu().detach())

# Convert list to tensor
all_features = torch.cat(all_features, dim=0)

# Save the feature embeddings
save_path = '/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Kuramoto'
torch.save(all_features, f'{save_path}/all_features.pt')

# Kuramoto 

In [4]:
torch.cuda.empty_cache()

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader



In [2]:
class EEGDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transform:
            sample = self.transform(sample)
        return sample

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchdiffeq import odeint
from torch.cuda.amp import autocast, GradScaler  # Importing the AMP utilities
import numpy as np
from scipy.signal import hilbert
from torch.utils.checkpoint import checkpoint

EEG_data = np.load('/home/vincent/AAA_projects/MVCS/Neuroscience/eeg_data_with_channels.npy', allow_pickle=True)
EEG_tensor = torch.FloatTensor(EEG_data)  # Assumes EEG_data is a NumPy ndarray

# Function to create windows for time-series data
def create_windows(data, window_size, stride):
    windows = []
    for i in range(0, len(data) - window_size, stride):
        windows.append(data[i:i+window_size])
    return torch.stack(windows)

In [4]:
# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
window_size = 50
stride = 10

# Add necessary transformations here to EEG_tensor if required
EEG_tensor = EEG_tensor.clone().detach().to(device)


In [5]:
def apply_hilbert_in_batches(data, batch_size):
    n_batches = int(np.ceil(data.shape[0] / batch_size))
    analytic_signal = np.zeros_like(data, dtype=np.complex64)  # change dtype as needed

    for i in range(n_batches):
        start_idx = i * batch_size
        end_idx = (i + 1) * batch_size
        analytic_signal[start_idx:end_idx, :] = hilbert(data[start_idx:end_idx, :])
        
    return analytic_signal

In [6]:
batch_size = 100  # Set as appropriate

# Apply Hilbert transform in batches
EEG_numpy = EEG_tensor.cpu().numpy()
analytic_signal_batches = apply_hilbert_in_batches(EEG_numpy, batch_size)

# Convert the angle to phases and move to GPU
phases = torch.tensor(np.angle(analytic_signal_batches), dtype=torch.float16).to(device)

# Load PLV matrix
plv_matrix_path = "/home/vincent/AAA_projects/MVCS/Neuroscience/Analysis/Phase Syncronization/plv_matrix.npy"
plv_matrix = torch.tensor(np.load(plv_matrix_path), dtype=torch.float16).to(device)

In [7]:
# Compute_phase_diff_matrix function
def compute_phase_diff_matrix(phases):
    time, channels = phases.shape[:2]
    phase_diff_matrix = torch.zeros(channels, channels, device=phases.device)
    for i in range(channels):
        for j in range(channels):
            phase_diff_matrix[i, j] = torch.mean(phases[:, i] - phases[:, j])
    return phase_diff_matrix

In [8]:
phase_diff_matrix = compute_phase_diff_matrix(phases).to(device)

# EEG channel names
eeg_channel_names = ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5', 'FC1', 'FC2', 'FC6',
                     'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2', 'CP5', 'CP1', 'CP2', 'CP6',
                     'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1', 'Oz', 'O2']

# Broad regions and corresponding channels
regions = {
    "frontal": ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8'],
    "temporal": ['T7', 'T8'],
    "parietal": ['CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8'],
    "occipital": ['O1', 'Oz', 'O2']
}

# Precompute omega and phase_diff_matrix
N = len(eeg_channel_names)
omega = torch.mean(plv_matrix, dim=1).to(device)
phase_diff_matrix = compute_phase_diff_matrix(phases).to(device)

In [9]:
# Modify the Kuramoto function to use PyTorch functions instead of NumPy
def kuramoto_weighted_bias(t, y, omega, K):
    weighted_sin = plv_matrix * torch.sin(y - y[:, None] - phase_diff_matrix)
    dydt = omega + K / N * torch.sum(weighted_sin, axis=1)
    return dydt

In [10]:
class KuramotoODEFunc(nn.Module):
    def __init__(self, omega, K, plv_matrix, phase_diff_matrix):
        super(KuramotoODEFunc, self).__init__()
        self.omega = omega
        self.K = K
        self.plv_matrix = plv_matrix
        self.phase_diff_matrix = phase_diff_matrix

    def forward(self, t, theta):
        # Reshape to accommodate the additional time dimension.
        theta = theta.view(-1, theta.shape[-1])
        N = theta.shape[1]
    
        # Compute the phase differences without unsqueezing
        theta_diff = theta[:, :, None] - theta[:, None, :]
        phase_diff_with_matrix = theta_diff - self.phase_diff_matrix
    
        # Compute the weighted sine values
        weighted_sin = self.plv_matrix * torch.sin(phase_diff_with_matrix)
    
        dtheta = self.omega + (self.K / N) * torch.sum(weighted_sin, dim=1)
    
        return dtheta.view(theta.shape)


In [11]:
class KuramotoLayer(nn.Module):
    def __init__(self, oscillator_count, time_steps, dt=0.01, plv_matrix=None, phase_diff_matrix=None):
        super(KuramotoLayer, self).__init__()
        self.oscillator_count = oscillator_count
        self.time_steps = time_steps
        self.dt = dt
        self.plv_matrix = plv_matrix
        self.phase_diff_matrix = phase_diff_matrix

        if plv_matrix is not None:
            omega_init = torch.mean(plv_matrix, dim=1)
            self.omega = nn.Parameter(omega_init, requires_grad=True)
        else:
            self.omega = nn.Parameter(torch.randn(oscillator_count), requires_grad=True)

        self.K = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def custom_forward(self, *inputs):
        initial_shape = inputs[0].shape  # Store the initial shape
    
        # Flatten the batch and time dimensions
        inputs_flattened = inputs[0].reshape(-1, initial_shape[-1])
        
        ode_func = KuramotoODEFunc(self.omega, self.K, self.plv_matrix, self.phase_diff_matrix)
        time_points = torch.arange(0, 10000 * self.dt, self.dt).to(device)  # Assume device is defined elsewhere
        theta_flattened = odeint(ode_func, inputs_flattened, time_points, method='bosh3', rtol=1e-6, atol=1e-8)

        # Reshape theta to its original shape
        theta = theta_flattened.reshape(*initial_shape, -1)  # -1 will automatically compute the required size
        return theta
        
    def forward(self, theta):
        device = theta.device
        self.plv_matrix = self.plv_matrix.to(device)
        self.phase_diff_matrix = self.phase_diff_matrix.to(device)
        theta = checkpoint(self.custom_forward, theta, self.omega, self.K, self.plv_matrix, self.phase_diff_matrix)
        theta = theta.to(torch.float16)
        mean_coherence = self.calculate_mean_coherence(theta)
        return theta, mean_coherence

    def forward_with_checkpoint(self, x):
        x = x.to(device)
        theta = checkpoint(self.custom_forward, x)
        mean_coherence = self.calculate_mean_coherence(theta)
        return theta, mean_coherence

    @staticmethod
    def calculate_mean_coherence(theta):
        N, _, _, _ = theta.shape
        mean_coherence = torch.mean(torch.cos(theta[:, -1, :] - theta[:, -1, :].mean(dim=1).unsqueeze(1)))
        return mean_coherence

In [12]:
# Make sure all tensors are on the correct device
phases = phases.to(device)

# Compute natural frequencies and phase differences just once
num_channels = len(eeg_channel_names)  # Get the number of channels
#print("Theta shape: ", theta.shape)
#print("Theta Unsqueeze(1) shape: ", theta.unsqueeze(1).shape)
#print("Theta Unsqueeze(2) shape: ", theta.unsqueeze(2).shape)

# Number of channels
N = len(eeg_channel_names)

# Initialize model and move to device
kuramoto_model = KuramotoLayer(N, 12800, plv_matrix=plv_matrix, phase_diff_matrix=phase_diff_matrix).to(dtype=torch.float16).to(device)

# Data Parallelism for multiple GPUs
if torch.cuda.device_count() > 1:
    kuramoto_model = nn.DataParallel(kuramoto_model)

scaler = GradScaler()
train_data = create_windows(EEG_tensor[:int(0.7 * len(EEG_tensor))], window_size, stride).detach().requires_grad_(True)
train_dataset = EEGDataset(data=train_data)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=False)

In [None]:
# Training loop and feature extraction
kuramoto_features_list = []
for i, batch in enumerate(train_loader):
    # Moves batch to device and changes dtype to float16
    batch = batch.to(device, dtype=torch.float16)

    # Using autocast for the forward pass
    with autocast():
        theta, mean_coherence = kuramoto_model(batch)

    kuramoto_features_list.append(mean_coherence)


In [None]:
# Save the combined features
kuramoto_features_tensor = torch.stack(kuramoto_features_list)
all_features_path = '/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Kuramoto/all_features.pt'
all_features = torch.load(all_features_path)
combined_features = torch.cat([all_features, kuramoto_features_tensor.unsqueeze(1)], dim=1)
combined_features_path = '/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Transformer/combined_features.pt'
torch.save(combined_features, combined_features_path)

In [16]:
import torch

# List of tensor paths
tensor_paths = [
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Transformer/band_power_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Transformer/EEG_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Transformer/fast_fourier_transform_psd_tensor.pth",
]

# Load tensors and print their shapes
loaded_tensors = {}
for path in tensor_paths:
    tensor_name = path.split("/")[-1].replace(".pth", "")
    tensor = torch.load(path)
    print(f"Shape of {tensor_name}: {tensor.shape}")
    loaded_tensors[tensor_name] = tensor

Shape of band_power_tensor: torch.Size([4227788, 32, 5])
Shape of EEG_tensor: torch.Size([1, 32, 1, 4227788])
Shape of fast_fourier_transform_psd_tensor: torch.Size([32, 4227788])


# Transformer block

In [22]:
# Initialize the model
d_model = 128
nhead = 8
num_layers = 2
dim_feedforward = 512

batch_size = 64
seq_len = 500  # truncated sequence length
feature_dim = 32

class EEGPredictor(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dim_feedforward, sequence_length):
        super(EEGPredictor, self).__init__()
        self.feature_transform = nn.Linear(feature_dim, d_model)  # Transform from 32 to d_model
        self.transformer_block = TransformerBlock(d_model, nhead, num_layers, dim_feedforward, seq_len)
        self.prediction_head = nn.Linear(d_model, 1)
        
    def forward(self, x):
        x = self.feature_transform(x)  # Add this line
        x = self.transformer_block(x)
        x = self.prediction_head(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dim_feedforward, sequence_length):
        super(TransformerBlock, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.pos_encoder = nn.Embedding(seq_len, d_model)
        self.position = torch.arange(0, seq_len, dtype=torch.long).unsqueeze(1)
        
    def forward(self, x):
        pos_encoding = self.pos_encoder(self.position[:x.size(0), :])
        x = x + pos_encoding
        return self.transformer(x)
redictor(d_model, nhead, num_layers, dim_feedforward, seq_len)

# Prepare the EEG tensor, removing singleton dimensions
eeg_tensor = loaded_tensors["EEG_tensor"].squeeze()  # Likely becomes [32, 4227788]
print("Squeezed shape:", eeg_tensor.shape)

# Since the tensor is 2D, permute using only two dimensions
eeg_tensor = eeg_tensor.permute(1, 0)  # This would swap the dimensions making it [4227788, 32]
print("Permuted shape:", eeg_tensor.shape)

num_batches = eeg_tensor.shape[0] // (batch_size * seq_len)

# To store the transformer outputs
transformer_outputs = []

for i in range(num_batches):
    start_idx = i * batch_size * seq_len
    end_idx = start_idx + (batch_size * seq_len)
    
    # Extract batch and reshape to [seq_len, batch_size, feature_dim]
    batch = eeg_tensor[start_idx:end_idx, :].reshape(seq_len, batch_size, feature_dim)
    
    # Forward pass
    output = model(batch)
    print("Batch output:", output.shape)
    
    # Save the transformer output for use in RNN
    transformer_outputs.append(output.detach())

# Convert list of outputs to a tensor (or keep as a list if that's easier for your application)
transformer_outputs = torch.stack(transformer_outputs)


Squeezed shape: torch.Size([32, 4227788])
Permuted shape: torch.Size([4227788, 32])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch.Size([500, 64, 1])
Batch output: torch

# RNN block

In [24]:
# Define an RNN model
class RNNModel(nn.Module):
    def __init__(self, rnn_input_size, rnn_hidden_size, rnn_num_layers):
        super(RNNModel, self).__init__()
        self.rnn = nn.LSTM(input_size=rnn_input_size, hidden_size=rnn_hidden_size, num_layers=rnn_num_layers)
        self.prediction_head = nn.Linear(rnn_hidden_size, 1)

    def forward(self, x):
        x, _ = self.rnn(x)
        x = self.prediction_head(x[-1])
        return x

# Initialize the RNN model
rnn_model = RNNModel(d_model, rnn_hidden_size=64, rnn_num_layers=1)

# Use the transformer_outputs as input for the RNN model
# You might need to reshape transformer_outputs to fit the expected input shape of the RNN
rnn_output = rnn_model(transformer_outputs)

NameError: name 'TransformerModel' is not defined

# Complete EEG Predictor

In [None]:
class CompleteEEGPredictor(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, dim_feedforward):
        super(CompleteEEGPredictor, self).__init__()
        
        # Initialize all feature embedding networks
        self.eeg_embedding = EEGEmbeddingNet()
        self.rotation_embedding = RotationEmbeddingNet()
        self.band_power_embedding = BandPowerEmbeddingNet()
        self.dspm_embedding = DSPMEmbeddingNet()
        self.fast_fourier_embedding = FastFourierEmbeddingNet()
        self.higuchi_fractal_embedding = HiguchiFractalEmbeddingNet()
        self.hurst_embedding = HurstEmbeddingNet()
        self.mfdfa_concatd_embedding = MFDFAConcatdEmbeddingNet()
        self.mfdfa_embedding = MFDFAEmbeddingNet()
        self.short_time_fourier_embedding = ShortTimeFourierEmbeddingNet()
        self.spectral_entropy_embedding = SpectralEntropyEmbeddingNet()
        self.spectral_centroids_embedding = SpectralCentroidsEmbeddingNet()
        self.freq_max_power_embedding = FreqMaxPowerEmbeddingNet()
        self.spectral_edge_freqs_embedding = SpectralEdgeFreqsEmbeddingNet()
        self.pairwise_measure_net = PairwiseMeasureNet(input_channels=32, output_channels=64)
        
        # Initialize Kuramoto layer
        self.kuramoto = KuramotoLayer(oscillator_count=32, time_steps=100)
        
        # Initialize Transformer block
        self.transformer_block = TransformerBlock(d_model, nhead, num_encoder_layers, dim_feedforward)
        
        # Initialize RNN block
        self.rnn_block = RNNBlock(input_size=d_model, hidden_size=256, num_layers=2, dropout=0.5)
        
    def forward(self, src):
        # Feature extraction using all the embedding networks
        embeddings = [
            self.eeg_embedding(src),
            self.rotation_embedding(src),
            self.band_power_embedding(src),
            self.dspm_embedding(src),
            self.fast_fourier_embedding(src),
            self.higuchi_fractal_embedding(src),
            self.hurst_embedding(src),
            self.mfdfa_concatd_embedding(src),
            self.mfdfa_embedding(src),
            self.short_time_fourier_embedding(src),
            self.spectral_entropy_embedding(src),
            self.spectral_centroids_embedding(src),
            self.freq_max_power_embedding(src),
            self.spectral_edge_freqs_embedding(src),
            self.pairwise_measure_net(src)
        ]
        
        # Concatenating the embeddings
        combined_features = torch.cat(embeddings, dim=-1)
        
        # Kuramoto layer
        src_kuramoto = self.kuramoto(combined_features)
        
        # Transformer block
        src_transformed = self.transformer_block(src_kuramoto)
        
        # RNN block
        output = self.rnn_block(src_transformed)
        
        return output