### Figures for Dynamic timewarping
1) Results Figure 4 (example DTW for two pairs, RDM & k-means clustering, sorted timeseries and average for cluster)
2) Supplementary Figure S3 (gap statistic)


In [1]:
"""
Author: linateichmann
Email: lina.teichmann@nih.gov

    Created on 2023-03-30 12:37:54
    Modified on 2023-03-30 12:37:54
"""

import numpy as np
import mne, os
import pandas as pd 
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

from scipy.spatial.distance import squareform
import dtaidistance
from matplotlib.patches import Rectangle

from scipy import stats

np.random.seed(0)
%matplotlib qt

bids_dir = '/System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/'
res_folder = f'{bids_dir}/derivatives/meg_paper/output/'
n_participants = 4
n_dims = 66
colors = ['mediumpurple','cornflowerblue','darkblue','teal','yellowgreen','mediumseagreen'] # for the clusters

# plotting 
colours = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/colors66.txt',header=None).to_numpy()
labels = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/labels_super_short.txt',header=None)

font = 'Arial'
text_size = 12
text_size_big = 16
text_size_small = 8

plt.rcParams['font.size'] = text_size
plt.rcParams['font.family'] = font


In [2]:
# DTW
# load one epoched datafile to get time vector
epochs = mne.read_epochs(f'{bids_dir}/derivatives/preprocessed/preprocessed_P1-epo.fif',preload=False)

# load results
all_dat_dims=[]
for p in range(1,n_participants+1):
    all_dat_dims.append(pd.read_csv(f'{res_folder}/regression/P{p}_linreg_within.csv',index_col=0))
all_dat_dims = np.array(all_dat_dims)

# scale data 
all_dat_dims_scaled = np.array([stats.zscore(all_dat_dims[p,:,:]) for p in range(n_participants)])
all_dat_dims_scaled_avg = np.mean(all_dat_dims_scaled,axis=0)

# run dynamic time warping
dist_mat = dtaidistance.dtw.distance_matrix_fast(all_dat_dims_scaled_avg.T)


Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-1.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-2.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-3.fif ...
    Found the data of int

In [128]:
# CLUSTERING
# Gap Statistic for K means
def optimalK(data, nrefs=3, maxClusters=15):
    """
    Calculates KMeans optimal K using Gap Statistic 
    Params:
        data: ndarry of shape (n_samples, n_features)
        nrefs: number of sample reference datasets to create
        maxClusters: Maximum number of clusters to test for
    Returns: (gaps, optimalK)
    """
    gaps = np.zeros((len(range(1, maxClusters)),))
    resultsdf = pd.DataFrame({'clusterCount':[], 'gap':[]})
    for gap_index, k in enumerate(range(1, maxClusters)):
        # Holder for reference dispersion results
        refDisps = np.zeros(nrefs)
        # For n references, generate random sample and perform kmeans getting resulting dispersion of each loop
        for i in range(nrefs):
            # Create new random reference set
            randomReference = np.random.random_sample(size=data.shape)
            
            # Fit to it
            km = KMeans(k,random_state=0)
            km.fit(randomReference)
            
            refDisp = km.inertia_
            refDisps[i] = refDisp
        # Fit cluster to original data and create dispersion
        km = KMeans(k,random_state=0)
        km.fit(data)
        
        origDisp = km.inertia_
        # Calculate gap statistic
        gap = np.log(np.mean(refDisps)) - np.log(origDisp)
        # Assign this loop's gap statistic to gaps
        gaps[gap_index] = gap
        resultsdf = pd.concat([resultsdf,pd.DataFrame({'clusterCount':[k], 'gap':[gap]})], ignore_index=True)
    return (gaps.argmax() + 1, resultsdf)


score_g, df = optimalK(dist_mat, nrefs=5, maxClusters=30)
plt.close('all')
fig,ax = plt.subplots()

ax.plot(df['clusterCount'], df['gap'], linestyle=':', marker='o', color='grey')
ax.set_xlabel('K')
ax.set_ylabel('Gap Statistic')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Supplementary_gapstatistic.pdf')

nclusters = int(df.loc[np.where(np.diff(df['gap'])<0.1)[0][0],'clusterCount'])

kmeans= KMeans(nclusters,random_state=0).fit(dist_mat)




reorder = [3,2,0,1,4]
kmeans_labels_renamed = np.zeros_like(kmeans.labels_)
for i in np.unique(kmeans.labels_):
    kmeans_labels_renamed[kmeans.labels_==i] = reorder[i]

