In [None]:
import numpy as np
import matplotlib.pyplot as plt
#!pip install infomeasure
import infomeasure as im
from sklearn.preprocessing import KBinsDiscretizer

Function to compute the pairwise mutual information (MI) matrix between channels using `infomeasure`

In [None]:
def compute_MI(data, n_bins=5, embedding_dim=4):
  '''
  Compute the pairwise mutual-information (MI) matrix between all channels.

  Each channel (column) is discretized into ``n_bins`` equal-width bins with
  scikit-learn's ``KBinsDiscretizer``. The MI of every channel pair is then
  estimated with ``infomeasure.mutual_information`` using the ordinal
  estimator (``approach='ordinal'``).

  Parameters
  ----------
  data : array-like, shape (n_samples, n_channels)
      Samples along rows, channels along columns.
  n_bins : int, default 5
      Number of equal-width bins used to discretize each channel.
  embedding_dim : int, default 4
      Embedding dimension forwarded to the ordinal estimator.

  Returns
  -------
  numpy.ndarray, shape (n_channels, n_channels)
      Symmetric matrix holding MI(i, j) off the diagonal. The diagonal is
      left at 0 (a channel's MI with itself is not computed).

  Raises
  ------
  ValueError
      If ``data`` is not two-dimensional.

  Notes
  -----
  NOTE(review): the ordinal estimator works on the rank order of values.
  Binning first leaves only ``n_bins`` distinct levels, which creates many
  ties in the ordinal patterns. Confirm that discretizing before ordinal
  estimation is intended.
  '''
  data = np.asarray(data)  # accept lists / other array-likes
  if data.ndim != 2:
    raise ValueError(
        f"data must be 2-D (n_samples, n_channels); got shape {data.shape}")
  n_channels = data.shape[1]

  # Equal-width binning; KBinsDiscretizer fits bin edges per column (channel).
  discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
  discretized_data = discretizer.fit_transform(data)

  mi_matrix = np.zeros((n_channels, n_channels))
  # MI is symmetric, so only the upper triangle is estimated and then mirrored.
  for i in range(n_channels):
    x = discretized_data[:, i]  # invariant across the inner loop
    for j in range(i + 1, n_channels):
      y = discretized_data[:, j]
      mi_value = im.mutual_information(x, y, approach='ordinal',
                                       embedding_dim=embedding_dim)
      mi_matrix[i, j] = mi_matrix[j, i] = mi_value
  return mi_matrix

Function to compute the multivariable mutual information across all channels

In [None]:
def compute_multivariableMI(data, n_bins=8, embedding_dim=4):
  '''
  Compute one multivariable mutual-information value across all channels.

  Each channel (column) is discretized into ``n_bins`` equal-width bins with
  scikit-learn's ``KBinsDiscretizer``. All channels are then passed together,
  as separate variables, to ``infomeasure.mutual_information`` with the
  ordinal estimator (``approach='ordinal'``).

  Parameters
  ----------
  data : array-like, shape (n_samples, n_channels)
      Samples along rows, channels along columns.
  n_bins : int, default 8
      Number of equal-width bins per channel. NOTE(review): this default
      differs from the 5 bins used for the pairwise MI; confirm it is intended.
  embedding_dim : int, default 4
      Embedding dimension forwarded to the ordinal estimator.

  Returns
  -------
  The value returned by ``im.mutual_information`` for all channels
  (presumably a float).

  Raises
  ------
  ValueError
      If ``data`` is not two-dimensional or has fewer than two channels.

  Notes
  -----
  NOTE(review): check the infomeasure docs for which multivariate
  generalization of MI is computed for more than two variables.
  Each channel can take up to ``embedding_dim!`` ordinal patterns, so the
  joint pattern space grows like ``(embedding_dim!) ** n_channels``. With
  many channels and short epochs, the estimate is likely strongly biased.
  '''
  data = np.asarray(data)  # accept lists / other array-likes
  if data.ndim != 2:
    raise ValueError(
        f"data must be 2-D (n_samples, n_channels); got shape {data.shape}")
  if data.shape[1] < 2:
    raise ValueError(
        f"multivariable MI needs at least 2 channels; got {data.shape[1]}")

  # Equal-width binning; KBinsDiscretizer fits bin edges per column (channel).
  discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
  discretized_data = discretizer.fit_transform(data)

  # One 1-D series per channel, each passed as its own positional variable.
  channels = [discretized_data[:, k] for k in range(discretized_data.shape[1])]
  return im.mutual_information(*channels, approach='ordinal',
                               embedding_dim=embedding_dim)

Load the data

In [None]:
# Cleaned epochs file; the layout assumed downstream is (epoch, channels, samples).
DATA_PATH = 'A338_PS2_clean_epochs.npy'
data = np.load(DATA_PATH)
# The analysis loop transposes each epoch to (samples, channels), so the
# array must be 3-D; fail early with a readable message otherwise.
assert data.ndim == 3, (
    f"expected a 3-D (epoch, channels, samples) array, got shape {data.shape}")
print(f"Data load, shape: {data.shape} (epoch, channels, samples)")

Main loop: compute the pairwise MI matrix and the multivariable MI for every epoch

In [None]:
# Per-epoch results: a pairwise MI matrix and a single multivariable MI value.
MI_matrices = []
Multi_MI = []

for epoch_data in data:
  # Assuming each epoch is stored as (channels, samples): both helpers unpack
  # data.shape as (n_samples, n_channels), hence the transpose.
  segment_for_analysis = epoch_data.T
  MI_matrices.append(compute_MI(segment_for_analysis))
  Multi_MI.append(compute_multivariableMI(segment_for_analysis))

# Exactly one result of each kind per epoch, in epoch order.
assert len(MI_matrices) == len(Multi_MI) == len(data)
print('Finish')


Visualization

Pairwise MI matrix (first epoch)

In [None]:
# Heat map of the pairwise MI matrix for the first epoch.
# The diagonal is 0 because compute_MI does not fill self-pairs.
# NOTE(review): the "bits" unit assumes infomeasure uses log base 2; confirm.
fig, ax = plt.subplots(figsize=(8, 6))
# Named `mi_image` (not `im`) to avoid shadowing the infomeasure alias.
mi_image = ax.imshow(MI_matrices[0], origin='lower', cmap='viridis')
fig.colorbar(mi_image, ax=ax, label="Mutual information (bits)")
ax.set_title("Mutual information matrix of first epoch")
ax.set_xlabel("Channel")
ax.set_ylabel("Channel")
plt.show()

Multivariable MI across epochs

In [None]:
# One multivariable MI value per epoch, plotted in epoch order.
# NOTE(review): the title hard-codes the recording (A338) and the condition;
# update it when analysing a different file.
fig, ax = plt.subplots(figsize=(10, 5))
epoch_positions = range(len(Multi_MI))
ax.plot(epoch_positions, Multi_MI, marker='o', linestyle='-')
ax.set(
    title="A338 Multi-variable MI - Control condition",
    xlabel="Epoch",
    ylabel="Multi-variable MI (bits)",
)
ax.grid(True)
plt.show()