In [None]:
import sys; sys.path.insert(0, '../../invert')
from invert.forward import get_info, create_forward_model
import mne
import pickle as pkl
import os
import numpy as np
import matplotlib.pyplot as plt
from config import *
verbose = 0  # verbosity level forwarded to every MNE call in this notebook
os.makedirs("forward_models", exist_ok=True)  # every forward model and info pickle is written here

# Settings

In [None]:
# Source-space spacing passed to mne.setup_source_space, keyed by a short resolution label.
samplings = dict(
    fine="oct6",
    coarse="ico4",
)

# Load the Files

In [None]:
# Fetch the fsaverage template subject into `subjects_dir`; its surfaces and BEM
# files are used for the source spaces and forward models below.
mne.datasets.fetch_fsaverage(subjects_dir=subjects_dir, verbose=True)

In [None]:
# Load the MEG recording whose sensor layout defines every forward model below.
# `meg_data_path`, `subjects_dir` and `subject` are presumably provided by `config`
# (star-imported in the first cell). With the MNE sample dataset this would be e.g.
#   mne.datasets.sample.data_path() / "MEG" / "sample" / "sample_audvis_filt-0-40_raw.fif"
raw = mne.io.read_raw_fif(meg_data_path)
raw = raw.pick_types(meg=True, eeg=False, eog=False, stim=False)
info = raw.info

# Persist the unaltered measurement info so later analyses can reuse it.
fn = "forward_models/info.pkl"
with open(fn, 'wb') as f:
    pkl.dump(info, f)

# `fs_dir` is the subject's FreeSurfer folder; `subjects_dir` is re-derived from it
# as a plain path string.
fs_dir = os.path.join(subjects_dir, subject)
subjects_dir = os.path.dirname(fs_dir)
# MRI<->head transform expected in the fsaverage template fetched above. Built from
# `subjects_dir` rather than a hard-coded, machine-specific absolute path.
trans = os.path.join(subjects_dir, "fsaverage", "bem", "fsaverage-trans.fif")

# Create and Store Clean Forward Models

In [None]:
# Three-shell BEM conductivities (inner to outer shell); the middle shell is
# `ratio` times less conductive than the other two.
ratio = 80
conductivity = [0.33, 0.33/ratio, 0.33]

ico = 4  # downsampling of the BEM surfaces (not of the source space)
# Pass `subjects_dir` explicitly, as setup_source_space below does, so the BEM is
# built from the same subject folder instead of an environment/config default.
surfaces = mne.make_bem_model(subject, ico=ico, conductivity=conductivity,
                              subjects_dir=subjects_dir, verbose=verbose)
bem = mne.make_bem_solution(surfaces, verbose=verbose)
fwds = dict()
srcs = dict()

for samp_label, sampling in samplings.items():
    src = mne.setup_source_space(subject, spacing=sampling, surface='white',
                                 subjects_dir=subjects_dir, add_dist=False,
                                 n_jobs=-1, verbose=verbose)
    fwd = mne.make_forward_solution(info, trans=trans, src=src, bem=bem, eeg=False,
                                    meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    # Fixed, surface-normal source orientation: one leadfield column per source.
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)

    fname = f"forward_models/Clean_{samp_label}-fwd.fif"
    mne.write_forward_solution(fname, fwd, overwrite=True, verbose=verbose)

    fwds[samp_label] = fwd.copy()
    srcs[samp_label] = src.copy()

# Visualize

## Leadfield Topographies (MEG)

In [None]:
# Show each source's leadfield column, normalised to unit norm, as a sensor
# topography. EvokedArray treats the columns as consecutive "time points", so the
# times passed to plot_joint select individual sources rather than latencies.
# Use the stored coarse model (the last one built in the loop above) explicitly
# instead of whatever `fwd` happens to be bound to in the kernel.
L = fwds["coarse"]["sol"]["data"].copy()
L /= np.linalg.norm(L, axis=0)
evoked = mne.EvokedArray(L, info)
evoked.plot_joint(times=[0, 5, 10, 15, 20, 25, 30, 32])

