In [2]:
# Colab-only setup: mount Google Drive at /content/drive. PICKLE_PATH in the
# config section points below /content/drive/MyDrive, so this must run first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install gensim

Collecting gensim
  Downloading gensim-4.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Downloading gensim-4.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (27.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.9/27.9 MB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gensim
Successfully installed gensim-4.4.0


In [7]:
# bert_encoder_from_scratch_with_pooling_multitype_allpairs.py
# Modified version that supports:
# - 'C' = Chunk (uses MLM + NSP) → uses ALL possible positive & negative pairs
# - 'Q' = Query (uses MLM only)
# - Added model saving and evaluation on test subset

import random
import math
import os
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from gensim.models import Word2Vec
import pickle

# -------------------------
# Config
# -------------------------
# Minimum corpus count for a token to enter the vocabulary (default of build_vocab).
VOCAB_MIN_FREQ = 1
# Token budget per example, special tokens included (default max_seq_len of
# BertPretrainingDataset).
# NOTE(review): BertEncoderModel's max_position_embeddings defaults to 512, which is
# smaller than this value — confirm no example can be longer than the position table.
MAX_SEQ_LEN = 1024
# Encoder width; also the width of the Word2Vec vectors (see WORD2VEC_SIZE).
HIDDEN_SIZE = 768
# Number of stacked TransformerEncoderLayer blocks.
NUM_LAYERS = 12
# Attention heads per layer.
NUM_HEADS = 12
# Inner width of each layer's feed-forward network.
FFN_DIM = 3072
# NOTE(review): not referenced by the visible code — the model classes hardcode 0.1.
DROPOUT = 0.1
# Word2Vec vectors are trained at the encoder width so they can seed the token embeddings.
WORD2VEC_SIZE = HIDDEN_SIZE
WORD2VEC_WINDOW = 5
WORD2VEC_MIN_COUNT = 1
# Per-token probability of being selected for the masked-language-model objective.
MLM_MASK_PROB = 0.15
BATCH_SIZE = 8
# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# AdamW learning rate used in main().
LEARNING_RATE = 2e-5
# Pickled corpus on the mounted Google Drive; main() reads this through globals().
PICKLE_PATH = "/content/drive/MyDrive/ANLPProj/harry_potter_chunks_hierarchical.pkl"
# Discriminator value that marks a corpus item as a chunk (MLM + NSP).
C_LABEL = "C"
# NOTE(review): not referenced by the visible code.
VERBOSE = True


# -------------------------
# Special tokens
# -------------------------
PAD_TOKEN = "[PAD]"
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
UNK_TOKEN = "[UNK]"
# build_vocab assigns ids to these tokens first, in this order, so they occupy
# ids 0..len(SPECIAL_TOKENS)-1. The MLM masking code relies on that layout: it treats
# every id below len(SPECIAL_TOKENS) as special and never masks it.
SPECIAL_TOKENS = [PAD_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN, UNK_TOKEN]

# -------------------------
# Utility: Vocab builder
# -------------------------
def build_vocab(sentences: List[str], min_freq: int = VOCAB_MIN_FREQ):
    """Build a whitespace-token vocabulary with the special tokens placed first.

    Args:
        sentences: Texts to scan; each one is stripped and split on whitespace.
        min_freq: Minimum number of occurrences a token needs to be kept.

    Returns:
        (stoi, itos): a token -> id dict and the matching id -> token list.
        Special tokens take the lowest ids; the remaining tokens follow in
        order of first appearance in ``sentences``.
    """
    from collections import Counter

    counts = Counter(tok for line in sentences for tok in line.strip().split())

    # Special tokens always come first so their ids are fixed and contiguous.
    itos = list(SPECIAL_TOKENS)
    stoi = {tok: pos for pos, tok in enumerate(itos)}

    # A corpus token that collides with a special token keeps the special id.
    frequent = [tok for tok, freq in counts.items() if freq >= min_freq and tok not in stoi]
    for tok in frequent:
        stoi[tok] = len(itos)
        itos.append(tok)
    return stoi, itos

# -------------------------
# Train or load Word2Vec
# -------------------------
def train_word2vec(sentences: List[str], vector_size=WORD2VEC_SIZE, window=WORD2VEC_WINDOW, min_count=WORD2VEC_MIN_COUNT, epochs=5):
    """Train a gensim Word2Vec model on whitespace-tokenized sentences.

    Args:
        sentences: Raw texts; each is stripped and split on whitespace.
        vector_size, window, min_count, epochs: Passed through to ``Word2Vec``.

    Returns:
        The trained ``Word2Vec`` model.
    """
    token_lists = []
    for line in sentences:
        token_lists.append(line.strip().split())

    return Word2Vec(
        sentences=token_lists,
        vector_size=vector_size,
        window=window,
        min_count=min_count,
        epochs=epochs,
        sg=0,  # per the gensim docs, 0 selects CBOW rather than skip-gram
    )

def build_embedding_matrix(w2v: Word2Vec, itos: List[str], hidden_size: int):
    """Create the initial token-embedding table from trained Word2Vec vectors.

    Args:
        w2v: Trained Word2Vec model; vectors are looked up through ``w2v.wv``.
        itos: id -> token list; row ``i`` of the result belongs to ``itos[i]``.
        hidden_size: Width of every embedding row.

    Returns:
        A float32 tensor of shape (len(itos), hidden_size). Tokens known to
        Word2Vec get their vector (truncated or zero-padded to ``hidden_size``),
        all other rows keep a N(0, 0.02) draw, and the [PAD] row is all zeros.
    """
    n_tokens = len(itos)
    table = np.random.normal(scale=0.02, size=(n_tokens, hidden_size)).astype(np.float32)

    keyed_vectors = w2v.wv
    for row, token in enumerate(itos):
        if token not in keyed_vectors:
            continue  # keep the random initialisation (e.g. for special tokens)
        vector = keyed_vectors[token]
        width = vector.shape[0]
        if width > hidden_size:
            vector = vector[:hidden_size]
        elif width < hidden_size:
            vector = np.pad(vector, (0, hidden_size - width))
        table[row] = vector

    # The padding row must contribute nothing to the input embeddings.
    table[itos.index(PAD_TOKEN)] = np.zeros(hidden_size, dtype=np.float32)
    return torch.tensor(table)

# -------------------------
# Dataset (supports queries and chunks)
# -------------------------
class BertPretrainingDataset(Dataset):
    """Pre-training examples for two kinds of text.

    * 'Q' (query): one example per text, used for MLM only.
    * 'C' (chunk): the text is split into sentences on '.', and every sentence
      pair becomes one example for MLM + NSP. Consecutive pairs (i, i+1) are
      labelled 1; every ordered pair whose indices differ by more than one is
      labelled 0.

    Items with any other discriminator are ignored.
    """

    def __init__(self, data: List[Tuple[str, str]], stoi: dict, max_seq_len=MAX_SEQ_LEN):
        """
        data: list of tuples [(text, discriminator)], where discriminator ∈ {'Q', 'C'}
        """
        self.stoi = stoi
        self.max_seq_len = max_seq_len
        # Each sample is (sentence_a, kind, sentence_b or None, nsp_label or None).
        self.samples = []

        for text, kind in data:
            if kind == "Q":
                self.samples.append((text, kind, None, None))
                continue
            if kind != "C":
                continue

            pieces = (part.strip() for part in text.strip().split('.'))
            sents = [part for part in pieces if part]
            if len(sents) < 2:
                # A lone sentence is paired with itself so the chunk still yields a sample.
                sents = sents * 2

            # Positive pairs: each sentence followed by its successor.
            self.samples.extend(
                (first, "C", second, 1) for first, second in zip(sents, sents[1:])
            )
            # Negative pairs: every ordered pair that is not adjacent (or identical).
            count = len(sents)
            self.samples.extend(
                (sents[i], "C", sents[j], 0)
                for i in range(count)
                for j in range(count)
                if abs(i - j) > 1
            )

    def __len__(self):
        return len(self.samples)

    def _tokenize_to_ids(self, text: str) -> List[int]:
        """Map whitespace tokens to ids, using the [UNK] id for unknown tokens."""
        unk_id = self.stoi[UNK_TOKEN]
        return [self.stoi.get(tok, unk_id) for tok in text.strip().split()]

    def __getitem__(self, idx):
        """Return the unpadded tensors for one sample (padding happens in collate_fn)."""
        first, kind, second, label = self.samples[idx]

        if kind == 'Q':
            # Query: [CLS] tokens [SEP], a single segment, no NSP target.
            body = self._tokenize_to_ids(first)[:self.max_seq_len - 2]
            input_ids = [self.stoi[CLS_TOKEN], *body, self.stoi[SEP_TOKEN]]
            segments = [0] * len(input_ids)
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor(segments, dtype=torch.long),
                "nsp_label": torch.tensor(-100, dtype=torch.long),  # dummy
                "batch_type": "Q"
            }

        if kind == 'C':
            # Chunk pair: [CLS] A [SEP] B [SEP], segment 0 for A and 1 for B.
            ids_a = self._tokenize_to_ids(first)
            ids_b = self._tokenize_to_ids(second)
            budget = self.max_seq_len - 3  # room left after the three special tokens
            while len(ids_a) + len(ids_b) > budget:
                # Trim the longer side one token at a time; ties trim sentence B.
                longer = ids_a if len(ids_a) > len(ids_b) else ids_b
                longer.pop()
            cls_id, sep_id = self.stoi[CLS_TOKEN], self.stoi[SEP_TOKEN]
            input_ids = [cls_id, *ids_a, sep_id, *ids_b, sep_id]
            segments = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
            return {
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "token_type_ids": torch.tensor(segments, dtype=torch.long),
                "nsp_label": torch.tensor(label, dtype=torch.long),
                "batch_type": "C"
            }

