### Part 1

In [1]:
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
import torch
import pickle
from sklearn.model_selection import train_test_split

from data import TextDataset
from encoder import Encoder
from train_encoder import train_bert

In [2]:
import os
# Set before any tokenizer is created. Presumably this disables the Hugging Face
# tokenizers' internal parallelism so that the forked DataLoader workers used
# below do not trigger the fork warning — TODO confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Read the pickled corpus; it is used below as a mapping from story id to an
# object exposing the story's text pieces through `.data`.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open("../../../shared/data/raw_text.pkl", "rb") as f:
    raw_texts = pickle.load(f)

# Placeholder for the combined story content (not filled in this cell).
all_texts = []

# One story kept aside for ad-hoc inspection.
sample_story = raw_texts['avatar']

# Collapse every story into a single space-separated string, keyed by story id.
all_stories = {}  # Dict[story_id: str]
for story_id, sequence in raw_texts.items():
    try:
        full_story = " ".join(sequence.data)
    except Exception as err:
        # Best effort: report the story that could not be joined and move on.
        print(f"Skipping {story_id}: {err}")
    else:
        all_stories[story_id] = full_story



In [4]:
# Split at the story level (not the sentence level): 10% of the story ids are
# held out for validation, with a fixed random_state so the split is repeatable.
story_ids = list(all_stories)
train_ids, val_ids = train_test_split(story_ids, test_size=0.1, random_state=42)

# Look the story text back up for each partition, preserving the split order.
train_texts = list(map(all_stories.__getitem__, train_ids))
val_texts = list(map(all_stories.__getitem__, val_ids))

In [5]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Both splits are wrapped with the same tokenizer and max_len=128.
train_ds, val_ds = (
    TextDataset(texts, tokenizer, max_len=128)
    for texts in (train_texts, val_texts)
)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = {"batch_size": 32, "num_workers": 1, "pin_memory": True}
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [6]:
import torch
import matplotlib.pyplot as plt
import pickle
from encoder import Encoder
from train_encoder import train_bert
import os


save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

# Loss curves are written here after EVERY finished config, so an interrupted
# sweep (kernel restart, OOM, pre-emption) keeps the runs that already completed.
results_path = "mlm_results.pkl"

# Seed re-applied before each run so that weight initialisation and the global
# torch RNG state are repeatable and every config starts from the same state.
seed = 42

