# Get BERT Hidden States

In [1]:
# from transformers import AutoTokenizer, AutoModelForMaskedLM

# tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
# model = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base")

# model


In [2]:
# # prepare input
# text = "Replace me by any text you'd like BERT to be fine-tuned on"
# encoded_input = tokenizer(text, return_tensors="pt")

# # forward pass
# output = model(**encoded_input, output_hidden_states=True)

In [3]:
# last_hidden_state = output.hidden_states[-1] # shape: batch_size, sequence_length, hidden_size
# h_s = output.hidden_states # tuple of length 13

# # Extract the CLS embedding
# cls_embedding = last_hidden_state[:, 0, :]

# last_hidden_state.shape, cls_embedding.shape, len(h_s)

# Get Sentences from Tree


In [4]:
from nltk.tree import Tree
import pandas as pd
from pathlib import Path


In [5]:
def extract_words_from_tree(tree):
    words = []
    if isinstance(tree, str):  # Base case: leaf node (word)
        return [tree]
    
    elif isinstance(tree, Tree):
        for subtree in tree:
            words.extend(extract_words_from_tree(subtree))
            
    #sentence = " ".join(words)
    return words


def extract_sent_list_from_tree_file(PATH):
    
    with open(PATH, "r", encoding='utf-8') as f:
        lines = f.readlines()
    
    sentences = []
    counter = 0
    for i, line in enumerate(lines):
        line = line.strip()
        try:
            tree = Tree.fromstring(line)
        except ValueError:
            try: # remove last ')'
                tree = Tree.fromstring(line[:-1])
                
            except ValueError:
                counter += 1
                print(f"=== ValueError: line {i} \n {line} ===")
                continue
        words = extract_words_from_tree(tree)
        sentences.append(words)
        
    print(f'Errors: {counter}')
    return sentences
    
sentences_EN = extract_sent_list_from_tree_file('/project/gpuuva021/shared/FMRI-Data/annotation/EN/lppEN_tree.csv')
sentences_CN = extract_sent_list_from_tree_file('/project/gpuuva021/shared/FMRI-Data/annotation/CN/lppCN_tree.csv')
sentences_FR = extract_sent_list_from_tree_file('/project/gpuuva021/shared/FMRI-Data/annotation/FR/lppFR_tree.csv')


Errors: 0
Errors: 0
Errors: 0


In [6]:
# TREE TEST
line =  '(S (VN (CLS je) (ADV ne) (CLO le) (V savais)) (ADV pas))'
tree = Tree.fromstring(line)
words = extract_words_from_tree(tree)
words

['je', 'ne', 'le', 'savais', 'pas']

# Get Sentence infos from csv file

In [7]:
PATH = Path("/project/gpuuva021/shared/FMRI-Data")
language = "EN"

# /project/gpuuva021/shared/FMRI-Data/annotation/EN/lppEN_word_information.csv

In [8]:
df = pd.read_csv(
            PATH / f"annotation/{language}/lpp{language}_word_information.csv"
        )

df_list = df['word'].tolist()
df.head()

Unnamed: 0.1,Unnamed: 0,word,lemma,onset,offset,logfreq,pos,section,top_down,bottom_up,left_corner
0,0,once,once,0.113,0.728,5.824406,ADV,1,3,1,2
1,1,when,when,0.728,0.919,7.562214,ADV,1,2,2,2
2,2,i,i,0.919,1.025,8.500759,PRON,1,3,2,3
3,3,was,was,1.025,1.158,8.848911,AUX,1,2,1,2
4,4,six,six,1.158,1.464,5.20894,NUM,1,3,1,2


In [9]:
sentences_EN[0]

['once',
 'when',
 'i',
 'was',
 'six',
 'years',
 'old',
 'i',
 'saw',
 'a',
 'magnificent',
 'picture',
 'in',
 'a',
 'book',
 'about',
 'the',
 'primeval',
 'forest',
 'called',
 'real',
 'life',
 'stories']

In [10]:
# after the last word of each sentence, add # to indicate the end of sentence
for i, sent in enumerate(sentences_EN):
    sentences_EN[i][-1] = sent[-1] + "#" # replace the last word with the word + #

# flatten the list of lists of words into a list of words
words = [item for sublist in sentences_EN for item in sublist]

words[-30:]


