# Load Tensors

In [1]:
import torch

# List of tensor paths
tensor_paths = [
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/arnold_tongues_rotation_numbers_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/dspm_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/higuchi_fractal_dimensions_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/Hurst_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/mfdfa_concatd_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/mfdfa_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/short_time_fourier_transform_tensor.pth",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_granular_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_hemispheric_avg_input_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/transfer_entropy_regional_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_entropy_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_centroids_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/freq_max_power_tensor.pt",
    "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/spectral_edge_freqs_tensor.pt",
]

# Initialize an empty dictionary to store the tensors and another for their shapes
tensors = {}
tensor_shapes = {}

# Load the tensors into a dictionary and collect their shapes
for path in tensor_paths:
    tensor_name = path.split('/')[-1].replace('.pt', '').replace('.pth', '')

    # Remove the 'h' from the end, if it exists
    if tensor_name.endswith("h"):
        tensor_name = tensor_name[:-1]

    # Load the tensor
    data = torch.load(path)
    tensors[tensor_name] = data

    # Check the type of the loaded data
    if isinstance(data, torch.Tensor):
        tensor_shapes[tensor_name] = data.shape
    elif isinstance(data, dict):  # Likely a state_dict
        tensor_shapes[tensor_name] = "state_dict (model parameters)"
    else:
        tensor_shapes[tensor_name] = "unknown type"

# Print the shapes of all loaded tensors
for name, shape in tensor_shapes.items():
    print(f"{name}: {shape}")


arnold_tongues_rotation_numbers_tensor: torch.Size([32, 300, 300])
dspm_tensor: torch.Size([19, 18840, 10])
higuchi_fractal_dimensions_tensor: torch.Size([1, 1, 4, 8])
Hurst_tensor: torch.Size([1, 1, 32, 1])
mfdfa_concatd_tensor: torch.Size([32, 1, 30, 2])
mfdfa_tensor: torch.Size([9, 32, 10, 1])
short_time_fourier_transform_tensor: torch.Size([32, 1001, 4229])
transfer_entropy_granular_tensor: torch.Size([4, 4])
transfer_entropy_hemispheric_avg_input_tensor: torch.Size([92, 92])
transfer_entropy_regional_tensor: torch.Size([4, 4])
spectral_entropy_tensor: torch.Size([1, 1, 32, 1])
spectral_centroids_tensor: torch.Size([1, 1, 32, 1])
freq_max_power_tensor: torch.Size([1, 1, 32, 1])
spectral_edge_freqs_tensor: torch.Size([1, 1, 32, 1])


# Match dimensions, reshape, and normalize

In [3]:
import torch
import torch.nn.functional as F

def preprocess_and_resize_tensor(tensor, target_shape):
    # Add missing batch and channel dimensions
    while len(tensor.shape) < 4:
        tensor = tensor.unsqueeze(0)

    # Reduce the channel dimension to 1 by taking the mean along that axis
    tensor = torch.mean(tensor, dim=1, keepdim=True)

    # Normalize
    mean = tensor.mean()
    std = tensor.std()
    if std != 0:
        tensor = (tensor - mean) / std

    # Reshape/resize to target_shape
    tensor = F.interpolate(tensor, size=target_shape[2:], mode='bilinear', align_corners=True)
    
    return tensor


# Your tensor shapes here are just for illustration
# Replace these with your actual loaded tensors
arnold_tongues_rotation_numbers_tensor = torch.rand([32, 300, 300])
dspm_tensor = torch.rand([19, 18840, 10])
higuchi_fractal_dimensions_tensor = torch.rand([1, 1, 4, 8])
Hurst_tensor = torch.rand([1, 1, 32, 1])
mfdfa_concatd_tensor = torch.rand([32, 1, 30, 2])
mfdfa_tensor = torch.rand([9, 32, 10, 1])
short_time_fourier_transform_tensor = torch.rand([32, 1001, 4229])
transfer_entropy_granular_tensor = torch.rand([4, 4])
transfer_entropy_hemispheric_avg_input_tensor = torch.rand([92, 92])
transfer_entropy_regional_tensor = torch.rand([4, 4])
spectral_entropy_tensor = torch.rand([1, 1, 32, 1])
spectral_centroids_tensor = torch.rand([1, 1, 32, 1])
freq_max_power_tensor = torch.rand([1, 1, 32, 1])
spectral_edge_freqs_tensor = torch.rand([1, 1, 32, 1])

