### Figures for dynamic time warping (DTW)
1) Results Figure 4 (example DTW for two pairs of dimensions, RDM & hierarchical clustering, cluster-sorted timeseries, and the average timecourse for each cluster)
2) Supplementary Figure S3 (gap statistic)


In [1]:
"""
Author: linateichmann
Email: lina.teichmann@nih.gov

    Created on 2023-03-30 12:37:54
    Modified on 2023-03-30 12:37:54
"""

import numpy as np
import mne, string
import pandas as pd 
import dtaidistance
from dtaidistance import clustering
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy import stats,signal
from scipy.cluster.hierarchy import dendrogram
from collections import defaultdict
import seaborn as sns

np.random.seed(0)
%matplotlib qt

bids_dir = '/System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/'
res_folder = f'{bids_dir}/derivatives/meg_paper/output/'
n_participants = 4
n_dims_all = 66



# plotting 
colours_all = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/colors66.txt',header=None).to_numpy()
labels_all = pd.read_csv(f'{bids_dir}/sourcedata/meg_paper/labels_super_short.txt',header=None)

font = 'Arial'
text_size = 12
text_size_big = 16
text_size_small = 8

plt.rcParams['font.size'] = text_size
plt.rcParams['font.family'] = font


In [2]:
# DTW
# Open one epoched file (data not preloaded) purely to obtain the time vector.
epochs = mne.read_epochs(f'{bids_dir}/derivatives/preprocessed/preprocessed_P1-epo.fif',preload=False)

# Load the per-participant regression results and stack them along a new
# leading (participant) axis.
all_dat_dims = np.array([
    pd.read_csv(f'{res_folder}/regression/P{p}_linreg_within.csv',index_col=0)
    for p in range(1,n_participants+1)
])


# Smooth the data with a 2nd-order low-pass Butterworth filter applied
# forwards and backwards (filtfilt, Gustafsson's method for the edges).
fs = 200            # sampling rate -- presumably Hz, TODO confirm
cutoff = fs * 0.2   # low-pass cutoff, in the same unit as fs
nyq = 0.5 * fs
Wn = cutoff/nyq     # cutoff expressed as a fraction of the Nyquist frequency (0.4)
b,a = signal.butter(N=2,Wn=Wn, btype='low', analog=False,output='ba')

all_dat_dims_smooth = np.array([
    signal.filtfilt(b,a,participant_dat,axis=0,method='gust')
    for participant_dat in all_dat_dims
])

# z-score each participant's smoothed data (scipy default: along axis 0)
all_dat_dims_scaled = np.array([stats.zscore(smoothed) for smoothed in all_dat_dims_smooth])

# average across participants
avg_ts_all = np.mean(all_dat_dims_scaled,axis=0)

# exclude the 12 dimensions with lowest overall peaks [see bottom two rows of supplementary figure]
# Peaks are taken from the participant-averaged *smoothed* (not z-scored) data.
n_exclude = 12
peak_per_dim = np.max(np.mean(all_dat_dims_smooth,axis=0),axis=0)
select_dims = np.sort(np.argsort(peak_per_dim)[n_exclude:])

# restrict timeseries, labels and colours to the retained dimensions
avg_ts = avg_ts_all[:,select_dims]
labels = labels_all.loc[select_dims,:].reset_index(drop=True)
colours = colours_all[select_dims,:]
n_dims = n_dims_all-n_exclude

# run dynamic time warping: pairwise distances between the retained timeseries
dist_mat = dtaidistance.dtw.distance_matrix_fast(avg_ts.T)


Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-1.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-2.fif ...
    Found the data of interest:
        t =    -100.00 ...    1300.00 ms
        0 CTF compensation matrices available
Reading /System/Volumes/Data/misc/data16/teichmanna2/2020_MEG_things/THINGS_biowulf/THINGS-MEG-bids/derivatives/preprocessed/preprocessed_P1-epo-3.fif ...
    Found the data of int

In [3]:
# Hierarchical clustering of the timeseries through dtaidistance's wrapper
# around SciPy linkage, using the fast DTW distance matrix as the metric.
model = clustering.LinkageTree(dtaidistance.dtw.distance_matrix_fast, {})
cluster_idx = model.fit(avg_ts.T)

# Colour threshold for the dendrogram: half of the largest value found in the
# third column of the linkage output.
color_thresh = 0.5*np.max(cluster_idx[:,2])

# Build the dendrogram without drawing it, just to get leaf order and colours.
dendro_dict = dendrogram(cluster_idx, no_plot=True,color_threshold=color_thresh,distance_sort='descending')
cluster_order = dendro_dict['leaves']
color_labels = dendro_dict['leaves_color_list']

