In [1]:
import torch
from torch.utils.data import DataLoader
import mne
import os
import numpy as np
import pickle
from model import *
from dataset import *
from utils import *
from train import *

# Check if GPU is available
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
device

'cuda'

In [2]:
train_story_list = dict(
    Moth1=["souls", "avatar", "legacy", "odetostepfather"],
    Moth2=["howtodraw", "myfirstdaywiththeyankees", "naked", "life"],
    Moth3=["tildeath", "fromboyhoodtofatherhood", "sloth", "exorcism"],
    Moth4=["adollshouse", "inamoment", "theclosetthatateeverything", "adventuresinsayingyes", "haveyoumethimyet"],
    Moth5=["thatthingonmyarm", "eyespy", "itsabox", "hangtime"],
)
train_stories = [story for session in train_story_list.keys() for story in train_story_list[session]]
val_stories = [["swimmingwithastronauts1", "swimmingwithastronauts2"]]

In [3]:
# let test stories be all 27 stories
with open("data/story_to_uniquestory.pkl", "rb") as f:
    story_to_uniquestory = pickle.load(f)
test_stories = list(story_to_uniquestory.keys())
test_stories = [s for s in test_stories if s != "stimuli_auditory_localizer" and not s[-1] in ["2", "3", "4", "5"]]
test_stories

['alternateithicatom1',
 'souls',
 'wheretheressmoke1',
 'avatar',
 'legacy',
 'odetostepfather',
 'undertheinfluence1',
 'howtodraw',
 'myfirstdaywiththeyankees',
 'naked',
 'life',
 'stagefright1',
 'tildeath',
 'fromboyhoodtofatherhood',
 'sloth',
 'exorcism',
 'buck1',
 'adollshouse',
 'inamoment',
 'theclosetthatateeverything',
 'adventuresinsayingyes',
 'haveyoumethimyet',
 'swimmingwithastronauts1',
 'thatthingonmyarm',
 'eyespy',
 'itsabox',
 'hangtime']

In [4]:
# config
name = "WdPnFq-seg8-flexconv4-A"
which = "val-loss-min"

with open(f"config/{name}.yaml", "r") as f:
    config = yaml.safe_load(f)
MEG_SUBJECT = config["MEG_SUBJECT"]
FMRI_SUBJECT = config["FMRI_SUBJECT"]
use_segment = config["use_segment"]
spacing = config["spacing"]
meg_loss_weight = config["meg_loss_weight"]
fmri_loss_weight = config["fmri_loss_weight"]
softmax_T = config["softmax_T"]
dataset_params = config["dataset"]
model_params = config["model"]

## Dataset

In [5]:
train_dataset = StoryDataset(
    MEG_SUBJECT,
    FMRI_SUBJECT,
    train_stories[0:1],
    name=name,
    spacing=spacing,
    preload=False,
    **dataset_params,
)  # evalulate don't care this
val_dataset = StoryDataset(
    MEG_SUBJECT,
    FMRI_SUBJECT,
    val_stories,
    name=name,
    spacing=spacing,
    pca_meg=train_dataset.pca_meg,
    pca_mri=train_dataset.pca_mri,
    **dataset_params,
)
test_dataset = StoryDataset(
    MEG_SUBJECT,
    FMRI_SUBJECT,
    test_stories,
    name=name,
    spacing=spacing,
    pca_meg=train_dataset.pca_meg,
    pca_mri=train_dataset.pca_mri,
    MEG_DIR="moth_meg",
    **dataset_params,
)
embed_dim = test_dataset.embed_dim

use_word: True, use_phoneme: True, use_freq: True, use_meg: False, use_mri: False
embed_dim:  852
use_word: True, use_phoneme: True, use_freq: True, use_meg: False, use_mri: False
embed_dim:  852
Loading story swimmingwithastronauts1...


Finished loading story swimmingwithastronauts1!
Loading story swimmingwithastronauts2...
Finished loading story swimmingwithastronauts2!
Preloaded all stories!
use_word: True, use_phoneme: True, use_freq: True, use_meg: False, use_mri: False
embed_dim:  852
Loading story alternateithicatom1...
Finished loading story alternateithicatom1!
Loading story souls...
Finished loading story souls!
Loading story wheretheressmoke1...
Finished loading story wheretheressmoke1!
Loading story avatar...
Finished loading story avatar!
Loading story legacy...
Finished loading story legacy!
Loading story odetostepfather...
Finished loading story odetostepfather!
Loading story undertheinfluence1...
Finished loading story undertheinfluence1!
Loading story howtodraw...
Finished loading story howtodraw!
Loading story myfirstdaywiththeyankees...
Finished loading story myfirstdaywiththeyankees!
Loading story naked...
Finished loading story naked!
Loading story life...
Finished loading story life!
Loading story