target_shape = [1, 1, 32, 32]

# List of all your tensors (use your actual tensors here)
all_tensors = [
    arnold_tongues_rotation_numbers_tensor,
    dspm_tensor,
    higuchi_fractal_dimensions_tensor,
    Hurst_tensor,
    mfdfa_concatd_tensor,
    mfdfa_tensor,
    short_time_fourier_transform_tensor,
    transfer_entropy_granular_tensor,
    transfer_entropy_hemispheric_avg_input_tensor,
    transfer_entropy_regional_tensor,
    spectral_entropy_tensor,
    spectral_centroids_tensor,
    freq_max_power_tensor,
    spectral_edge_freqs_tensor,
]

# Preprocess all tensors
processed_tensors = [preprocess_and_resize_tensor(tensor, target_shape) for tensor in all_tensors]

# Print out the new shapes
for i, tensor in enumerate(processed_tensors):
    print(f"Processed tensor {i+1} shape: {tensor.shape}")

Processed tensor 1 shape: torch.Size([1, 1, 32, 32])
Processed tensor 2 shape: torch.Size([1, 1, 32, 32])
Processed tensor 3 shape: torch.Size([1, 1, 32, 32])
Processed tensor 4 shape: torch.Size([1, 1, 32, 32])
Processed tensor 5 shape: torch.Size([32, 1, 32, 32])
Processed tensor 6 shape: torch.Size([9, 1, 32, 32])
Processed tensor 7 shape: torch.Size([1, 1, 32, 32])
Processed tensor 8 shape: torch.Size([1, 1, 32, 32])
Processed tensor 9 shape: torch.Size([1, 1, 32, 32])
Processed tensor 10 shape: torch.Size([1, 1, 32, 32])
Processed tensor 11 shape: torch.Size([1, 1, 32, 32])
Processed tensor 12 shape: torch.Size([1, 1, 32, 32])
Processed tensor 13 shape: torch.Size([1, 1, 32, 32])
Processed tensor 14 shape: torch.Size([1, 1, 32, 32])


# CNN

In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

tensor_names = [
    'arnold_tongues_rotation_numbers_tensor',
    'dspm_tensor',
    'higuchi_fractal_dimensions_tensor',
    'Hurst_tensor',
    'mfdfa_concatd_tensor',
    'mfdfa_tensor',
    'short_time_fourier_transform_tensor',
    'transfer_entropy_granular_tensor',
    'transfer_entropy_hemispheric_avg_input_tensor',
    'transfer_entropy_regional_tensor',
    'spectral_entropy_tensor',
    'spectral_centroids_tensor',
    'freq_max_power_tensor',
    'spectral_edge_freqs_tensor',
]

class BaseEmbeddingNet(nn.Module):
    def __init__(self, input_channels, conv_output_channels, reduce_to_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, conv_output_channels, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(conv_output_channels)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_reduce = nn.Linear(conv_output_channels, reduce_to_dim)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc_reduce(x)
        return x

processed_tensors_dict = {name: tensor for name, tensor in zip(tensor_names, processed_tensors)}

net_params = {name: {'input_channels': 1, 'conv_output_channels': 16, 'reduce_to_dim': 8} 
              for name in processed_tensors_dict.keys()}

# Create BaseEmbeddingNets for each tensor
embedding_nets = {name: BaseEmbeddingNet(**params) for name, params in net_params.items()}

# Move networks to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for net in embedding_nets.values():
    net.to(device)

# Create custom dataset and dataloader
class CustomDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        first_tensor = next(iter(self.tensors.values()))
        return first_tensor.size(0)

    def __getitem__(self, idx):
        result = {}
        for key, val in self.tensors.items():
            if val.shape[0] > idx:
                result[key] = val[idx]
        return result

def custom_collate(batch):
    collated_batch = {}
    all_keys = set([key for item in batch for key in item.keys()])
    
    for key in all_keys:
        collated_batch[key] = torch.stack([item[key] for item in batch if key in item.keys()], dim=0)
    
    return collated_batch

# Use processed_tensors for your CustomDataset
dataset = CustomDataset(processed_tensors_dict)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=custom_collate)

