In [1]:
import mne
import numpy as np
import pandas as pd
import scipy.io
from scipy.interpolate import NearestNDInterpolator, griddata

In [None]:
import plotly
import plotly.graph_objects as go
import plotly.express as px

# Enable inline rendering of plotly figures in this notebook.
# NOTE(review): connected=True makes the notebook load plotly.js from an online CDN
# instead of embedding it — rendered figures then need network access to display.
plotly.offline.init_notebook_mode(connected=True)

In [None]:
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm

### Read in data

In [None]:
# Load the ICA-cleaned EEGLAB recording (path is relative to the notebook's directory).
raw = mne.io.read_raw_eeglab(input_fname="../927/fixica.set")

In [None]:
# Electrode (x, y, z) positions taken from the recording's montage: one row per
# channel, in the iteration order of the montage's `ch_pos` mapping.
node_list = [
    position.tolist()
    for position in raw.get_montage().get_positions()["ch_pos"].values()
]
node_coord = np.array(node_list)

In [None]:
# Tabular copy of the recording. Later cells treat it as one row per sample with a
# "time" column plus one column per channel (NOTE(review): layout assumed from how
# the frame is used below — confirm against the installed MNE's to_data_frame).
df = raw.to_data_frame()

In [None]:
# Regular 100 x 100 grid spanning the electrodes' (x, y) bounding box.
xi, yi = (
    np.linspace(min(node_coord[:, axis]), max(node_coord[:, axis]), num=100)
    for axis in (0, 1)
)
x_grid, y_grid = np.meshgrid(xi, yi)

# Estimate the scalp height at every grid point from the electrode heights, which
# produces the head-shaped surface. Cubic griddata leaves grid points outside the
# electrodes' convex hull as NaN.
z_grid = griddata(
    points=(node_coord[:, 0], node_coord[:, 1]),
    values=node_coord[:, 2],
    xi=(x_grid, y_grid),
    method="cubic",
)

In [None]:
# Flatten the interpolated scalp surface into a list of (x, y, z) query points.
# The meshgrid arrays are read in row-major order, so point k = r * n_cols + c is
# grid cell [r, c]: x = x_grid[r, c] (= xi[c]), y = y_grid[r, c] (= yi[r]),
# z = z_grid[r, c]. Each point's z therefore belongs to its own (x, y). Any
# per-point result reshaped with np.reshape(..., (-1, 100)) lines up cell-for-cell
# with x_grid / y_grid / z_grid when plotted.
xyz = np.column_stack((x_grid.ravel(), y_grid.ravel(), z_grid.ravel())).tolist()

## Interpolation for a single timestamp

In [None]:
# Rows of the recording whose "time" value is exactly 0.
t0 = df.loc[df["time"].eq(0)]

In [None]:
# EEG value of every channel at the selected time point: the mean of each
# non-"time" column of `t0`, in column order.
eeg_list = [np.mean(t0[channel]) for channel in t0.columns if channel != "time"]
eeg = np.array(eeg_list)

In [None]:
# Nearest-neighbour model: any 3-D point takes the EEG value of the closest
# electrode position in `node_coord` (values taken from this time point's `eeg`).
interpolate_model = NearestNDInterpolator(x=node_coord, y=eeg)

In [None]:
# Evaluate the model at every scalp-grid point. NaN heights (grid points where
# the cubic height interpolation had no data) are replaced with 0 before the lookup.
interpolated_eeg = interpolate_model(np.nan_to_num(np.asarray(xyz)))

### Use plotly to plot one time stamp

In [None]:
# Interactive scalp surface coloured by the interpolated EEG values; the colour
# array is reshaped into rows of 100 to match the grid arrays.
fig = go.Figure(
    data=[
        go.Surface(
            x=x_grid,
            y=y_grid,
            z=z_grid,
            surfacecolor=interpolated_eeg.reshape(-1, 100),
        )
    ]
)
fig.show()

### Use matplotlib to plot one time stamp

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Map the interpolated EEG values (reshaped to rows of 100 to match the grid)
# onto the "jet" colormap, normalised to this time point's min/max.
color_dimension = np.reshape(interpolated_eeg, (-1, 100))
minn, maxx = color_dimension.min(), color_dimension.max()
norm = matplotlib.colors.Normalize(minn, maxx)
m = plt.cm.ScalarMappable(norm=norm, cmap="jet")
m.set_array([])  # so `m` can be handed to fig.colorbar without data of its own
fcolors = m.to_rgba(color_dimension)

