### Part 1

In [20]:
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
import torch
import pickle
from sklearn.model_selection import train_test_split

from data import TextDataset
from encoder import Encoder
from train_encoder import train_bert

In [21]:
import os
# Set before any tokenizer is created in this notebook.
# NOTE(review): presumably this silences the Hugging Face `tokenizers` fork
# warning once DataLoader worker processes are spawned — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [22]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Read the pickled story collection from disk.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../data/raw_text.pkl", "rb") as raw_file:
    raw_texts = pickle.load(raw_file)

# NOTE(review): `all_texts` and `sample_story` are not used by any later cell
# visible in this notebook; kept so that hidden dependents keep working.
all_texts = []
sample_story = raw_texts['avatar']

# story_id -> the story's tokens joined into one space-separated string
all_stories = {}

for story_id, sequence in raw_texts.items():
    try:
        all_stories[story_id] = " ".join(sequence.data)
    except Exception as e:
        # Best effort: report the story that could not be joined and move on.
        print(f"Skipping {story_id}: {e}")



  raw_texts = pickle.load(f)


In [23]:
# 1) every story id that was joined successfully
story_ids = list(all_stories)

# 2) hold out 10% of the stories for validation (fixed seed for repeatability)
train_ids, val_ids = train_test_split(story_ids, test_size=0.1, random_state=42)

# 3) look the story text back up for each side of the split
train_texts = list(map(all_stories.__getitem__, train_ids))
val_texts = list(map(all_stories.__getitem__, val_ids))

In [24]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Wrap the raw strings; max_len=128 matches the `max_len` the Encoder is built with below.
train_ds = TextDataset(train_texts, tokenizer, max_len=128)
val_ds = TextDataset(val_texts, tokenizer, max_len=128)

# Pinned host memory only helps host->GPU copies. On a CPU-only machine it does
# nothing and makes PyTorch emit a warning, so enable it only when CUDA exists.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_ds, batch_size=32, shuffle=True, num_workers=1, pin_memory=use_pinned_memory
)
val_loader = DataLoader(
    val_ds, batch_size=32, shuffle=False, num_workers=1, pin_memory=use_pinned_memory
)

In [25]:
import torch
import matplotlib.pyplot as plt
import pickle
from encoder import Encoder
from train_encoder import train_bert
import os


save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

results_path = "mlm_results.pkl"
seed = 42        # same value as the train/val split's random_state
num_epochs = 35  # 40 was also tried; can be tuned
num_heads = 4    # fixed for now
max_len = 128    # same max_len the TextDataset objects were built with

