In [12]:
import numpy as np
import torch
import mne
import os, sys
sys.path.append(".")
sys.path.append("./code")



In [13]:
raw_data_folder = 'C:/Users/Micro/Desktop/MainResearchProjects/SNN-seizure-detection-data/ethz_ieeg/long-term/ID01/'

In [14]:
EEG_raw_files = os.listdir(raw_data_folder)

In [15]:
from eeg_recording import SingleSubjectRecording

In [16]:
import scipy
example_data = scipy.io.loadmat(os.path.join(raw_data_folder, EEG_raw_files[0]))['EEG']

In [21]:
mne_example_data = mne.io.RawArray(example_data, mne.create_info(ch_names=[f'eeg unknow_ch_{i}'for i in range(example_data.shape[0])], sfreq=512))

Creating RawArray with float64 data, n_channels=88, n_times=1843200
    Range : 0 ... 1843199 =      0.000 ...  3599.998 secs
Ready.


In [22]:
mne_example_data.set_channel_types(dict(zip(mne_example_data.ch_names, ['eeg'] * len(mne_example_data.ch_names))))

  mne_example_data.set_channel_types(dict(zip(mne_example_data.ch_names, ['eeg'] * len(mne_example_data.ch_names))))


0,1
Measurement date,Unknown
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,Not available
Good channels,88 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,512.00 Hz
Highpass,0.00 Hz
Lowpass,256.00 Hz
Duration,00:59:60 (HH:MM:SS)


In [23]:
mne_example_data = mne_example_data.set_eeg_reference(ref_channels="average")

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


In [19]:
data = mne.filter.filter_data(mne_example_data.get_data(), sfreq = 512, l_freq = 0.5, h_freq=70)

Setting up band-pass filter from 0.5 - 70 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 70.00 Hz
- Upper transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 78.75 Hz)
- Filter length: 3381 samples (6.604 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


In [20]:
mne_example_data = mne.io.RawArray(data, mne_example_data.info)

Creating RawArray with float64 data, n_channels=47, n_times=248320
    Range : 0 ... 248319 =      0.000 ...   484.998 secs
Ready.


In [24]:
recording = SingleSubjectRecording(0 << 8 + 2, mne_example_data)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


In [25]:
recording.run_latent_kmeans(n_states=9, use_gfp=True, n_inits=100)

MemoryError: Unable to allocate 1.21 GiB for an array with shape (88, 1843200) and data type float64

In [23]:
microstates = recording.latent_maps
microstates.shape

(15, 47)

In [27]:
recording.gev_tot

0.44715825323618047

In [221]:
data = mne_example_data.get_data()

In [180]:
from data_utils import corr_vectors, get_gfp_peaks

activation = microstates.dot(data)
segmentation = np.argmax(np.abs(activation), axis=0)
print(activation)

gfp_curve = None
(gfp_peeks, gfp_curve) = get_gfp_peaks(
    data, min_peak_dist=2, smoothing=None, smoothing_window=100
)
gfp_sum_sq = np.sum(gfp_curve**2)
map_corr = corr_vectors(data, microstates[segmentation].T)

gev_tot = sum((gfp_curve * map_corr) ** 2) / gfp_sum_sq
gev_tot



[[ 30.78629073  32.38387228  32.53269755 ... 184.13942999 159.92214329
  137.17126715]
 [-32.93767442 -38.67300829 -47.76458465 ... -25.41480863 -34.60075786
  -42.55006125]
 [ -5.43664283   4.14588945  17.33972476 ... 362.70422868 329.65057379
  298.06995285]
 ...
 [-70.90081658 -73.39596684 -80.83949145 ... -99.93996995 -96.75787581
  -95.07881456]
 [-71.20564574 -69.23512586 -74.24628047 ... -64.02229791 -58.65491401
  -53.85131512]
 [-36.10473468 -39.34547843 -43.1381015  ...  16.56768132  25.29825924
   35.16601051]]


0.489340963862868

In [276]:
from data_utils import get_gfp_peaks
class MicrostateTrainingModel(torch.nn.Module):
    
    def __init__(self, n_microstate, n_channels, W = None):
        super(MicrostateTrainingModel, self).__init__()
        weight_shape = (n_microstate, n_channels)
        if W is None:
            W = np.random.random(weight_shape)
            W = torch.DoubleTensor(W)
            norm = W.pow(2).sum(dim=1).sqrt()
            W = W / norm.unsqueeze(-1)
            self.W = torch.nn.Linear(n_microstate, n_channels, bias  = False)
            self.W.weight = torch.nn.Parameter(torch.DoubleTensor(W))
        else:
            self.W = torch.nn.Linear(n_microstate, n_channels, bias  = False)
            self.W.weight = torch.nn.Parameter(torch.DoubleTensor(W))
            
        self.loss_function = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.999))

            
    
    def corr_vectors(self, A, B, axis=0):
        An = A - torch.mean(A, dim=axis, keepdims=True)
        Bn = B - torch.mean(B, dim=axis, keepdims=True)
        new_An = An / torch.linalg.norm(An, dim=axis, keepdims=True)
        new_Bn = Bn / torch.linalg.norm(Bn, dim=axis, keepdims=True)
        return torch.sum(new_An * new_Bn, dim=axis)
    
    def forward(self, eeg_data):
        
        microstates_maps = self.W
        eeg_data = torch.DoubleTensor(eeg_data)
        activation = torch.abs(microstates_maps(eeg_data.T)).T
        segmentation = torch.argmax(activation, axis = 0)
        return segmentation
        
    def loss(self, eeg_data, segmentation, **kwargs):
        # loss
        (peaks, gfp_curve) = get_gfp_peaks(
            eeg_data,
            min_peak_dist=kwargs.pop("min_peak_dist", 2),
            smoothing=kwargs.pop("smoothing", None),
            smoothing_window=kwargs.pop("smoothing_window", 100),
        )
        gfp_sum_sq = np.sum(gfp_curve ** 2)

        eeg_data = torch.DoubleTensor(eeg_data)
        map_corr = self.corr_vectors(eeg_data, self.W.weight[segmentation].T)
        
        
        gev = torch.sum((torch.DoubleTensor(gfp_curve) * map_corr) ** 2) / gfp_sum_sq
        
        loss = self. loss_function(gev, torch.DoubleTensor([1]))
        print(loss, gev)
        return loss
        

In [277]:
model = MicrostateTrainingModel(n_microstate=31, n_channels=47, W = microstates)

In [278]:
data = mne_example_data.get_data()

In [279]:
segmentation = model.forward(data)

In [280]:
running_loss = 0
for i  in range(10000):
        # Zero your gradients for every batch!
        model.optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model.forward(data)

        # Compute the loss and its gradients
        loss = model.loss(data, outputs)
        loss.backward()

        # Adjust learning weights
        model.optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        print(loss)
        
        with torch.no_grad():
                for param in model.parameters():
                        norm = param.pow(2).sum(dim=1).sqrt()
                        param = param / norm.unsqueeze(-1)

tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>) tensor(0.4086, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>) tensor(0.4087, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.3496, dtype=torch.float64, grad_fn=<MseLossBackward0>) tensor(0.4087, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(0.3496, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.3496, dtype=torch.float64, grad_fn=<MseLossBackward0>) tensor(0.4087, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(0.3496, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>) tensor(0.4087, dtype=torch.float64, grad_fn=<DivBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.3497, dtype=torch.float64, grad_fn=<MseLossBackward

KeyboardInterrupt: 