# Plot the scalp surface with explicit per-face colours.
surf = ax.plot_surface(x_grid, y_grid, z_grid, facecolors=fcolors)

# Build the colour bar from `m`, which holds the EEG value -> colour mapping.
# `surf` was coloured via explicit `facecolors`, so it carries no data array or
# norm of its own, and a colour bar built from it would not reflect the EEG scale.
fig.colorbar(m, ax=ax)

plt.show()

### Animation using matplotlib

In [None]:
# Render matplotlib figures as static images in the cell output.
%matplotlib inline

In [None]:
from IPython.display import HTML
from matplotlib import animation, rc

In [None]:
# Fresh 3-D axes for the matplotlib animation; `update_graph` redraws onto them.
fig = plt.figure()
ax = fig.add_subplot(projection="3d")


def update_graph(tp):
    """Redraw the global ``ax`` with the scalp surface coloured by EEG at time ``tp``.

    Used as the ``FuncAnimation`` callback, which passes each value of its
    ``frames`` sequence as ``tp``. The colour scale is re-normalised to the
    min/max of this frame only, so colours are not comparable across frames.

    Returns the global ``fig``.
    """
    frame_rows = df[df["time"] == tp]
    ax.cla()

    # One value per channel: the mean of that channel's rows at this time point.
    channel_values = np.array(
        [np.mean(frame_rows[col]) for col in frame_rows.columns if col != "time"]
    )
    nearest = NearestNDInterpolator(node_coord, channel_values)
    color_dimension = nearest(np.nan_to_num(np.array(xyz))).reshape(-1, 100)

    scalar_map = plt.cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(color_dimension.min(), color_dimension.max()),
        cmap="jet",
    )
    scalar_map.set_array([])

    ax.plot_surface(
        x_grid, y_grid, z_grid, facecolors=scalar_map.to_rgba(color_dimension)
    )
    return fig

In [None]:
# Animate over the first 10 time values; FuncAnimation calls update_graph once per value.
ani = animation.FuncAnimation(fig, update_graph, frames=df["time"].head(10))

In [None]:
# Uncomment to render the animation inline as an HTML/JavaScript player.
# NOTE(review): to_jshtml renders every frame up front, which may be slow —
# presumably why this is disabled; without it the section produces no animation.
# HTML(ani.to_jshtml())

### Animation using plotly

In [None]:
# Define frames
import plotly.graph_objects as go

# Time values of the first 50 samples; one animation frame is built per value
# (despite the name, this holds the frame times, not a count).
nb_frames = df["time"].iloc[:50]

In [None]:
# Electrode positions as a table with columns channel, x, y, z (one row per channel).
node_df = (
    pd.DataFrame(raw.get_montage().get_positions()["ch_pos"])
    .T
    .reset_index()
    .rename(columns={"index": "channel", 0: "x", 1: "y", 2: "z"})
)

In [None]:
def interpolated_time(k):
    """Return the interpolated EEG colour grid for time value ``k``.

    Each channel's value is the mean of its rows in ``df`` whose "time" equals
    ``k``. Each query point in ``xyz`` (NaNs replaced by 0) takes the value of
    the nearest electrode in ``node_coord``. The result is reshaped into rows of
    100 to match the surface grid.
    """
    rows_at_k = df[df["time"] == k]
    channel_means = [np.mean(rows_at_k[col]) for col in rows_at_k.columns if col != "time"]
    nearest = NearestNDInterpolator(node_coord, np.array(channel_means))
    return nearest(np.nan_to_num(np.array(xyz))).reshape(-1, 100)

In [None]:
# One frame per time value in `nb_frames`; each swaps in that time's surface
# colouring. Frames are named str(time) because the slider steps address them
# by name.
fig = go.Figure(
    frames=[
        go.Frame(
            name=str(k),
            data=go.Surface(
                x=x_grid, y=y_grid, z=z_grid, surfacecolor=interpolated_time(k)
            ),
        )
        for k in nb_frames
    ]
)

# Trace 0: the surface displayed before the animation starts (time 0).
fig.add_trace(
    go.Surface(x=x_grid, y=y_grid, z=z_grid, surfacecolor=interpolated_time(0))
)
# Trace 1: one marker per electrode, labelled with its channel name.
fig.add_trace(
    go.Scatter3d(
        connectgaps=True,
        x=node_df["x"],
        y=node_df["y"],
        z=node_df["z"],
        text=node_df["channel"],
        mode="markers",
    )
)