# ─── 1) Define your hyper‑parameter grid ──────────────────────────────────────
configs = [
    {"lr": 5e-4, "num_layers": 2, "hidden_size": 128},
    {"lr": 5e-4, "num_layers": 4, "hidden_size": 256},
    {"lr": 1e-4, "num_layers": 4, "hidden_size": 256},
    {"lr": 1e-4, "num_layers": 6, "hidden_size": 512},
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
results = {}  # config string -> (train_losses, val_losses)


# ─── 2) Loop over configs ─────────────────────────────────────────────────

for cfg in configs:
    config_str = f"lr={cfg['lr']},layers={cfg['num_layers']},hs={cfg['hidden_size']}"
    print(f"\n Training with {config_str}")

    # Reset the RNG so this run does not depend on how many configs ran before it.
    torch.manual_seed(seed)

    # Instantiate encoder with the cfg; the feed-forward width is tied to 2x hidden_size.
    model = Encoder(
        vocab_size=tokenizer.vocab_size,
        hidden_size=cfg["hidden_size"],
        num_heads=4,  # fixed for now
        num_layers=cfg["num_layers"],
        intermediate_size=cfg["hidden_size"] * 2,
        max_len=128
    ).to(device)

    # Call training loop
    model, train_losses, val_losses = train_bert(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        tokenizer=tokenizer,
        epochs=35, #40, can be tuned
        lr=cfg["lr"],
        device=device
    )

    # Store losses and checkpoint them immediately.
    results[config_str] = (train_losses, val_losses)
    with open(results_path, "wb") as f:
        pickle.dump(results, f)

    # Save model to saved_models folder
    filename = f"encoder_lr{cfg['lr']}_layers{cfg['num_layers']}_hs{cfg['hidden_size']}.pt"
    save_path = os.path.join(save_dir, filename)
    torch.save(model.state_dict(), save_path)
    print(f"Saved model to: {save_path}")

# Final write keeps the original behaviour of always producing the results file,
# even when `configs` is empty.
with open(results_path, "wb") as f:
    pickle.dump(results, f)

Using device: cuda

 Training with lr=0.0005,layers=2,hs=128
Epoch 1/35 — train: 10.4352, val: 10.1911
Epoch 2/35 — train: 10.0177, val: 10.0208
Epoch 3/35 — train: 9.7289, val: 9.6310
Epoch 4/35 — train: 9.3819, val: 9.3914
Epoch 5/35 — train: 8.9530, val: 8.9530
Epoch 6/35 — train: 8.7871, val: 8.7570
Epoch 7/35 — train: 8.4488, val: 8.4374
Epoch 8/35 — train: 8.2173, val: 8.2353
Epoch 9/35 — train: 7.9796, val: 8.0041
Epoch 10/35 — train: 7.6848, val: 7.5925
Epoch 11/35 — train: 7.3822, val: 7.5015
Epoch 12/35 — train: 7.1574, val: 7.5737
Epoch 13/35 — train: 6.9396, val: 7.2475
Epoch 14/35 — train: 6.7642, val: 7.3772
Epoch 15/35 — train: 6.6847, val: 7.0580
Epoch 16/35 — train: 6.5803, val: 7.2271
Epoch 17/35 — train: 6.4136, val: 7.0179
Epoch 18/35 — train: 6.4506, val: 6.9291
Epoch 19/35 — train: 6.3766, val: 6.9859
Epoch 20/35 — train: 6.4764, val: 7.4159
Epoch 21/35 — train: 6.4955, val: 6.7978
Epoch 22/35 — train: 6.2033, val: 6.8398
Epoch 23/35 — train: 6.4447, val: 6.8933
E

In [7]:
# Write the in-memory loss curves (config string -> (train, val) losses) to disk.
with open("mlm_results.pkl", "wb") as results_file:
    results_file.write(pickle.dumps(results))

In [8]:
import matplotlib.pyplot as plt
import pickle
import os

# ─── Read the saved loss curves back from disk ─────────────────────────────
with open("mlm_results.pkl", "rb") as f:
    results = pickle.load(f)

# ─── Make sure the figure folder exists ────────────────────────────────────
os.makedirs("loss_plots", exist_ok=True)

# ─── One figure per config, written straight to disk ───────────────────────
for config_str, (train_curve, val_curve) in results.items():
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_curve, label="Train Loss")
    ax.plot(val_curve, label="Validation Loss", linestyle="--")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"Loss Curve — {config_str}")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    # "=", "," and "." are all mapped to "_" so the config string is a safe filename.
    safe_name = config_str.translate(str.maketrans("=,.", "___"))
    path = f"loss_plots/loss_{safe_name}.png"
    fig.savefig(path)
    plt.close(fig)  # release the figure so they do not pile up across iterations

    print(f"Saved: {path}")


Saved: loss_plots/loss_lr_0_0005_layers_2_hs_128.png
Saved: loss_plots/loss_lr_0_0005_layers_4_hs_256.png
Saved: loss_plots/loss_lr_0_0001_layers_4_hs_256.png
Saved: loss_plots/loss_lr_0_0001_layers_6_hs_512.png


In [9]:
import matplotlib.pyplot as plt
import pickle
import os

# Load the saved loss curves (config string -> (train_losses, val_losses)).
with open("mlm_results.pkl", "rb") as f:
    results = pickle.load(f)
os.makedirs("loss_plots", exist_ok=True)

# Keep only configs whose learning-rate field is exactly lr=0.0005. Matching the
# whole comma-separated field avoids also picking up e.g. "lr=0.00055".
configs_to_plot = sorted(k for k in results if "lr=0.0005" in k.split(","))
if not configs_to_plot:
    raise ValueError("mlm_results.pkl contains no runs with lr=0.0005; nothing to plot")