## Lead Field

In [6]:
# load forward solution to get the lead field
fname_fwd = f"data/{MEG_SUBJECT}-{spacing}-fwd.fif"
fwd = mne.read_forward_solution(fname_fwd)
fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)  # let's do fixed orientation
lead_field = fwd_fixed["sol"]["data"]
lead_field = torch.from_numpy(lead_field)
n_channels, n_neurons = lead_field.shape
n_channels, n_neurons

Reading forward solution from /home/yishuli/MEG-fMRI/data/A-oct6-fwd.fif...
    Reading a source space...
    Computing patch statistics...
    Patch information added...
    [done]
    Reading a source space...
    Computing patch statistics...
    Patch information added...
    [done]
    2 source spaces read
    Desired named matrix (kind = 3523 (FIFF_MNE_FORWARD_SOLUTION_GRAD)) not available
    Read MEG forward solution (8196 sources, 306 channels, free orientations)
    Source spaces transformed to the forward solution coordinate frame
    Average patch normals will be employed in the rotation to the local surface coordinates....
    Converting to surface-based source orientations...
    [done]


(306, 8196)

## Get Sources

In [7]:
model = TransformerSourceModel(
    embed_dim=embed_dim,
    lead_field=lead_field,
    **model_params,
).to(device)
ckpt_path = f"trained_models/{name}_{which}.pth"
info = load_checkpoint(ckpt_path, model, None, None, device=device)
print(info)

(68, 1.0055146501442056, 0.07546357089945342, 0.9299632706765684)


In [8]:
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
with torch.no_grad():
    validate(
        model,
        val_dataloader,
        device,
        subject=MEG_SUBJECT,
        meg_loss_weight=meg_loss_weight,
        fmri_loss_weight=fmri_loss_weight,
        softmax_T=softmax_T,
    )

val_meg_loss: 0.07466866387896606, val_fmri_loss: 0.9308459862652395


In [9]:
model.eval()
neurons_dict, neurons_power_dict = {}, {}
for i in range(len(test_dataset)):
    # load data
    embeds, meg, fmri = test_dataset[i]
    embeds = embeds.to(device)
    # foward pass
    with torch.no_grad():
        neurons, _, _ = model(embeds.unsqueeze(0))
        neurons = neurons.squeeze(0).detach().cpu()
        # calculate power
        neurons_power = hilbert_torch(neurons, dim=0)
        neurons_power = neurons_power.abs()
        neurons_power.pow_(2)
        # demean
        neurons = (neurons.abs() - neurons.abs().mean(dim=0, keepdims=True)).numpy()
        neurons_power = (neurons_power - neurons_power.mean(dim=0, keepdims=True)).numpy()
    neurons_dict[test_stories[i]] = neurons
    neurons_power_dict[test_stories[i]] = neurons_power
    # save memory
    del embeds
    torch.cuda.empty_cache()

## Word Onsets

In [10]:
# load grid transcript
with open("data/moth_word_surprisal_context20.pkl", "rb") as fp:
    story_word_surprisal_dict = pickle.load(fp)

In [11]:
all_words_dict = {}
for key in story_word_surprisal_dict.keys():
    if key == "stimuli_auditory_localizer":
        continue
    for t in story_word_surprisal_dict[key]:
        word = t[2].lower().strip()
        if t[-1] == "NOUN":
            if word not in all_words_dict:
                all_words_dict[word] = 0
            else:
                all_words_dict[word] += 1

In [12]:
# all_words_dict sort by values
all_words_dict = dict(sorted(all_words_dict.items(), key=lambda item: item[1], reverse=True))
all_words_dict