kmeans.labels_ = kmeans_labels_renamed
order = np.argsort(kmeans.labels_)
kmeans_labels = np.sort(kmeans.labels_)


In [129]:
# MAKE RESULTS FIGURE

plt.close('all')
fig = plt.figure(constrained_layout=True,figsize=(8.25, 11.75))

################################
###*********** A ************###
################################
fig.text(0.01,0.97,'A',fontsize=text_size_big*2)

x_pos = [0.1,0.38,0.73]
y_pos = [0.8,0.6]

## 1) Compare different timeseries
dim_of_interest = [[11,60],[11,1]]
# dim_of_interest = [[7,13],[7,6]]
avg_ts = np.mean(all_dat_dims_scaled,axis=0)

for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[0],y_pos[i]+0.05,0.25,0.1])

    x = avg_ts[:,pair[0]]
    y = avg_ts[:,pair[1]]

    ax.plot(epochs.times*1000,x,linewidth=2, color=colours[pair[0],:],label=labels.loc[pair[0],0])
    ax.plot(epochs.times*1000,y, linewidth=2, color=colours[pair[1],:],label=labels.loc[pair[1],0])
    plt.legend(frameon=False,fontsize = 14)

    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.legend(frameon=False,fontsize=text_size_small-1)

    ax.set_xlabel('time (ms)')
    ax.set_ylabel('Correlation (zscored)')


### 2) Demo of dynamic time warping paths
shift = 3
for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[1],y_pos[i]+0.05,0.25,0.1])
    x = avg_ts[:,pair[0]]
    y = avg_ts[:,pair[1]]
    path = dtaidistance.dtw.warping_path(x, y)
    time = (epochs.times*1000)
    for [map_x, map_y] in path:
        ax.plot([time[map_x], time[map_y]], [x[map_x]+shift, y[map_y]], ':', color='grey', linewidth=0.5)

    ax.plot(time,x+shift,linewidth=2, color=colours[pair[0],:],label=labels.loc[pair[0],0])
    ax.plot(time,y, linewidth=2, color=colours[pair[1],:],label=labels.loc[pair[1],0])
    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.set_yticks([])
    ax.set_xlabel('time (ms)')


### 3) cost matrix
for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[2],y_pos[i],0.25,0.2])


    x = avg_ts[:,pair[0]]
    y = avg_ts[:,pair[1]]
    path = dtaidistance.dtw.warping_path(x, y)
    d,cost_matrix = dtaidistance.dtw.warping_paths(x,y)

    im =ax.imshow(cost_matrix, cmap="YlGnBu")
    ax.invert_yaxis()

    # Get the warp path in x and y directions
    path_x = [p[1] for p in path]
    path_y = [p[0] for p in path]

    path_xx = [x+0.5 for x in path_x]
    path_yy = [y+0.5 for y in path_y]

    ax.plot(path_xx, path_yy, color='blue', linewidth=3, alpha=0.2)

    ax.set_xticks(np.arange(20,len(time),100))
    ax.set_xticklabels([int(i) for i in time[ax.get_xticks()]],fontsize=text_size_small)
    ax.set_yticks(np.arange(20,len(time),100))
    ax.set_yticklabels([int(i) for i in time[ax.get_yticks()]],fontsize=text_size_small)

    ax.plot([0, 1], [0, 1], '--', color='grey',lw=1, transform=ax.transAxes)


    ax.set_xlabel('time (ms):\n' + labels.loc[pair[0],0],fontsize=text_size_small)
    ax.set_ylabel('time (ms):\n' + labels.loc[pair[1],0],fontsize=text_size_small)


    cbar = plt.colorbar(im,shrink=0.5)
    cbar.set_ticks([np.nanmin(cost_matrix[cost_matrix != -np.inf]),np.nanmax(cost_matrix[cost_matrix != np.inf])])
    cbar.set_ticklabels(['low','high'],fontsize=text_size_small)
    cbar.set_label('Warping cost', labelpad=-8, rotation=270)
    cbar.outline.set_visible(False)


    ax.annotate('path of lowest cost',
                xy=(path_xx[250], path_yy[250]), xycoords='data',
                xytext=(50, 250), textcoords='data',fontsize=text_size_small,
                arrowprops=dict(facecolor='blue', shrink=0.2,alpha=0.2,ec=None),
                horizontalalignment='left', verticalalignment='bottom')

    ax.set_aspect('equal')


################################
###*********** B ************###
################################
fig.text(0.01,0.55,'B',fontsize=text_size_big*2)