# Convert the per-leaf colour strings into integer cluster ids, numbered in
# order of first appearance (the defaultdict hands out its current size as the
# id for any colour it has not seen yet).
element_to_index = defaultdict(lambda: len(element_to_index))
cluster_labels = []
for leaf_colour in color_labels:
    cluster_labels.append(element_to_index[leaf_colour])

# Reorder rows and columns of the distance matrix to follow the dendrogram.
sorted_matrix = dist_mat[np.ix_(cluster_order, cluster_order)]

# cluster colours: one viridis colour per distinct cluster id
n_clusters_found = len(set(cluster_labels))
colors = sns.color_palette("viridis",n_colors=n_clusters_found)
sns.set_palette('viridis')



In [4]:
# MAKE RESULTS FIGURE
plt.close('all')
fig = plt.figure(constrained_layout=True,figsize=(8.25, 11.75))

################################
###*********** A ************###
################################
fig.text(0.01,0.97,'A',fontsize=text_size_big*2)

# Left edges of the three panel-A columns and bottom edges of its two rows,
# as passed to ax.set_position([left, bottom, width, height]).
x_pos = [0.1,0.38,0.73]
y_pos = [0.8,0.6]

## 1) Compare different timeseries
dim_of_interest_labels = [['colorful/playful','cylindrical/conical/cushioning'],['colorful/playful','food-related']]
# Resolve every label to its row position in `labels`; an IndexError here means
# the label is not among the retained dimensions.
dim_of_interest = [[np.where(labels.loc[:,0]==name)[0][0] for name in pair_names]
                   for pair_names in dim_of_interest_labels]

for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[0],y_pos[i]+0.05,0.25,0.1])

    x = avg_ts[:,pair[0]]
    y = avg_ts[:,pair[1]]

    ax.plot(epochs.times*1000,x,linewidth=2, color=colours[pair[0],:],label=labels.loc[pair[0],0])
    ax.plot(epochs.times*1000,y, linewidth=2, color=colours[pair[1],:],label=labels.loc[pair[1],0])

    ax.set_xlim([epochs.times[0]*1000,epochs.times[-1]*1000])

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Single legend call on the axes. The earlier pyplot-level
    # plt.legend(..., fontsize=14) was immediately replaced by this one and
    # has been dropped.
    ax.legend(frameon=False,fontsize=text_size_small-1)

    ax.set_xlabel('time (ms)')
    ax.set_ylabel('Correlation (zscored)')
    ax.set_ylim([-2,4])



### 2) Demo of dynamic time warping paths
# The first series of each pair is drawn `shift` units above the second so the
# two traces do not overlap; dotted lines join the samples DTW matched.
shift = 5
time = (epochs.times*1000)
for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[1],y_pos[i]+0.05,0.25,0.1])

    x, y = avg_ts[:,pair[0]], avg_ts[:,pair[1]]
    path = dtaidistance.dtw.warping_path(x, y)

    # one thin connector per (index in x, index in y) step of the warping path
    for map_x, map_y in path:
        ax.plot([time[map_x], time[map_y]], [x[map_x]+shift, y[map_y]], ':', color='grey', linewidth=0.5)

    # the two timeseries themselves, drawn on top of the connectors
    ax.plot(time,x+shift,linewidth=2, color=colours[pair[0],:],label=labels.loc[pair[0],0])
    ax.plot(time,y, linewidth=2, color=colours[pair[1],:],label=labels.loc[pair[1],0])
    ax.set_xlim([time[0],time[-1]])

    # the y-axis carries no meaning once one trace is shifted, so hide it
    for side in ('top','right','left'):
        ax.spines[side].set_visible(False)
    ax.set_yticks([])
    ax.set_xlabel('time (ms)')


### 3) cost matrix
# Orientation: warping_paths(x, y) returns a matrix whose first index runs over
# x (pair[0]) and whose second index runs over y (pair[1]). imshow draws the
# first index vertically, so the vertical axis belongs to pair[0] and the
# horizontal axis to pair[1]. The axis labels below follow that (they were
# previously the other way round).

# Time vector in ms, derived here so this section does not depend on a
# variable left over from the previous loop.
time = epochs.times*1000