def collate_fn(batch, pad_id):
    """Pad a list of dataset items to a common length and stack them.

    Args:
        batch: Items with "input_ids", "token_type_ids", "nsp_label" and
            "batch_type" entries.
        pad_id: Token id used to right-pad "input_ids".

    Returns:
        Dict with stacked "input_ids", "token_type_ids" (padded with 0) and
        "attention_mask" (1 where the padded ids differ from ``pad_id``), the
        stacked "nsp_labels", and "batch_type" as a list of per-item strings.
    """
    token_rows = [item["input_ids"] for item in batch]
    segment_rows = [item["token_type_ids"] for item in batch]
    nsp_labels = torch.stack([item["nsp_label"] for item in batch]).long()
    kinds = [item["batch_type"] for item in batch]

    widest = max(row.size(0) for row in token_rows)

    ids_out, segments_out, masks_out = [], [], []
    for ids, segments in zip(token_rows, segment_rows):
        fill = (0, widest - ids.size(0))  # right-pad only
        ids_padded = F.pad(ids, fill, value=pad_id)
        ids_out.append(ids_padded)
        segments_out.append(F.pad(segments, fill, value=0))
        # The mask is derived from the padded ids, so any pad_id counts as padding.
        masks_out.append(ids_padded.ne(pad_id).long())

    return {
        "input_ids": torch.stack(ids_out),
        "token_type_ids": torch.stack(segments_out),
        "attention_mask": torch.stack(masks_out),
        "nsp_labels": nsp_labels,
        "batch_type": kinds
    }

