# ITLPPD Experiments

### Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from sklearn.decomposition import PCA 
from itertools import product

## Models

Below are the architectures.

The Autoencoder compresses each fixed-length window of the signal into a low-dimensional latent code and reconstructs the window from that code.

The Information Theoretic Learning Point Process Detector (ITLPPD) extends the Autoencoder with a latent-space regularizer. Each new latent code is compared, through a Gaussian kernel of width sigma, with a buffer of the previous L latent codes. The resulting loss, weighted by lambda_reg, is added to the reconstruction loss to steer learning. The same quantity also yields a point process: when a sample's latent code is unexpected (its regularization loss exceeds the threshold delta), the sample is flagged as a spike event.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA

# Autoencoder Model
class Autoencoder(nn.Module):
    def __init__(self, D, C, M_C):
        '''
        Fully connected autoencoder whose hidden widths shrink by 2**C per layer.

        Hyperparameters:
        ---------------
        D: window_size (input and output dimension)
        C: compression factor, 2**C; must be >= 1 so the widths actually shrink
        M_C: minimum compression; layers after the first are only added while
             their width stays >= M_C. Must be >= 1.

        Attributes:
        -----------
        encoder: Encoder part of the autoencoder (nn.Linear layers only, no activations)
        decoder: Decoder part of the autoencoder (mirror image of the encoder)
        LD: Latent dimension, i.e. the width of the last encoder layer; None when
            even the first compression D // 2**C is smaller than M_C
        num_compressions: Number of compression layers
        compressions: List of compression sizes, widest first
        latent_code: Detached latent code of the most recent forward() input
        -----------

        Raises:
        -------
        ValueError: if a hyperparameter is missing, or if C < 1 or M_C < 1
                    (either would make the compression loop never terminate).
        '''
        super().__init__()

        if D is None:
            raise ValueError("window_size must be provided")
        self.D = D
        if C is None:
            raise ValueError("compression factor must be provided")
        if C < 1:
            raise ValueError("compression factor must be at least 1")
        self.C = C
        if M_C is None:
            raise ValueError("minimum compression must be provided")
        if M_C < 1:
            raise ValueError("minimum compression must be at least 1")
        self.M_C = M_C

        # Calculate compression sizes: the first layer is always created, further
        # layers are appended while the next width would still be at least M_C.
        factor = 2**C
        self.compressions = [D // factor]
        while self.compressions[-1] // factor >= M_C:
            self.compressions.append(self.compressions[-1] // factor)

        # The latent code is the output of the last encoder layer, which is a
        # valid latent space whenever it respects the minimum size -- including
        # encoders made of a single compression layer.
        self.LD = self.compressions[-1] if self.compressions[-1] >= M_C else None

        self.num_compressions = len(self.compressions)
        self.latent_code = None

        # Encoder: D -> compressions[0] -> ... -> compressions[-1]
        self.encoder = nn.Sequential()
        self.encoder.add_module("encoder_0", nn.Linear(D, self.compressions[0]))
        for i in range(self.num_compressions - 1):
            self.encoder.add_module(f"encoder_{i+1}", nn.Linear(self.compressions[i], self.compressions[i + 1]))

        # Decoder: mirror of the encoder, compressions[-1] -> ... -> compressions[0] -> D
        self.decoder = nn.Sequential()
        for i in range(self.num_compressions - 1, 0, -1):
            self.decoder.add_module(f"decoder_{i}", nn.Linear(self.compressions[i], self.compressions[i - 1]))
        self.decoder.add_module("decoder_0", nn.Linear(self.compressions[0], D))

    def forward(self, x):
        """Return the reconstruction of ``x``; a detached copy of the latent code is kept in ``self.latent_code``."""
        x = self.encoder(x)
        self.latent_code = x.detach()  # Store latent code (no autograd graph attached)
        x = self.decoder(x)
        return x

# ITLPPD Model
class ITLPPD:
    def __init__(self, D=None, C=None, M_C=None, L=None, lr=None, sigma=None, lambda_reg=None, delta=None, device=None):
        '''
        Information Theoretic Learning Point Process Detector: an Autoencoder trained
        with an extra latent-code regularizer that also emits spike (point process) events.

        Hyperparameters:
        ---------------
        D: window_size
        C: compression factor, 2**C
        M_C: minimum compression
        L: number of windows/lags kept in the latent buffer (must be >= 1)
        lr: Learning rate for the optimizer
        sigma: Standard deviation of the Gaussian kernel comparing latent codes (must be > 0)
        lambda_reg: Weight for the regularization term
        delta: Threshold value for the regularizer to trigger point process events
        device: Device to run the model on (e.g., "cuda" or "cpu"); defaults to
                "cuda" when available, otherwise "cpu"

        Attributes:
        -----------
        model: Instance of the Autoencoder class
        Zs: Buffer of the previous L latent codes (tensor of shape [L, latent_dim])
        latent_history: List that stores latent codes at each training step
        spike_train_history: List that stores spike (point process) events per sample
        criterion: Loss function (MSE)
        optimizer: Optimizer (Adam)
        -----------

        Raises:
        -------
        ValueError: if a hyperparameter is missing or out of range, or if the
                    Autoencoder has no latent dimension for these settings.
        '''
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        if D is None:
            raise ValueError("window_size must be provided")
        if C is None:
            raise ValueError("compression factor must be provided")
        if M_C is None:
            raise ValueError("minimum compression must be provided")
        if L is None:
            raise ValueError("L must be provided")
        if sigma is None:
            raise ValueError("Sigma must be provided")
        if lr is None:
            raise ValueError("Learning rate must be provided")
        if delta is None:
            raise ValueError("Delta threshold must be provided")
        if lambda_reg is None:
            raise ValueError("lambda_reg must be provided")
        if L < 1:
            # An empty buffer makes the similarity mean NaN and the buffer update grow it.
            raise ValueError("L must be at least 1")
        if sigma <= 0:
            # The kernel divides by 2 * sigma**2.
            raise ValueError("Sigma must be positive")

        self.model = Autoencoder(D, C, M_C).to(self.device)
        self.L = L
        self.sigma = sigma
        self.lambda_reg = lambda_reg
        self.delta = delta  # Threshold for triggering a point process event

        # Initialize Zs as a tensor of zeros on the device (shape: L x latent dimension)
        if self.model.LD is None:
            raise ValueError("Latent dimension is not defined in the model")
        self.Zs = torch.zeros((L, self.model.LD), device=self.device)

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # To store latent codes from each training iteration (each element is a tensor of shape [batch_size, latent_dim])
        self.latent_history = []
        # To store spike (point process) events per sample
        self.spike_train_history = []

    def train(self, train_loader, num_epochs=10, log_every=1024):
        """
        Train the autoencoder with an additional latent-code regularization term.

        For each sample in a batch (in batch order) the regularization loss compares
        its latent code with the buffer of the previous L codes, and a spike event
        is recorded (1 if the loss exceeds delta, else 0). The loss that is
        back-propagated is ``recon_loss + lambda_reg * mean(per-sample reg loss)``.

        Parameters:
        -----------
        train_loader: iterable of input batches (tensors of shape [batch_size, D])
        num_epochs: number of passes over train_loader
        log_every: print the mean epoch loss every ``log_every`` epochs
                   (0 or None disables the progress print)
        """
        self.model.train()

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            for data in train_loader:
                data = data.to(self.device)

                # Run encoder and decoder separately so the latent code keeps its
                # autograd graph: Autoencoder.forward only exposes a *detached*
                # copy, which would make the regularization term gradient-free.
                latent_codes = self.model.encoder(data)  # shape: (batch_size, latent_dim)
                outputs = self.model.decoder(latent_codes)
                self.model.latent_code = latent_codes.detach()
                # Log latent codes from this batch (each element is of shape: [batch_size, latent_dim])
                self.latent_history.append(self.model.latent_code.cpu())

                recon_loss = self.criterion(outputs, data)

                reg_losses, spikes = self._latent_regularization(latent_codes)
                self.spike_train_history.extend(spikes)

                # Average the regularization loss over all samples in the batch.
                total_loss = recon_loss + self.lambda_reg * reg_losses.mean()

                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                epoch_loss += total_loss.item()

            if log_every and (epoch + 1) % log_every == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.8f}")

        print("Training Complete!")

    def _latent_regularization(self, latent_codes):
        """
        Score each latent code against the rolling buffer ``self.Zs``, in batch order.

        Sample i is compared with the buffer as it stands after samples 0..i-1 have
        been pushed into it (the same rule as a per-sample loop), computed in one shot.

        Returns:
        --------
        reg_losses: tensor of shape (batch_size,) with
                    1 - mean_k exp(-||z_i - Z_k||^2 / (2 sigma^2)); differentiable
                    with respect to ``latent_codes`` (buffer entries are detached).
        spikes: list of 0/1 events, 1 where the loss exceeds ``self.delta``.

        Side effect: ``self.Zs`` becomes the most recent L (detached) latent codes.
        """
        batch_size = latent_codes.size(0)
        # Old buffer followed by the new codes; row order is arrival order.
        history = torch.cat((self.Zs, latent_codes.detach()), dim=0)  # (L + batch_size, latent_dim)
        # windows[i] holds rows i .. i+L-1 of history, i.e. the buffer seen by sample i.
        windows = history.unfold(0, self.L, 1)[:batch_size]  # (batch_size, latent_dim, L)
        squared_diff = (latent_codes.unsqueeze(-1) - windows).pow(2).sum(dim=1)  # (batch_size, L)
        similarity = torch.exp(-squared_diff / (2 * self.sigma**2))
        reg_losses = 1 - similarity.mean(dim=1)
        # Compare as Python floats, like a per-sample `loss.item() > delta` check.
        spikes = [1 if loss > self.delta else 0 for loss in reg_losses.detach().tolist()]
        # Keep the last L codes; clone so the buffer does not pin the whole history tensor.
        self.Zs = history[-self.L:].clone()
        return reg_losses, spikes

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def encode(self, data):
        """Return the latent code for the given input data."""
        self.model.eval()
        with torch.no_grad():
            data = data.to(self.device)
            latent_code = self.model.encoder(data)
        return latent_code.cpu()

    def get_latent_codes(self, data_loader):
        """Iterate over the data loader and return latent codes for all samples."""
        self.model.eval()
        latent_codes_list = []
        with torch.no_grad():
            for data in data_loader:
                data = data.to(self.device)
                latent_codes = self.model.encoder(data)
                latent_codes_list.append(latent_codes.cpu())
        return torch.cat(latent_codes_list, dim=0)

    def print_latent_history(self):
        """Print the stored latent history in a nicely formatted manner."""
        print("=== Latent History ===")
        for iter_idx, batch_latents in enumerate(self.latent_history):
            print(f"Iteration {iter_idx+1}:")
            for sample_idx in range(batch_latents.shape[0]):
                code = batch_latents[sample_idx].tolist()
                formatted_code = ", ".join([f"{v:.4f}" for v in code])
                print(f"  Sample {sample_idx+1}: [{formatted_code}]")
            print("-" * 50)
        print("=== End of Latent History ===")

    def print_spike_train_history(self):
        """Print the recorded spike events corresponding to each sample."""
        print("=== Spike Train History ===")
        for idx, spike in enumerate(self.spike_train_history):
            print(f"Sample {idx+1}: Spike event = {spike}")
        print("=== End of Spike Train History ===")

    def plot_spike_train_history(self, save_path=None):
        """
        Plot the spike train events over time using a scatter plot with different colors for 0s and 1s.

        The figure is saved (creating missing directories) *before* it is shown:
        plt.show() may close/clear the current figure, which would otherwise
        leave an empty image on disk.
        """
        spike_indices = [i for i, val in enumerate(self.spike_train_history) if val == 1]
        non_spike_indices = [i for i, val in enumerate(self.spike_train_history) if val == 0]

        plt.figure(figsize=(10, 4))
        plt.scatter(non_spike_indices, [0] * len(non_spike_indices), color='r', marker='o', label="No Spike (0)")
        plt.scatter(spike_indices, [1] * len(spike_indices), color='b', marker='o', label="Spike (1)")

        plt.xlabel("Sample Index")
        plt.ylabel("Spike Event (1=Spike, 0=No Spike)")
        plt.title("Spike Train History Over Samples")
        plt.ylim(-0.1, 1.1)
        plt.grid(True)
        plt.legend()

        if save_path is not None:
            self._ensure_parent_dir(save_path)
            plt.savefig(save_path)
        plt.show()
        plt.close()

    def plot_latent_space_history(self, save_path=None, title=None):
        """
        Plot the evolution of the latent space over training iterations.

        This function extracts the latent code of the first sample from each batch (stored in latent_history)
        and plots their evolution. If the latent code dimension is 2, a direct scatter plot is generated;
        otherwise, PCA is used to project the latent codes into 2D.

        Parameters:
        -----------
        save_path : str, optional
            If provided, the figure is saved to this path (missing directories are created).
        title : str, optional
            Custom title for the plot.
        """
        if not self.latent_history:
            print("No latent history to plot; train the model first.")
            return

        # First sample's latent code from each batch -> shape [num_batches, latent_dim]
        first_sample_latents = torch.stack([batch_latents[0] for batch_latents in self.latent_history])
        first_sample_latents = first_sample_latents.detach().cpu().numpy()

        iterations = np.arange(first_sample_latents.shape[0])
        plt.figure(figsize=(8, 6))

        if first_sample_latents.shape[1] == 2:
            # If the latent dimension is 2, plot directly.
            plt.scatter(first_sample_latents[:, 0], first_sample_latents[:, 1],
                        c=iterations, cmap='viridis', s=50)
            plt.xlabel("Latent Dimension 1")
            plt.ylabel("Latent Dimension 2")
            if title is None:
                plt.title("Latent Codes Evolution (first sample in each batch)")
            else:
                plt.title(title)
        else:
            # Otherwise, project the latent codes to 2D using PCA.
            pca = PCA(n_components=2)
            latents_2d = pca.fit_transform(first_sample_latents)
            plt.scatter(latents_2d[:, 0], latents_2d[:, 1],
                        c=iterations, cmap='viridis', s=50)
            plt.xlabel("Principal Component 1")
            plt.ylabel("Principal Component 2")
            if title is None:
                plt.title("Latent Codes Evolution (first sample in each batch) - PCA Projection")
            else:
                plt.title(title)

        plt.colorbar(label="Training Iteration")

        if save_path is not None:
            self._ensure_parent_dir(save_path)
            plt.savefig(save_path)
        plt.show()
        plt.close()


## Data

In [None]:
class EEGDataset(Dataset):
    """
    Map-style dataset of fixed-length windows cut from a 1D signal.

    signal: 1D numpy array (e.g., one EEG channel)
    window_size: number of samples per window (D)
    step: offset between consecutive window starts (default: window_size,
          i.e. non-overlapping windows); a trailing partial window is dropped
    """

    def __init__(self, signal, window_size, step=None):
        self.signal = signal
        self.window_size = window_size
        self.step = window_size if step is None else step
        # Every start position that still leaves room for a full window.
        last_start = len(signal) - window_size
        starts = range(0, last_start + 1, self.step)
        self.windows = np.array([signal[s:s + window_size] for s in starts])

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        # One window as a float32 tensor of shape [window_size].
        return torch.tensor(self.windows[idx], dtype=torch.float32)

## Main

In [None]:
fs = 250                      # Sampling frequency (Hz)
batch_size = 32
num_epochs = 4096

D_candidates = [32, 64, 128, 256]    # Window size in samples (256 samples ~ 1 s at fs = 250 Hz)
C_candidates = [1, 2, 3]             # Compression factor (layer width divided by 2**C)
M_C_candidates = [2, 4, 8]           # Minimum compression size
L_candidates = [16, 32, 64]          # Number of latent codes to store
lr_candidates = [2**-i for i in [6, 10, 13]]
lambda_reg_candidates = [0, 1/4, 1, 128]
sigma_candidates = [2**-i for i in [0, 3, 6]]
delta_candidates = [0.25, 0.5, 0.75]

# 4 * 3 * 3 * 3 * 3 * 4 * 3 * 3 = 11,664 combinations per experiment
combinations = list(product(D_candidates, C_candidates, M_C_candidates, L_candidates, lr_candidates, lambda_reg_candidates, sigma_candidates, delta_candidates))

# Channels specification
essential_channels = ['EXG Channel 6', 'EXG Channel 7']
keep_channels = [
    'EXG Channel 0', 'EXG Channel 1', 'EXG Channel 2', 'EXG Channel 3',
    'EXG Channel 4', 'EXG Channel 5', 'EXG Channel 6', 'EXG Channel 7'
]
drop_channels = [
    'Accel Channel 0', 'Accel Channel 1', 'Accel Channel 2', 'Not Used',
    'Digital Channel 0 (D11)', 'Digital Channel 1 (D12)', 'Digital Channel 2 (D13)',
    'Digital Channel 3 (D17)', 'Not Used', 'Digital Channel 4 (D18)',
    'Analog Channel 0', 'Analog Channel 1', 'Analog Channel 2',
    'Timestamp', 'Marker Channel', 'Timestamp (Formatted)'
]

# Output folders for the figures saved below (savefig does not create directories).
os.makedirs("plots/latent_codes", exist_ok=True)
os.makedirs("plots/spike_trains", exist_ok=True)

# Loop over experiments (for example, experiments 1, 2, and 3)
for i in [1, 2, 3]:
    # Load and normalize the recording once per experiment: it does not depend
    # on the hyperparameters, so re-reading it for every combination is wasted work.
    df = pd.read_csv(f"wink{i}_received_data.txt", skiprows=4, delimiter=', ', engine='python')

    # Drop unwanted channels and remove rows with NaNs
    df = df.drop(columns=drop_channels, errors='ignore').dropna().reset_index(drop=True)

    # Normalize data by magnitude for channels to keep
    for channel in keep_channels:
        df[channel] = df[channel] / np.linalg.norm(df[channel])

    # Select one channel for training (e.g., "EXG Channel 6")
    signal = df['EXG Channel 6'].values

    # Loop over all hyperparameter combinations
    for hyperparameters in combinations:
        D, C, M_C, L, lr, lambda_reg, sigma, delta = hyperparameters
        print(f"\nProcessing experiment {i} ...")
        print(f"Hyperparameters: D={D}, C={C}, M_C={M_C}, L={L}, lr={lr}, lambda_reg={lambda_reg}, sigma={sigma}, delta={delta}")
        hyp_str = f"D={D}_C={C}_M_C={M_C}_L={L}_lr={lr}_lambda_reg={lambda_reg}_sigma={sigma}_delta={delta}"

        # Create dataset from sliding windows (non-overlapping windows)
        dataset = EEGDataset(signal, window_size=D, step=D)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Bind before `try` so the cleanup in `finally` is safe even when
        # ITLPPD(...) itself raises (e.g. the "undefined latent dimension" skip).
        model = None
        try:
            # device=None lets ITLPPD use CUDA when available and fall back to CPU.
            model = ITLPPD(D=D, C=C, M_C=M_C, L=L, lr=lr, sigma=sigma, lambda_reg=lambda_reg, delta=delta, device=None)
            print(f"Training model on experiment {i} data...")
            model.train(dataloader, num_epochs=num_epochs)

            # Plot latent space evolution
            save_path = f"plots/latent_codes/experiment_{i}_{hyp_str}.png"
            model.plot_latent_space_history(save_path=save_path, title="Latent Codes Evolution (first sample in each batch)")

            # (Optional) Plot spike train history
            save_path = f"plots/spike_trains/experiment_{i}_{hyp_str}.png"
            model.plot_spike_train_history(save_path=save_path)

            # (Optional) Test prediction on a sample window and visualize reconstruction.
            # Inference only: evaluation mode and no autograd graph.
            sample_window = next(iter(dataloader))
            model.model.eval()
            with torch.no_grad():
                reconstructed = model.model(sample_window.to(model.device))
            plt.figure(figsize=(8, 4))
            plt.plot(sample_window[0].cpu().numpy(), label='Original')
            plt.plot(reconstructed[0].cpu().numpy(), label='Reconstructed')
            plt.legend()
            plt.title(f'Reconstruction Example for Experiment {i}')
            plt.show()
            plt.close()

            print("#####################################################")
        except ValueError as e:
            if "Latent dimension is not defined in the model" in str(e):
                print("Skipping due to undefined latent dimension.")
            else:
                raise
        except Exception as e:
            print(f"An error occurred: {e}")
        finally:
            # Drop the reference so GPU memory can be reclaimed before the next run.
            model = None
            torch.cuda.empty_cache()