# ─── 1) Define your hyper‑parameter grid ──────────────────────────────────────
configs = [
    {"lr": 5e-4, "num_layers": 2, "hidden_size": 128},
    {"lr": 5e-4, "num_layers": 4, "hidden_size": 256},
    {"lr": 1e-4, "num_layers": 4, "hidden_size": 256},
    {"lr": 1e-4, "num_layers": 6, "hidden_size": 512},
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
results = {}  # config string -> (train_losses, val_losses)


# ─── 2) Loop over configs ─────────────────────────────────────────────────

for cfg in configs:
    config_str = f"lr={cfg['lr']},layers={cfg['num_layers']},hs={cfg['hidden_size']}"
    print(f"\n Training with {config_str}")

    # Re-seed torch before every run so each config starts from the same RNG state.
    # NOTE(review): only torch's RNG is seeded here; if train_bert draws from
    # `random` or numpy, seed those as well — confirm.
    torch.manual_seed(seed)

    # Instantiate encoder with the cfg; the feed-forward width is twice the hidden size.
    model = Encoder(
        vocab_size=tokenizer.vocab_size,
        hidden_size=cfg["hidden_size"],
        num_heads=num_heads,
        num_layers=cfg["num_layers"],
        intermediate_size=cfg["hidden_size"] * 2,
        max_len=max_len
    ).to(device)

    # Call training loop
    model, train_losses, val_losses = train_bert(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        tokenizer=tokenizer,
        epochs=num_epochs,
        lr=cfg["lr"],
        device=device
    )

    # Store losses
    results[config_str] = (train_losses, val_losses)

    # Save model to saved_models folder. The lr is formatted with Python's default
    # float repr, so 1e-4 appears in the filename as "0.0001".
    filename = f"encoder_lr{cfg['lr']}_layers{cfg['num_layers']}_hs{cfg['hidden_size']}.pt"
    save_path = os.path.join(save_dir, filename)
    torch.save(model.state_dict(), save_path)
    print(f"Saved model to: {save_path}")

    # Persist the loss curves after every config, so an interrupted sweep still
    # leaves the finished runs on disk.
    with open(results_path, "wb") as f:
        pickle.dump(results, f)

Using device: cpu

 Training with lr=0.0005,layers=2,hs=128
Epoch 1/35 — train: 10.3398, val: 10.0106
Epoch 2/35 — train: 10.0185, val: 9.9075
Epoch 3/35 — train: 9.6449, val: 9.5353
Epoch 4/35 — train: 9.3383, val: 9.2914
Epoch 5/35 — train: 8.9953, val: 8.8622
Epoch 6/35 — train: 8.6455, val: 8.7619
Epoch 7/35 — train: 8.2683, val: 8.4776
Epoch 8/35 — train: 8.0058, val: 8.1878
Epoch 9/35 — train: 7.7112, val: 8.1268
Epoch 10/35 — train: 7.4238, val: 7.6349
Epoch 11/35 — train: 7.2199, val: 7.4491
Epoch 12/35 — train: 7.1084, val: 7.2637
Epoch 13/35 — train: 6.8956, val: 7.4326
Epoch 14/35 — train: 6.7210, val: 7.5106
Epoch 15/35 — train: 6.6108, val: 7.3105
Epoch 16/35 — train: 6.3265, val: 7.0123
Epoch 17/35 — train: 6.6248, val: 7.0932
Epoch 18/35 — train: 6.4312, val: 7.1567
Epoch 19/35 — train: 6.3707, val: 6.9545
Epoch 20/35 — train: 6.4515, val: 6.8379
Epoch 21/35 — train: 6.2690, val: 6.9824
Epoch 22/35 — train: 6.2269, val: 6.7614
Epoch 23/35 — train: 6.3649, val: 6.9558
Epo

In [26]:
# Write the collected loss curves to disk (same file the training cell writes).
with open("mlm_results.pkl", "wb") as results_file:
    pickle.dump(results, results_file)

In [27]:
import matplotlib.pyplot as plt
import pickle
import os

# ─── Read the stored loss curves back in ───────────────────────────────────
with open("mlm_results.pkl", "rb") as results_file:
    results = pickle.load(results_file)

# ─── Make sure the figure folder exists ────────────────────────────────────
os.makedirs("loss_plots", exist_ok=True)

# '=', ',' and '.' are each replaced by '_' when building the file name.
filename_table = str.maketrans({"=": "_", ",": "_", ".": "_"})

# ─── One figure per configuration ──────────────────────────────────────────
for config_str, (train_curve, val_curve) in results.items():
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_curve, label="Train Loss")
    ax.plot(val_curve, label="Validation Loss", linestyle="--")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"Loss Curve — {config_str}")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    safe_name = config_str.translate(filename_table)
    path = f"loss_plots/loss_{safe_name}.png"
    fig.savefig(path)
    plt.close(fig)  # release the figure so the next one starts clean

    print(f"Saved: {path}")


Saved: loss_plots/loss_lr_0_0005_layers_2_hs_128.png
Saved: loss_plots/loss_lr_0_0005_layers_4_hs_256.png
Saved: loss_plots/loss_lr_0_0001_layers_4_hs_256.png
Saved: loss_plots/loss_lr_0_0001_layers_6_hs_512.png


In [35]:
import matplotlib.pyplot as plt
import pickle
import os

# Load results
with open("mlm_results.pkl", "rb") as f:
    results = pickle.load(f)
os.makedirs("loss_plots", exist_ok=True)

# Keep only runs whose learning rate is exactly 5e-4. Keys look like
# "lr=0.0005,layers=2,hs=128", so compare the first comma-separated field
# instead of substring matching (which would also accept e.g. lr=0.00055).
configs_to_plot = sorted(k for k in results if k.split(",")[0] == "lr=0.0005")
if not configs_to_plot:
    raise ValueError("No configs with lr=0.0005 found in mlm_results.pkl")

# One panel per matching config, 4.5in wide each (9in for the usual two runs).
# squeeze=False keeps `axes` an array even when only one config matches.
n_panels = len(configs_to_plot)
fig, axes = plt.subplots(
    1, n_panels, figsize=(4.5 * n_panels, 5), sharey=True, squeeze=False
)
axes = axes.ravel()

for panel_idx, (ax, config_str) in enumerate(zip(axes, configs_to_plot)):
    train, val = results[config_str]
    ax.plot(train, label="Train Loss")
    ax.plot(val, label="Validation Loss", linestyle="--")
    ax.set_title(config_str.replace(",", "\n"))  # multiline title
    ax.set_xlabel("Epoch")
    ax.grid(True)
    if panel_idx == 0:
        # The y axis is shared, so only the left-most panel is labelled.
        ax.set_ylabel("Loss")
    ax.legend()

