In [1]:
import os
import pickle
import numpy as np
import torch
from tqdm import tqdm
from transformers import BertTokenizerFast, BertModel
from preprocessing import lanczosinterp2D, make_delayed

In [2]:
raw_text_path = "../../../shared/data/raw_text.pkl"
output_dir = "../results/embeddings/bert_XY"
os.makedirs(output_dir, exist_ok=True)

# Load tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
# ## Embedding extraction function

def extract_bert_embeddings(seq, model, tokenizer, chunk_size=128, stride=64, hidden_size=768):
    device = next(model.parameters()).device
    text_words = seq.data
    total_words = len(text_words)
    word_embeddings = [None] * total_words

    for start in range(0, total_words, stride):
        chunk_words = text_words[start:start + chunk_size]
        tokens = tokenizer(
            chunk_words,
            is_split_into_words=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=chunk_size,
            return_attention_mask=True,
            return_token_type_ids=True
        )

        input_ids = tokens["input_ids"]
        token_type_ids = tokens["token_type_ids"]
        attention_mask = tokens["attention_mask"]

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask
            )
            hidden_states = outputs.last_hidden_state.squeeze(0).cpu()

        word_ids = tokens.word_ids(batch_index=0)
        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            abs_word_idx = start + word_idx
            if abs_word_idx >= total_words:
                continue
            if word_embeddings[abs_word_idx] is None:
                word_embeddings[abs_word_idx] = []
            word_embeddings[abs_word_idx].append(hidden_states[token_idx])

    for i in range(total_words):
        if word_embeddings[i] is None:
            word_embeddings[i] = torch.zeros(hidden_size)
        else:
            word_embeddings[i] = torch.stack(word_embeddings[i]).mean(0)

    return torch.stack(word_embeddings).numpy()

In [None]:
# Old version
'''
with open(raw_text_path, "rb") as f:
    raw_texts = pickle.load(f)

for story_id, seq in tqdm(raw_texts.items(), desc="Extracting BERT embeddings"):
    try:
        emb = extract_bert_embeddings(seq, model, tokenizer)
        X_interp = lanczosinterp2D(emb, seq.data_times, seq.tr_times)

        # Trim start and end
        TR = np.mean(np.diff(seq.tr_times))
        n_skip_start = int(np.ceil(5 / TR))
        n_skip_end = int(np.ceil(10 / TR))
        X_interp = X_interp[n_skip_start:-n_skip_end]

        # Create lag features
        X_delayed = make_delayed(X_interp, [1, 2, 3, 4])

        np.save(os.path.join(output_dir, f"{story_id}.npy"), X_delayed)
    except Exception as e:
        print(f"⚠ Skipping {story_id}: {e}")

print("Done: Embeddings saved to", output_dir)
'''

In [4]:
with open(raw_text_path, "rb") as f:
    raw_texts = pickle.load(f)

for story_id, seq in tqdm(raw_texts.items(), desc="Extracting BERT embeddings"):
    try:
        emb = extract_bert_embeddings(seq, model, tokenizer)
        X_interp = lanczosinterp2D(emb, seq.data_times, seq.tr_times)

        # Trim start and end
        TR = np.mean(np.diff(seq.tr_times))
        n_skip_start = int(np.ceil(5 / TR))
        n_skip_end = int(np.ceil(10 / TR))
        X_interp = X_interp[n_skip_start:-n_skip_end]

        # Create lag features
        X_delayed = make_delayed(X_interp, [1, 2, 3, 4])

        y_path = os.path.join("../../../shared/data/subject2", f"{story_id}.npy")
        Y = np.load(y_path)
        Y = Y[n_skip_start:-n_skip_end]
        Y_delayed = make_delayed(Y, [1, 2, 3, 4])

        np.save(os.path.join(output_dir, f"{story_id}_X.npy"), X_delayed)
        np.save(os.path.join(output_dir, f"{story_id}_Y.npy"), Y_delayed)

    except Exception as e:
        print(f"⚠️ Skipping {story_id}: {e}")

print("Done: Embeddings saved to", output_dir)

Extracting BERT embeddings:  11%|█         | 12/109 [00:22<02:37,  1.62s/it]

⚠️ Skipping dialogue4: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue4.npy'


Extracting BERT embeddings:  27%|██▋       | 29/109 [00:59<02:37,  1.97s/it]

⚠️ Skipping myfirstdaywiththeyankees: [Errno 2] No such file or directory: '../../../shared/data/subject2/myfirstdaywiththeyankees.npy'


Extracting BERT embeddings:  46%|████▌     | 50/109 [01:44<01:34,  1.60s/it]

⚠️ Skipping dialogue2: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue2.npy'


Extracting BERT embeddings:  60%|█████▉    | 65/109 [02:50<06:06,  8.32s/it]

⚠️ Skipping dialogue1: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue1.npy'


Extracting BERT embeddings:  62%|██████▏   | 68/109 [03:34<06:31,  9.56s/it]

⚠️ Skipping dialogue5: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue5.npy'


Extracting BERT embeddings:  67%|██████▋   | 73/109 [04:34<04:48,  8.01s/it]

⚠️ Skipping onlyonewaytofindout: [Errno 2] No such file or directory: '../../../shared/data/subject2/onlyonewaytofindout.npy'


Extracting BERT embeddings:  94%|█████████▎| 102/109 [11:41<01:32, 13.17s/it]

⚠️ Skipping dialogue3: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue3.npy'


Extracting BERT embeddings:  95%|█████████▌| 104/109 [11:44<00:36,  7.25s/it]

⚠️ Skipping dialogue6: [Errno 2] No such file or directory: '../../../shared/data/subject2/dialogue6.npy'


Extracting BERT embeddings: 100%|██████████| 109/109 [12:38<00:00,  6.96s/it]

Done: Embeddings saved to ../results/embeddings/bert_XY