ax = fig.add_subplot()
ax.set_position([0.57,0.25,0.4,.4])

mat = dist_mat[order,:][:,order]
mask = np.ones_like(mat)
mask = np.triu(mask,1)

cm = 'Greys'
max = 9
min = 0
rdm = ax.imshow(mat,cmap=cm,vmax=max,vmin=min)

ax.set_xticks(np.arange(-.5,n_dims,1))
ax.set_yticks(np.arange(-.5,n_dims,1))
ax.grid(color='w', linestyle='-', linewidth=1,alpha=0.2)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(bottom=False, left=False)

ax.set_aspect('equal')
ax.invert_yaxis()

cbar = plt.colorbar(rdm,shrink=0.5)
cbar.set_ticks([min,max])
cbar.set_ticklabels(['similar','dissimilar'],fontsize=text_size_small)
cbar.set_label('DTW distance measure', labelpad=-28, rotation=270)
cbar.outline.set_visible(False)


# add coloured rectangles to the RDM
ticks = [0]+list(np.where(np.diff(kmeans_labels)==1)[0])+[len(dist_mat)-1]
for i,[i1,i2] in enumerate(zip(ticks[0:-1],ticks[1:])):
    rect = Rectangle(xy=(i1,i1),width=i2-i1,height=i2-i1,color=colors[i],fill=False,lw=3)
    ax.add_artist(rect)
    print([i1,i2])

# labels
ax = fig.add_subplot()
ax.set_position([0.45,0.05,0.01,.53])
ax.set_yticks(np.arange(n_dims))
ax.set_yticklabels(list(labels.loc[order,0]),fontsize=8)
[ax.spines[i].set_visible(False) for i in ['top','right','bottom']]
ax.set_xticks([])


# timeseries
ax = fig.add_subplot()
ax.set_position([0.1,0.04,0.1,.545])

for i,[dim,col_idx] in enumerate(zip(order,kmeans_labels)):
    ax.plot(epochs.times*1000,avg_ts[:,dim]+i*3,color=colors[col_idx])
ax.set_ylim([-1,n_dims*3])
ax.axis('off')

## average timeseries for each cluster
# ax = fig.add_subplot()
# ax.set_position([0.56,0.15,0.35,0.15])
# ax.set_xlabel('time (ms)',fontsize=text_size)
# ax.set_ylabel('Average cluster\n correlation (zscored)',fontsize=text_size)
# ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

# # ticks = [0]+list(np.where(np.diff(kmeans_labels)==1)[0])+[len(dist_mat)-1]
# ticks = [0]+list(np.where(np.diff(np.array([0]+list(kmeans_labels)))==1)[0])+[len(dist_mat)]

# for i,[i1,i2] in enumerate(zip(ticks[0:-1],ticks[1:])):
#     ax.plot(epochs.times*1000,np.mean(avg_ts[:,order][:,i1:i2],axis=1),color=colors[i],lw=2,alpha=0.8)
# [ax.spines[i].set_visible(False) for i in ['top','right']]


n_subplots = len(np.unique(kmeans_labels))

left = 0.57
right = 0.57+0.35
top = 0.3
bottom = 0.04

horiz_pos = np.linspace(left,right,n_subplots+1)
vert_pos = np.linspace(bottom,top,n_subplots+1)

all_pos = np.empty((n_subplots+1,n_subplots+1,4))
ticks_start = ticks[0:-1]
ticks_end = ticks[1:]

axs_mask = np.ones((n_subplots+1,n_subplots+1))
axs_mask = np.triu(axs_mask)
axs = ['#' for col in range(n_subplots)]


for row in range(n_subplots):
    ax = fig.add_subplot()
    ax.set_position([horiz_pos[1],vert_pos[row],0.25,0.04])
    ax.set_ylim([-3,4])
    axs[row]=ax

    if row==2:
        ax.set_ylabel('Average cluster\n correlation (zscored)',fontsize=text_size)
    if row==0:
        ax.set_xlabel('Time (s)',fontsize=text_size)
    else:
        ax.set_xticklabels([])


    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.yticks(fontsize=10)
    plt.xticks(fontsize=10)


start_end = []
ts = []
for i in range(n_subplots):
    i1 = ticks_start[i]
    i2 = ticks_end[i]
    axs[i].plot(epochs.times*1000,np.mean(avg_ts[:,order][:,i1:i2],axis=1),color=colors[i],lw=2,alpha=0.8)
    start_end.append([i1,i2])
    ts.append(np.mean(avg_ts[:,order][:,i1:i2],axis=1))



