In [None]:
%matplotlib widget

import pathlib
from os import path as op

import matplotlib.pyplot as plt
import mne
import nilearn.image
import nilearn.plotting
import numpy as np
import psutil
from ipywidgets import IntSlider, SelectionSlider, fixed, interact, interactive
from mne.simulation import simulate_evoked

# --- Paths into the MNE "sample" dataset (pathlib throughout) -------------
data_path = pathlib.Path(mne.datasets.sample.data_path())
subjects_dir = data_path / 'subjects'
subject = 'sample'
meg_path = data_path / 'MEG' / subject

fname_ave = meg_path / 'sample_audvis-ave.fif'
fname_cov = meg_path / 'sample_audvis-cov.fif'
fname_bem = subjects_dir / subject / 'bem' / 'sample-5120-5120-5120-bem-sol.fif'
fname_trans = meg_path / 'sample_audvis_raw-trans.fif'
t1_fname = subjects_dir / subject / 'mri' / 'T1.mgz'
info_fname = meg_path / 'sample_audvis_raw.fif'

# --- Noise covariance, BEM, coregistration and anatomy --------------------
cov = mne.read_cov(fname_cov, verbose=False)
bem = mne.read_bem_solution(fname_bem, verbose=False)
trans = mne.read_trans(fname_trans)
trans_mri_to_head = mne.transforms.invert_transform(trans)
t1_img = nilearn.image.load_img(str(t1_fname))

# --- Evoked data: MEG channels only, cropped to 70-80 ms for the fit ------
evoked = mne.read_evokeds(fname_ave, condition='Right Auditory',
                          baseline=(None, 0), verbose=False)
evoked.pick_types(meg=True, eeg=False)
evoked_full = evoked.copy()
evoked.crop(0.07, 0.08)

# Measurement info used when simulating sensor data; no channels marked bad.
info = mne.io.read_info(info_fname, verbose=False)
info['bads'] = []

# Reference dipole location, specified in mm in MRI coordinates and converted
# to meters. NOTE(review): origin of these numbers is not documented here.
dipole_mri_pos = np.array([-60.32, -9.89, 65.41]) / 1000
dipole_head_pos = mne.transforms.apply_trans(trans=trans_mri_to_head,
                                             pts=dipole_mri_pos)
dipole_mni_pos = mne.head_to_mni(dipole_head_pos, subject, trans_mri_to_head,
                                 subjects_dir=subjects_dir)

# Output file for a fitted dipole (saving is currently disabled further down).
# Relative to the working directory rather than a user-specific absolute path.
dipole_fname = pathlib.Path('dipole.dip')
cpu_count = psutil.cpu_count(logical=True)

# Head-frame z coordinates (mm) at which dipoles are fitted and simulated.
z_vals = [60, 70, 80]

In [None]:
def gen_dipole(evoked, cov, bem, trans, pos):
    """Fit a dipole at a given position and keep only its first time sample.

    Parameters
    ----------
    evoked : mne.Evoked
        Sensor data to fit.
    cov : mne.Covariance
        Noise covariance passed to ``mne.fit_dipole``.
    bem : mne.bem.ConductorModel
        BEM solution passed to ``mne.fit_dipole``.
    trans : mne.transforms.Transform
        Head <-> MRI transform passed to ``mne.fit_dipole``.
    pos : array-like, shape (3,)
        Dipole position handed to ``mne.fit_dipole(pos=...)`` (per the MNE
        docs, a given position is held fixed during the fit).

    Returns
    -------
    dip : mne.Dipole
        The fitted dipole, cropped in place so only its first time point
        remains. The fit residual is discarded.
    """
    fitted = mne.fit_dipole(evoked, cov, bem, trans, pos=pos, min_dist=1)[0]
    first_sample_time = fitted.times[0]
    fitted.crop(tmax=first_sample_time)
    return fitted