# -------------------------
# MLM Masking
# -------------------------
def create_mlm_labels_and_masked_input(input_ids, pad_id, mask_token_id, vocab_size, mask_prob=MLM_MASK_PROB):
    """Select tokens for the MLM objective and corrupt the input accordingly.

    Each non-special, non-padding position is selected with probability
    ``mask_prob``. Of the selected positions, roughly 80% are replaced by
    ``mask_token_id``, 10% by a random non-special token id, and 10% are left
    unchanged.

    Args:
        input_ids: Integer tensor of token ids.
        pad_id: Id of the padding token (never selected).
        mask_token_id: Id written at positions that are masked out.
        vocab_size: Exclusive upper bound for random replacement ids.
        mask_prob: Selection probability per eligible position.

    Returns:
        (input_ids_masked, mlm_labels): the corrupted copy of ``input_ids`` and
        a label tensor holding the original id at selected positions and -100
        everywhere else. Both live on the same device as ``input_ids``.
    """
    device = input_ids.device
    special_upper = len(SPECIAL_TOKENS)

    # Build the probability matrix on the input's device so that every mask
    # used for indexing below lives on the same device as the tensor it indexes.
    prob_matrix = torch.full(input_ids.shape, mask_prob, dtype=torch.float, device=device)
    # Special tokens occupy the ids below special_upper; they and padding are never selected.
    not_maskable = (input_ids == pad_id) | (input_ids < special_upper)
    prob_matrix[not_maskable] = 0.0
    masked_positions = torch.bernoulli(prob_matrix).bool()

    mlm_labels = torch.full_like(input_ids, -100)
    mlm_labels[masked_positions] = input_ids[masked_positions]

    input_ids_masked = input_ids.clone()
    # One uniform draw per position decides between [MASK] / random token / unchanged.
    rand_for_replace = torch.rand(input_ids.shape, dtype=torch.float, device=device)
    mask_replace = masked_positions & (rand_for_replace < 0.8)
    random_replace = masked_positions & (rand_for_replace >= 0.8) & (rand_for_replace < 0.9)
    input_ids_masked[mask_replace] = mask_token_id

    count = int(random_replace.sum().item())
    # torch.randint needs low < high; with no non-special tokens in the
    # vocabulary there is nothing to sample, so those positions stay unchanged.
    if count > 0 and vocab_size > special_upper:
        rand_tokens = torch.randint(special_upper, vocab_size, (count,), device=device)
        input_ids_masked[random_replace] = rand_tokens
    return input_ids_masked, mlm_labels

