In [1]:
import os, sys, torch
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Add the parent directory to the import search path so the project-local
# `src.models.noised_encoder` module imported on the next line can be found.
# NOTE(review): "../" is resolved against the kernel's working directory —
# presumably the folder containing this notebook; confirm before moving the file.
sys.path.append("../")
from src.models.noised_encoder import NoisedEncoder


In [3]:
# 1) Set your checkpoint path here
ckpt_base_path = "../outputs/noise-adding/google-bert/bert-base-multilingual-cased/2025-09-22/04-45-25/"
# The base path already ends with a separator, so joining produces exactly
# the same strings as plain concatenation would.
ckpt_path = os.path.join(ckpt_base_path, "None_noise_std0.01_noise_typegaussian.pt")
config_path = os.path.join(ckpt_base_path, ".hydra/config.yaml")

# Fail fast, with the offending path in the message, if either file is absent.
if not os.path.exists(ckpt_path):
    raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config path does not exist: {config_path}")

# Prefer the GPU whenever torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Read the saved checkpoint, mapping its tensors onto `device`, and load the
# YAML config stored next to it.
# NOTE(review): torch.load unpickles the file, so only point `ckpt_path` at
# checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
cfg = OmegaConf.load(config_path)

In [27]:
# Load the model and tokenizer
# `encoder` and `noised_encoder` are both built from the same config entry
# (cfg.model.encoder); nothing distinguishes them at this point.
# NOTE(review): `noised_encoder` only differs from `encoder` once the
# checkpoint weights are loaded into it in a later cell.
tokenizer = instantiate(cfg.model.tokenizer)
encoder = instantiate(cfg.model.encoder)
noised_encoder = instantiate(cfg.model.encoder)

Some weights of the model checkpoint at google-bert/bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at google-bert/bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing Be

In [19]:
# Keep the noise std read from the loaded config so the value is still
# available after cfg.training.noise.noise_std is overwritten.
original_std = cfg.training.noise.noise_std

In [50]:
# Override the noise std directly on the shared `cfg` object.
# NOTE(review): this is an in-place mutation, so every cell executed after
# this one sees 0.005 instead of the value stored in the config file.
cfg.training.noise.noise_std = 0.005

In [51]:

# Build another encoder from the config and wrap it in NoisedEncoder,
# passing the current `cfg` (including whatever noise_std it holds now).
# NOTE(review): the checkpoint weights are not loaded into this wrapper in
# any of the visible cells — confirm that is intended.
temp_encoder = instantiate(cfg.model.encoder)
new_noised_encoder = NoisedEncoder(cfg, temp_encoder)

Some weights of the model checkpoint at google-bert/bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [32]:
def _clean_state_dict(state_dict: dict) -> dict:
    if not any(k.startswith("encoder.") for k in state_dict.keys()):
        return state_dict
    return {k.replace("encoder.", "", 1): v for k, v in state_dict.items()}
# Normalise the loaded checkpoint's key names (drops the "encoder." prefix
# when present) so they can be matched against the model's parameter names.
clean_ckpt = _clean_state_dict(ckpt)

In [33]:
# Load state dict into the noised encoder
# strict=False tolerates key mismatches, so report them explicitly instead of
# letting a partially loaded model pass unnoticed.
missing, unexpected = noised_encoder.load_state_dict(clean_ckpt, strict=False)
if missing:
    print(f"[load_state_dict] Missing keys: {len(missing)} (showing first 10)\n", missing[:10])
if unexpected:
    print(f"[load_state_dict] Unexpected keys: {len(unexpected)} (showing first 10)\n", unexpected[:10])
noised_encoder.eval()
# The model is never moved with .to(...) in this cell, so report the device
# its parameters actually live on rather than the globally selected `device`.
print("Model loaded to:", next(noised_encoder.parameters()).device)

Model loaded to: cpu


In [34]:

def run_model_over_text(model, text: str):
    """Tokenize ``text`` with the notebook-level ``tokenizer`` and run ``model`` on it.

    The forward pass runs under ``torch.no_grad()``.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments.
        text: Text to encode; it is truncated to at most 256 tokens.

    Returns:
        Whatever ``model`` returns for the encoded text (the raw model output,
        not a scalar).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
    # Send the inputs to wherever the model's weights live. The notebook never
    # moves the models to `device`, so using the global `device` directly
    # would mismatch on a CUDA machine. Fall back to `device` for callables
    # that expose no parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        target_device = device
    inputs = {k: v.to(target_device) for k, v in inputs.items()}
    with torch.no_grad():
        res = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"))
    return res

In [35]:
# Run the same sentence through all three models and print each raw output.
example_text = "This is a sample historical text snippet to test the model."
for candidate in (encoder, noised_encoder, new_noised_encoder):
    print(run_model_over_text(candidate, example_text))

MaskedLMOutput(loss=None, logits=tensor([[[ -9.2823,  -9.2035,  -9.2564,  ...,  -8.9687,  -8.7790,  -9.1370],
         [-11.5207, -10.9052, -10.7849,  ...,  -8.9367,  -8.1311,  -9.7978],
         [-14.0182, -14.2658, -14.0961,  ..., -10.7366, -10.5054, -11.6423],
         ...,
         [-11.7990, -11.6905, -12.1672,  ..., -11.3576, -10.9677, -11.5446],
         [-11.9769, -11.8335, -11.6036,  ..., -10.7477, -10.2442, -11.4482],
         [ -9.9157,  -9.4814,  -9.4806,  ...,  -8.2701,  -8.3781,  -8.9941]]]), hidden_states=None, attentions=None)
MaskedLMOutput(loss=None, logits=tensor([[[-8.2281, -8.3524, -7.7910,  ..., -6.6339, -6.9057, -6.1463],
         [-7.7523, -8.1757, -7.4527,  ..., -6.8175, -6.5724, -5.9234],
         [-8.4843, -8.9459, -8.0856,  ..., -7.2133, -6.9029, -6.5164],
         ...,
         [-8.7283, -8.5921, -8.0858,  ..., -7.0931, -7.4747, -6.7143],
         [-8.5823, -8.4527, -7.8191,  ..., -6.6339, -6.9772, -6.2922],
         [-7.5783, -7.7716, -7.0237,  ..., -5.809

In [36]:
def predict_masked_word(text: str, tokenizer, model, k=5):
    """Print and return the top-``k`` predictions for the mask token in ``text``.

    Args:
        text: Input string containing the tokenizer's mask token. When several
            mask tokens are present, only the first one is predicted.
        tokenizer: Callable tokenizer exposing ``mask_token_id`` and ``decode``.
        model: Callable that accepts the tokenizer output as keyword arguments
            and returns an object with a ``logits`` attribute indexed as
            ``[batch, position, vocabulary]``.
        k: Number of candidates to return.

    Returns:
        A list of ``k`` decoded token strings, most likely first. The list is
        empty when ``text`` contains no mask token.
    """
    # Tokenize the input
    inputs = tokenizer(text, return_tensors="pt")

    # Find the position(s) of the mask token along the sequence axis
    mask_positions = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_positions.numel() == 0:
        # Nothing to predict; skip the forward pass entirely.
        return []
    first_mask = int(mask_positions[0])

    # Move the inputs to the model's device when the model exposes parameters,
    # so the call also works for a model that lives on the GPU.
    try:
        model_device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        model_device = None
    if model_device is not None:
        inputs = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    # Get model predictions
    with torch.no_grad():
        predictions = model(**inputs).logits

    # Index the mask position directly instead of squeeze(): squeeze() turns
    # the k=1 result into a bare int and a multi-mask result into a nested list.
    mask_token_logits = predictions[0, first_mask, :]
    top_k_tokens = torch.topk(mask_token_logits, k).indices.tolist()

    # Decode each candidate once, print it, and collect it
    results = []
    for token_id in top_k_tokens:
        decoded = tokenizer.decode([token_id])
        print(f"Predicted: {decoded}")
        results.append(decoded)
    return results


In [52]:
# Ask each model to fill in the same masked sentence, labelling the output
# with the model's class name.
masked_prompt = f"The capital of France is {tokenizer.mask_token}."
for enc in (encoder, noised_encoder, new_noised_encoder):
    print("Using encoder:", type(enc).__name__)
    predict_masked_word(masked_prompt, tokenizer, enc)


Using encoder: BertForMaskedLM
Predicted: Paris
Predicted: Rome
Predicted: Metz
Predicted: d
Predicted: Strasbourg
Using encoder: BertForMaskedLM
Predicted: ,
Predicted: '
Predicted: )
Predicted: :
Predicted: a
Using encoder: NoisedEncoder
Predicted: France
Predicted: .
Predicted: Paris
Predicted: ,
Predicted: -