for i,pair in enumerate(dim_of_interest):

    ax = fig.add_subplot()
    ax.set_position([x_pos[2],y_pos[i],0.25,0.2])

    x = avg_ts[:,pair[0]]
    y = avg_ts[:,pair[1]]
    path = dtaidistance.dtw.warping_path(x, y)
    d,cost_matrix = dtaidistance.dtw.warping_paths(x,y)

    im =ax.imshow(cost_matrix, cmap="YlGnBu")
    ax.invert_yaxis()

    # Split the warp path into horizontal (index into y) and vertical (index
    # into x) coordinates, then offset by half a cell as in the original figure.
    path_x = [step[1] for step in path]
    path_y = [step[0] for step in path]
    path_xx = [col+0.5 for col in path_x]
    path_yy = [row+0.5 for row in path_y]

    ax.plot(path_xx, path_yy, color='blue', linewidth=3, alpha=0.2)

    # Same tick positions on both axes, labelled with the time (ms) at that sample.
    tick_idx = np.arange(20,len(time),100)
    tick_labels = [int(t) for t in time[tick_idx]]
    ax.set_xticks(tick_idx)
    ax.set_xticklabels(tick_labels,fontsize=text_size_small)
    ax.set_yticks(tick_idx)
    ax.set_yticklabels(tick_labels,fontsize=text_size_small)

    # dashed diagonal of the axes as a reference line
    ax.plot([0, 1], [0, 1], '--', color='grey',lw=1, transform=ax.transAxes)

    # horizontal axis <-> second series (pair[1]); vertical axis <-> first series (pair[0])
    ax.set_xlabel('time (ms):\n' + labels.loc[pair[1],0],fontsize=text_size_small)
    ax.set_ylabel('time (ms):\n' + labels.loc[pair[0],0],fontsize=text_size_small)

    # Colourbar with only its extremes labelled; the matrix contains infinite
    # entries, so the range is taken over finite values only.
    cbar = plt.colorbar(im,shrink=0.5)
    finite_cost = cost_matrix[np.isfinite(cost_matrix)]
    cbar.set_ticks([finite_cost.min(),finite_cost.max()])
    cbar.set_ticklabels(['low','high'],fontsize=text_size_small)
    cbar.set_label('Warping cost', labelpad=-8, rotation=270)
    cbar.outline.set_visible(False)

    # The arrow points at a hand-picked step of the path, different per pair.
    anchor = 250 if i == 1 else 220
    ax.annotate('path of lowest cost',
                xy=(path_xx[anchor], path_yy[anchor]), xycoords='data',
                xytext=(50, 250), textcoords='data',fontsize=text_size_small,
                arrowprops=dict(facecolor='blue', shrink=0.2,alpha=0.2,ec=None),
                horizontalalignment='left', verticalalignment='bottom')

    ax.set_aspect('equal')


################################
###*********** B ************###
################################
# NOTE(review): `similarity` is not used in this section; the import is kept in
# case a later cell relies on it. It belongs with the imports at the top.
from dtaidistance import similarity
fig.text(0.01,0.55,'B',fontsize=text_size_big*2)

ax = fig.add_subplot()
ax.set_position([0.57,0.25,0.4,.4])

# DTW distance matrix, rows/columns already reordered to follow the dendrogram
mat = sorted_matrix
cm = 'Greys'
rdm = ax.imshow(mat,cmap=cm)
ax.invert_yaxis()
ax.invert_xaxis()

# Range of the matrix, used for the two colourbar ticks. Deliberately NOT named
# `min`/`max`: assigning to those names would replace the Python builtins for
# every cell executed afterwards.
rdm_min = np.min(mat)
rdm_max = np.max(mat)

# Ticks sit on the cell edges so the faint white grid separates the cells.
ax.set_xticks(np.arange(-.5,n_dims,1))
ax.set_yticks(np.arange(-.5,n_dims,1))
ax.grid(color='w', linestyle='-', linewidth=1,alpha=0.2)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(bottom=False, left=False)

ax.set_aspect('equal')

cbar = plt.colorbar(rdm,shrink=0.5)
cbar.set_ticks([rdm_min,rdm_max])
cbar.set_ticklabels(['similar','dissimilar'],fontsize=text_size_small)
cbar.set_label('Dynamic time warping:\nDissimilarity metric', labelpad=-10, rotation=270)
cbar.outline.set_visible(False)

# add coloured rectangles to the RDM
# `ticks` holds the cluster boundaries along the dendrogram leaf order as
# half-open [start, end) positions: 0, the first position of every subsequent
# cluster, and finally the number of leaves. np.diff flags the LAST leaf of a
# cluster, hence the +1. Slices of the form [i1:i2] built from consecutive
# entries therefore select exactly the members of one cluster.
ticks = [0]+list(np.where(np.diff(cluster_labels)!=0)[0]+1)+[len(dist_mat)]
for i,[i1,i2] in enumerate(zip(ticks[0:-1],ticks[1:])):
    # imshow centres cell k on integer coordinate k, so a cluster occupying
    # cells i1..i2-1 spans i1-0.5 .. i2-0.5 in data coordinates.
    rect = Rectangle(xy=(i1-0.5,i1-0.5),width=i2-i1,height=i2-i1,color=colors[i],fill=False,lw=3)
    ax.add_artist(rect)