# -------------------------
# Transformer encoder
# -------------------------
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-block is followed by dropout, a residual connection and a
    LayerNorm applied after the addition.
    """

    def __init__(self, hidden_size, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        # Attribute names double as state_dict keys, so they are part of the
        # checkpoint format and must stay as they are.
        self.self_attn = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        expand = nn.Linear(hidden_size, ffn_dim)
        project = nn.Linear(ffn_dim, hidden_size)
        self.ffn = nn.Sequential(expand, nn.GELU(), project)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the block to ``x``; positions where ``mask`` is 0 are ignored as attention keys."""
        ignore_keys = mask.eq(0)
        attended = self.self_attn(x, x, x, key_padding_mask=ignore_keys)[0]
        hidden = self.ln1(x + self.dropout(attended))
        transformed = self.ffn(hidden)
        return self.ln2(hidden + self.dropout(transformed))

class BertEncoderModel(nn.Module):
    """Transformer encoder with MLM and NSP heads.

    Input embeddings are the sum of token, learned position and segment
    embeddings. The MLM head reuses the token-embedding matrix as its output
    projection (plus a separate bias); the NSP head classifies the hidden state
    at position 0.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        hidden_size: Model width.
        num_layers: Number of TransformerEncoderLayer blocks.
        num_heads: Attention heads per block.
        ffn_dim: Inner width of each block's feed-forward network.
        max_position_embeddings: Size of the learned position table; inputs
            longer than this are rejected.
        pad_token_id: Id of the padding token.
        embedding_weights: Optional tensor copied into the token-embedding table.
    """

    def __init__(self, vocab_size, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, ffn_dim=FFN_DIM, max_position_embeddings=512, pad_token_id=0, embedding_weights=None):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
        if embedding_weights is not None:
            self.token_embeddings.weight.data.copy_(embedding_weights)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.segment_embeddings = nn.Embedding(2, hidden_size)
        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(0.1)
        self.layers = nn.ModuleList([TransformerEncoderLayer(hidden_size, num_heads, ffn_dim) for _ in range(num_layers)])
        self.nsp_classifier = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 2))
        self.mlm_bias = nn.Parameter(torch.zeros(vocab_size))

    def encode(self, ids, tt=None, mask=None):
        """Return the final hidden states for a batch of token ids.

        Args:
            ids: Token ids, indexed as (batch, position).
            tt: Segment ids; defaults to all zeros.
            mask: Attention mask (non-zero = real token); defaults to ``ids != pad_token_id``.

        Raises:
            ValueError: If the sequence is longer than the position-embedding table.
        """
        seq_len = ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup (which surfaces as a device-side assert on CUDA).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings={max_positions}; "
                "truncate the input or build the model with a larger max_position_embeddings."
            )
        if tt is None:
            tt = torch.zeros_like(ids)
        if mask is None:
            mask = (ids != self.pad_token_id).long()
        # Shape (1, seq_len): broadcast over the batch when the embeddings are summed.
        pos = torch.arange(seq_len, device=ids.device).unsqueeze(0)
        x = self.token_embeddings(ids) + self.position_embeddings(pos) + self.segment_embeddings(tt)
        x = self.emb_dropout(self.emb_ln(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x

    def forward(self, ids, tt=None, mask=None):
        """Return ``(mlm_logits, nsp_logits)`` for a batch of token ids."""
        seq_out = self.encode(ids, tt, mask)
        # Hidden state at position 0 summarises the sequence for the NSP head.
        pooled = seq_out[:, 0]
        nsp_logits = self.nsp_classifier(pooled)
        # Weight tying: the token-embedding matrix is the MLM output projection.
        mlm_logits = F.linear(seq_out, self.token_embeddings.weight, self.mlm_bias)
        return mlm_logits, nsp_logits

#new
def load_chunks_from_pkl(pkl_path):
    """
    Read a pickle file and flatten its content into ``[(chunk_text, None), ...]``.

    Supported layouts:
      - list of str, dict or tuple items (a tuple contributes its first element;
        anything else is converted with ``str``)
      - dict with a list under 'documents' (Chromadb-like)
      - any other dict: its values are used; list values are flattened with ``str``
      - anything else: ``str(data)`` as a single chunk

    Raises FileNotFoundError when ``pkl_path`` does not exist.
    """
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"Pickle file not found: {pkl_path}")
    with open(pkl_path, "rb") as fh:
        # NOTE: unpickling can execute arbitrary code — only load trusted files.
        data = pickle.load(fh)

    top_level_keys = ("chunk", "text", "content", "page_content", "document", "body")
    metadata_keys = ("text", "content", "page_content", "chunk")

    def extract_text_from_dict(d):
        """Pick the most likely text field of a dict, falling back to ``str(d)``."""
        for key in top_level_keys:
            if isinstance(d.get(key), str):
                return d[key]
        meta = d.get("metadata")
        if isinstance(meta, dict):
            for key in metadata_keys:
                if isinstance(meta.get(key), str):
                    return meta[key]
        # Last resort: the first non-blank string value, else the dict's repr.
        first_text = next((v for v in d.values() if isinstance(v, str) and v.strip()), None)
        return str(d) if first_text is None else first_text

    def text_of(item, unpack_tuples):
        """Convert one container item to text."""
        if isinstance(item, str):
            return item
        if isinstance(item, dict):
            return extract_text_from_dict(item)
        if unpack_tuples and isinstance(item, tuple) and len(item) >= 1:
            return item[0]
        return str(item)

    if isinstance(data, list):
        texts = [text_of(item, True) for item in data]
    elif isinstance(data, dict):
        documents = data.get("documents")
        if isinstance(documents, list):
            # Tuples inside 'documents' are stringified, not unpacked.
            texts = [text_of(doc, False) for doc in documents]
        else:
            texts = []
            for value in data.values():
                if isinstance(value, list):
                    texts.extend(str(entry) for entry in value)
                else:
                    texts.append(text_of(value, False))
    else:
        texts = [str(data)]

    return [(str(text), None) for text in texts]


# -------------------------
# Training and Evaluation
# -------------------------
def main():
    """Train the encoder for one epoch, save it, and report held-out accuracy.

    Steps:
      1. Build the corpus from PICKLE_PATH, else CORPUS_FALLBACK, else a small
         built-in dummy corpus (all looked up through ``globals()``).
      2. Build the vocabulary, train Word2Vec and seed the token embeddings.
      3. Split the examples 80/20, train with MLM + NSP, and save the weights
         and vocabulary under ``saved_bert_encoder/``.
      4. Print MLM and NSP accuracy on the held-out split.

    Raises:
        ValueError: If the corpus yields fewer than two examples, so no
            train/test split is possible.
    """
    import json

    PICKLE_PATH = globals().get("PICKLE_PATH", None)
    CORPUS_FALLBACK = globals().get("CORPUS_FALLBACK", None)
    C_label = globals().get("C", globals().get("C_LABEL", "C"))

    corpus = []

    # try loading from pickle if provided
    if PICKLE_PATH:
        try:
            corpus = _label_corpus_items(load_chunks_from_pkl(PICKLE_PATH), C_label)
        except Exception as e:
            # Deliberate best effort: keep the notebook runnable without the data file.
            print(f"[main] Failed to load pickle: {e!r} — falling back to CORPUS_FALLBACK or default dummy corpus")
            corpus = []

    # if pickle didn't yield anything, use CORPUS_FALLBACK if available
    if not corpus:
        if CORPUS_FALLBACK:
            corpus = _label_corpus_items(CORPUS_FALLBACK, C_label)
        else:
            # final fallback -> dummy corpus so the notebook remains runnable
            corpus = [
                ("the quick brown fox jumps over the lazy dog. the dog did not mind.", C_label),
                ("i love machine learning and transformers.", "Q"),
                ("deep learning enables summarization and translation. it is powerful.", C_label),
                ("best restaurants near me", "Q")
            ]

    texts = [x[0] for x in corpus]
    stoi, itos = build_vocab(texts)
    vocab_size = len(itos)
    w2v = train_word2vec(texts)
    emb = build_embedding_matrix(w2v, itos, HIDDEN_SIZE)
    pad_id = stoi[PAD_TOKEN]; mask_id = stoi[MASK_TOKEN]
    ds = BertPretrainingDataset(corpus, stoi)

    # Split train/test
    total_len = len(ds)
    if total_len < 2:
        raise ValueError(
            f"Need at least 2 training examples to create a train/test split, got {total_len}."
        )
    test_len = max(1, total_len // 5)
    train_len = total_len - test_len
    train_ds, test_ds = random_split(ds, [train_len, test_len])
    dl_train = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda b: collate_fn(b, pad_id))
    dl_test = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda b: collate_fn(b, pad_id))

    # The position table must cover the longest sequence the dataset can emit
    # (the dataset truncates to MAX_SEQ_LEN, the model default is only 512).
    model = BertEncoderModel(vocab_size, max_position_embeddings=MAX_SEQ_LEN, embedding_weights=emb).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    nsp_loss_fct = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(1):
        for batch in dl_train:
            ids = batch["input_ids"].to(DEVICE)
            tts = batch["token_type_ids"].to(DEVICE)
            mask = batch["attention_mask"].to(DEVICE)
            nsp_labels = batch["nsp_labels"].to(DEVICE)
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(ids, pad_id, mask_id, vocab_size)
            ids_masked, mlm_labels = ids_masked.to(DEVICE), mlm_labels.to(DEVICE)
            mlm_logits, nsp_logits = model(ids_masked, tts, mask)

            # With mean reduction the loss is NaN when every target is ignored,
            # and one NaN step would corrupt all weights — so only compute the
            # MLM loss when at least one position was selected.
            if (mlm_labels != -100).any():
                mlm_loss = mlm_loss_fct(mlm_logits.view(-1, vocab_size), mlm_labels.view(-1))
            else:
                mlm_loss = torch.tensor(0.0, device=DEVICE)

            # Query rows carry the dummy label -100; train NSP on the chunk rows
            # of the batch instead of dropping NSP for the whole mixed batch.
            has_nsp = nsp_labels != -100
            if has_nsp.any():
                nsp_loss = nsp_loss_fct(nsp_logits[has_nsp], nsp_labels[has_nsp])
            else:
                nsp_loss = torch.tensor(0.0, device=DEVICE)

            loss = mlm_loss + nsp_loss
            if not loss.requires_grad:
                # Neither objective had a target in this batch; nothing to learn from.
                continue
            opt.zero_grad()
            loss.backward()
            opt.step()
            print(f"Loss {loss.item():.4f} (MLM {mlm_loss.item():.4f}, NSP {nsp_loss.item():.4f})")

    # -------------------------
    # Save model and vocab
    # -------------------------
    save_dir = "saved_bert_encoder"
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "bert_encoder.pt"))
    with open(os.path.join(save_dir, "vocab.json"), "w") as f:
        json.dump({"stoi": stoi, "itos": itos}, f)
    print(f"Model and vocab saved to {save_dir}")

    # -------------------------
    # Evaluation
    # -------------------------
    model.eval()
    total_mlm_correct = 0
    total_mlm_count = 0
    total_nsp_correct = 0
    total_nsp_count = 0

    with torch.no_grad():
        for batch in dl_test:
            ids = batch["input_ids"].to(DEVICE)
            tts = batch["token_type_ids"].to(DEVICE)
            mask = batch["attention_mask"].to(DEVICE)
            nsp_labels = batch["nsp_labels"].to(DEVICE)
            ids_masked, mlm_labels = create_mlm_labels_and_masked_input(ids, pad_id, mask_id, vocab_size)
            ids_masked, mlm_labels = ids_masked.to(DEVICE), mlm_labels.to(DEVICE)
            mlm_logits, nsp_logits = model(ids_masked, tts, mask)
            # MLM accuracy over the selected positions only
            mlm_preds = mlm_logits.argmax(-1)
            mask_positions = mlm_labels != -100
            total_mlm_correct += (mlm_preds[mask_positions] == mlm_labels[mask_positions]).sum().item()
            total_mlm_count += mask_positions.sum().item()
            # NSP accuracy over the rows that carry a real NSP label
            has_nsp = nsp_labels != -100
            if has_nsp.any():
                nsp_preds = nsp_logits.argmax(-1)
                total_nsp_correct += (nsp_preds[has_nsp] == nsp_labels[has_nsp]).sum().item()
                total_nsp_count += int(has_nsp.sum().item())

    mlm_acc = total_mlm_correct / max(1, total_mlm_count)
    nsp_acc = total_nsp_correct / max(1, total_nsp_count)
    print(f"MLM Accuracy: {mlm_acc:.4f}, NSP Accuracy: {nsp_acc:.4f}")
    print("Training and evaluation done.")