# Collect feature embeddings
all_features = []
for i, batch in enumerate(dataloader):
    features_list = [net(batch[key].to(device, dtype=torch.float32)) for key, net in embedding_nets.items()]
    concatenated_features = torch.cat(features_list, dim=1)
    all_features.append(concatenated_features.cpu().detach())

# Convert list to tensor
all_features = torch.cat(all_features, dim=0)

# Save the feature embeddings
save_path = '/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Kuramoto'
torch.save(all_features, f'{save_path}/all_features.pt')

# Train Test Validation

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Load EEG data
EEG_data = np.load('/home/vincent/AAA_projects/MVCS/Neuroscience/eeg_data_with_channels.npy', allow_pickle=True)
EEG_tensor = torch.FloatTensor(EEG_data)  # Assumes EEG_data is a NumPy ndarray

window_size = 100  # You can adjust this
stride = 10  # This is also adjustable

# Create windows
def create_windows(data, window_size, stride):
    windows = []
    for i in range(0, len(data) - window_size, stride):
        windows.append(data[i:i+window_size])
    return torch.stack(windows)

class EEGDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        if self.transform:
            sample = self.transform(sample)
        return sample

# Create windowed data
train_data = create_windows(EEG_tensor[:int(0.7 * len(EEG_tensor))], window_size, stride)
val_data = create_windows(EEG_tensor[int(0.7 * len(EEG_tensor)):int(0.85 * len(EEG_tensor))], window_size, stride)
test_data = create_windows(EEG_tensor[int(0.85 * len(EEG_tensor)):], window_size, stride)

# Create Datasets and DataLoaders
train_dataset = EEGDataset(data=train_data)
val_dataset = EEGDataset(data=val_data)
test_dataset = EEGDataset(data=test_data)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)  # Set shuffle=False for time series data
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Kuramoto layer

In [14]:
import torch
import torch.nn as nn
from torchdiffeq import odeint
from torch.utils.data import DataLoader
import numpy as np
from scipy.signal import hilbert

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load EEG data
EEG_data_path = "/home/vincent/AAA_projects/MVCS/Neuroscience/Models/CNN/EEG_tensor.pth"
EEG_data = torch.load(EEG_data_path).to(device)

# Hilbert Transform for analytical signal
analytic_signal = hilbert(EEG_data.cpu().numpy())
phases = torch.tensor(np.angle(analytic_signal), dtype=torch.float32).to(device)

# EEG channel names
eeg_channel_names = ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5', 'FC1', 'FC2', 'FC6',
                     'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2', 'CP5', 'CP1', 'CP2', 'CP6',
                     'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1', 'Oz', 'O2']

# Broad regions and corresponding channels
regions = {
    "frontal": ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8'],
    "temporal": ['T7', 'T8'],
    "parietal": ['CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8'],
    "occipital": ['O1', 'Oz', 'O2']
}

# Load PLV matrix
plv_matrix = torch.tensor(np.load("/home/vincent/AAA_projects/MVCS/Neuroscience/Analysis/Phase Syncronization/plv_matrix.npy"), dtype=torch.float32).to(device)

# Compute_phase_diff_matrix function
def compute_phase_diff_matrix(phases, num_channels):
    # Assuming phases is shape (Time, Channels), and you want (Channels, Channels)
    phase_diff_matrix = torch.mean(phases.unsqueeze(1) - phases.unsqueeze(0), dim=2)
    return phase_diff_matrix


# Precompute omega and phase_diff_matrix
N = len(eeg_channel_names)
omega = torch.mean(plv_matrix, dim=1).to(device)
phase_diff_matrix = compute_phase_diff_matrix(phases, N).to(device)  # Added N as the second argument

# Modify the Kuramoto function to use PyTorch functions instead of NumPy
def kuramoto_weighted_bias(t, y, omega, K):
    weighted_sin = plv_matrix * torch.sin(y - y[:, None] - phase_diff_matrix)
    dydt = omega + K / N * torch.sum(weighted_sin, axis=1)
    return dydt