## Helmet & Cortex

In [None]:
# Interpolate the leadfield "evoked" onto the MEG helmet and show it together
# with the sensors and the pial surface.
maps = mne.make_field_map(
    evoked,
    trans=trans,
    ch_type="meg",
    subject=subject,
    subjects_dir=subjects_dir,
)
# `evoked` holds one source per sample, so `time` selects a single source
# (the sample closest to this value) rather than a physiological latency.
time = 0.083
fig = mne.viz.create_3d_figure((256, 256))
mne.viz.plot_alignment(
    evoked.info,
    subject=subject,
    subjects_dir=subjects_dir,
    fig=fig,
    trans=trans,
    meg="sensors",
    eeg=False,
    surfaces="pial",
    coord_frame="mri",
)
# The leadfield columns were normalised to unit norm, so a fixed tesla-scale
# vmax (e.g. 5e-13) would saturate the whole map; let MNE scale to the data.
evoked.plot_field(
    maps, time=time, fig=fig, time_label=None, vmax=None, time_viewer=False
)
mne.viz.set_3d_view(
    fig,
    azimuth=40,
    elevation=87,
    focalpoint=(0.0, -0.01, 0.04),
    roll=-25,
    distance=0.55,
)

# MEG Sensor-Placement Errors

## Translation Posterior

In [None]:
# Shift every MEG sensor along -y ("posterior") by 1 and 2 mm and recompute the
# fine-resolution forward model for each shift.
direction = "posterior"

samp_label = "fine"
sampling = samplings[samp_label]

translation_list = [1e-3, 2e-3]  # 1-2 mm in meters
for translation in translation_list:
    mm = round(translation * 1e3)  # shift in whole millimetres, used in labels and file names
    info_trans = info.copy()
    for ch in info_trans["chs"]:
        ch["loc"][1] -= translation  # loc[1] is the sensor's y coordinate

    src = srcs[samp_label].copy()
    fwd = mne.make_forward_solution(info_trans, trans=trans, src=src,
                                    bem=bem, eeg=False, meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)
    # Tag after the last reassignment of `fwd`, so the tag stays on the model that is kept.
    fwd.comment = dict(sampling=sampling, error_type=f"translation {direction}", error_magnitude=mm)
    mne.write_forward_solution(f"forward_models/Altered_{samp_label}_translation-{mm}mm-{direction}-fwd.fif", fwd, overwrite=True, verbose=0)

    fn = f"forward_models/info_translation-{mm}mm-{direction}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump(info_trans, f)

    print(f"Saved {fn}")

In [None]:
# Compare the original layout with the last (largest) shift from the loop above.
info.plot_sensors(title="Original sensor positions")
info_trans.plot_sensors(title=f"Sensors shifted {translation * 1e3:g} mm {direction}")

## Translation Dorsal

In [None]:
# Shift every MEG sensor along +z ("dorsal") by 1 and 2 mm and recompute the
# fine-resolution forward model for each shift.
direction = "dorsal"

samp_label = "fine"
sampling = samplings[samp_label]

translation_list = [1e-3, 2e-3]  # 1-2 mm in meters
for translation in translation_list:
    mm = round(translation * 1e3)  # shift in whole millimetres, used in labels and file names
    info_trans = info.copy()
    for ch in info_trans["chs"]:
        ch["loc"][2] += translation  # loc[2] is the sensor's z coordinate

    src = srcs[samp_label].copy()
    fwd = mne.make_forward_solution(info_trans, trans=trans, src=src,
                                    bem=bem, eeg=False, meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)
    # Tag the model like the other altered-model cells, after the last
    # reassignment of `fwd` so the tag stays on the model that is kept.
    fwd.comment = dict(sampling=sampling, error_type=f"translation {direction}", error_magnitude=mm)
    mne.write_forward_solution(f"forward_models/Altered_{samp_label}_translation-{mm}mm-{direction}-fwd.fif", fwd, overwrite=True, verbose=0)

    fn = f"forward_models/info_translation-{mm}mm-{direction}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump(info_trans, f)

    print(f"Saved {fn}")