['not',
 'respond',
 'when',
 'questioned',
 'you',
 'will',
 'easily',
 'guess',
 'who',
 'it',
 'is#',
 'so',
 'think',
 'of',
 'me#',
 'save',
 'me',
 'from',
 'this',
 'sorrow#',
 'write',
 'to',
 'me',
 'quickly',
 'to',
 'tell',
 'me',
 'he',
 'is',
 'back#']

In [12]:
len(words), len(df_list)

(15376, 15376)

In [16]:
# integrate back into df
df['word'] = words

# keep only relevant columns
df = df[['word', 'onset', 'offset', 'section']]
df.head()

Unnamed: 0,word,onset,offset,section
0,once,0.113,0.728,1
1,when,0.728,0.919,1
2,i,0.919,1.025,1
3,was,1.025,1.158,1
4,six,1.158,1.464,1


In [25]:
# extract as lists
onsets, offsets, sections = df['onset'].tolist(), df['offset'].tolist(), df['section'].tolist()

# create list of dicts
data = []
sentence = ""
temp_onsets, temp_offsets, temp_sections = [], [], []
for i, word in enumerate(words):
    sentence = sentence + word + " "
    temp_offsets.append(offsets[i])
    temp_onsets.append(onsets[i])
    temp_sections.append(sections[i])
    
    
    if word[-1] == "#":
        data.append({"sentence": sentence[:-2] + '.', "onset": temp_onsets[0], "offset": temp_offsets[-1], "section": sections[0]})
        sentence, temp_onsets, temp_offsets, temp_sections = "", [], [], [] # reset
        
data[:2]      

[{'sentence': 'once when i was six years old i saw a magnificent picture in a book about the primeval forest called real life stories.',
  'onset': 0.113,
  'offset': 8.339,
  'section': 1},
 {'sentence': 'it showed a boa constrictor swallowing a wild animal.',
  'onset': 9.247,
  'offset': 12.416,
  'section': 1}]

## (Get Sentences from Dependency File)



In [11]:
import regex as re

def sentences_from_dep_file(PATH):
    with open(PATH, "r", encoding='utf-8') as f:
        lines = f.readlines()
        
    sentences = ['']
    
    print(len(lines))
    
    for line in lines:
        if line != '\n':
            line = line.strip()
            # print(line)
            
            word = re.search(r",\s*(.*?)\s*-\d+", line)
            nr = re.search(r"-(\d+\))", line)
            
            if word:
                result_word = word.group(1).strip()
                result_word = result_word + " "
            else:
                print("No word found.")
                print(line)
                continue
                
            if nr:
                result_nr = nr.group(1).strip()
            else:
                
                if line[:-3] == f"''')":
                    continue
                else:
                    print("no nr found.")
                    print(line)
                    print(len(line))
                    print(line[:-4])
                    continue
        
            # print(result_word, result_nr)
            
            
            if result_nr == '1)':
                # print(f'FIrst sent first word: {result_word}')
                sentences.append(result_word)
                
                
            else:
                # print(f'{result_word}')
                sentences[-1] = sentences[-1] + result_word
              
    return sentences

sents=sentences_from_dep_file('/project/gpuuva021/shared/FMRI-Data/annotation/FR/lppFR_dependency.csv')
sents
len(sents)

17201
no nr found.
conj:et(laisser-8, intéresser-22''')
36
conj:et(laisser-8, intéresser-22
no nr found.
conj:et(intéresser-22, intéresser-22''')
40
conj:et(intéresser-22, intéresser-22
no nr found.
nmod:de(bridge-5, golf-7''')
28
nmod:de(bridge-5, golf-7
no nr found.
conj:et(golf-7, golf-7''')
26
conj:et(golf-7, golf-7
no nr found.
conj:ou(question-6, question-6''')
34
conj:ou(question-6, question-6
no nr found.
advcl(hésite-2, tâtonne-12''')
30
advcl(hésite-2, tâtonne-12
no nr found.
conj:et(tâtonne-12, tâtonne-12''')
34
conj:et(tâtonne-12, tâtonne-12
no nr found.
obl:comme(avait-13, planètes-18''')
35
obl:comme(avait-13, planètes-18
no nr found.
conj:et(planètes-18, planètes-18''')
36
conj:et(planètes-18, planètes-18
no nr found.
obl:arg(agit-4, brindille-7''')
31
obl:arg(agit-4, brindille-7
no nr found.
conj:ou(brindille-7, brindille-7''')
36
conj:ou(brindille-7, brindille-7
no nr found.
nmod:de(marteau-12, boulon-15''')
33
nmod:de(marteau-12, boulon-15
no nr found.
conj:et(boulon-

1366