# One panel per matching config (4.5 in wide each, i.e. 9x5 for the usual two),
# so no run is silently dropped. squeeze=False keeps `axes` indexable even for
# a single panel.
n_panels = len(configs_to_plot)
fig, axes = plt.subplots(
    1, n_panels, figsize=(4.5 * n_panels, 5), sharey=True, squeeze=False
)
axes = axes[0]

for panel_idx, (ax, config_str) in enumerate(zip(axes, configs_to_plot)):
    train, val = results[config_str]
    ax.plot(train, label="Train Loss")
    ax.plot(val, label="Validation Loss", linestyle="--")
    ax.set_title(config_str.replace(",", "\n"))  # multiline title
    ax.set_xlabel("Epoch")
    ax.grid(True)
    if panel_idx == 0:
        # The y axis is shared, so only the left-most panel gets a label.
        ax.set_ylabel("Loss")
    ax.legend()

fig.suptitle("Training vs Validation Loss (lr=5e-4)")
fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for suptitle


save_path = "loss_plots/loss_lr_0_0005_comparison.png"
fig.savefig(save_path, dpi=600)
plt.close(fig)
print(f"Saved: {save_path}")


Saved: loss_plots/loss_lr_0_0005_comparison.png


In [10]:
import matplotlib.pyplot as plt
import pickle
import os

# Read the saved loss curves back from disk.
with open("mlm_results.pkl", "rb") as f:
    results = pickle.load(f)

os.makedirs("loss_plots", exist_ok=True)

# Config chosen as the best model; its key must match the sweep's config string.
target_config = "lr=0.0001,layers=6,hs=512"
train, val = results[target_config]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train, label="Train Loss")
ax.plot(val, label="Validation Loss", linestyle="--")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Best Model\nlr=1e-4, layers=6, hidden=512", fontsize=13)
ax.legend()
ax.grid(True)
fig.tight_layout()

save_path = "loss_plots/best_model_loss.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"Saved best model loss plot to: {save_path}")


Saved best model loss plot to: loss_plots/best_model_loss.png


### Part 2

In [11]:
from transformers import BertTokenizerFast
from encoder import Encoder
from tqdm import tqdm
import torch
import numpy as np
from preprocessing import lanczosinterp2D, make_delayed
import os
import pickle

In [12]:
# Checkpoint written by the sweep for the lr=1e-4 / 6-layer / 512-hidden config.
model_path = "saved_models/encoder_lr0.0001_layers6_hs512.pt"

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
vocab_size = len(tokenizer.get_vocab())
print("Real vocab size:", vocab_size)

# Architecture arguments; they have to match the ones the checkpoint was trained
# with, otherwise load_state_dict will reject the weights.
encoder_kwargs = dict(
    vocab_size=tokenizer.vocab_size,
    hidden_size=512,
    num_heads=4,
    num_layers=6,
    intermediate_size=1024,
    max_len=128,
)
model = Encoder(**encoder_kwargs)

# Weights are mapped onto the CPU regardless of the device they were saved from.
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Sanity check: the embedding table should have one row per tokenizer entry.
print("tokenizer vocab size:", tokenizer.vocab_size)
print("model token embedding shape:", model.token_emb.weight.shape)

