## Convex non-negative matrix factorization (cNMF)
cNMF is an approach that finds canonical response types based on "soft clustering." Mathematically, it is a factorization method that decomposes a matrix of event-related activity into component functions that can be thought of as "canonical responses." These components depend on both the magnitude of the response (high-gamma activity, HGA) and its latency. The factorization is formalized as:

<p>
    <center>
        $X\approx\hat{X}=FG^T$,
    </center>
</p>
<p>
    <center>
        $F=XW$,
    </center>
</p>

where $X$ is the neural time series of shape $n$ timepoints $\times$ $p$ electrodes; $W$ is a $p \times k$ matrix (electrodes $\times$ clusters) of cluster weights applied to the neural time series, so that $F = XW$ is an $n \times k$ matrix whose columns are the cluster (basis) time courses; and $G$ is a $p \times k$ matrix (electrodes $\times$ clusters) giving the weighting of individual electrodes within each cluster.

This notebook makes such a matrix $X$ of high-gamma activity epoched to sentence onset, then calculates cNMF on it. In the context of this notebook, each cNMF component (or "basis") function represents a differential response to speaking and listening across channels.

In [None]:
# Local paths -- update these for your machine before running!
# Root of the kurteff2024_code git checkout (event CSVs, stats, cNMF outputs).
git_path = '/path/to/git/kurteff2024_code/'
# Root of the BIDS dataset that holds the preprocessed recordings.
data_path = '/path/to/bids/dataset/'

In [11]:
import numpy as np
import pandas as pd
import mne
import os
import csv
from tqdm.notebook import tqdm
from img_pipe import img_pipe
import sys
import pymf3
import h5py
import warnings

from matplotlib import pyplot as plt
plt.style.use('seaborn')
from matplotlib import rcParams as rc
rc['pdf.fonttype'] = 42
from matplotlib import cm
from matplotlib.gridspec import GridSpec
import seaborn as sns
import matplotlib as mpl
%matplotlib inline

In [2]:
# Epoch window in seconds relative to sentence onset, plus epoching options.
tmin, tmax = -1, 2
level = 'sentence'
baseline = None  # handed to mne.Epochs unchanged
# Path of the grouped "nospkrall" cNMF HDF5 file.
nmf_file = os.path.join(git_path, "analysis", "cnmf", "h5",
                        "NMF_grouped_nospkrall.hf5")

In [5]:
# Subjects are the per-subject folders of the event CSV tree (TCH* / S0* IDs).
# Sorted because os.listdir order is arbitrary, and this order fixes the row
# (channel) order of the concatenated cNMF matrices and the files saved later.
events_csv_dir = os.path.join(git_path,"preprocessing","events","csv")
subjs = sorted(s for s in os.listdir(events_csv_dir) if "TCH" in s or "S0" in s)
exclude = ["TCH8"]
no_imaging = ["S0010"]  # no anatomy: all channels are used without localization
subjs = [s for s in subjs if s not in exclude]

# For each subject, keep the block suffixes (e.g. "B1") whose folder contains
# the sentence-level speaker event file. This looks in the same events tree
# that epoch_data() reads from, so block selection and epoching agree.
blocks = {
    s: sorted(
        b.split("_")[-1]
        for b in os.listdir(os.path.join(events_csv_dir, s))
        if f"{s}_B" in b and os.path.isfile(
            os.path.join(events_csv_dir, s, b, f"{b}_spkr_sn_all.txt"))
    )
    for s in subjs
}

## Load data

In [6]:
def get_ch_names(subj,blocks,git_path,data_path):
    '''
    Return the channel names of a subject's high-gamma raw FIF file.

    The FIF file of the subject's first listed block is opened without
    loading its data (preload=False).

    Parameters
    ----------
    subj : str
        Subject ID, e.g. "S0004".
    blocks : dict
        Maps subject ID -> list of block suffixes (e.g. "B1"); the first entry
        for `subj` is used.
    git_path : str
        Unused; kept for call-signature compatibility.
    data_path : str
        Dataset root containing "ecog/<subj>/..." or "ecog/sub-<subj>/...".

    Returns
    -------
    list of str
        The raw file's ``info['ch_names']``.

    Raises
    ------
    FileNotFoundError
        If the FIF file exists under neither subject-folder spelling.
    '''
    # Use the `subj` argument (not a caller-side global) to pick the block.
    blockid = "_".join([subj, blocks[subj][0]])
    # The subject folder may or may not carry the BIDS-style "sub-" prefix.
    candidates = [
        os.path.join(data_path, "ecog", subj_dir, blockid,
                     "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif")
        for subj_dir in (subj, f"sub-{subj}")
    ]
    fif_path = next((p for p in candidates if os.path.isfile(p)), None)
    if fif_path is None:
        raise FileNotFoundError(
            f"Still cannot find raw. (Tried sub- prefix): {candidates}"
        )
    # Read errors on an existing file now propagate instead of being masked.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        raw = mne.io.read_raw_fif(fif_path, preload=False, verbose=False)
    return raw.info['ch_names']
def epoch_data(subj,blocks,git_path,data_path,
               channel='spkr',level='sn',condition='all',
               tmin=-.5, tmax=2, baseline=None,
               set_picks=False,picks=None):
    '''
    Epoch a subject's high-gamma data around sentence events, across blocks.

    For each block the raw FIF file is loaded and the tab-separated event file
    <git_path>/preprocessing/events/csv/<subj>/<subj>_<block>/<subj>_<block>_<channel>_<level>_<condition>.txt
    is read (columns: onset in s, offset in s, integer event id). Onset and
    offset times are converted to samples with ceil(t * sfreq), an mne.Epochs
    object is built per block, and all blocks are concatenated.

    Parameters
    ----------
    subj : str
        Subject ID.
    blocks : list of str
        Block suffixes (e.g. "B1") to epoch.
    git_path : str
        Repository root containing the event files.
    data_path : str
        Dataset root containing "ecog/<subj>/..." or "ecog/sub-<subj>/...".
    channel, level, condition : str
        Select the event file (see the path above).
    tmin, tmax : float
        Epoch window in seconds relative to each event.
    baseline : tuple or None
        Passed unchanged to mne.Epochs.
    set_picks : bool
        If False, every channel of each raw file is used and `picks` is ignored.
    picks : list of str or None
        Channels to keep when `set_picks` is True.

    Returns
    -------
    mne.Epochs or None
        Concatenated epochs, or None if no block contained any events.

    Raises
    ------
    FileNotFoundError
        If a block's raw FIF file (or its event file) does not exist.
    '''
    epochs = []
    for b in blocks:
        blockid = f'{subj}_{b}'
        # The subject folder may or may not carry the BIDS-style "sub-" prefix.
        candidates = [
            os.path.join(data_path, "ecog", subj_dir, blockid,
                         "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif")
            for subj_dir in (subj, f"sub-{subj}")
        ]
        fif_path = next((p for p in candidates if os.path.isfile(p)), None)
        if fif_path is None:
            raise FileNotFoundError(
                f"Still cannot find raw. (Tried sub- prefix): {candidates}"
            )
        # Read errors on an existing file propagate instead of being masked.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            raw = mne.io.read_raw_fif(fif_path, preload=True, verbose=False)
        fs = raw.info['sfreq']
        block_picks = picks if set_picks else raw.info['ch_names']
        onset_index, offset_index, id_index = 0,1,2
        eventfile = os.path.join(
            git_path,"preprocessing","events","csv",subj,blockid,
            f"{blockid}_{channel}_{level}_{condition}.txt"
        )
        with open(eventfile,'r') as f:
            r = csv.reader(f,delimiter='\t')
            # Rows become MNE events [onset sample, offset sample, event id];
            # MNE generally ignores the middle events column when epoching.
            events = np.array([[np.ceil(float(row[onset_index])*fs).astype(int),
                                np.ceil(float(row[offset_index])*fs).astype(int),
                                int(row[id_index])] for row in r]).astype(int)
        if len(events) > 0:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                epochs.append(mne.Epochs(raw,events,picks=block_picks,tmin=tmin,tmax=tmax,
                                         baseline=baseline,preload=True,verbose=False))
    if len(epochs) > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return mne.concatenate_epochs(epochs)
    return None
def epoch_other_events(subj,blocks,git_path,data_path,
                      epoch_type="click",tmin=-.5,tmax=2,baseline=None,set_picks=False,picks=None):
    '''
    Function for non-spkr/mic epochs.
    Supported values for epoch_type: "click", "text"

    For each block the raw FIF file is loaded and, if present, the block's
    event file ("<blockid>_click_eve.txt" or "<blockid>_display_text.txt"
    under <git_path>/preprocessing/events/csv/<subj>/<blockid>/) is turned into
    MNE events (times in seconds -> samples via ceil(t * sfreq)). Blocks without
    an event file are skipped with a warning. Epochs from all blocks are
    concatenated.

    Returns
    -------
    mne.Epochs or None
        Concatenated epochs, or None if no block contained any events.

    Raises
    ------
    ValueError
        If `epoch_type` is not "click" or "text".
    FileNotFoundError
        If a block's raw FIF file does not exist.
    '''
    # Fail fast: otherwise `eventfile` below would be undefined (first block)
    # or silently reused from a previous block.
    if epoch_type not in ("click", "text"):
        raise ValueError(
            f"Unsupported epoch_type {epoch_type!r}; expected 'click' or 'text'"
        )
    epochs = []
    for b in blocks:
        blockid = "_".join([subj,b])
        # The subject folder may or may not carry the BIDS-style "sub-" prefix.
        candidates = [
            os.path.join(data_path, "ecog", subj_dir, blockid,
                         "HilbAA_70to150_8band", "ecog_hilbAA70to150.fif")
            for subj_dir in (subj, f"sub-{subj}")
        ]
        fif_path = next((p for p in candidates if os.path.isfile(p)), None)
        if fif_path is None:
            raise FileNotFoundError(
                f"Still cannot find raw. (Tried sub- prefix): {candidates}"
            )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            raw = mne.io.read_raw_fif(fif_path, preload=True, verbose=False)
        fs = raw.info['sfreq']
        block_picks = picks if set_picks else raw.info['ch_names']
        if epoch_type == "click":
            eventfile = os.path.join(git_path,"preprocessing","events","csv",
                                     subj,blockid,f"{blockid}_click_eve.txt")
            # Click files for S0026/TCH14 keep onset/offset/id in columns 0/1/2;
            # every other subject's click file uses columns 0/2/4.
            if subj not in ['S0026','TCH14']:
                onset_index, offset_index, id_index = 0,2,4
            else:
                onset_index, offset_index, id_index = 0,1,2
        else:  # "text"
            eventfile = os.path.join(git_path,"preprocessing","events","csv",subj,blockid,
                                     f"{blockid}_display_text.txt")
            onset_index, offset_index, id_index = 0,1,2
        if os.path.isfile(eventfile):
            with open(eventfile,'r') as f:
                r = csv.reader(f,delimiter='\t')
                events = np.array([[np.ceil(float(row[onset_index])*fs).astype(int),
                                    np.ceil(float(row[offset_index])*fs).astype(int),
                                    float(row[id_index])] for row in r]).astype(int)
        else:
            warnings.warn(f"No {epoch_type} events for {blockid}, skipping...")
            events = []
        if len(events) > 0:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                epochs.append(mne.Epochs(raw,events,picks=block_picks,tmin=tmin,tmax=tmax,
                                         baseline=baseline,preload=True,verbose=False))
    if len(epochs) > 0:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return mne.concatenate_epochs(epochs)
    return None
def get_sig_chs(subj,git_path,ch_names,
                p_thresh=0.05,nsubjs=16,nboots=1000,debug=False):
    '''
    Returns a subset of ch_names that only has p values below
    the specified threshold.

    A channel is kept when its "spkr_p" OR its "mic_p" value in
    <git_path>/stats/bootstraps/seeg_elec_significance_<nsubjs>_subjs_<nboots>_boots.csv
    (row matched on "subj" and "ch_name") is below `p_thresh`. The order of
    `ch_names` is preserved.

    Raises
    ------
    IndexError
        If a channel in `ch_names` has no row for `subj` in the CSV.
    '''
    csv_fpath = os.path.join(git_path,"stats","bootstraps",
                             f"seeg_elec_significance_{nsubjs}_subjs_{nboots}_boots.csv")
    pvals_df = pd.read_csv(csv_fpath)
    # Narrow to this subject once rather than rescanning the full table
    # twice for every channel.
    subj_df = pvals_df.loc[pvals_df["subj"] == subj]
    sig_chs, dropped_chs = [], []
    for ch in ch_names:
        ch_rows = subj_df.loc[subj_df["ch_name"] == ch]
        spkr_p = ch_rows["spkr_p"].values[0]
        mic_p = ch_rows["mic_p"].values[0]
        if spkr_p < p_thresh or mic_p < p_thresh:
            sig_chs.append(ch)
        else:
            dropped_chs.append(ch)
    if debug:
        # Dropped channels are the NON-significant ones (neither p below threshold).
        print(
            f"{subj}: Dropped {len(dropped_chs)} channels with spkr_p and mic_p >= {p_thresh}: {dropped_chs}"
        )
    return sig_chs
def get_elecs_outside_brain(subj,imaging_path,tch_imaging_path):
    '''
    Return the electrode names listed in "<subj>_IN_BOLT.txt", cleaned up so
    they match the FIF channel names.

    The file is read from "<data_path>/<subj>_complete/elecs/" with its first
    row skipped, where `data_path` is the module-level global.
    NOTE(review): `imaging_path` and `tch_imaging_path` are accepted but never
    used -- confirm which root the IN_BOLT file should really come from.

    For S0014, S0015 and S0017, known bad channels are removed; for S0014,
    "MAF-LOF" names are also rewritten as "MAFLOF".
    '''
    bolt_txt = os.path.join(data_path, f"{subj}_complete", "elecs", f"{subj}_IN_BOLT.txt")
    raw_names = np.loadtxt(bolt_txt, dtype=str, skiprows=1)
    # np.loadtxt yields a 0-d array when the file holds a single entry.
    if raw_names.ndim == 0:
        names = [str(raw_names)]
    elif raw_names.shape[0] == 0:
        names = []
    else:
        names = list(raw_names)
    # Known bad channels per subject (removed before any renaming).
    bad_by_subj = {
        "S0014": ["STG-HG10","MTG-PH12"],
        "S0015": ['MMF-LOF16','SP-PI15','SP-PI16'],
        "S0017": ['ASF-MOF16', 'AMF-LOF14'],
    }
    bads = bad_by_subj.get(subj, [])
    names = [n for n in names if n not in bads]
    if subj == "S0014":
        # Imaging names use "MAF-LOF" where the FIF channel names use "MAFLOF".
        names = [n.replace("MAF-LOF","MAFLOF") for n in names]
    return names

In [7]:
# Epoch the raw data from all participants: sentence-onset epochs for every
# (condition, audio channel) pair, plus click and on-screen text events.
# Channels listed in all_excluded_electrodes.csv are left out of the picks.
excl_df = pd.read_csv(os.path.join(git_path,"analysis","all_excluded_electrodes.csv"))
conditions = ['el','sh','all']
channels = ['spkr', 'mic']
epochs = dict()
for s in tqdm(subjs):
    chs = get_ch_names(s,blocks,git_path,data_path)
    excl_ch_names = list(excl_df.loc[excl_df['subject']==s]['channel'].values)
    picks = [ch for ch in chs if ch not in excl_ch_names]
    print(f"Subj {s} has {len(chs)} channels, picking {len(picks)} of them...")
    if len(picks) == 0:
        # Nothing left to epoch for this subject.
        continue
    # Settings shared by every epoching call for this subject.
    epoch_kwargs = dict(tmin=tmin, tmax=tmax, baseline=baseline,
                        picks=picks, set_picks=True)
    epochs[s] = dict()
    for cond in conditions:
        epochs[s][cond] = dict()
        for channel in channels:
            ep = epoch_data(s, blocks[s], git_path, data_path,
                            channel=channel, level='sn', condition=cond,
                            **epoch_kwargs)
            epochs[s][cond][channel] = ep
            if ep is None:
                print(f"No {channel} {cond} epochs for {s}")
            else:
                print(f"Loaded {len(ep)} epochs.")
    # Non-sentence events: mouse clicks and displayed text.
    for other_type in ("click", "text"):
        epochs[s][other_type] = epoch_other_events(
            s, blocks[s], git_path, data_path, epoch_type=other_type,
            **epoch_kwargs
        )

  0%|          | 0/16 [00:00<?, ?it/s]

Subj S0004 has 81 channels, picking 38 of them...
Loaded 93 epochs.
Loaded 95 epochs.
Loaded 93 epochs.
Loaded 92 epochs.
Loaded 186 epochs.
Loaded 187 epochs.
Subj S0006 has 104 channels, picking 29 of them...
Loaded 114 epochs.
Loaded 120 epochs.
Loaded 114 epochs.
Loaded 120 epochs.
Loaded 228 epochs.
Loaded 240 epochs.
Subj S0007 has 96 channels, picking 39 of them...
Loaded 93 epochs.
Loaded 99 epochs.
Loaded 94 epochs.
Loaded 100 epochs.
Loaded 188 epochs.
Loaded 199 epochs.
Subj S0010 has 146 channels, picking 6 of them...
Loaded 81 epochs.
Loaded 88 epochs.
Loaded 84 epochs.
Loaded 91 epochs.
Loaded 165 epochs.
Loaded 179 epochs.




Subj S0014 has 103 channels, picking 37 of them...
Loaded 37 epochs.
Loaded 40 epochs.
Loaded 38 epochs.
Loaded 38 epochs.
Loaded 75 epochs.
Loaded 78 epochs.
Subj S0015 has 121 channels, picking 60 of them...
Loaded 28 epochs.
Loaded 32 epochs.
Loaded 35 epochs.
Loaded 31 epochs.
Loaded 63 epochs.
Loaded 63 epochs.
Subj S0017 has 125 channels, picking 79 of them...
Loaded 47 epochs.
Loaded 50 epochs.
Loaded 45 epochs.
Loaded 45 epochs.
Loaded 92 epochs.
Loaded 95 epochs.
Subj S0018 has 142 channels, picking 46 of them...
Loaded 92 epochs.
Loaded 96 epochs.
Loaded 93 epochs.
Loaded 99 epochs.
Loaded 185 epochs.
Loaded 195 epochs.
Subj S0019 has 100 channels, picking 68 of them...
Loaded 96 epochs.
Loaded 99 epochs.
Loaded 93 epochs.
Loaded 98 epochs.
Loaded 189 epochs.
Loaded 197 epochs.
Subj S0020 has 106 channels, picking 53 of them...
Loaded 95 epochs.
Loaded 100 epochs.
Loaded 94 epochs.
Loaded 100 epochs.
Loaded 189 epochs.
Loaded 200 epochs.
Subj S0021 has 108 channels, picking 1



Subj TCH14 has 255 channels, picking 107 of them...
Loaded 46 epochs.
Loaded 48 epochs.
Loaded 44 epochs.
Loaded 49 epochs.
Loaded 90 epochs.
Loaded 97 epochs.




## Format $X$

We are building one large $n$ samples $\times$ $p$ electrodes matrix by concatenating across all subjects. The $n$ samples are all timepoints in our `tmin`–`tmax` window (−1 to 2 s relative to sentence onset). Channels are kept separate ($p$), and each channel's response is averaged across epochs. (In the code below, `X` is stored transposed: channels $\times$ concatenated timepoints.)

<center>
    $X_p(t) = \frac{1}{N_{\text{epochs}}}\sum\limits_{e=1}^{N_{\text{epochs}}} H\gamma_{p,e}(t), \quad -1 \le t \le 2\ \text{s}$
</center>
<p></p>
<center>
    $X_s = [X_{p_1}|X_{p_2}|\dots|X_{p_P}]$
</center>
<p></p>
<center>
    $X = [X_{s_1}|X_{s_2}|\dots|X_{s_S}]$
</center>

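We compute the trial-averaged response (and its standard error) for each epoch type and store them in a nested `dict`. The epoch types are `spkr`, `mic`, the `el` and `sh` sentence conditions, `click`, and `text`. The matrices passed to cNMF, along with the fitted models, are saved to HDF5 (with string metadata in accompanying text files) for easy access.

We are also going to save anatomical information alongside these responses, so let's grab that first.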

In [15]:
# Collect anatomy for every subject with imaging. Each epoched (FIF) channel is
# matched to the warped imaging electrode file, and the matched names, surface
# coordinates, MNI coordinates and short anatomy labels are kept.
imaging = dict()
for s in tqdm(subjs):
    if s in no_imaging:
        continue
    imaging[s] = dict()
    patient = img_pipe.freeCoG(f"{s}_complete", hem="stereo")
    # Read the warped electrode file once; anatomy and coordinates both come from it.
    elec_data = patient.get_elecs(elecfile_prefix="TDT_elecs_all_warped")
    an = elec_data['anatomy']
    # Names are compared with hyphens removed on both sides.
    fs_ch_names = [a[0][0].replace("-","") for a in an]
    if s == "S0020":
        # One device for S0020 is named incorrectly so we have to write an exception for it.
        fs_ch_names = [c.replace("APIOF'","AIPOF'") for c in fs_ch_names]
    # name -> first index (same result as list.index, but O(1) per lookup)
    fs_index = {}
    for fs_i, fs_name in enumerate(fs_ch_names):
        fs_index.setdefault(fs_name, fs_i)
    fif_ch_names = epochs[s]['all']['spkr'].info['ch_names']
    fs_idxs, fif_idxs, dropped_chs = [], [], []
    for fif_i, ch_name in enumerate(fif_ch_names):
        key = ch_name.replace("-","")
        if key in fs_index:
            fif_idxs.append(fif_i)
            fs_idxs.append(fs_index[key])
        else:
            dropped_chs.append(ch_name)
    if len(dropped_chs) > 0:
        print(
            f"{s}: Dropping these channels as they aren't present in the anat: {', '.join(dropped_chs)}"
        )
    imaging[s]['ch_names'] = list(np.array([a[0][0] for a in an])[fs_idxs])
    imaging[s]['elec_surf'] = elec_data['elecmatrix'][fs_idxs]
    imaging[s]['elec_mri'] = cnmf_utils.tkRAS_to_MNI(imaging[s]['elec_surf'])
    imaging[s]['anat'] = list(np.array([cnmf_utils.get_anatomy_short(a[3][0]) for a in an])[fs_idxs])
    imaging[s]['fs_idxs'] = fs_idxs
    imaging[s]['fif_idxs'] = fif_idxs

  0%|          | 0/16 [00:00<?, ?it/s]

S0019: Dropping these channels as they aren't present in the anat: LAT2, LAT3, LAT4, LAT5, LPF1, LPF2, LPF3, LPF5, LPF6, LPF7, LPF8


In [17]:
# Trial-averaged response ('resp', mean over epochs) and cnmf_utils.sem
# ('sem') for every subject and epoch type. Subjects with imaging keep only the
# channels matched to their anatomy; others keep all channels. Epoch types with
# no epochs get None and are recorded in `missing_types`.
epoch_types = ['spkr', 'mic', 'el', 'sh', 'click', 'text']
nmf_X = dict()
missing_types = {epoch_type: [] for epoch_type in epoch_types}
for s in subjs:
    nmf_X[s] = dict()
    for epoch_type in epoch_types:
        # Where each epoch type lives in `epochs`:
        #   spkr/mic   -> the 'all' sentence condition, that audio channel
        #   el/sh      -> that sentence condition, speaker channel
        #   click/text -> stored directly under the event type
        if epoch_type in ('spkr', 'mic'):
            ep = epochs[s]['all'][epoch_type]
        elif epoch_type in ('el', 'sh'):
            ep = epochs[s][epoch_type]['spkr']
        else:
            ep = epochs[s][epoch_type]
        if ep is None:
            nmf_X[s][epoch_type] = {'resp': None, 'sem': None}
            missing_types[epoch_type].append(s)
            continue
        # Fetch the (epochs x channels x times) array once for both statistics.
        ep_data = ep.get_data()
        if s not in no_imaging:
            ep_data = ep_data[:, imaging[s]['fif_idxs'], :]
        nmf_X[s][epoch_type] = {'resp': ep_data.mean(0), 'sem': cnmf_utils.sem(ep_data)}
print(missing_types)

{'spkr': [], 'mic': [], 'el': [], 'sh': ['S0023', 'TCH06'], 'click': [], 'text': ['S0010']}


In [18]:
# Candidate numbers of cNMF bases: k = 2..28.
num_bases = np.arange(2,29)
nmf, percent_variance = dict(), dict()
# Per-model containers (keyed by `bname`) filled by the X-building cells below.
# They must exist before those cells assign into them.
all_subjs, resp, stderr, elecs, anat, ch_names = dict(), dict(), dict(), dict(), dict(), dict()

In [27]:
# Model 1: mic, spkr and click responses, concatenated along time.
bname = 'spkrmicclick'
nmf_fpath = os.path.join(git_path, "analysis", "cnmf", "h5",
                         "NMF_grouped_" + bname + ".hf5")

In [28]:
# Build the cNMF input for this model: one row per channel (all subjects
# concatenated), with the trial-averaged mic, spkr and click responses
# concatenated along time.
all_subjs[bname],resp[bname],stderr[bname],elecs[bname],anat[bname],ch_names[bname] = dict(),dict(),dict(
    ),dict(),dict(),dict()
epoch_types = ['spkr','mic','click']
# Skip subjects lacking any of these epoch types. Every epoch type must then
# contribute the same rows (channels) in the same order, which the hstack
# below requires.
skip_subjs = set().union(*(missing_types[et] for et in epoch_types))
for epoch_type in epoch_types:
    als, r, se, el, an, cn = [], [], [], [], [], []
    for s in subjs:
        if s in skip_subjs:
            continue
        x_s = nmf_X[s][epoch_type]
        has_imaging = s not in no_imaging
        # With imaging, rows are the anatomy-matched channels; otherwise all channels.
        n_rows = len(imaging[s]['fif_idxs']) if has_imaging else x_s['resp'].shape[0]
        for i in range(n_rows):
            als.append(s)
            r.append(x_s['resp'][i,:])
            # 'sem' is indexed as a pair of (channel x time) arrays.
            se.append((x_s['sem'][0][i,:], x_s['sem'][1][i,:]))
            if has_imaging:
                el.append(imaging[s]['elec_surf'][i])
                an.append(imaging[s]['anat'][i])
                cn.append(imaging[s]['ch_names'][i])
            else:
                el.append(np.zeros(3))
                an.append('Anatomy unavailable')
                cn.append(epochs[s]['all']['spkr'].info['ch_names'][i])
    all_subjs[bname][epoch_type] = np.hstack((als))
    resp[bname][epoch_type] = np.vstack((r))
    stderr[bname][epoch_type] = np.array(se)
    elecs[bname][epoch_type] = np.vstack((el))
    anat[bname][epoch_type] = np.hstack((an))
    ch_names[bname][epoch_type] = np.hstack((cn))
X = np.hstack((resp[bname]['mic'],resp[bname]['spkr'],resp[bname]['click']))
print(X.shape) # N channels x N times*3

(796, 903)


In [29]:
nmf[bname] = dict()
percent_variance[bname] = []
if not os.path.isfile(nmf_fpath):
    # Fit convex NMF for every candidate number of bases and report the
    # percent variance explained (PVE) of each fit.
    for nb in num_bases:
        nmf[bname][nb] = pymf3.convexNMF(X,W=None,H=None,num_bases=nb)
        nmf[bname][nb].factorize()
        percent_variance[bname].append(cnmf_utils.pve(X, nmf[bname][nb]))
        print(f"{nb} clusters, PVE={percent_variance[bname][-1]*100:.2f} percent variance explained")
    with h5py.File(nmf_fpath, 'w') as f:
        # S0004's time axis; stacking rows across subjects already requires
        # every subject to have the same number of samples.
        f.create_dataset("/times", data=epochs['S0004']['all']['spkr'].times)
        for epoch_type in ['spkr', 'mic', 'click']:
            f.create_dataset(f"/resp/{epoch_type}", data=resp[bname][epoch_type])
        f.create_dataset("/pve", data=percent_variance[bname])
        f.create_dataset("/num_bases", data=num_bases)
        f.create_dataset("/elecs", data=elecs[bname]['spkr'])
        for nb in num_bases:
            f.create_dataset(f"/{nb}_bases/W", data=nmf[bname][nb].W)
            f.create_dataset(f"/{nb}_bases/H", data=nmf[bname][nb].H)
    # h5py cannot store NumPy fixed-width unicode ('<U') arrays directly, so the
    # per-row string metadata (subject, anatomy, channel name) goes to text files.
    txt_prefix = os.path.join(git_path,"analysis","cnmf","h5",f"NMF_grouped_{bname}")
    np.savetxt(f"{txt_prefix}_all_subjs.txt", all_subjs[bname]['spkr'], fmt="%s")
    np.savetxt(f"{txt_prefix}_anat.txt", anat[bname]['spkr'], fmt="%s")
    np.savetxt(f"{txt_prefix}_ch_names.txt", ch_names[bname]['spkr'], fmt="%s")
else:
    print("Loading NMF from file")
    # Open the file once. Look PVE up by the saved number of bases rather than
    # assuming num_bases starts at 2.
    with h5py.File(nmf_fpath,'r') as f:
        saved_pve = dict(zip(f["/num_bases"][:].tolist(), f["/pve"][:]))
        for nb in num_bases:
            nmf[bname][nb] = pymf3.convexNMF(X,W=None,H=None,num_bases=nb)
            nmf[bname][nb].W = f[f"/{nb}_bases/W"][:]
            nmf[bname][nb].H = f[f"/{nb}_bases/H"][:]
            percent_variance[bname].append(saved_pve[int(nb)])

2 clusters, PVE=63.32 percent variance explained
3 clusters, PVE=76.98 percent variance explained
4 clusters, PVE=78.42 percent variance explained
5 clusters, PVE=79.95 percent variance explained
6 clusters, PVE=81.24 percent variance explained
7 clusters, PVE=83.28 percent variance explained
8 clusters, PVE=84.82 percent variance explained
9 clusters, PVE=85.61 percent variance explained
10 clusters, PVE=86.12 percent variance explained
11 clusters, PVE=86.64 percent variance explained
12 clusters, PVE=87.04 percent variance explained
13 clusters, PVE=87.43 percent variance explained
14 clusters, PVE=87.46 percent variance explained
15 clusters, PVE=87.87 percent variance explained
16 clusters, PVE=88.38 percent variance explained
17 clusters, PVE=88.30 percent variance explained
18 clusters, PVE=88.37 percent variance explained
19 clusters, PVE=89.19 percent variance explained
20 clusters, PVE=89.30 percent variance explained
21 clusters, PVE=89.53 percent variance explained
22 clust

In [30]:
# Model 2: mic, el, sh and click responses, concatenated along time.
bname = 'micelshuffclick'
nmf_fpath = os.path.join(git_path, "analysis", "cnmf", "h5",
                         "NMF_grouped_" + bname + ".hf5")

In [31]:
# Build the cNMF input for this model: one row per channel (all subjects
# concatenated), with the trial-averaged mic, el, sh and click responses
# concatenated along time.
all_subjs[bname],resp[bname],stderr[bname],elecs[bname],anat[bname],ch_names[bname] = dict(),dict(),dict(
    ),dict(),dict(),dict()
epoch_types = ['mic','el','sh','click']
# Skip subjects lacking ANY of these epoch types (click included). Every epoch
# type must then contribute the same rows (channels) in the same order, which
# the hstack below requires. A plain set avoids numpy membership tests on a
# possibly empty (float-typed) hstack of name lists.
skip_subjs = set().union(*(missing_types[et] for et in epoch_types))
for epoch_type in epoch_types:
    als, r, se, el, an, cn = [], [], [], [], [], []
    for s in subjs:
        if s in skip_subjs:
            continue
        x_s = nmf_X[s][epoch_type]
        has_imaging = s not in no_imaging
        # With imaging, rows are the anatomy-matched channels; otherwise all channels.
        n_rows = len(imaging[s]['fif_idxs']) if has_imaging else x_s['resp'].shape[0]
        for i in range(n_rows):
            als.append(s)
            r.append(x_s['resp'][i,:])
            # 'sem' is indexed as a pair of (channel x time) arrays.
            se.append((x_s['sem'][0][i,:], x_s['sem'][1][i,:]))
            if has_imaging:
                el.append(imaging[s]['elec_surf'][i])
                an.append(imaging[s]['anat'][i])
                cn.append(imaging[s]['ch_names'][i])
            else:
                el.append(np.zeros(3))
                an.append('Anatomy unavailable')
                cn.append(epochs[s]['all']['spkr'].info['ch_names'][i])
    all_subjs[bname][epoch_type] = np.hstack((als))
    resp[bname][epoch_type] = np.vstack((r))
    stderr[bname][epoch_type] = np.array(se)
    elecs[bname][epoch_type] = np.vstack((el))
    anat[bname][epoch_type] = np.hstack((an))
    ch_names[bname][epoch_type] = np.hstack((cn))
X = np.hstack((resp[bname]['mic'],resp[bname]['el'],resp[bname]['sh'],resp[bname]['click']))
print(X.shape)

(667, 1204)


In [32]:
nmf[bname] = dict()
percent_variance[bname] = []
if not os.path.isfile(nmf_fpath):
    # Fit convex NMF for every candidate number of bases and report the
    # percent variance explained (PVE) of each fit.
    for nb in num_bases:
        nmf[bname][nb] = pymf3.convexNMF(X,W=None,H=None,num_bases=nb)
        nmf[bname][nb].factorize()
        percent_variance[bname].append(cnmf_utils.pve(X, nmf[bname][nb]))
        print(f"{nb} clusters, PVE={percent_variance[bname][-1]*100:.2f} percent variance explained")
    with h5py.File(nmf_fpath, 'w') as f:
        # S0004's time axis; stacking rows across subjects already requires
        # every subject to have the same number of samples.
        f.create_dataset("/times", data=epochs['S0004']['all']['spkr'].times)
        for epoch_type in ['mic', 'el', 'sh', 'click']:
            f.create_dataset(f"/resp/{epoch_type}", data=resp[bname][epoch_type])
        f.create_dataset("/pve", data=percent_variance[bname])
        f.create_dataset("/num_bases", data=num_bases)
        f.create_dataset("/elecs", data=elecs[bname]['mic'])
        for nb in num_bases:
            f.create_dataset(f"/{nb}_bases/W", data=nmf[bname][nb].W)
            f.create_dataset(f"/{nb}_bases/H", data=nmf[bname][nb].H)
    # h5py cannot store NumPy fixed-width unicode ('<U') arrays directly, so the
    # per-row string metadata (subject, anatomy, channel name) goes to text files.
    # This model has no 'spkr' entry, so the metadata comes from the 'mic' rows
    # (all epoch types share the same rows).
    txt_prefix = os.path.join(git_path,"analysis","cnmf","h5",f"NMF_grouped_{bname}")
    np.savetxt(f"{txt_prefix}_all_subjs.txt", all_subjs[bname]['mic'], fmt="%s")
    np.savetxt(f"{txt_prefix}_anat.txt", anat[bname]['mic'], fmt="%s")
    np.savetxt(f"{txt_prefix}_ch_names.txt", ch_names[bname]['mic'], fmt="%s")
else:
    print("Loading NMF from file")
    # Open the file once. Look PVE up by the saved number of bases rather than
    # assuming num_bases starts at 2.
    with h5py.File(nmf_fpath,'r') as f:
        saved_pve = dict(zip(f["/num_bases"][:].tolist(), f["/pve"][:]))
        for nb in num_bases:
            nmf[bname][nb] = pymf3.convexNMF(X,W=None,H=None,num_bases=nb)
            nmf[bname][nb].W = f[f"/{nb}_bases/W"][:]
            nmf[bname][nb].H = f[f"/{nb}_bases/H"][:]
            percent_variance[bname].append(saved_pve[int(nb)])

2 clusters, PVE=62.50 percent variance explained
3 clusters, PVE=73.17 percent variance explained
4 clusters, PVE=77.68 percent variance explained
5 clusters, PVE=79.34 percent variance explained
6 clusters, PVE=81.40 percent variance explained
7 clusters, PVE=82.54 percent variance explained
8 clusters, PVE=83.48 percent variance explained
9 clusters, PVE=84.42 percent variance explained
10 clusters, PVE=85.40 percent variance explained
11 clusters, PVE=85.88 percent variance explained
12 clusters, PVE=86.10 percent variance explained
13 clusters, PVE=86.34 percent variance explained
14 clusters, PVE=87.14 percent variance explained
15 clusters, PVE=87.18 percent variance explained
16 clusters, PVE=87.15 percent variance explained
17 clusters, PVE=87.47 percent variance explained
18 clusters, PVE=87.90 percent variance explained
19 clusters, PVE=88.16 percent variance explained
20 clusters, PVE=87.86 percent variance explained
21 clusters, PVE=88.24 percent variance explained
22 clust