# timeseries
# One trace per dimension, stacked bottom-to-top in dendrogram leaf order and
# coloured by cluster. Each trace is divided by its own maximum and then offset
# by its row number.
ax = fig.add_subplot()
ax.set_position([0.05,0.04,0.1,.545])
ts_plot = avg_ts
for row,(dim,col_idx) in enumerate(zip(cluster_order,cluster_labels)):
    trace = ts_plot[:,dim]
    ax.plot(epochs.times*1000,trace/np.max(trace)+row,color=colors[col_idx])

# axis limits and visibility are the same for every trace, so set them once
ax.set_ylim([-1,n_dims])
ax.axis('off')


## plot dendrogram
# No ax= is passed to dendrogram, so it draws into the current axes, i.e. the
# subplot created just above.
ax = fig.add_subplot()
ax.set_position([0.22,0.04,0.18,.545])
leaf_names = labels.loc[:,0].to_list()
dendro_dict = dendrogram(cluster_idx, labels=leaf_names, orientation='left',
                         distance_sort='descending', color_threshold=color_thresh,
                         above_threshold_color='grey')

# strip the frame and the distance axis; only the leaf labels remain
for s in ('top','left','bottom'):
    ax.spines[s].set_visible(False)
ax.set_xticks([])

## average timeseries for each cluster
# Creates one small axes per cluster (stacked vertically) and stores them in
# `axs`; the data is drawn into them afterwards.
n_subplots = len(np.unique(cluster_labels))

# bounding region (figure fractions) from which the axes positions are derived
left = 0.57
right = 0.57+0.35
top = 0.34
bottom = 0.04

horiz_pos = np.linspace(left,right,n_subplots+1)
vert_pos = np.linspace(bottom,top,n_subplots+1)

# NOTE(review): `all_pos` and `axs_mask` are not read anywhere in this figure;
# kept in case a later cell uses them.
all_pos = np.empty((n_subplots+1,n_subplots+1,4))
# start / end boundary of every cluster, taken from consecutive `ticks` entries
ticks_start = ticks[0:-1]
ticks_end = ticks[1:]

axs_mask = np.ones((n_subplots+1,n_subplots+1))
axs_mask = np.triu(axs_mask)
axs = ['#' for col in range(n_subplots)]   # placeholders, replaced by Axes below

letters = string.ascii_lowercase
# Rows are created bottom-up while letters run top-down, so the top-most axes
# is labelled 'cluster a'.
for row,title in zip(range(n_subplots),np.flip(np.arange(n_subplots))):
    ax = fig.add_subplot()
    ax.set_position([horiz_pos[1],vert_pos[row],0.25,0.04])
    axs[row]=ax

    if row==0:
        # Only the bottom axes gets an x label. The data plotted on these axes
        # is epochs.times*1000, i.e. milliseconds (label previously said 's').
        ax.set_xlabel('Time (ms)',fontsize=text_size)
    else:
        ax.set_xticklabels([])


    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim([-2,4])
    plt.yticks(fontsize=10)
    plt.xticks(fontsize=10)
    ax.text(500,3.5,f'cluster {letters[title]}',fontsize=text_size)
# Shared y label, positioned in data coordinates of the last axes created above.
plt.text(-500, -12, 'Average cluster timecourse', va='center',ha='center', rotation='vertical', fontsize=text_size)


# NOTE(review): `start_end` and `ts` are initialised but never filled in this cell.
start_end = []
ts = []
# For every cluster: faint individual member traces plus their mean in bold,
# drawn into the axes prepared for that cluster.
for i in range(n_subplots):
    i1, i2 = ticks_start[i], ticks_end[i]

    # columns i1:i2 of the timeseries once reordered to dendrogram leaf order
    members = ts_plot[:,cluster_order][:,i1:i2]

    axs[i].plot(epochs.times*1000,members,color=colors[i],lw=1,alpha=0.2)
    axs[i].plot(epochs.times*1000,np.mean(members,axis=1),color=colors[i],lw=2,alpha=1)

# write the finished figure to disk
fig.savefig(f'{bids_dir}/derivatives/meg_paper/figures/Figure4.pdf')