def rotate_fwd(fwd, dipole_ori):
    """Return a fixed-orientation forward solution pointing along ``dipole_ori``.

    Parameters
    ----------
    fwd : mne.Forward
        Forward solution to rotate. A copy is made first, so the caller's
        object is left untouched (relies on ``Forward.copy`` copying the
        nested source-space dicts). NOTE(review): only ``fwd['src'][0]`` gets
        a single (1, 3) normal, so this assumes a one-source forward such as
        the one produced by ``mne.make_forward_dipole`` for a single dipole.
    dipole_ori : array-like, shape (3,) or (1, 3)
        Desired orientation. It does not need to be normalized, and integer
        values are accepted.

    Returns
    -------
    fwd : mne.Forward
        Fixed-orientation forward solution for the requested orientation.

    Raises
    ------
    ValueError
        If ``dipole_ori`` does not have exactly 3 elements, or is zero-length
        or non-finite.
    """
    # Validate and normalize up front, before any expensive conversions.
    # np.array(..., dtype=float) always copies, so the caller's array is never
    # modified by the in-place division. It also avoids the integer-dtype
    # failure of ``/=`` for inputs such as [0, 0, 1].
    ori = np.array(dipole_ori, dtype=float).reshape(1, 3)
    norm = np.linalg.norm(ori)
    if not np.isfinite(norm) or norm == 0:
        raise ValueError('dipole_ori must be a finite, non-zero 3-vector, '
                         f'got {ori.ravel().tolist()}')
    ori /= norm

    fwd = fwd.copy()

    # Ensure we have a free-orientation forward solution.
    if fwd['source_ori'] != mne.io.constants.FIFF.FIFFV_MNE_FREE_ORI:
        fwd = mne.convert_forward_solution(fwd, surf_ori=False,
                                           force_fixed=False, copy=False,
                                           verbose=False)

    # Overwrite the source normal. The conversion to a surface-oriented,
    # fixed forward below then projects onto this orientation.
    fwd['src'][0]['nn'] = ori

    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       copy=False, verbose=False)

    return fwd


def plot(phi, theta, z_pos, t1_img, dipoles,
         fwds, stcs, info, cov, axes):
    """Redraw the dipole location and its simulated MEG/EEG topographies.

    Intended as the callback of ``ipywidgets.interactive``: it is re-run
    whenever a slider changes.

    Parameters
    ----------
    phi : float
        Azimuth of the dipole orientation, in degrees.
    theta : float
        Polar angle of the dipole orientation (measured from +z), in degrees.
    z_pos : int
        Key into ``dipoles``/``fwds``/``stcs``. It is also written, divided
        by 1000, into the z coordinate of the plotted dipole position.
    t1_img, cov
        Unused. Kept so the ``interactive`` call signature stays unchanged.
    dipoles, fwds, stcs : dict
        Pre-computed dipoles, forward solutions and source estimates, keyed
        by ``z_pos``. They are not modified.
    info : mne.Info
        Measurement info used to simulate the sensor data.
    axes : list of matplotlib.axes.Axes
        ``axes[0]`` receives the dipole location plot. ``axes[1]``,
        ``axes[2]`` and ``axes[3]`` receive the mag, grad and eeg topomaps.

    Notes
    -----
    Also reads the module-level names ``trans``, ``subject``,
    ``subjects_dir``, ``rotate_fwd`` and ``simulate_evoked``.
    """
    fwd = fwds[z_pos]
    stc = stcs[z_pos]
    # Work on a copy so the cached dipole in ``dipoles`` is never mutated
    # between slider updates.
    dipole = dipoles[z_pos].copy()

    # Spherical (degrees) -> cartesian orientation vector.
    phi_rad = np.deg2rad(phi)
    theta_rad = np.deg2rad(theta)
    ori = np.array([[np.sin(theta_rad) * np.cos(phi_rad),
                     np.sin(theta_rad) * np.sin(phi_rad),
                     np.cos(theta_rad)]])
    ori /= np.linalg.norm(ori)

    # Simulate noise-free sensor data for the rotated dipole.
    fwd = rotate_fwd(fwd=fwd, dipole_ori=ori)
    evoked = simulate_evoked(fwd,
                             stc,
                             info,
                             cov=None,
                             nave=np.inf, verbose='error')
    evoked.del_proj()
    evoked.set_eeg_reference(projection=True, verbose=False)

    # Do the actual plotting.
    for ax in axes:
        ax.clear()

    for ax_num, ch_type in enumerate(['mag', 'grad', 'eeg'], start=1):
        evoked.plot_topomap(ch_type=ch_type, colorbar=False,
                            outlines='skirt',
                            nrows=1, ncols=1,
                            times=evoked.times[-1],
                            res=256, show=False,
                            axes=axes[ax_num])
        axes[ax_num].set_title(ch_type, fontweight='bold')

    dipole.ori = ori.reshape(1, 3)
    dipole.pos[0, -1] = z_pos / 1000.  # Convert to meters
    dipole.plot_locations(trans=trans, subject=subject,
                          subjects_dir=subjects_dir,
                          ax=axes[0], show=False, coord_frame='head')
    axes[0].axis('off')
    axes[0].set_title(None)

    # Hide every other axes of the figure (the layout placeholders the panels
    # were placed on) without relying on a global figure handle.
    for other_ax in axes[0].figure.axes:
        if all(other_ax is not panel_ax for panel_ax in axes):
            other_ax.axis('off')

    
    #########

    # if fwd['source_ori'] != mne.io.constants.FIFF.FIFFV_MNE_FIXED_ORI:
    #     raise ValueError('Please provide a fixed-orientation fwd to plot().')