fig.suptitle("Training vs Validation Loss (lr=5e-4)")
fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for suptitle


save_path = "loss_plots/loss_lr_0_0005_comparison.png"
fig.savefig(save_path, dpi=600)
plt.close(fig)
print(f"Saved: {save_path}")


Saved: loss_plots/loss_lr_0_0005_comparison.png


In [36]:
import matplotlib.pyplot as plt
import pickle
import os

# Read the stored loss curves
with open("mlm_results.pkl", "rb") as results_file:
    results = pickle.load(results_file)

os.makedirs("loss_plots", exist_ok=True)

# Configuration treated as the best model; the key must exist in `results`.
target_config = "lr=0.0001,layers=6,hs=512"
train, val = results[target_config]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train, label="Train Loss")
ax.plot(val, label="Validation Loss", linestyle="--")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Best Model\nlr=1e-4, layers=6, hidden=512", fontsize=13)
ax.legend()
ax.grid(True)
fig.tight_layout()

save_path = "loss_plots/best_model_loss.png"
fig.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"Saved best model loss plot to: {save_path}")


✔️ Saved best model loss plot to: loss_plots/best_model_loss.png


### Part 2

In [None]:
import torch
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
import pickle
from tqdm import tqdm
import os

: 

In [None]:
# Imported here because the Part 2 import cell does not bring in Encoder, so this
# cell would raise NameError when Part 2 is run in a fresh kernel.
from encoder import Encoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# The architecture must match the checkpoint being loaded (4 layers, hidden 256,
# feed-forward 2x hidden), otherwise load_state_dict fails on mismatched keys/shapes.
# NOTE(review): the Part 1 plotting cell labels lr=1e-4 / layers=6 / hs=512 as the
# "best model", but this cell loads the 4-layer / 256 checkpoint — confirm which
# model is intended here.
model = Encoder(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_heads=4,
    num_layers=4,
    intermediate_size=512,
    max_len=128
)

# The original hard-coded name is tried first. The training cell, however, writes
# to saved_models/ and formats lr=1e-4 as "0.0001", so that path is the fallback.
checkpoint_candidates = [
    "encoder_lr1e-4_layers4_hs256.pt",
    os.path.join("saved_models", "encoder_lr0.0001_layers4_hs256.pt"),
]
checkpoint_path = next((p for p in checkpoint_candidates if os.path.exists(p)), None)
if checkpoint_path is None:
    raise FileNotFoundError(
        f"Encoder checkpoint not found; looked for: {checkpoint_candidates}"
    )

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)
model.eval()  # inference mode: disables dropout and similar train-only behaviour

In [None]:
# Load raw text
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/ocean/projects/mth240012p/shared/data/raw_text.pkl", "rb") as f:
    raw_texts = pickle.load(f)

story_embeddings = {}  # story_id -> pooled encoder vector (numpy array)
error_stories = []     # (story_id, error message) for stories that failed

