In [1]:
import os, math, random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from datasets import load_dataset
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertModel,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
)

from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every RNG the notebook can draw from so the stochastic steps below
# (model initialisation, shuffled loaders, random MLM masking) start from the
# same state after a Restart & Run All.
# NOTE(review): this fixes the seeds only; bit-identical results on the MPS
# backend are not guaranteed by seeding alone — confirm if exact repeatability matters.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use Apple's Metal backend (MPS) when PyTorch reports it, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


In [3]:
# Only the tokenizer is loaded from a pretrained checkpoint; the model below is
# built from a config object rather than from saved weights.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Architecture of the small BERT, collected in one mapping so the sizes are
# easy to scan and tweak.
mlm_arch = dict(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=1024,
    max_position_embeddings=128,
    type_vocab_size=2,
)
mlm_config = BertConfig(**mlm_arch)

mlm_model = BertForMaskedLM(mlm_config)
mlm_model = mlm_model.to(device)

# Echo the effective sizes as read back from the instantiated model.
effective_config = mlm_model.config
print("MLM hidden size:", effective_config.hidden_size)
print("MLM layers:", effective_config.num_hidden_layers)
print("Vocab:", effective_config.vocab_size)

MLM hidden size: 256
MLM layers: 4
Vocab: 30522


In [4]:
# Training split of WikiText-2 (raw variant) is the pretraining corpus.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
print("Before filter:", len(raw))


def _has_text(example):
    """Return True when the row's "text" field is present and not blank."""
    text = example["text"]
    return text is not None and len(text.strip()) > 0


# Drop rows whose text is missing or whitespace-only.
raw = raw.filter(_has_text)
print("After filter:", len(raw))

first_text = raw[0]["text"]
print("Example:", first_text[:200])

Before filter: 36718


Filter: 100%|██████████| 36718/36718 [00:00<00:00, 1019971.09 examples/s]

After filter: 23767
Example:  = Valkyria Chronicles III = 






In [5]:
# Sequence length shared by the tokenizer calls in this notebook.
max_len = 128


def tokenize_mlm(examples):
    """Tokenize a batch of rows for masked-language-model pretraining.

    Reads ``examples["text"]`` and returns the tokenizer's encoding, truncated
    and padded to ``max_len`` tokens, with the special-tokens mask requested
    alongside the usual fields.
    """
    encode_options = dict(
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_special_tokens_mask=True,
    )
    return tokenizer(examples["text"], **encode_options)

# Tokenize the corpus in batches and drop the source columns, so the mapped
# dataset holds tokenizer outputs only.
source_columns = raw.column_names
tokenized_mlm = raw.map(tokenize_mlm, batched=True, remove_columns=source_columns)

print(tokenized_mlm)
first_row = tokenized_mlm[0]
print(first_row.keys())

Map: 100%|██████████| 23767/23767 [00:01<00:00, 16513.69 examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 23767
})
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'])





In [7]:
# The collator is configured for masked-LM batches with a 15% masking probability.
data_collator_mlm = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

mlm_loader_options = dict(
    batch_size=8,  # safe for Mac
    shuffle=True,
    collate_fn=data_collator_mlm,
)
train_loader_mlm = DataLoader(tokenized_mlm, **mlm_loader_options)

# AdamW over every parameter of the MLM model.
mlm_parameters = mlm_model.parameters()
optimizer_mlm = torch.optim.AdamW(mlm_parameters, lr=5e-4, weight_decay=0.01)

In [8]:
num_epochs = 1
mlm_model.train()

for epoch in range(num_epochs):
    # Running sum of losses and the number of batches that actually took an
    # optimizer step (skipped batches are excluded from the average).
    total_loss, valid_steps = 0.0, 0

    pbar = tqdm(train_loader_mlm, desc=f"MLM Epoch {epoch+1}/{num_epochs}")
    for batch in pbar:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = mlm_model(**batch).loss

        # safety: a NaN/Inf loss would poison the weights, so skip the batch
        # before any gradient is computed or applied
        if not torch.isfinite(loss):
            continue

        optimizer_mlm.zero_grad()
        loss.backward()
        optimizer_mlm.step()

        # Read the scalar once and reuse it for both the sum and the progress bar.
        loss_value = loss.item()
        total_loss += loss_value
        valid_steps += 1
        pbar.set_postfix(loss=float(loss_value))

    # max(..., 1) keeps the division defined even if every batch was skipped.
    avg_loss = total_loss / max(valid_steps, 1)
    print(f"MLM Epoch {epoch+1}: avg loss = {avg_loss:.6f} | valid steps: {valid_steps}/{len(train_loader_mlm)}")

MLM Epoch 1/1: 100%|██████████| 2971/2971 [02:43<00:00, 18.12it/s, loss=7.05]

MLM Epoch 1: avg loss = 7.016170 | valid steps: 2971/2971