#     dipole_ori = np.array([dipole_ori_x, dipole_ori_y, dipole_ori_z], dtype='float')

#     dipole_angle_rad = np.deg2rad(dipole_angle)
#     if fixed_plane == 'axial':
#         x = np.cos(dipole_angle_rad)
#         y = np.sin(dipole_angle_rad)
#         z = 0
#     elif fixed_plane == 'coronal':
#         x = np.cos(dipole_angle_rad)
#         y = 0
#         z = np.sin(dipole_angle)
#     elif fixed_plane == 'saggital':
#         x = 0
#         y = np.cos(dipole_angle_rad)
#         z = np.sin(dipole_angle_rad)
#     else:
#         raise ValueError('Unknown fixed_plane passed!')
#     nilearn.plotting.plot_anat(t1_img, title='Dipole loc.',
#                                cut_coords=dipole_mni_pos, axes=axes[2],
#                                display_mode='ortho')

#     if fixed_plane == 'axial':
#         pos = dipole_mni_pos[-1]
#         display_mode = 'z'
#     elif fixed_plane == 'coronal':
#         pos = dipole_mni_pos[1]
#         display_mode = 'y'
#     elif fixed_plane == 'saggital':
#         pos = dipole_mni_pos[0]
#         display_mode = 'x'

#     ax = nilearn.plotting.plot_anat(t1_img, title='Dipole ori.',
#                                cut_coords=[pos], axes=axes[0],
#                                display_mode=display_mode)

#     dx = np.cos(dipole_angle_rad)
#     dy = np.sin(dipole_angle_rad)
    
#     axes[1].plot(0, 0, 'o', ms=5, color='red')
#     circle = plt.Circle((0, 0), 1.0, color='r', lw=1, fill=False)
#     axes[1].add_artist(circle)
#     axes[1].arrow(x=0, y=0, dx=dx, dy=dy, color='red', lw=3, head_width=0.1, length_includes_head=True)
#     axes[1].set_xlim(-1.05, 1.05)
#     axes[1].set_ylim(-1.05, 1.05)
#     axes[1].axis('off')
#     axes[1].text(0, 0.3, f'{int(round(dipole_angle))}°', fontsize=12, color='red')