In [None]:
# Compare the original layout with the last (largest) shift from the loop above.
info.plot_sensors(title="Original sensor positions")
info_trans.plot_sensors(title=f"Sensors shifted {translation * 1e3:g} mm {direction}")

## Translation Right

In [None]:
# Shift every MEG sensor along +x ("right") by 1 and 2 mm and recompute the
# fine-resolution forward model for each shift.
direction = "right"

samp_label = "fine"
sampling = samplings[samp_label]

translation_list = [1e-3, 2e-3]  # 1-2 mm in meters
for translation in translation_list:
    mm = round(translation * 1e3)  # shift in whole millimetres, used in labels and file names
    info_trans = info.copy()
    for ch in info_trans["chs"]:
        ch["loc"][0] += translation  # loc[0] is the sensor's x coordinate

    src = srcs[samp_label].copy()
    fwd = mne.make_forward_solution(info_trans, trans=trans, src=src,
                                    bem=bem, eeg=False, meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)
    # Tag after the last reassignment of `fwd`, so the tag stays on the model that is kept.
    fwd.comment = dict(sampling=sampling, error_type=f"translation {direction}", error_magnitude=mm)
    mne.write_forward_solution(f"forward_models/Altered_{samp_label}_translation-{mm}mm-{direction}-fwd.fif", fwd, overwrite=True, verbose=0)

    fn = f"forward_models/info_translation-{mm}mm-{direction}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump(info_trans, f)

    print(f"Saved {fn}")

In [None]:
# Compare the original layout with the last (largest) shift from the loop above.
info.plot_sensors(title="Original sensor positions")
info_trans.plot_sensors(title=f"Sensors shifted {translation * 1e3:g} mm {direction}")

# Rotation

In [None]:
import math

def rotate_coordinates(coords, axis, degree):
    """Rotate 3-D points about one Cartesian axis through the origin.

    Parameters
    ----------
    coords : iterable of array-like, each of length 3
        Points to rotate (e.g. the rows of an (n, 3) array).
    axis : {'x', 'y', 'z'}
        Axis to rotate about.
    degree : float
        Rotation angle in degrees; positive angles follow the right-hand rule.

    Returns
    -------
    list of numpy.ndarray
        One rotated point per input point, in input order.

    Raises
    ------
    ValueError
        If ``axis`` is not 'x', 'y' or 'z'.
    """
    angle = math.radians(degree)
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    if axis not in ('x', 'y', 'z'):
        raise ValueError("Invalid axis. Choose from 'x', 'y', or 'z'.")

    # Standard right-handed rotation matrices, one per axis.
    matrices = {
        'x': [[1, 0, 0],
              [0, cos_a, -sin_a],
              [0, sin_a, cos_a]],
        'y': [[cos_a, 0, sin_a],
              [0, 1, 0],
              [-sin_a, 0, cos_a]],
        'z': [[cos_a, -sin_a, 0],
              [sin_a, cos_a, 0],
              [0, 0, 1]],
    }
    rot = np.array(matrices[axis])

    return [np.dot(rot, np.array(point)) for point in coords]

## Rotation Right

In [None]:
# Rigidly rotate the MEG sensor array about the device y-axis by 1 and 2 degrees
# and recompute the fine-resolution forward model for each angle.
samp_label = "fine"  # set here instead of relying on a value left over from an earlier cell
sampling = samplings[samp_label]