In [9]:
MLM_SAVE_DIR = "./bert_mlm_scratch_final"
os.makedirs(MLM_SAVE_DIR, exist_ok=True)

# Model and tokenizer go into the same directory; the encoder is later
# reloaded from this path with from_pretrained.
for artifact in (mlm_model, tokenizer):
    artifact.save_pretrained(MLM_SAVE_DIR)

print("Saved MLM model to:", MLM_SAVE_DIR)

Writing model shards: 100%|██████████| 1/1 [00:00<00:00, 31.90it/s]

Saved MLM model to: ./bert_mlm_scratch_final





In [10]:
# Reload the checkpoint saved by the MLM stage as a plain BertModel.
encoder = BertModel.from_pretrained(MLM_SAVE_DIR)
encoder = encoder.to(device)

hidden = encoder.config.hidden_size

# The head takes three hidden-sized vectors concatenated together and emits
# one score per NLI class.
classifier_in = hidden * 3
classifier_head = nn.Linear(classifier_in, 3).to(device)  # 3 classes for NLI

criterion = nn.CrossEntropyLoss()

print("Encoder hidden:", hidden)
print("Classifier in:", classifier_in, "out:", 3)

Loading weights: 100%|██████████| 69/69 [00:00<00:00, 2039.97it/s, Materializing param=encoder.layer.3.output.dense.weight]              
[1mBertModel LOAD REPORT[0m from: ./bert_mlm_scratch_final
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
pooler.dense.weight                        | MISSING    | 
pooler.dense.bias                          | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


Encoder hidden: 256
Classifier in: 768 out: 3


In [11]:
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over the positions selected by ``attention_mask``.

    Args:
        last_hidden_state: token embeddings, (batch, seq, hidden).
        attention_mask: per-token weights, (batch, seq); positions with 0
            contribute nothing to the average.

    Returns:
        The masked sum over the sequence axis divided by the mask total. The
        divisor is floored at 1e-9 so an all-zero mask row yields zeros rather
        than a division by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq, 1)
    weighted_total = (last_hidden_state * weights).sum(dim=1)
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / weight_total

def configurations(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Build the pair feature ``[u, v, |u - v|]``, concatenated along the last dimension."""
    return torch.cat((u, v, (u - v).abs()), dim=-1)

In [12]:
snli = load_dataset("snli")


def _is_labelled(example):
    """Return True when the row carries one of the three class ids (0, 1, 2)."""
    return example["label"] in [0, 1, 2]


# Filter out -1 labels (missing), then keep a fixed-size prefix of each split.
train_labelled = snli["train"].filter(_is_labelled)
val_labelled = snli["validation"].filter(_is_labelled)
train_snli = train_labelled.select(range(20000))
val_snli = val_labelled.select(range(5000))

for preview in (train_snli, val_snli, train_snli[0]):
    print(preview)

Generating test split: 100%|██████████| 10000/10000 [00:00<00:00, 2282117.63 examples/s]
Generating validation split: 100%|██████████| 10000/10000 [00:00<00:00, 3517236.06 examples/s]
Generating train split: 100%|██████████| 550152/550152 [00:00<00:00, 7659487.07 examples/s]
Filter: 100%|██████████| 550152/550152 [00:00<00:00, 1080020.96 examples/s]
Filter: 100%|██████████| 10000/10000 [00:00<00:00, 906288.68 examples/s]

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 20000
})
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 5000
})
{'premise': 'A person on a horse jumps over a broken down airplane.', 'hypothesis': 'A person is training his horse for a competition.', 'label': 1}





In [13]:
def tokenize_snli(batch):
    """Tokenize premises and hypotheses of an SNLI batch as separate sequences.

    Both text columns are encoded independently with the same truncation and
    padding settings (``max_len`` tokens). The result exposes the ids and
    attention mask of each side under prefixed keys, and copies the batch's
    ``label`` column to ``labels``.
    """
    shared_options = dict(truncation=True, padding="max_length", max_length=max_len)
    premise_enc = tokenizer(batch["premise"], **shared_options)
    hypothesis_enc = tokenizer(batch["hypothesis"], **shared_options)

    columns = {}
    columns["premise_input_ids"] = premise_enc["input_ids"]
    columns["premise_attention_mask"] = premise_enc["attention_mask"]
    columns["hypothesis_input_ids"] = hypothesis_enc["input_ids"]
    columns["hypothesis_attention_mask"] = hypothesis_enc["attention_mask"]
    columns["labels"] = batch["label"]
    return columns

def _tokenize_split(split):
    """Map ``tokenize_snli`` over ``split`` in batches, dropping its original columns."""
    return split.map(tokenize_snli, batched=True, remove_columns=split.column_names)


train_tok = _tokenize_split(train_snli)
val_tok = _tokenize_split(val_snli)

first_train_row = train_tok[0]
print(first_train_row.keys())

Map: 100%|██████████| 20000/20000 [00:01<00:00, 18945.91 examples/s]
Map: 100%|██████████| 5000/5000 [00:00<00:00, 19267.54 examples/s]

dict_keys(['premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'])





In [14]:
def collate_nli(features):
    """Stack a list of per-example dicts into a dict of tensors.

    The keys are taken from the first example; for each key the values of all
    examples are gathered in order and converted with ``torch.tensor``.
    """
    field_names = features[0].keys()
    return {name: torch.tensor([row[name] for row in features]) for name in field_names}

# Training batches are shuffled; validation batches keep dataset order.
train_loader = DataLoader(
    train_tok,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_nli,
)
val_loader = DataLoader(
    val_tok,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_nli,
)

In [15]:
# One optimizer updates the encoder and the classification head together.
trainable_params = [*encoder.parameters(), *classifier_head.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-5, weight_decay=0.01)

num_epochs = 1  # baseline; you can raise to 2-3 if time
for epoch in range(num_epochs):
    encoder.train()
    classifier_head.train()
    total_loss = 0.0

    pbar = tqdm(train_loader, desc=f"SBERT Epoch {epoch+1}/{num_epochs}")
    for batch in pbar:
        # "a" is the premise side, "b" the hypothesis side.
        input_ids_a = batch["premise_input_ids"].to(device)
        attn_a = batch["premise_attention_mask"].to(device)
        input_ids_b = batch["hypothesis_input_ids"].to(device)
        attn_b = batch["hypothesis_attention_mask"].to(device)
        labels = batch["labels"].to(device).long()

        # The same encoder embeds both sentences (premise first, then
        # hypothesis); each is mean-pooled into a single vector.
        u = mean_pool(encoder(input_ids_a, attention_mask=attn_a).last_hidden_state, attn_a)
        v = mean_pool(encoder(input_ids_b, attention_mask=attn_b).last_hidden_state, attn_b)

        logits = classifier_head(configurations(u, v))
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the sum and the progress bar.
        step_loss = loss.item()
        total_loss += step_loss
        pbar.set_postfix(loss=float(step_loss))

    print(f"SBERT Epoch {epoch+1}: avg loss = {total_loss/len(train_loader):.6f}")

SBERT Epoch 1/1: 100%|██████████| 1250/1250 [01:16<00:00, 16.27it/s, loss=0.911]

SBERT Epoch 1: avg loss = 1.053322





In [16]:
# Switch both modules to eval mode before scoring the validation loader.
encoder.eval()
classifier_head.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Eval"):
        input_ids_a = batch["premise_input_ids"].to(device)
        input_ids_b = batch["hypothesis_input_ids"].to(device)
        attn_a = batch["premise_attention_mask"].to(device)
        attn_b = batch["hypothesis_attention_mask"].to(device)

        out_a = encoder(input_ids_a, attention_mask=attn_a)
        out_b = encoder(input_ids_b, attention_mask=attn_b)

        u = mean_pool(out_a.last_hidden_state, attn_a)
        v = mean_pool(out_b.last_hidden_state, attn_b)

        x = configurations(u, v)
        logits = classifier_head(x)

        preds = torch.argmax(logits, dim=-1)

        all_preds.extend(preds.cpu().tolist())
        # Gold labels are only consumed by the report, so they are read
        # straight from the collated batch without a device round trip.
        all_labels.extend(batch["labels"].long().tolist())

# Pin the label ids explicitly: without `labels`, classification_report infers
# the classes from the data and raises if their count differs from
# len(target_names) (e.g. when a class is absent from both lists).
label_ids = [0, 1, 2]
target_names = ["entailment", "neutral", "contradiction"]  # SNLI mapping is 0/1/2
print(classification_report(all_labels, all_preds, labels=label_ids, target_names=target_names))

Eval: 100%|██████████| 157/157 [00:05<00:00, 27.91it/s]

               precision    recall  f1-score   support

   entailment       0.50      0.65      0.56      1685
      neutral       0.50      0.42      0.46      1650
contradiction       0.47      0.40      0.43      1665

     accuracy                           0.49      5000
    macro avg       0.49      0.49      0.48      5000
 weighted avg       0.49      0.49      0.48      5000






In [17]:
SBERT_SAVE_DIR = "./sbert_nli_model"
os.makedirs(SBERT_SAVE_DIR, exist_ok=True)

# Encoder and tokenizer are written with save_pretrained; the classification
# head is a plain nn.Linear, so its state dict is stored as a separate file in
# the same directory.
for artifact in (encoder, tokenizer):
    artifact.save_pretrained(SBERT_SAVE_DIR)

head_path = os.path.join(SBERT_SAVE_DIR, "classifier_head.pt")
torch.save(classifier_head.state_dict(), head_path)

print("Saved SBERT encoder + classifier to:", SBERT_SAVE_DIR)

Writing model shards: 100%|██████████| 1/1 [00:00<00:00, 22.95it/s]


Saved SBERT encoder + classifier to: ./sbert_nli_model