# Loop over each story
for story_id, sequence in tqdm(raw_texts.items()):
    try:
        # Get full story text
        story_text = " ".join(sequence.data)

        # Every story is padded/truncated to exactly 128 tokens, so only the first
        # 128 tokens of a long story contribute to its embedding.
        tokens = tokenizer(
            story_text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        input_ids = tokens["input_ids"].to(device)
        token_type_ids = tokens["token_type_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        # Get encoder hidden states before the MLM head
        with torch.no_grad():
            # Build the position ids directly on the target device.
            positions = torch.arange(input_ids.size(1), device=device).unsqueeze(0)
            x = model.token_emb(input_ids) \
              + model.pos_emb(positions) \
              + model.type_emb(token_type_ids)
            for layer in model.layers:
                x = layer(x, attention_mask)
            x = model.norm(x)

            # Mean pooling over seq_len, counting only real tokens: padding
            # positions (attention_mask == 0) are excluded from both the sum and
            # the divisor. When no padding is present this equals x.mean(dim=1).
            # Assumes x is (1, seq_len, hidden) — TODO confirm against Encoder.
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            embedding = pooled.squeeze().cpu().numpy()
            story_embeddings[story_id] = embedding

    except Exception as e:
        # Best effort: remember the failure and keep going with the other stories.
        error_stories.append((story_id, str(e)))
        continue

# Surface failures instead of dropping them silently.
if error_stories:
    print(f"{len(error_stories)} stories failed to embed:")
    for failed_id, message in error_stories:
        print(f"  {failed_id}: {message}")

In [None]:
# Save embeddings
# Create the output folder first; open(..., "wb") does not create missing parent
# directories and would raise FileNotFoundError on a fresh checkout.
embeddings_dir = "../results/embeddings"
os.makedirs(embeddings_dir, exist_ok=True)

with open(os.path.join(embeddings_dir, "encoder_embeddings.pkl"), "wb") as f:
    pickle.dump(story_embeddings, f)

print(f"Done. Saved {len(story_embeddings)} embeddings to ../results/embeddings/")

The cells above are the work completed so far; the code below has not yet run successfully.

In [None]:
from ridge_utils.DataSequence import DataSequence
from ridge_utils.ziploader import ZipDataLoader

In [None]:
# Load embeddings
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../results/embeddings/encoder_embeddings.pkl", "rb") as f:
    story_embeddings = pickle.load(f)  # dict: story_id -> np.array [D]

# Load subject2
subject = ZipDataLoader("/ocean/projects/mth240012p/shared/data/subject2.zip")

# Create DataSequence dict
ds_dict = {}
skipped_stories = []  # (story_id, embedding shape, number of data_times)
for sid, emb in story_embeddings.items():
    stim = subject.stimuli[sid]
    times = stim.data_times
    tr_times = stim.tr_times
    split_inds = stim.split_inds
    # A story is kept only when the embedding's first axis has one entry per
    # element of data_times.
    # NOTE(review): the embeddings saved above are one pooled vector of shape [D]
    # per story, so emb.shape[0] is the hidden size rather than a per-time count;
    # this check will probably reject every story — confirm what DataSequence expects.
    if emb.shape[0] != len(times):
        skipped_stories.append((sid, emb.shape, len(times)))
        continue
    ds_dict[sid] = DataSequence(emb, split_inds, times, tr_times)

# Report dropped stories so a later KeyError on ds_dict is easy to trace back.
if skipped_stories:
    print(f"Skipped {len(skipped_stories)} of {len(story_embeddings)} stories (length mismatch):")
    for skipped_id, emb_shape, n_times in skipped_stories:
        print(f"  {skipped_id}: embedding shape {emb_shape} vs {n_times} data_times")

In [None]:
# numpy is used below but is not imported by any earlier cell of this notebook.
import numpy as np

train_stories = subject.train_story_ids
test_stories  = subject.test_story_ids

# Fail early with a readable message if a required story has no DataSequence
# (the plain ds_dict[sid] lookup would raise a bare KeyError for the first one).
missing_stories = [
    sid for sid in list(train_stories) + list(test_stories) if sid not in ds_dict
]
if missing_stories:
    raise KeyError(f"No DataSequence built for stories: {missing_stories}")

# Features: chunksums(interp="lanczos") of each story, stacked in story order.
X_train = np.concatenate([ds_dict[sid].chunksums(interp="lanczos") for sid in train_stories])
X_test  = np.concatenate([ds_dict[sid].chunksums(interp="lanczos") for sid in test_stories])

# Responses: stacked in the same story order as the features above.
Y_train = np.concatenate([subject.responses[sid] for sid in train_stories])
Y_test  = np.concatenate([subject.responses[sid] for sid in test_stories])

In [None]:
# numpy is used for the alpha grid but is not imported by any earlier cell.
import numpy as np

from ridge_utils.utils import zscore, make_delayed
from ridge_utils.ridge import bootstrap_ridge

# z-score the features and responses; train and test are normalised separately.
X_train_z = zscore(X_train)
Y_train_z = zscore(Y_train)
X_test_z = zscore(X_test)
Y_test_z = zscore(Y_test)

# Expand the features with delayed copies at delays 0..3.
# NOTE(review): presumably these model the lag of the response — confirm the choice of delays.
delays = [0, 1, 2, 3]
X_train_d = make_delayed(X_train_z, delays)
X_test_d = make_delayed(X_test_z, delays)

# 20 candidate regularisation strengths, log-spaced from 1e1 to 1e4.
alphas = np.logspace(1, 4, 20)
wt, test_corrs, val_alphas, allRcorrs, valinds = bootstrap_ridge(
    X_train_d, Y_train_z, X_test_d, Y_test_z,
    alphas=alphas,
    nboots=10,
    chunklen=10,
    nchunks=2,
    normalpha=True,
    return_wt=True
)


In [None]:
import matplotlib.pyplot as plt

# Distribution of the test correlations returned by bootstrap_ridge.
fig, ax = plt.subplots()
ax.hist(test_corrs, bins=50)
ax.set_xlabel("Test correlation")
ax.set_ylabel("Voxel count")
ax.set_title("Encoder Embedding Performance on Subject 2")
ax.grid(True)
fig.tight_layout()
plt.show()