{'i': 1758,
 'uh': 181,
 'um': 149,
 'time': 125,
 'day': 94,
 'thing': 93,
 'kind': 90,
 'people': 82,
 'home': 78,
 'something': 76,
 'things': 74,
 'years': 72,
 'way': 69,
 'life': 66,
 'right': 66,
 'one': 51,
 'house': 50,
 'job': 46,
 'school': 46,
 'guy': 45,
 'family': 45,
 'room': 45,
 'oh': 45,
 'man': 44,
 'moment': 43,
 'mother': 43,
 'water': 43,
 'everything': 41,
 'mom': 39,
 'phone': 38,
 'cause': 38,
 'stuff': 38,
 'lot': 36,
 'car': 36,
 'friends': 35,
 'night': 35,
 'anything': 35,
 'year': 35,
 'door': 34,
 'twenty': 33,
 'whole': 32,
 'sort': 32,
 'dad': 31,
 'everybody': 29,
 'girl': 29,
 'head': 29,
 'look': 29,
 'president': 29,
 'hand': 28,
 'yeah': 27,
 'parents': 27,
 'york': 26,
 'work': 26,
 'boy': 26,
 'birthday': 26,
 'okay': 26,
 'thought': 25,
 'story': 25,
 'someone': 25,
 'stage': 25,
 'god': 24,
 'thirty': 24,
 'father': 24,
 'bat': 24,
 'couple': 23,
 'children': 23,
 'book': 23,
 'front': 23,
 'call': 23,
 'point': 23,
 'yes': 22,
 'morning': 22,


In [13]:
# load grid transcript
with open("data/moth_phon_words.pkl", "rb") as fp:
    _ = pickle.load(fp)
    grid_transcript_words = pickle.load(fp)

# loading textgrids
words = dict()
word_onset = dict()
removed_time = 0
DEFAULT_BAD_WORDS = ["sentence_start", "sentence_end", "{SL}", "{{BR}", "{BR}", "(BR}", "{BR", "{LG}", "{ls}", "{LS}", "{IG}", "{CG}", "{LS)", "{NS}", "{NS_AP}", "{SP}", "sp", "", " "]
special_words = ["time", "day", "people", "home", "year", "years", "life", "house", "job", "school", "room", "phone", "door", "hand"]
cnt = 0

for this_story in test_stories:
    this_story_unique = story_to_uniquestory[this_story]
    transcript_words = grid_transcript_words[this_story_unique]
    # correct delay
    time_features = [(float(tp[0]), float(tp[1]), tp[2]) for tp in transcript_words]
    time_features_corrected = get_stretched_features(time_features, MEG_SUBJECT, None, None, use_mean_rate=True)
    # remove bad words!
    words_this_story, word_onset_this_story = [], []
    for t in time_features_corrected:
        if t[2].lower() in special_words:
            words_this_story.append(t[2])
            word_onset_this_story.append(float(t[0]) - removed_time)
            cnt += 1
    words[this_story] = words_this_story
    word_onset[this_story] = np.array(word_onset_this_story)
cnt

854

In [14]:
# example
print(f"first five words of alternateithicatom:", words["alternateithicatom1"][:15])
print("onset times:", word_onset["alternateithicatom1"][:15])

first five words of alternateithicatom: ['TIME', 'TIME', 'JOB', 'YEARS', 'JOB', 'JOB', 'JOB', 'JOB', 'JOB', 'PEOPLE', 'TIME', 'LIFE', 'LIFE', 'LIFE', 'TIME']
onset times: [ 15.54472169  40.09085792  72.04375831  72.89224207  75.75712255
  80.78813217  87.31646608 117.19307679 159.64721155 246.97116425
 254.7772149  311.42598394 343.66836703 354.7186203  398.72997232]


In [15]:
# define lags
sfreq = 50
step = int(1000 / sfreq)  # 20ms
lag_sample_start = -10
lag_sample_end = 50
lag_samples = np.arange(lag_sample_start, lag_sample_end + 1)
lags = lag_samples * step

In [16]:
# get neuron response for each word onset
neurons_onset_story, neurons_power_onset_story = [], []
for story in test_stories:
    word_onset_story = word_onset[story]
    neurons_story = neurons_dict[story]
    neurons_power_story = neurons_power_dict[story]
    for t in word_onset_story:
        t_sample = int(t * 50)
        t_sample_start = t_sample + lag_sample_start
        t_sample_end = t_sample + lag_sample_end + 1
        if t_sample_start > 0 and t_sample_end < neurons_story.shape[0]:
            neurons_onset_story.append(neurons_story[t_sample_start:t_sample_end, :])
            neurons_power_onset_story.append(neurons_power_story[t_sample_start:t_sample_end, :])
            cnt += 1
neurons_onset_story = np.mean(neurons_onset_story, axis=0)
neurons_power_onset_story = np.mean(neurons_power_onset_story, axis=0)
neurons_onset_story.shape

(61, 8196)