#text
# fig.text(left-0.06,0.05, 'Key characteristics\n1) Presence/absence ~150 ms peak\n2) Relative amplitude of ~300 ms peak')




fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Figure4.pdf')

[0, 3]
[3, 16]
[16, 53]
[53, 62]
[62, 65]


In [410]:

plt.close('all')
fig = plt.figure(constrained_layout=True,figsize=(8.25, 11.75))

n_subplots = len(np.unique(kmeans_labels))

left = 0.57
right = 0.57+0.35
top = 0.3
bottom = 0.04

horiz_pos = np.linspace(left,right,n_subplots+1)
vert_pos = np.linspace(bottom,top,n_subplots+1)

all_pos = np.empty((n_subplots+1,n_subplots+1,4))
ticks_start = ticks[0:-1]
ticks_end = ticks[1:]


def twod(a, b):
    lst = [['#' for col in range(a)] for row in range(b)]
    return lst

axs_mask = np.ones((n_subplots+1,n_subplots+1))
axs_mask = np.triu(axs_mask)

axs = twod(n_subplots+1,n_subplots+1)



for row in range(n_subplots+1):
    for col_i, col in enumerate(np.flip(np.arange(n_subplots+1))):

        ax = fig.add_subplot()
        ax.set_position([horiz_pos[row],vert_pos[col],0.04,0.04])
        ax.set_ylim([-3,3])
        axs[col_i][row]=ax

        if axs_mask[col_i][row]==1:
            if ((row==0) and (col==5)):
                ax.set_ylabel('Average cluster\n correlation (zscored)',fontsize=text_size)
                ax.set_xticklabels([])
            elif ((col==0) and (row==5)):
                ax.set_xlabel('Time (s)',fontsize=text_size)
                ax.set_yticklabels([])
            elif ((col==5) and (row==5)):
                ax.axis('off')
            else:
                ax.set_xticklabels([])
                ax.set_yticklabels([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            plt.yticks(fontsize=10)
            plt.xticks(fontsize=10)
        else:
            ax.axis('off')


start_end = []
ts = []
for i,ii in zip(range(n_subplots),range(1,n_subplots+1)):
    i1 = ticks_start[i]
    i2 = ticks_end[i]
    axs[ii][5].plot(epochs.times,np.mean(avg_ts[:,order][:,i1:i2],axis=1),color=colors[i],lw=2,alpha=0.8)
    axs[0][i].plot(epochs.times,np.mean(avg_ts[:,order][:,i1:i2],axis=1),color=colors[i],lw=2,alpha=0.8)
    start_end.append([i1,i2])
    ts.append(np.mean(avg_ts[:,order][:,i1:i2],axis=1))



for row in range(n_subplots):
    for col in np.arange(n_subplots):
        if row==col:
            continue
        if axs_mask[col][row]==1:
            axs[col+1][row].plot(epochs.times*1000,ts[row]-ts[col],color='k',alpha=0.5,lw=0.5)



In [459]:

plt.close('all')
fig = plt.figure(constrained_layout=True,figsize=(8.25, 11.75))

n_subplots = len(np.unique(kmeans_labels))

left = 0.57
right = 0.57+0.35
top = 0.3
bottom = 0.04

horiz_pos = np.linspace(left,right,n_subplots+1)
vert_pos = np.linspace(bottom,top,n_subplots+1)

all_pos = np.empty((n_subplots+1,n_subplots+1,4))
ticks_start = ticks[0:-1]
ticks_end = ticks[1:]

axs_mask = np.triu(axs_mask)
axs = ['#' for col in range(n_subplots)]


for row in range(n_subplots):
    ax = fig.add_subplot()
    ax.set_position([horiz_pos[1],vert_pos[row],0.25,0.04])
    ax.set_ylim([-3,3])
    axs[row]=ax

    if row==2:
        ax.set_ylabel('Average cluster\n correlation (zscored)',fontsize=text_size)
    if row==0:
        ax.set_xlabel('Time (s)',fontsize=text_size)
    else:
        ax.set_xticklabels([])


    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.yticks(fontsize=10)
    plt.xticks(fontsize=10)


start_end = []
ts = []
for i in range(n_subplots):
    i1 = ticks_start[i]
    i2 = ticks_end[i]
    axs[i].plot(epochs.times*1000,np.mean(avg_ts[:,order][:,i1:i2],axis=1),color=colors[i],lw=2,alpha=0.8)
    start_end.append([i1,i2])
    ts.append(np.mean(avg_ts[:,order][:,i1:i2],axis=1))