def frame_args(duration):
    """Build plotly ``animate`` options that use ``duration`` (ms) for frames and transitions.

    The animation switches immediately, resumes from the current frame and uses
    a linear easing for the transition.
    """
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=dict(duration=duration, easing="linear"),
    )


# One slider step per animation frame. Selecting a step jumps straight to that
# frame (duration 0). Step labels are the frame's position, not its time value.
sliders = [
    {
        "pad": {"b": 10, "t": 60},
        "len": 0.9,
        "x": 0.1,
        "y": 0,
        "steps": [
            {
                "args": [[f.name], frame_args(0)],
                "label": str(k),
                "method": "animate",
            }
            for k, f in enumerate(fig.frames)
        ],
    }
]

# Fixed z-axis range (autorange is disabled below). Use NaN-aware min/max: the
# cubic griddata surface is NaN outside the electrodes' convex hull, and turning
# those NaNs into 0 would drag 0 into the range. The electrode heights are
# included so the marker trace is not clipped by the fixed range.
z_values = np.concatenate([np.ravel(z_grid), node_df["z"].to_numpy(dtype=float)])
z_range = [float(np.nanmin(z_values)), float(np.nanmax(z_values))]

fig.update_layout(
    title="EEG Interpolated 3D Graph",
    width=1000,
    height=600,
    scene=dict(
        zaxis=dict(range=z_range, autorange=False),
        aspectratio=dict(x=1.5, y=1.5, z=1),
    ),
    updatemenus=[
        {
            "buttons": [
                {
                    "args": [None, frame_args(50)],
                    "label": "&#9654;",  # play symbol
                    "method": "animate",
                },
                {
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",  # pause symbol
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 1, "t": 1},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ],
    sliders=sliders,
)

### 3D animated scatter plot using plotly

In [None]:
import matplotlib as mpl
import matplotlib.colors
import plotly.express as px

In [None]:
# Long-format view of the recording: one row per (time, channel) sample.
# It gets its own name instead of rebinding `df`, so the wide per-channel frame
# read by `update_graph`, `interpolated_time` and `nb_frames` is not silently
# replaced if those cells are re-run after this one.
plot_df = raw.to_data_frame()
col_names = plot_df.columns.tolist()[1:]  # assumes "time" is the first column
long_df = plot_df.melt(
    id_vars="time", value_vars=col_names, var_name="channels", value_name="signal"
)
channel_dict = raw.get_montage().get_positions()["ch_pos"]

# Symmetric-log colour normalisation over the full signal range (vectorised
# Series.min/max rather than Python's builtin min/max over every row).
# NOTE(review): `n` and `m` are not used by the px.scatter_3d call in this notebook.
n = mpl.colors.SymLogNorm(
    linthresh=1,
    linscale=1,
    vmin=long_df["signal"].min(),
    vmax=long_df["signal"].max(),
)
m = mpl.cm.ScalarMappable(norm=n, cmap="seismic")

# Keep only samples with time <= 50.
# NOTE(review): the time unit comes from raw.to_data_frame() — confirm seconds vs ms.
plot_df3 = long_df[long_df["time"] <= 50]

In [None]:
# Attach each sample's electrode position as x/y/z columns, looked up by channel
# name. A single comprehension replaces the per-row `iterrows` loop, which built
# a Series for every one of the (samples x channels) rows.
plot_df4 = plot_df3.copy()
chan_xyz = np.array(
    [channel_dict[ch][:3] for ch in plot_df4["channels"]], dtype=float
).reshape(-1, 3)  # reshape keeps a (0, 3) shape when the frame is empty
plot_df4["chan_locax"] = chan_xyz[:, 0]
plot_df4["chan_locay"] = chan_xyz[:, 1]
plot_df4["chan_locaz"] = chan_xyz[:, 2]

In [None]:
# Animated 3-D scatter: one marker per electrode, coloured by its signal value,
# with one animation frame per time value.
fig = px.scatter_3d(
    plot_df4,
    x="chan_locax",
    y="chan_locay",
    z="chan_locaz",
    color="signal",
    animation_frame="time",
    hover_name="channels",
    text="channels",
    color_continuous_scale="RdYlGn_r",
).update_layout(margin={"l": 0, "r": 0, "t": 30, "b": 10})
fig.show()