class KuramotoODEFunc(nn.Module):
    def __init__(self, omega, K, plv_matrix, phase_diff_matrix):
        super(KuramotoODEFunc, self).__init__()
        self.omega = omega
        self.K = K
        self.plv_matrix = plv_matrix
        self.phase_diff_matrix = phase_diff_matrix

    def forward(self, t, theta):
        # Reshape to accommodate the additional time dimension.
        theta = theta.view(-1, theta.shape[-1])
        N = theta.shape[1]
        print("Theta shape: ", theta.shape)
        print("PLV Matrix shape: ", self.plv_matrix.shape)
        print("Phase Diff Matrix shape: ", self.phase_diff_matrix.shape)
        weighted_sin = self.plv_matrix * torch.sin(theta.unsqueeze(1) - theta.unsqueeze(2) - self.phase_diff_matrix)
        dtheta = self.omega + (self.K / N) * torch.sum(weighted_sin, dim=1)
        return dtheta.view(theta.shape)

class KuramotoLayer(nn.Module):
    def __init__(self, oscillator_count, time_steps, dt=0.01, plv_matrix=None, phase_diff_matrix=None):
        super(KuramotoLayer, self).__init__()
        self.oscillator_count = oscillator_count
        self.time_steps = time_steps
        self.dt = dt
        self.plv_matrix = plv_matrix
        self.phase_diff_matrix = phase_diff_matrix
        
        # Initialize omega based on PLV if available
        if plv_matrix is not None:
            omega_init = torch.mean(plv_matrix, dim=1)  
            self.omega = nn.Parameter(omega_init)

        else:
            self.omega = nn.Parameter(torch.randn(oscillator_count))
        
        self.K = nn.Parameter(torch.tensor(1.0))  # global coupling strength

    def forward(self, x):
        x = x.to(device) 
        ode_func = KuramotoODEFunc(self.omega, self.K, self.plv_matrix, self.phase_diff_matrix)
        time_points = torch.arange(0, self.time_steps * self.dt, self.dt)
        theta = odeint(ode_func, x, time_points)
        mean_coherence = self.calculate_mean_coherence(theta)
        return theta, mean_coherence

    @staticmethod
    def calculate_mean_coherence(theta):
        # Calculate phase coherence across oscillators
        N, _, _ = theta.shape
        mean_coherence = torch.mean(torch.cos(theta[:, -1, :] - theta[:, -1, :].mean(dim=1).unsqueeze(1)))
        return mean_coherence

# Number of channels
N = len(eeg_channel_names)

# Make sure all tensors are on the correct device
phases = phases.to(device)

# Compute natural frequencies and phase differences just once
omega = torch.mean(plv_matrix, dim=1).to(device)
num_channels = len(eeg_channel_names)  # Get the number of channels
phase_diff_matrix = compute_phase_diff_matrix(phases, num_channels).to(device)

# Initialize model and move to device
kuramoto_model = KuramotoLayer(N, 100, plv_matrix=plv_matrix, phase_diff_matrix=phase_diff_matrix).to(device)

# Data Parallelism for multiple GPUs
if torch.cuda.device_count() > 1:
    kuramoto_model = nn.DataParallel(kuramoto_model)

# Placeholder for training data
train_data = create_windows(EEG_tensor[:int(0.7 * len(EEG_tensor))], window_size, stride)
train_dataset = EEGDataset(data=train_data)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)

# Extract Kuramoto features
kuramoto_features_list = []
for i, batch in enumerate(train_loader):
    theta, mean_coherence = kuramoto_model(batch.to(device))
    kuramoto_features_list.append(mean_coherence)

# Save the combined features
kuramoto_features_tensor = torch.stack(kuramoto_features_list)
all_features = torch.load('/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Kuramoto/all_features.pt')
combined_features = torch.cat([all_features, kuramoto_features_tensor.unsqueeze(1)], dim=1)
torch.save(combined_features, '/home/vincent/AAA_projects/MVCS/Neuroscience/Models/Transformer/combined_features.pt')

Theta shape:  torch.Size([6400, 32])
PLV Matrix shape:  torch.Size([32, 32])
Phase Diff Matrix shape:  torch.Size([1, 1, 1, 4227788])
Theta shape:  torch.Size([6400, 32])
PLV Matrix shape:  torch.Size([32, 32])
Phase Diff Matrix shape:  torch.Size([1, 1, 1, 4227788])




RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_59630/2060336086.py", line 93, in forward
    theta = odeint(ode_func, x, time_points)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 28, in integrate
    self._before_integrate(t)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 161, in _before_integrate
    f0 = self.func(t[0], self.y0)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 189, in forward
    return self.base_func(t, y)
  File "/home/vincent/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_59630/2060336086.py", line 66, in forward
    weighted_sin = self.plv_matrix * torch.sin(theta.unsqueeze(1) - theta.unsqueeze(2) - self.phase_diff_matrix)
RuntimeError: The size of tensor a (32) must match the size of tensor b (4227788) at non-singleton dimension 3


# Transformer block

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dim_feedforward):
        super(TransformerBlock, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward)
        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.pos_encoder = nn.Embedding(4227788, d_model)  # Adjust this based on your data
        self.position = torch.arange(0, 4227788, dtype=torch.long).unsqueeze(1)
        
    def forward(self, x):
        pos_encoding = self.pos_encoder(self.position[:x.size(0), :])
        x = x + pos_encoding
        return self.transformer(x)

# RNN block

In [None]:
class RNNBlock(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(RNNBlock, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.fc_out = nn.Linear(hidden_size, 32)  # 32 channels of EEG data
        
    def forward(self, x):
        rnn_out, _ = self.rnn(x)
        return self.fc_out(rnn_out)

# Complete EEG Predictor

In [None]:
class CompleteEEGPredictor(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, dim_feedforward):
        super(CompleteEEGPredictor, self).__init__()
        
        # Initialize all feature embedding networks
        self.eeg_embedding = EEGEmbeddingNet()
        self.rotation_embedding = RotationEmbeddingNet()
        self.band_power_embedding = BandPowerEmbeddingNet()
        self.dspm_embedding = DSPMEmbeddingNet()
        self.fast_fourier_embedding = FastFourierEmbeddingNet()
        self.higuchi_fractal_embedding = HiguchiFractalEmbeddingNet()
        self.hurst_embedding = HurstEmbeddingNet()
        self.mfdfa_concatd_embedding = MFDFAConcatdEmbeddingNet()
        self.mfdfa_embedding = MFDFAEmbeddingNet()
        self.short_time_fourier_embedding = ShortTimeFourierEmbeddingNet()
        self.spectral_entropy_embedding = SpectralEntropyEmbeddingNet()
        self.spectral_centroids_embedding = SpectralCentroidsEmbeddingNet()
        self.freq_max_power_embedding = FreqMaxPowerEmbeddingNet()
        self.spectral_edge_freqs_embedding = SpectralEdgeFreqsEmbeddingNet()
        self.pairwise_measure_net = PairwiseMeasureNet(input_channels=32, output_channels=64)
        
        # Initialize Kuramoto layer
        self.kuramoto = KuramotoLayer(oscillator_count=32, time_steps=100)
        
        # Initialize Transformer block
        self.transformer_block = TransformerBlock(d_model, nhead, num_encoder_layers, dim_feedforward)
        
        # Initialize RNN block
        self.rnn_block = RNNBlock(input_size=d_model, hidden_size=256, num_layers=2, dropout=0.5)
        
    def forward(self, src):
        # Feature extraction using all the embedding networks
        embeddings = [
            self.eeg_embedding(src),
            self.rotation_embedding(src),
            self.band_power_embedding(src),
            self.dspm_embedding(src),
            self.fast_fourier_embedding(src),
            self.higuchi_fractal_embedding(src),
            self.hurst_embedding(src),
            self.mfdfa_concatd_embedding(src),
            self.mfdfa_embedding(src),
            self.short_time_fourier_embedding(src),
            self.spectral_entropy_embedding(src),
            self.spectral_centroids_embedding(src),
            self.freq_max_power_embedding(src),
            self.spectral_edge_freqs_embedding(src),
            self.pairwise_measure_net(src)
        ]
        
        # Concatenating the embeddings
        combined_features = torch.cat(embeddings, dim=-1)
        
        # Kuramoto layer
        src_kuramoto = self.kuramoto(combined_features)
        
        # Transformer block
        src_transformed = self.transformer_block(src_kuramoto)
        
        # RNN block
        output = self.rnn_block(src_transformed)
        
        return output