# Device Separation & Information Capacity Monte Carlo

This notebook estimates the achievable Information Capacity (IC) for pairs of neural recording devices (e.g., DISC, SEEG) as a function of spatial separation or relative rotation.

Objectives
- Quantify the impact of montage selection on per‑channel IC.
- Compare monopolar vs. intra‑device vs. inter‑device differential (montage) strategies.
- Explore scaling behavior with source count and size (dipole density experiments).

Key Concepts
- Lead Field (LF): Gain matrix mapping dipole moment vectors to electrode voltages.
- Dipole Sampling: Randomized dipole orientations and positions (non-overlapping in time) within a region of interest (ROI).
- Shannon–Hartley Capacity: \(C = B \log_2(1 + \mathrm{SNR})\) summed across dipoles/electrodes.
- Montage Optimization: Selecting electrode pairings that maximize instantaneous power for each dipole.
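
As a rough illustration of the capacity term listed above, the sketch below evaluates \(C = B \log_2(1 + \mathrm{SNR})\) for a few signal amplitudes. The helper name `shannon_capacity` and the example amplitudes are illustrative assumptions, not part of the notebook; the 50 Hz bandwidth and 4.1 µV RMS noise mirror the defaults in the Settings cell.

```python
import math

def shannon_capacity(v_uv, noise_rms_uv, bandwidth_hz):
    """Capacity in bits/s for one source observed at amplitude v_uv (µV)."""
    snr = (v_uv ** 2) / (noise_rms_uv ** 2)  # power ratio (dimensionless)
    return bandwidth_hz * math.log2(1.0 + snr)

# Hypothetical amplitudes: silent, equal to the noise floor, 10x the noise floor
amplitudes_uv = [0.0, 4.1, 41.0]
capacities = [shannon_capacity(v, 4.1, 50) for v in amplitudes_uv]

for v, c in zip(amplitudes_uv, capacities):
    print(f"{v:5.1f} uV -> {c:7.2f} bits/s")
```

A source at the noise floor (SNR = 1) contributes exactly B bits/s, and each further tenfold increase in amplitude adds roughly 6.6 B bits/s, so the capacity grows only logarithmically with signal power.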

Experiments Implemented
1. Gap Sweep (Small & Large ROI): Translate or rotate two devices; compute IC curves.
2. Source Scaling (S vs. C): Vary number of dipoles; evaluate IC trends for single vs. paired devices.

Usage Notes
- Ensure lead field archives exist in `leadfields/` with consistent spatial dimensions.
- Parameter `transform` controls experiment mode: `'translate'` or `'rotate'`.
- Monte Carlo iteration counts (`dual_iterations`, `single_iterations`) govern statistical stability.
- Noise values are squared RMS microvolt noise (already converted to power domain).

Outputs
- Confidence intervals (mean, lower, upper, std) aggregated for each configuration.
- Plots saved under `outputs/` capturing IC vs. gap and IC vs. source count.

Extend
- Add 3D visualization of electrode overlap regions.
- Introduce adaptive stopping when CI width converges below threshold.
- Incorporate biologically informed dipole orientation distributions.


# Maximum Information Capacity
Monte Carlo simulation estimating per‑channel information capacity under varying device separation.


### Procedure
- For N electrodes, generate N/M sources at distinct voxels.
- Assign each source a random orientation and fixed magnitude (scaled when sweeping count).
- Sources are temporally disjoint (no overlap) ⇒ independent SNR contributions.
- For each dipole, only the best (non‑redundant) electrode or montage pairing is counted.
- Capacity normalized per channel for cross‑device comparability.
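
The source-generation steps above can be sketched as follows. This is a minimal illustration, assuming isotropic orientations obtained by normalizing Gaussian vectors; the notebook cells below draw components with `np.random.random` instead, so treat `sample_dipoles` and the seeded `rng` as hypothetical names rather than the notebook's own code.

```python
import numpy as np

def sample_dipoles(n_sources, magnitude, rng):
    """Draw n_sources dipole moments with random direction and fixed norm."""
    vecs = rng.standard_normal(size=(n_sources, 3))
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # unit vectors
    return vecs * magnitude                               # fixed |p| in A*m

rng = np.random.default_rng(0)
n_electrodes, M = 128, 1            # illustrative values only
n_sources = n_electrodes // M       # N/M sources
moments = sample_dipoles(n_sources, 4e-9, rng)
norms = np.linalg.norm(moments, axis=1)
print(moments.shape, norms.min(), norms.max())
```

Dividing by the vector norm before scaling is what keeps every source at the same magnitude regardless of its orientation.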


## Settings
Core experimental configuration — adjust before running cells below.


In [None]:
# Imports
import numpy as np
import sys
from os import path
import os
import time
import matplotlib.pyplot as plt
import scipy.stats  # Student's t quantiles and SEM for confidence intervals

# Path setup to allow relative imports when running notebook directly
sys.path.insert(0, path.join('../..'))
from modules.leadfield_importer import FieldImporter

# ------------------------------------------------------------------
# CONFIGURATION BLOCK
# ------------------------------------------------------------------
folder = r"...\SEPIO_dataset"  # Root dataset directory (replace placeholder)
device = 'DISC'                # Primary device label for plots
fields_file = path.join(folder, 'leadfields', 'DISC_500um_grid.npz')
seeg_file   = path.join(folder, 'leadfields', 'SEEG_500um_grid.npz')
output      = path.join(folder, 'outputs')
os.makedirs(output, exist_ok=True)

# Source & signal parameters
magnitude = 4e-9      # A*m total dipole budget (scaled in source sweep)
B = 50               # Bandwidth (Hz) per channel for Shannon-Hartley
M = 1                # N/M sources logic; currently 1 ⇒ N sources
single_iterations = int(1e3)  # MC iterations for single-device (if used)
dual_iterations   = int(1e3)  # MC iterations for dual-device conditions
noise = np.square(4.1)        # DISC RMS noise (µV) squared → power domain
seeg_noise = np.square(2.7)   # SEEG RMS noise squared

# Transformation mode: spatial sweep type
transform = 'translate'       # {'translate', 'rotate'}
xforms_translate = range(0, 10)   # Gap values (mm) small ROI
xforms_rotate    = range(0, 44)   # Rotation angles (deg)

sources_list = np.array([2, 4, 8, 16, 32, 64, 128, 256])  # Source scaling sweep
xforms = xforms_translate if transform == 'translate' else xforms_rotate

# ------------------------------------------------------------------
# VALIDATION (non-fatal warnings)
# ------------------------------------------------------------------
for lf in [fields_file, seeg_file]:
    if not path.isfile(lf):
        print(f"[WARN] Leadfield missing: {lf}")
if transform not in {'translate','rotate'}:
    raise ValueError("transform must be 'translate' or 'rotate'")
print(f"Experiment: {device} vs SEEG | mode={transform} | iterations={dual_iterations}")


Dipole magnitude reference (a surface density of 2 nAm/mm² equals 2e-3 Am/m²):

- 100 µm × 100 µm patch: 2e-3 Am/m² × (1e-4 m × 1e-4 m) = 2e-11 Am = 0.02 nAm
- 500 µm × 500 µm patch: 2e-3 Am/m² × (5e-4 m × 5e-4 m) = 5e-10 Am = 0.5 nAm
- 1 nA·µm = 1e-9 A × 1e-6 m = 1e-15 Am; × 1.2e5 = 1.2e-10 Am ≈ 0.1 nAm
- 200 mm² patch: 2e-3 Am/m² × 2e-4 m² = 4e-7 Am = 400 nAm
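
The arithmetic in these notes can be checked with a few lines of plain Python. This is only a sanity-check sketch: the variable names are made up for illustration, and the 1.2e5 factor is taken from the note as-is without interpreting it.

```python
NAM = 1e-9  # one nA*m expressed in A*m

# 2 nA*m per mm^2 converted to A*m per m^2 (1 mm^2 = 1e-6 m^2)
density = 2e-9 / 1e-6

patch_100um = density * (1e-4 * 1e-4)   # 100 um x 100 um patch
patch_500um = density * (5e-4 * 5e-4)   # 500 um x 500 um patch
patch_200mm2 = density * 2e-4           # 200 mm^2 patch
scaled_unit = 1e-9 * 1e-6 * 1.2e5       # 1 nA*um times the 1.2e5 factor

for label, value in [("100 um patch", patch_100um),
                     ("500 um patch", patch_500um),
                     ("1 nA*um x 1.2e5", scaled_unit),
                     ("200 mm^2 patch", patch_200mm2)]:
    print(f"{label:>16}: {value / NAM:.3g} nAm")
```

The configured `magnitude = 4e-9` (4 nAm) therefore sits between the single-voxel estimate (0.5 nAm) and the 200 mm² patch estimate (400 nAm).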

## Helper Methods
Functions below provide montage power selection and confidence interval calculation.
- All montage selectors operate on voltage arrays shaped (dipoles, electrodes).
- Voltages are expected in microvolts before power conversion (square).
- Minimum voltage threshold suppresses sub-noise contributions.
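
To make the thresholding behavior concrete, here is a toy example on a hand-written voltage array. It re-implements the monopolar and differential selection inline, so that it runs on its own, rather than calling the helpers defined in the next cell; the voltages are invented for illustration.

```python
import numpy as np

# Toy voltages in µV: 2 dipoles (rows) x 3 electrodes (columns)
v = np.array([[30.0, -25.0, 5.0],
              [15.0, -12.0, 2.0]])
min_voltage = 20.0  # µV, same default as the helper functions

# Monopolar: square each channel, zero sub-threshold powers, keep the best
mono = np.square(v)
mono[mono < min_voltage**2] = 0
mono_best = mono.max(axis=1)

# Differential: most positive channel minus most negative channel
vmax = np.maximum(v.max(axis=1), 0)
vmin = np.minimum(v.min(axis=1), 0)
diff_best = np.square(vmax - vmin)
diff_best[diff_best < min_voltage**2] = 0

print(mono_best)  # [900.   0.]
print(diff_best)  # [3025.  729.]
```

The second dipole never exceeds 20 µV on any single electrode, so its monopolar power is zeroed, while the 27 µV differential clears the threshold and still contributes to capacity.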


In [None]:
### Defining core functions

# Adapted from: https://stackoverflow.com/questions/15033511/compute-a-confidence-interval-from-sample-data
def mean_confidence_interval(data, confidence=0.95):
    """Compute mean, lower CI bound, upper CI bound, and sample std.

    Parameters
    ----------
    data : array-like
        Sequence of numeric samples.
    confidence : float, default 0.95
        Confidence level for interval using Student's t distribution.

    Returns
    -------
    tuple (mean, lower, upper, std)
        Mean and symmetric confidence bounds plus the population standard
        deviation (`np.std`, ddof=0).
    """
    a = np.array(data, dtype=float)
    n = len(a)
    m = np.mean(a)
    se = scipy.stats.sem(a) if n > 1 else 0.0
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1) if n > 1 else 0.0
    return m, m - h, m + h, np.std(a)

# Montage / Power Selection Utilities

def get_best_channel_each_dipole(v, min_voltage=20):
    """Return max power per dipole using best single electrode.

    Parameters
    ----------
    v : ndarray (D, E)
        Voltages (µV) per dipole (rows) per electrode (columns).
    min_voltage : float
        Threshold below which power is zeroed.
    """
    pwr = np.square(v)
    pwr[pwr < min_voltage**2] = 0
    return np.nanmax(pwr, axis=1)


def get_best_montage_each_dipole(v1, v2, min_voltage=20):
    """Best differential montage power combining two devices.

    Takes max minus min across concatenated channel set; if improvement absent,
    monopolar effectively retained via thresholding.
    """
    all_channels = np.concatenate((v1, v2), axis=1)
    all_channels = np.sort(all_channels, axis=1)
    best_channel_max = np.maximum(all_channels[:,-1], np.zeros((v1.shape[0])))
    best_channel_min = np.minimum(all_channels[:,0], np.zeros((v1.shape[0])))
    powers = np.square(best_channel_max - best_channel_min)
    powers[powers < min_voltage**2] = 0
    return powers


def get_best_inter_montage(v1, v2, min_voltage=20):
    """Optimal differential across opposite devices only.

    Considers (max(dev1) - min(dev2)) and (max(dev2) - min(dev1)), selecting higher.
    """
    sorted1 = np.sort(v1, axis=1)
    sorted2 = np.sort(v2, axis=1)
    max1 = np.maximum(sorted1[:, -1], 0)
    max2 = np.maximum(sorted2[:, -1], 0)
    min1 = np.minimum(sorted1[:, 0], 0)
    min2 = np.minimum(sorted2[:, 0], 0)
    p1 = np.square(max1 - min2)
    p2 = np.square(max2 - min1)
    best = np.maximum(p1, p2)
    best[best < min_voltage**2] = 0
    return best


def get_best_intra_montage(v, min_voltage=20):
    """Optimal differential within a single device.

    Uses channel max versus channel min for each dipole.
    """
    v_sorted = np.sort(v, axis=1)
    vmax = np.maximum(v_sorted[:, -1], 0)
    vmin = np.minimum(v_sorted[:, 0], 0)
    powers = np.square(vmax - vmin)
    powers[powers < min_voltage**2] = 0
    return powers


def inter_device_montage(v1, v2, allow_add=False):
    """Return all pairwise inter-device montages (differentials or additions).

    Parameters
    ----------
    v1, v2 : ndarray (D, E1/E2)
        Dipole voltages per electrode for device 1 and 2.
    allow_add : bool
        If True, include additive combinations; else only subtraction.

    Returns
    -------
    ndarray (D, E1*E2*2)
        All montage voltage results.
    """
    shape1 = v1.shape
    shape2 = v2.shape
    if shape1[0] != shape2[0]:
        raise ValueError("Both devices must contain the same number of samples.")

    # Tile first device: each electrode twice, pattern repeated once per device-2 electrode
    v1 = np.tile(np.repeat(v1, repeats=2, axis=1), (1, shape2[1]))

    # Repeat second device: each electrode held for one full pass over device 1
    v2 = np.repeat(v2, 2*shape1[1], axis=1)

    # Sign pattern per column pair: (-1, +1) gives difference then sum; (-1, -1) gives the difference twice
    signs = [-1, 1] if allow_add else [-1, -1]
    m = np.tile(signs, (shape1[0], shape1[1]*shape2[1]))
    v2 = v2 * m

    # Add together and return
    return v1 + v2

def intra_device_montage(v, allow_add=False):
    """All within-device montages (differentials; additions optional).

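    Returns an array of shape (D, E^2) holding v[:, e1] - v[:, e2] at column
    e2 + e1 * E; the diagonal (e1 == e2) columns stay zero. `allow_add` is
    accepted for signature parity with `inter_device_montage` but is unused.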
    """
    D, E = v.shape
    montages = np.zeros((D, E**2))
    for e1 in range(E):
        for e2 in range(E):
            if e1 == e2:
                continue
            montages[:, e2 + e1 * E] = v[:, e1] - v[:, e2]
    return montages


def all_montages(v1, v2, allow_add=False):
    """Concatenate inter- and intra-device montage voltages.

    Useful for exhaustive search; high dimensional for large E.
    """
    inter = inter_device_montage(v1, v2, allow_add)
    intra1 = intra_device_montage(v1, allow_add)
    intra2 = intra_device_montage(v2, allow_add)
    return np.concatenate((inter, intra1, intra2), axis=1)


## 2 Device Gap Experiment: Small ROI
Translational or rotational sweep across limited spatial extent. Evaluates IC degradation or enhancement with separation.


In [None]:
### DiSc process
print("Starting DiSc")
conf_intervals = []
conf_intervals_inter = []
conf_intervals_intra = []
conf_intervals_monopolar = []
conf_intervals_single = []

for index, xform in enumerate(xforms):
    xform = int(np.round(xform))
    print(f'Starting xform {xform}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(fields_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    num_positions = round(num_electrodes*2 / M)  # *2 because we use 2 devices
    field_importer.duplicate_fields()  # 2 devices
    if transform == 'translate':
        field_importer.translate(x=0, y=-xform, electrodes=range(0, num_electrodes))  # move device 1 (electrodes 1-128)
        field_importer.translate(x=0, y=xform, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (electrodes 128-256)
    elif transform == 'rotate':
        trans = round(np.sin(np.radians(xform))*20/2)
        field_importer.rotate(psi=xform, electrodes=range(0, num_electrodes))
        field_importer.rotate(psi=-xform, electrodes=range(num_electrodes, 2*num_electrodes))
        field_importer.translate(x=-trans, y=-10, electrodes=range(0, num_electrodes))
        field_importer.translate(x=trans, y=10, electrodes=range(num_electrodes, 2*num_electrodes))
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list = []
    C_list_inter = []
    C_list_intra = []
    C_list_monopolar = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((num_positions, num_positions*3))  # First, create empty matrix to fill with dipoles
        for idx in range(num_positions):
            # Random orientation (components drawn from [0, 1)), rescaled to fixed magnitude
            rand_dipole = np.random.random(size=3)
            rand_dipole = rand_dipole / np.linalg.norm(rand_dipole) * magnitude
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # ROI: the only region where the two field solutions still overlap when the devices are furthest apart
        x_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        y_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        z_locs = np.int64(np.floor(np.random.random(size=num_positions)*15)+3)

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_inter = get_best_inter_montage(v1, v2)
        pwr_intra1 = get_best_intra_montage(v1)
        pwr_intra2 = get_best_intra_montage(v2)
        pwr_intra = np.maximum(pwr_intra1, pwr_intra2)
        pwr_monopolar = get_best_channel_each_dipole(np.concatenate((v1, v2), axis=1))
        pwr_single = get_best_channel_each_dipole(v1[:num_positions//2, :])

        snr_montages = pwr_montages / noise
        snr_inter = pwr_inter / noise
        snr_intra = pwr_intra / noise
        snr_monopolar = pwr_monopolar / noise
        snr_single = pwr_single / noise  # Only use half the sources so that N=M

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C = B * np.nansum( np.log2(1 + snr_montages) )
        C = C / (num_electrodes*2)  # C per channel
        C_inter = B * np.nansum( np.log2(1 + snr_inter) )
        C_inter = C_inter / (num_electrodes*2)
        C_intra = B * np.nansum( np.log2(1 + snr_intra) )
        C_intra = C_intra / (num_electrodes*2)
        C_monopolar = B * np.nansum( np.log2(1 + snr_monopolar) )
        C_monopolar = C_monopolar / (num_electrodes*2)  # C per channel (2 devices)
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list = np.append(C_list, C)
        C_list_inter = np.append(C_list_inter, C_inter)
        C_list_intra = np.append(C_list_intra, C_intra)
        C_list_monopolar = np.append(C_list_monopolar, C_monopolar)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended xform {xform}. Iteration took {time.time()-start_time} seconds.')
    
    conf_intervals = np.append(conf_intervals, mean_confidence_interval(C_list))
    conf_intervals_inter = np.append(conf_intervals_inter, mean_confidence_interval(C_list_inter))
    conf_intervals_intra = np.append(conf_intervals_intra, mean_confidence_interval(C_list_intra))
    conf_intervals_monopolar = np.append(conf_intervals_monopolar, mean_confidence_interval(C_list_monopolar))
    conf_intervals_single = np.append(conf_intervals_single, mean_confidence_interval(C_list_single))

In [None]:
### SEEG process
print("Starting SEEG")
s_conf_intervals = []
s_conf_intervals_inter = []
s_conf_intervals_intra = []
s_conf_intervals_monopolar = []
s_conf_intervals_single = []

for index, xform in enumerate(xforms):
    xform = int(np.round(xform))
    print(f'Starting xform {xform}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(seeg_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    num_positions = round(num_electrodes*2 / M)  # *2 because we use 2 devices
    field_importer.duplicate_fields()  # 2 devices
    if transform == 'translate':
        field_importer.translate(x=0, y=-xform, electrodes=range(0, num_electrodes))  # move device 1 (electrodes 1-128)
        field_importer.translate(x=0, y=xform, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (electrodes 128-256)
    elif transform == 'rotate':
        trans = round(np.sin(np.radians(xform))*20/2)
        field_importer.rotate(psi=xform, electrodes=range(0, num_electrodes))
        field_importer.rotate(psi=-xform, electrodes=range(num_electrodes, 2*num_electrodes))
        field_importer.translate(x=-trans, y=-10, electrodes=range(0, num_electrodes))
        field_importer.translate(x=trans, y=10, electrodes=range(num_electrodes, 2*num_electrodes))
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list = []
    C_list_inter = []
    C_list_intra = []
    C_list_monopolar = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((num_positions, num_positions*3))  # First, create empty matrix to fill with dipoles
        for idx in range(num_positions):
            # Random orientation (components drawn from [0, 1)), rescaled to fixed magnitude
            rand_dipole = np.random.random(size=3)
            rand_dipole = rand_dipole / np.linalg.norm(rand_dipole) * magnitude
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # ROI: the only region where the two field solutions still overlap when the devices are furthest apart
        x_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        y_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        z_locs = np.int64(np.floor(np.random.random(size=num_positions)*15)+3)

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_inter = get_best_inter_montage(v1, v2)
        pwr_intra1 = get_best_intra_montage(v1)
        pwr_intra2 = get_best_intra_montage(v2)
        pwr_intra = np.maximum(pwr_intra1, pwr_intra2)
        pwr_monopolar = get_best_channel_each_dipole(np.concatenate((v1, v2), axis=1))
        pwr_single = get_best_channel_each_dipole(v1[:num_positions//2, :])

        # SEEG cell: normalize by the SEEG noise power, not the DISC value
        snr_montages = pwr_montages / seeg_noise
        snr_inter = pwr_inter / seeg_noise
        snr_intra = pwr_intra / seeg_noise
        snr_monopolar = pwr_monopolar / seeg_noise
        snr_single = pwr_single / seeg_noise  # Only use half the sources so that N=M

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C = B * np.nansum( np.log2(1 + snr_montages) )
        C = C / (num_electrodes*2)  # C per channel
        C_inter = B * np.nansum( np.log2(1 + snr_inter) )
        C_inter = C_inter / (num_electrodes*2)
        C_intra = B * np.nansum( np.log2(1 + snr_intra) )
        C_intra = C_intra / (num_electrodes*2)
        C_monopolar = B * np.nansum( np.log2(1 + snr_monopolar) )
        C_monopolar = C_monopolar / (num_electrodes*2)  # C per channel (2 devices)
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list = np.append(C_list, C)
        C_list_inter = np.append(C_list_inter, C_inter)
        C_list_intra = np.append(C_list_intra, C_intra)
        C_list_monopolar = np.append(C_list_monopolar, C_monopolar)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended xform {xform}. Iteration took {time.time()-start_time} seconds.')
    
    s_conf_intervals = np.append(s_conf_intervals, mean_confidence_interval(C_list))
    s_conf_intervals_inter = np.append(s_conf_intervals_inter, mean_confidence_interval(C_list_inter))
    s_conf_intervals_intra = np.append(s_conf_intervals_intra, mean_confidence_interval(C_list_intra))
    s_conf_intervals_monopolar = np.append(s_conf_intervals_monopolar, mean_confidence_interval(C_list_monopolar))
    s_conf_intervals_single = np.append(s_conf_intervals_single, mean_confidence_interval(C_list_single))

### Plot Results (Small ROI)
Confidence interval shading uses (upper - mean) as symmetric uncertainty. Curves compare montage strategies and device types.
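
The strided indexing used in the plotting cell (`[0::4]` for the mean, `[2::4]` for the upper bound) follows from how the results are stored: each sweep point appends a `(mean, lower, upper, std)` tuple to a flat array. A small sketch with made-up numbers shows the equivalent reshaped view; `flat` and `table` are illustrative names.

```python
import numpy as np

# Two hypothetical sweep points, each contributing (mean, lower, upper, std)
flat = np.array([])
for stats in [(10.0, 9.0, 11.0, 2.0), (8.0, 7.5, 8.5, 1.0)]:
    flat = np.append(flat, stats)   # np.append flattens the tuple

# One row per sweep point, one column per statistic
table = flat.reshape(-1, 4)
mean, lower, upper, std = table.T

half_width = upper - mean           # symmetric CI half-width used for shading
print(mean)        # [10.  8.]
print(half_width)  # [1.  0.5]
```

Reshaping once and unpacking named columns is an alternative to repeating the `[k::4]` slices and makes it harder to pair the wrong mean with the wrong bound.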


In [None]:
x_axis = (np.array(xforms)) if transform == 'translate' else np.array(xforms)*2

# DISC colors: blue, cornflowerblue, deepskyblue, skyblue, lightblue
# SEEG colors: red, orange, darksalmon, tomato, salmon
error=conf_intervals[2::4]-conf_intervals[0::4]
plt.fill_between(x_axis,conf_intervals[0::4]+error, conf_intervals[0::4]-error, alpha=0.2, color='blue')
plt.plot(x_axis, conf_intervals[0::4],label=f'2 {device} montage', color='blue')

# error=conf_intervals_inter[2::4]-conf_intervals_inter[0::4]
# plt.fill_between(x_axis,conf_intervals_inter[0::4]+error, conf_intervals_inter[0::4]-error, alpha=0.2, color='cornflowerblue')
# plt.plot(x_axis, conf_intervals_inter[0::4],label=f'2 {device} inter-montage', color='cornflowerblue')

# error=conf_intervals_intra[2::4]-conf_intervals_intra[0::4]
# plt.fill_between(x_axis,conf_intervals_intra[0::4]+error, conf_intervals_intra[0::4]-error, alpha=0.2, color='deepskyblue')
# plt.plot(x_axis, conf_intervals_intra[0::4],label=f'2 {device} intra-montage', color='deepskyblue')

error=conf_intervals_monopolar[2::4]-conf_intervals_monopolar[0::4]
plt.fill_between(x_axis,conf_intervals_monopolar[0::4]+error, conf_intervals_monopolar[0::4]-error, alpha=0.2, color='skyblue')
plt.plot(x_axis, conf_intervals_monopolar[0::4],label=f'2 {device} monopolar', color='skyblue', linestyle='dashed')

# error=conf_intervals_single[2::4]-conf_intervals_single[0::4]
# plt.fill_between(x_axis,conf_intervals_single[0::4]+error, conf_intervals_single[0::4]-error, alpha=0.2, color='lightblue')
# plt.plot(x_axis, conf_intervals_single[0::4],label=f'1 {device}', color='lightblue', linestyle='dotted')

# SEEG colors: red, orange, darksalmon, tomato, salmon
error=s_conf_intervals[2::4]-s_conf_intervals[0::4]
plt.fill_between(x_axis,s_conf_intervals[0::4]+error, s_conf_intervals[0::4]-error, alpha=0.2, color='red')
plt.plot(x_axis, s_conf_intervals[0::4],label=f'2 SEEG montage', color='red')

error=s_conf_intervals_monopolar[2::4]-s_conf_intervals_monopolar[0::4]
plt.fill_between(x_axis,s_conf_intervals_monopolar[0::4]+error, s_conf_intervals_monopolar[0::4]-error, alpha=0.2, color='tomato')
plt.plot(x_axis, s_conf_intervals_monopolar[0::4],label=f'2 SEEG monopolar', color='tomato', linestyle='dashed')


plt.title(f'Gap sweep ({dual_iterations} iterations, {magnitude*10**9}nAm)')
plt.ylabel('Bits/s per channel')
if transform == 'translate':
    plt.xlabel('Center to Center gap (mm)')
else:
    plt.xlabel('Angle (degrees)')
plt.legend()
plt.savefig(path.join(output,f'3_separation_sweep_small.png'), transparent=True, format='png',dpi=300)

## 2 Device Gap Experiment: Large ROI
Expanded spatial sweep (larger translation range) to assess long-range interaction and diminishing returns of montage selection.


In [None]:
# Reset variables in case of changes
magnitude = 4e-9  # See below, magnitude of sources. Unit = A*m; Usually 0.5 - 10 nA*m
B = 50  # Bandwidth, dependent on recording backend
M = 1  # N/M number of sources
single_iterations = int(1e3)
dual_iterations = int(1e3)
noise = np.square(4.1)
seeg_noise = np.square(2.7)
transform = 'translate'  # Options: {'translate', 'rotate'}
xforms_translate = range(0, 20) # 0-10 for small ROI, 0-20 for large ROI; repeated for both
xforms_rotate = range(0,44)
sources_list = np.array([2, 4, 8, 16, 32, 64, 128, 256])
xforms = xforms_translate if transform == 'translate' else xforms_rotate

In [None]:
### DiSc process
print("Starting DiSc")
conf_intervals = []
conf_intervals_inter = []
conf_intervals_intra = []
conf_intervals_monopolar = []
conf_intervals_single = []

for index, xform in enumerate(xforms):
    xform = int(np.round(xform))
    print(f'Starting xform {xform}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(fields_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    num_positions = round(num_electrodes*2 / M)  # *2 because we use 2 devices
    field_importer.duplicate_fields()  # 2 devices
    if transform == 'translate':
        field_importer.translate(x=0, y=-xform, electrodes=range(0, num_electrodes))  # move device 1 (electrodes 1-128)
        field_importer.translate(x=0, y=xform, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (electrodes 128-256)
    elif transform == 'rotate':
        trans = round(np.sin(np.radians(xform))*20/2)
        field_importer.rotate(psi=xform, electrodes=range(0, num_electrodes))
        field_importer.rotate(psi=-xform, electrodes=range(num_electrodes, 2*num_electrodes))
        field_importer.translate(x=-trans, y=-10, electrodes=range(0, num_electrodes))
        field_importer.translate(x=trans, y=10, electrodes=range(num_electrodes, 2*num_electrodes))
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list = []
    C_list_inter = []
    C_list_intra = []
    C_list_monopolar = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((num_positions, num_positions*3))  # First, create empty matrix to fill with dipoles
        for idx in range(num_positions):
            # Random orientation (components drawn from [0, 1)), rescaled to fixed magnitude
            rand_dipole = np.random.random(size=3)
            rand_dipole = rand_dipole / np.linalg.norm(rand_dipole) * magnitude
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # ROI: the only region where the two field solutions still overlap when the devices are furthest apart
        x_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        y_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        z_locs = np.int64(np.floor(np.random.random(size=num_positions)*15)+3)

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_inter = get_best_inter_montage(v1, v2)
        pwr_intra1 = get_best_intra_montage(v1)
        pwr_intra2 = get_best_intra_montage(v2)
        pwr_intra = np.maximum(pwr_intra1, pwr_intra2)
        pwr_monopolar = get_best_channel_each_dipole(np.concatenate((v1, v2), axis=1))
        pwr_single = get_best_channel_each_dipole(v1[:num_positions//2, :])

        snr_montages = pwr_montages / noise
        snr_inter = pwr_inter / noise
        snr_intra = pwr_intra / noise
        snr_monopolar = pwr_monopolar / noise
        snr_single = pwr_single / noise  # Only use half the sources so that N=M

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C = B * np.nansum( np.log2(1 + snr_montages) )
        C = C / (num_electrodes*2)  # C per channel
        C_inter = B * np.nansum( np.log2(1 + snr_inter) )
        C_inter = C_inter / (num_electrodes*2)
        C_intra = B * np.nansum( np.log2(1 + snr_intra) )
        C_intra = C_intra / (num_electrodes*2)
        C_monopolar = B * np.nansum( np.log2(1 + snr_monopolar) )
        C_monopolar = C_monopolar / (num_electrodes*2)  # C per channel (2 devices)
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list = np.append(C_list, C)
        C_list_inter = np.append(C_list_inter, C_inter)
        C_list_intra = np.append(C_list_intra, C_intra)
        C_list_monopolar = np.append(C_list_monopolar, C_monopolar)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended xform {xform}. Iteration took {time.time()-start_time} seconds.')
    
    conf_intervals = np.append(conf_intervals, mean_confidence_interval(C_list))
    conf_intervals_inter = np.append(conf_intervals_inter, mean_confidence_interval(C_list_inter))
    conf_intervals_intra = np.append(conf_intervals_intra, mean_confidence_interval(C_list_intra))
    conf_intervals_monopolar = np.append(conf_intervals_monopolar, mean_confidence_interval(C_list_monopolar))
    conf_intervals_single = np.append(conf_intervals_single, mean_confidence_interval(C_list_single))

In [None]:
### SEEG process
print("Starting SEEG")
s_conf_intervals = []
s_conf_intervals_inter = []
s_conf_intervals_intra = []
s_conf_intervals_monopolar = []
s_conf_intervals_single = []

for index, xform in enumerate(xforms):
    xform = int(np.round(xform))
    print(f'Starting xform {xform}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(seeg_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    num_positions = round(num_electrodes*2 / M)  # *2 because we use 2 devices
    field_importer.duplicate_fields()  # 2 devices
    if transform == 'translate':
        field_importer.translate(x=0, y=-xform, electrodes=range(0, num_electrodes))  # move device 1 (electrodes 1-128)
        field_importer.translate(x=0, y=xform, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (electrodes 128-256)
    elif transform == 'rotate':
        trans = round(np.sin(np.radians(xform))*20/2)
        field_importer.rotate(psi=xform, electrodes=range(0, num_electrodes))
        field_importer.rotate(psi=-xform, electrodes=range(num_electrodes, 2*num_electrodes))
        field_importer.translate(x=-trans, y=-10, electrodes=range(0, num_electrodes))
        field_importer.translate(x=trans, y=10, electrodes=range(num_electrodes, 2*num_electrodes))
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list = []
    C_list_inter = []
    C_list_intra = []
    C_list_monopolar = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((num_positions, num_positions*3))  # First, create empty matrix to fill with dipoles
        for idx in range(num_positions):
            # Random orientation (components drawn from [0, 1)), rescaled to fixed magnitude
            rand_dipole = np.random.random(size=3)
            rand_dipole = rand_dipole / np.linalg.norm(rand_dipole) * magnitude
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # ROI: the only region where the two field solutions still overlap when the devices are furthest apart
        x_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        y_locs = np.int64(np.floor(np.random.random(size=num_positions)*21)+20)
        z_locs = np.int64(np.floor(np.random.random(size=num_positions)*15)+3)

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_inter = get_best_inter_montage(v1, v2)
        pwr_intra1 = get_best_intra_montage(v1)
        pwr_intra2 = get_best_intra_montage(v2)
        pwr_intra = np.maximum(pwr_intra1, pwr_intra2)
        pwr_monopolar = get_best_channel_each_dipole(np.concatenate((v1, v2), axis=1))
        pwr_single = get_best_channel_each_dipole(v1[:num_positions//2, :])

        snr_montages = pwr_montages / noise
        snr_inter = pwr_inter / noise
        snr_intra = pwr_intra / noise
        snr_monopolar = pwr_monopolar / noise
        snr_single = pwr_single / noise  # Only use half the sources so that N=M

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C = B * np.nansum( np.log2(1 + snr_montages) )
        C = C / (num_electrodes*2)  # C per channel
        C_inter = B * np.nansum( np.log2(1 + snr_inter) )
        C_inter = C_inter / (num_electrodes*2)
        C_intra = B * np.nansum( np.log2(1 + snr_intra) )
        C_intra = C_intra / (num_electrodes*2)
        C_monopolar = B * np.nansum( np.log2(1 + snr_monopolar) )
        C_monopolar = C_monopolar / (num_electrodes*2)  # C per channel (2 devices)
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list = np.append(C_list, C)
        C_list_inter = np.append(C_list_inter, C_inter)
        C_list_intra = np.append(C_list_intra, C_intra)
        C_list_monopolar = np.append(C_list_monopolar, C_monopolar)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended xform {xform}. Iteration took {time.time()-start_time} seconds.')
    
    s_conf_intervals = np.append(s_conf_intervals, mean_confidence_interval(C_list))
    s_conf_intervals_inter = np.append(s_conf_intervals_inter, mean_confidence_interval(C_list_inter))
    s_conf_intervals_intra = np.append(s_conf_intervals_intra, mean_confidence_interval(C_list_intra))
    s_conf_intervals_monopolar = np.append(s_conf_intervals_monopolar, mean_confidence_interval(C_list_monopolar))
    s_conf_intervals_single = np.append(s_conf_intervals_single, mean_confidence_interval(C_list_single))

### Plot Results (Large ROI)
This cell mirrors the small-ROI analysis. With a larger ROI, the region where both devices' fields overlap makes up a smaller fraction of the sampled volume, so expect a broader plateau or a shift in the optimal gap.
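
The plotting cell reads every result array with a stride of 4. A minimal sketch of that layout follows, assuming `mean_confidence_interval` returns the four values `(mean, lower, upper, std)` listed under Outputs and that they are appended flat, one group per sweep point. `unpack_ci` is a hypothetical helper for illustration, not a function from this notebook, and the numbers are made up.

```python
import numpy as np

def unpack_ci(flat):
    """Split a flat [mean, lower, upper, std, mean, ...] array into four columns."""
    table = np.asarray(flat, dtype=float).reshape(-1, 4)
    mean, lower, upper, std = table.T
    return mean, lower, upper, std

# Two sweep points with made-up values, purely to show the layout.
flat = np.array([1.0, 0.8, 1.2, 0.1,
                 2.0, 1.7, 2.3, 0.2])
mean, lower, upper, std = unpack_ci(flat)

# Half-width of the shaded band; same quantity as flat[2::4] - flat[0::4].
half_width = upper - mean
```

Unpacking once per array keeps the mean and its band tied to the same series, which is easy to get wrong when the stride slices are repeated by hand for each curve.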


In [None]:
# In rotate mode the devices turn by +xform and -xform, so the relative angle is 2*xform
x_axis = np.array(xforms) if transform == 'translate' else np.array(xforms)*2

# DISC colors: blue, cornflowerblue, deepskyblue, skyblue, lightblue
# SEEG colors: red, orange, darksalmon, tomato, salmon
error=conf_intervals[2::4]-conf_intervals[0::4]
plt.fill_between(x_axis,conf_intervals[0::4]+error, conf_intervals[0::4]-error, alpha=0.2, color='blue')
plt.plot(x_axis, conf_intervals[0::4],label=f'2 {device} montage', color='blue')

# error=conf_intervals_inter[2::4]-conf_intervals_inter[0::4]
# plt.fill_between(x_axis,conf_intervals_inter[0::4]+error, conf_intervals_inter[0::4]-error, alpha=0.2, color='cornflowerblue')
# plt.plot(x_axis, conf_intervals_inter[0::4],label=f'2 {device} inter-montage', color='cornflowerblue')

# error=conf_intervals_intra[2::4]-conf_intervals_intra[0::4]
# plt.fill_between(x_axis,conf_intervals_intra[0::4]+error, conf_intervals_intra[0::4]-error, alpha=0.2, color='deepskyblue')
# plt.plot(x_axis, conf_intervals_intra[0::4],label=f'2 {device} intra-montage', color='deepskyblue')

error=conf_intervals_monopolar[2::4]-conf_intervals_monopolar[0::4]
plt.fill_between(x_axis,conf_intervals_monopolar[0::4]+error, conf_intervals_monopolar[0::4]-error, alpha=0.2, color='skyblue')
plt.plot(x_axis, conf_intervals_monopolar[0::4],label=f'2 {device} monopolar', color='skyblue', linestyle='dashed')

# error=conf_intervals_single[2::4]-conf_intervals_single[0::4]
# plt.fill_between(x_axis,conf_intervals_single[0::4]+error, conf_intervals_single[0::4]-error, alpha=0.2, color='lightblue')
# plt.plot(x_axis, conf_intervals_single[0::4],label=f'1 {device}', color='lightblue', linestyle='dotted')

# SEEG colors: red, orange, darksalmon, tomato, salmon
error=s_conf_intervals[2::4]-s_conf_intervals[0::4]
plt.fill_between(x_axis,s_conf_intervals[0::4]+error, s_conf_intervals[0::4]-error, alpha=0.2, color='red')
plt.plot(x_axis, s_conf_intervals[0::4],label=f'2 SEEG montage', color='red')

error=s_conf_intervals_monopolar[2::4]-s_conf_intervals_monopolar[0::4]
plt.fill_between(x_axis,s_conf_intervals_monopolar[0::4]+error, s_conf_intervals_monopolar[0::4]-error, alpha=0.2, color='tomato')
plt.plot(x_axis, s_conf_intervals_monopolar[0::4],label=f'2 SEEG monopolar', color='tomato', linestyle='dashed')


plt.title(f'Gap sweep ({dual_iterations} iterations, {magnitude*10**9}nAm)')
plt.ylabel('Bits per channel')
if transform == 'translate':
    plt.xlabel('Center to Center gap (mm)')
else:
    plt.xlabel('Angle (degrees)')
plt.legend()
plt.savefig(path.join(output,f'3_separation_sweep_large.png'), transparent=True, format='png',dpi=300)

## Source Count vs Capacity (S vs C)
Vary the number of dipoles while holding the total moment budget fixed (each dipole is scaled by `magnitude / nsources`) to probe saturation behavior and montage efficacy.
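
As a rough illustration of the quantity tracked in this experiment, the sketch below splits a fixed moment budget equally across sources and evaluates the Shannon–Hartley sum per channel. The function name `per_channel_capacity`, its parameters, and all numeric values are assumptions made for this example; the notebook cells compute the same expression inline from the montage powers.

```python
import numpy as np

def per_channel_capacity(power, noise_power, bandwidth, num_channels):
    """Sum B*log2(1 + SNR) over sources, then divide by the channel count."""
    snr = np.asarray(power, dtype=float) / noise_power
    total = bandwidth * np.nansum(np.log2(1.0 + snr))  # NaN entries are skipped
    return total / num_channels

# Equal split of a fixed total moment (arbitrary illustrative units).
total_moment = 8.0
shares = {n: total_moment / n for n in (1, 2, 4)}

# Two sources, each with SNR 3, recorded on two channels.
capacity = per_channel_capacity(power=[3.0, 3.0], noise_power=1.0,
                                bandwidth=1.0, num_channels=2)
```

Because each term is logarithmic in SNR, spreading the same budget over more sources adds terms to the sum while shrinking each one, which is why a saturation curve is the expected shape.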


In [None]:
### DiSc processing
print("Starting DiSc")
conf_intervals_double = []
conf_intervals_single = []

for nsources in sources_list:
    print(f'Num dipoles: {nsources}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(fields_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    field_importer.duplicate_fields()  # 2 devices
    field_importer.translate(x=0, y=-4, electrodes=range(0, num_electrodes))  # move device 1 (electrodes 1-128)
    field_importer.translate(x=0, y=4, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (electrodes 129-256)
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list_double = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((nsources, nsources*3))  # First, create empty matrix to fill with dipoles
        for idx in range(nsources):
            rand_dipole = np.random.random(size=3)
            # Each dipole gets an equal share of the total moment budget (magnitude / nsources)
            rand_dipole = rand_dipole * np.linalg.norm(rand_dipole) * (magnitude / nsources)
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # Sample source voxels uniformly over x, y in 0..60 and z in 0..20
        x_locs = np.int64(np.floor(np.random.random(size=nsources)*61))
        y_locs = np.int64(np.floor(np.random.random(size=nsources)*61))
        z_locs = np.int64(np.floor(np.random.random(size=nsources)*21))

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_single = get_best_channel_each_dipole(v1)

        snr_montages = pwr_montages / noise
        snr_single = pwr_single / noise  # single device sees all nsources dipoles here

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C_double = B * np.nansum( np.log2(1 + snr_montages) )
        C_double = C_double / (num_electrodes*2)  # C per channel
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list_double = np.append(C_list_double, C_double)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended iteration nsources={nsources}. Iteration took {time.time()-start_time} seconds.')
    
    conf_intervals_double = np.append(conf_intervals_double, mean_confidence_interval(C_list_double))
    conf_intervals_single = np.append(conf_intervals_single, mean_confidence_interval(C_list_single))

In [None]:
### SEEG processing
print("Starting SEEG")
s_conf_intervals_double = []
s_conf_intervals_single = []

for nsources in sources_list:
    print(f'Num dipoles: {nsources}.')
    start_time = time.time()
    
    field_importer = FieldImporter()
    field_importer.load(seeg_file)
    num_electrodes = np.shape(field_importer.fields)[4]
    field_importer.duplicate_fields()  # 2 devices
    field_importer.translate(x=0, y=-4, electrodes=range(0, num_electrodes))  # move device 1 (first num_electrodes channels)
    field_importer.translate(x=0, y=4, electrodes=range(num_electrodes, 2*num_electrodes))  # move device 2 (remaining num_electrodes channels)
    fields = field_importer.fields

    # Initialize lists that will hold results for each iteration in MC
    C_list_double = []
    C_list_single = []
    for i in range(0, int(dual_iterations)):
        dipoles = np.zeros((nsources, nsources*3))  # First, create empty matrix to fill with dipoles
        for idx in range(nsources):
            rand_dipole = np.random.random(size=3)
            # Each dipole gets an equal share of the total moment budget (magnitude / nsources)
            rand_dipole = rand_dipole * np.linalg.norm(rand_dipole) * (magnitude / nsources)
            dipoles[idx, idx*3:idx*3+3] = rand_dipole

        # Sample source voxels uniformly over x, y in 0..60 and z in 0..20
        x_locs = np.int64(np.floor(np.random.random(size=nsources)*61))
        y_locs = np.int64(np.floor(np.random.random(size=nsources)*61))
        z_locs = np.int64(np.floor(np.random.random(size=nsources)*21))

        G1 = np.reshape(fields[x_locs,y_locs,z_locs,:,0:num_electrodes], (-1, num_electrodes))
        G1 = np.nan_to_num(G1)
        v1 = np.matmul(dipoles, G1)
        v1 = np.multiply(v1, 1e6)  # uV conversion

        G2 = np.reshape(fields[x_locs,y_locs,z_locs,:,num_electrodes:2*num_electrodes], (-1, num_electrodes))
        G2 = np.nan_to_num(G2)
        v2 = np.matmul(dipoles, G2)
        v2 = np.multiply(v2, 1e6)  # uV conversion

        # Calculate powers, then take maximum dipole only for C calculation
        pwr_montages = get_best_montage_each_dipole(v1, v2)
        pwr_single = get_best_channel_each_dipole(v1)

        snr_montages = pwr_montages / noise
        snr_single = pwr_single / noise  # single device sees all nsources dipoles here

        # Shannon-Hartley Eqn: C = B*log2(1+SNR)
        # Sum C from all devices (snr is an array) to get total C. Can take B out of the sum and multiply at end.
        C_double = B * np.nansum( np.log2(1 + snr_montages) )
        C_double = C_double / (num_electrodes*2)  # C per channel
        C_single = B * np.nansum( np.log2(1 + snr_single) )
        C_single = C_single / (num_electrodes)
        
        # Add answers to list for statistics later
        C_list_double = np.append(C_list_double, C_double)
        C_list_single = np.append(C_list_single, C_single)
    print(f'Ended iteration nsources={nsources}. Iteration took {time.time()-start_time} seconds.')
    
    s_conf_intervals_double = np.append(s_conf_intervals_double, mean_confidence_interval(C_list_double))
    s_conf_intervals_single = np.append(s_conf_intervals_single, mean_confidence_interval(C_list_single))

In [None]:
# This first
# seeg_conf_intervals_double = conf_intervals_double
# seeg_conf_intervals_single = conf_intervals_single

In [None]:
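# Run after generating the per-channel plot: convert per-channel capacity to total bits by
# multiplying by the channel count (SEEG: 36 dual, 18 single; DISC: 256 dual, 128 single).
# Re-running this cell rescales the arrays again.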
s_conf_intervals_double = np.multiply(s_conf_intervals_double, 36)
s_conf_intervals_single = np.multiply(s_conf_intervals_single, 18)
conf_intervals_double = np.multiply(conf_intervals_double, 256)
conf_intervals_single = np.multiply(conf_intervals_single, 128)

## Plot Results (Source Scaling)
Compares the dual-device montage against a single device as the source count grows. The secondary axis shows the equivalent dipole diameter implied by holding the total moment constant.
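
The secondary axis can be reproduced in isolation. The sketch below restates the expression used in the plotting cell, `2 * sqrt((total moment in nAm / 2) / n)`; the function name `equivalent_diameter` and the sample values are illustrative assumptions, and the millimetre interpretation depends on the moment-per-area convention the notebook adopts, which is not restated here.

```python
import numpy as np

def equivalent_diameter(total_moment_nAm, source_counts):
    """Diameter per source when a fixed total moment is split across n sources."""
    counts = np.asarray(source_counts, dtype=float)
    return 2.0 * np.sqrt((total_moment_nAm / 2.0) / counts)

# Illustrative values only: 8 nAm shared by 1, 4, and 16 sources.
diameters = equivalent_diameter(8.0, [1, 4, 16])
```

Under this expression the diameter falls as the inverse square root of the source count, so quadrupling the number of sources halves the diameter.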


In [None]:
x_axis = sources_list
fig, ax1 = plt.subplots(figsize=(9, 6))
ax2 = ax1.twinx()
dipole_sizes = np.sqrt(np.divide(magnitude*1e9/2, sources_list))*2
# DISC colors: blue, cornflowerblue, deepskyblue, skyblue, lightblue
# SEEG colors: red, orange, darksalmon, tomato, salmon
error=conf_intervals_double[2::4]-conf_intervals_double[0::4]
ax1.fill_between(x_axis,conf_intervals_double[0::4]+error, conf_intervals_double[0::4]-error, alpha=0.2, color='blue')
ax1.plot(x_axis, conf_intervals_double[0::4],label=f'2 {device} montage', color='blue')

error=conf_intervals_single[2::4]-conf_intervals_single[0::4]
ax1.fill_between(x_axis,conf_intervals_single[0::4]+error, conf_intervals_single[0::4]-error, alpha=0.2, color='lightblue')
ax1.plot(x_axis, conf_intervals_single[0::4],label=f'1 {device}', color='lightblue', linestyle='dotted')

error=s_conf_intervals_double[2::4]-s_conf_intervals_double[0::4]
ax1.fill_between(x_axis,s_conf_intervals_double[0::4]+error, s_conf_intervals_double[0::4]-error, alpha=0.2, color='red')
ax1.plot(x_axis, s_conf_intervals_double[0::4],label=f'2 SEEG montage', color='red')

error=s_conf_intervals_single[2::4]-s_conf_intervals_single[0::4]
ax1.fill_between(x_axis,s_conf_intervals_single[0::4]+error, s_conf_intervals_single[0::4]-error, alpha=0.2, color='salmon')
ax1.plot(x_axis, s_conf_intervals_single[0::4],label=f'1 SEEG', color='salmon', linestyle='dotted')

ax2.plot(x_axis, dipole_sizes, color='black', label='Dipole diameter')  # label gives ax2.legend() an entry to show

plt.title(f'{device} dipole sweep ({dual_iterations} iterations, Total: {magnitude*1e9} nAm)')
ax1.set_ylabel('Total Bits')
# ax1.set_ylabel('Bits per channel')
ax1.set_xlabel('Num dipoles')
ax2.set_ylabel('Dipole Diameter (mm)')
# plt.ylim(0, 2)
ax1.legend()
ax2.legend()
plt.savefig(path.join(output,'3_dipole_size_sweep.png'), transparent=True, format='png',dpi=300)