def _label_corpus_items(items, label):
    """Reduce each raw corpus item to its text and pair it with ``label``."""
    labelled = []
    for item in items:
        if isinstance(item, tuple) and len(item) >= 1:
            text = item[0]
        elif isinstance(item, str):
            text = item
        elif isinstance(item, dict):
            text = item.get("text") or item.get("page_content") or item.get("content") or item.get("chunk") or str(item)
        else:
            text = str(item)
        labelled.append((text, label))
    return labelled


if __name__ == "__main__":
    main()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss 6.8585 (MLM 6.8163, NSP 0.0422)
Loss 6.3676 (MLM 6.0456, NSP 0.3220)
Loss 6.6001 (MLM 6.4841, NSP 0.1161)
Loss 6.5873 (MLM 6.1692, NSP 0.4181)
Loss 5.4714 (MLM 5.0652, NSP 0.4063)
Loss 6.8655 (MLM 6.5000, NSP 0.3655)
Loss 7.8900 (MLM 6.6502, NSP 1.2398)
Loss 6.0568 (MLM 5.6339, NSP 0.4230)
Loss 4.4229 (MLM 4.3900, NSP 0.0329)
Loss 7.3947 (MLM 6.5684, NSP 0.8263)
Loss 5.6463 (MLM 5.5549, NSP 0.0914)
Loss 7.0188 (MLM 6.6861, NSP 0.3328)
Loss 7.1575 (MLM 6.8142, NSP 0.3432)
Loss 5.4082 (MLM 5.3520, NSP 0.0562)
Loss 6.6765 (MLM 5.9725, NSP 0.7040)
Loss 5.6043 (MLM 5.2930, NSP 0.3113)
Loss 7.1495 (MLM 6.7980, NSP 0.3515)
Loss 5.5553 (MLM 5.4691, NSP 0.0862)
Loss 6.6124 (MLM 6.2474, NSP 0.3651)
Loss 6.7582 (MLM 6.5213, NSP 0.2368)
Loss 6.1402 (MLM 6.0391, NSP 0.1011)
Loss 6.6040 (MLM 6.4939, NSP 0.1101)
Loss 5.5422 (MLM 5.4346, NSP 0.1076)
Loss 5.8627 (MLM 5.7853, NSP 0.0774)
Loss 5.8432 (MLM 5.2196, NSP 0.6236)
Loss 6.046