degrees_list = [1, 2]
direction = "right"
alteration = "rotation"
for degrees in degrees_list:
    deg_label = str(degrees).replace(".", "")  # file-name-safe angle, e.g. 1 -> "1", 0.5 -> "05"
    info_rotate = info.copy()
    for ch_orig, ch_rot in zip(info["chs"], info_rotate["chs"]):
        # In MNE's MEG channel layout, loc[:3] is the coil position and loc[3:12]
        # are the coil's three orientation unit vectors. Rotate all four so each
        # coil keeps its orientation relative to the helmet (a rigid rotation),
        # rather than moving positions only.
        vectors = np.asarray(ch_orig["loc"][:12]).reshape(4, 3)
        ch_rot["loc"][:12] = np.concatenate(rotate_coordinates(vectors, 'y', degrees))

    src = srcs[samp_label].copy()
    fwd = mne.make_forward_solution(info_rotate, trans=trans, src=src,
                                    bem=bem, eeg=False, meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)
    # Tag after the last reassignment of `fwd`, so the tag stays on the model that is kept.
    fwd.comment = dict(sampling=sampling, error_type=f"{alteration} {direction}", error_magnitude=degrees)
    mne.write_forward_solution(f"forward_models/Altered_{samp_label}_{alteration}-{deg_label}deg-{direction}-fwd.fif", fwd, overwrite=True, verbose=0)

    fn = f"forward_models/info_{alteration}-{deg_label}deg-{direction}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump(info_rotate, f)

    print(f"Saved {fn}")

In [None]:
# Compare the original layout with the last (largest) rotation from the loop above.
info.plot_sensors(title="Original sensor positions")
info_rotate.plot_sensors(title=f"Sensors rotated {degrees:g} deg {direction}")

## Rotation Up

In [None]:
# Rigidly rotate the MEG sensor array about the device x-axis by 0.25 and 0.5
# degrees and recompute the fine-resolution forward model for each angle.
samp_label = "fine"  # set here instead of relying on a value left over from an earlier cell
sampling = samplings[samp_label]

degrees_list = [0.25, 0.5]
direction = "up"
alteration = "rotation"
for degrees in degrees_list:
    # `degrees` is a float, which has no `.replace`; convert to str first.
    # Also avoids the nested same-quote f-string (a SyntaxError before Python 3.12).
    deg_label = str(degrees).replace(".", "")  # e.g. 0.25 -> "025", 0.5 -> "05"
    info_rotate = info.copy()
    for ch_orig, ch_rot in zip(info["chs"], info_rotate["chs"]):
        # In MNE's MEG channel layout, loc[:3] is the coil position and loc[3:12]
        # are the coil's three orientation unit vectors. Rotate all four so each
        # coil keeps its orientation relative to the helmet (a rigid rotation),
        # rather than moving positions only.
        vectors = np.asarray(ch_orig["loc"][:12]).reshape(4, 3)
        ch_rot["loc"][:12] = np.concatenate(rotate_coordinates(vectors, 'x', degrees))

    src = srcs[samp_label].copy()
    fwd = mne.make_forward_solution(info_rotate, trans=trans, src=src,
                                    bem=bem, eeg=False, meg=True, mindist=5.0, n_jobs=-1,
                                    verbose=verbose)
    fwd = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True,
                                       use_cps=True, verbose=verbose)
    # Tag after the last reassignment of `fwd`, so the tag stays on the model that is kept.
    fwd.comment = dict(sampling=sampling, error_type=f"{alteration} {direction}", error_magnitude=degrees)
    mne.write_forward_solution(f"forward_models/Altered_{samp_label}_{alteration}-{deg_label}deg-{direction}-fwd.fif", fwd, overwrite=True, verbose=0)

    fn = f"forward_models/info_{alteration}-{deg_label}deg-{direction}.pkl"
    with open(fn, 'wb') as f:
        pkl.dump(info_rotate, f)

    print(f"Saved {fn}")

In [None]:
%matplotlib qt
info.plot_sensors()
info_rotate.plot_sensors()