In [None]:
def _build_dipole_models(z_values_mm, head_pos, evoked, cov, bem, trans, info):
    """Fit a dipole and build its forward model for each requested depth.

    Parameters
    ----------
    z_values_mm : iterable of int
        Head-frame z coordinates in mm. Each value replaces the last
        component of ``head_pos`` after conversion to meters.
    head_pos : numpy.ndarray, shape (3,)
        Reference dipole position in head coordinates. It is copied, never
        modified.
    evoked, cov, bem, trans, info
        Passed through to ``gen_dipole`` and ``mne.make_forward_dipole``.

    Returns
    -------
    dipoles, fwds, stcs : dict
        Fitted dipoles, forward solutions and source estimates, keyed by the
        z value in mm. All three are empty if ``z_values_mm`` is empty.
    """
    dipoles, fwds, stcs = {}, {}, {}
    for z_mm in z_values_mm:
        print(f'Generating dipole fwd solution for depth z={z_mm} mm.')
        pos = head_pos.copy()
        pos[-1] = z_mm / 1000.  # Convert to meters
        dipole = gen_dipole(evoked=evoked, cov=cov, bem=bem,
                            trans=trans, pos=pos)
        # FIXME there seems to be a bug in MNE that prevents us from saving!
        # dipole.save(dipole_fname, overwrite=True)

        fwd, stc = mne.make_forward_dipole(dipole[0], bem=bem, info=info,
                                           trans=trans, verbose='error')

        dipoles[z_mm] = dipole
        fwds[z_mm] = fwd
        stcs[z_mm] = stc
    return dipoles, fwds, stcs


# Loop temporaries stay local to the function, so nothing leaks into the
# notebook namespace, and an empty ``z_vals`` simply yields empty dicts.
dipoles, fwds, stcs = _build_dipole_models(z_vals, dipole_head_pos, evoked,
                                           cov, bem, trans, info)

#     # Convert to free-orientation fwd.
#     fwd_free = mne.convert_forward_solution(fwd,
#                                             surf_ori=False,
#                                             copy=True, verbose='error')

In [None]:
# Figure layout: a 2 x 5 grid. The 3D dipole-location view spans the first
# two columns, and the mag, grad and eeg topomaps get one column each.
widths = [3, 3, 3, 3, 3]
heights = [5, 5]

gs_kw = dict(width_ratios=widths, height_ratios=heights)
fig, fig_axes = plt.subplots(ncols=5, nrows=2,
                             gridspec_kw=gs_kw, figsize=(10, 3))
gs = fig_axes[1, 2].get_gridspec()

# The axes created by plt.subplots() only define the grid. Hide them so they
# do not draw frames underneath the real panels added below.
for placeholder_ax in fig_axes.ravel():
    placeholder_ax.axis('off')

axes = [
    fig.add_subplot(gs[:2, :2], projection='3d'),  # dipole location
    fig.add_subplot(gs[:2, 2]),  # mag topomap
    fig.add_subplot(gs[:2, 3]),  # grad topomap
    fig.add_subplot(gs[:2, 4]),  # eeg topomap
]

# Dipole orientation in spherical coordinates (degrees): phi is the azimuth,
# theta the polar angle from +z. A theta above 180 only repeats orientations
# already reachable with phi + 180, so theta is capped at 180.
phi_slider = IntSlider(min=0, max=360, step=1, value=0,
                       continuous_update=False)
theta_slider = IntSlider(min=0, max=180, step=1, value=90,
                         continuous_update=False)
# Offer exactly the depths that were pre-computed, so the slider can never
# select a z value missing from ``dipoles``/``fwds``/``stcs``.
z_options = sorted(dipoles)
z_slider = SelectionSlider(options=z_options, value=z_options[0],
                           continuous_update=False)
interactive_plot = interactive(plot,
                               phi=phi_slider,
                               theta=theta_slider,
                               z_pos=z_slider,
                               t1_img=fixed(t1_img),
                               dipoles=fixed(dipoles),
                               fwds=fixed(fwds), stcs=fixed(stcs),
                               info=fixed(info), cov=fixed(cov),
                               axes=fixed(axes))

output = interactive_plot.children[-1]
output.layout.height = '600px'
interactive_plot