def extract_embeddings(seq, model, tokenizer, chunk_size=128, stride=64, hidden_size=512):
    """Compute one embedding vector per word of ``seq.data``.

    The words are processed in overlapping windows of ``chunk_size`` words that
    advance by ``stride`` words. Each window is tokenized (padded/truncated to
    128 tokens), run through ``model`` without gradients, and every sub-word
    token's hidden state is collected under the word it belongs to. A word's
    final embedding is the mean over all token vectors collected for it across
    all windows that contain it.

    Args:
        seq: Object whose ``.data`` attribute is the sequence of words.
        model: Module called as ``model(input_ids=..., token_type_ids=...,
            attention_mask=..., return_hidden=True)``. The first element of its
            result is used as the hidden states; it is assumed to carry a
            leading batch dimension of 1 — TODO confirm against Encoder.forward.
        tokenizer: Fast tokenizer supporting ``is_split_into_words=True`` and
            ``word_ids()`` on its output.
        chunk_size: Number of words per window.
        stride: Number of words the window advances each step.
        hidden_size: Embedding width used when no model output is available
            (empty input). Otherwise the width is taken from the model output.

    Returns:
        numpy array with one row per word in ``seq.data``. Words that never
        received a token (e.g. truncated in every window) get an all-zero row.
    """
    device = next(model.parameters()).device
    text_words = seq.data
    total_words = len(text_words)

    # Nothing to embed: return a well-formed empty matrix instead of letting
    # torch.stack fail on an empty list.
    if total_words == 0:
        return torch.zeros((0, hidden_size)).numpy()

    # One bucket of token vectors per word.
    token_vectors = [[] for _ in range(total_words)]

    # Width/dtype for zero rows; replaced by the model's real output below so a
    # mismatching `hidden_size` argument cannot break the final stack.
    embed_dim = hidden_size
    embed_dtype = torch.float32

    for start in range(0, total_words, stride):
        chunk_words = text_words[start:start + chunk_size]
        tokens = tokenizer(
            chunk_words,
            is_split_into_words=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
            return_attention_mask=True,
            return_token_type_ids=True
        )

        # Maps each token position to its word index inside the chunk
        # (None for special/padding tokens).
        word_ids = tokens.word_ids(batch_index=0)

        input_ids = tokens["input_ids"].to(device)
        token_type_ids = tokens["token_type_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            hidden_states = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                return_hidden=True
            )[0]

        hidden_states = hidden_states.squeeze(0).cpu()
        embed_dim = hidden_states.shape[-1]
        embed_dtype = hidden_states.dtype

        for token_idx, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue
            absolute_word_idx = start + word_idx
            if absolute_word_idx >= total_words:
                continue
            token_vectors[absolute_word_idx].append(hidden_states[token_idx])

    # Average the collected token vectors per word; zero-fill words with none.
    word_embeddings = [
        torch.stack(vectors).mean(0) if vectors
        else torch.zeros(embed_dim, dtype=embed_dtype)
        for vectors in token_vectors
    ]

    return torch.stack(word_embeddings).numpy()

Real vocab size: 30522
tokenizer vocab size: 30522
model token embedding shape: torch.Size([30522, 512])


In [13]:
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open("../../../shared/data/raw_text.pkl", "rb") as f:
    raw_texts = pickle.load(f)

# Create the output folder up front. Without it every np.save below fails and
# the broad except would report each story as "Skipping" with nothing written.
embeddings_dir = "../results/embeddings/encoder"
os.makedirs(embeddings_dir, exist_ok=True)

for story_id, seq in tqdm(raw_texts.items()):
    try:
        # Per-word embeddings from the trained encoder.
        emb = extract_embeddings(seq, model, tokenizer)
        # Resample from word times (seq.data_times) onto the TR grid (seq.tr_times).
        X_interp = lanczosinterp2D(
            emb, oldtime=seq.data_times, newtime=seq.tr_times
        )
        # Mean spacing of the TR grid; used to turn the 5 / 10 trims below into
        # row counts (presumably seconds — confirm the unit of tr_times).
        TR = np.mean(np.diff(seq.tr_times))
        n_skip_start = int(np.ceil(5 / TR))
        n_skip_end = int(np.ceil(10 / TR))
        X_interp = X_interp[n_skip_start:-n_skip_end]
        # Build the delayed feature matrix for delays 1..4 (semantics defined in
        # preprocessing.make_delayed).
        X_delayed = make_delayed(X_interp, [1,2,3,4])
        np.save(f"{embeddings_dir}/{story_id}.npy", X_delayed)
    except Exception as e:
        # Best effort: report the failing story and continue with the rest.
        print(f"⚠️ Skipping {story_id}: {e}")

100%|██████████| 109/109 [01:58<00:00,  1.09s/it]
