In [1]:
import os, sys, torch
from transformers import AutoModel, AutoTokenizer
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sys.path.append("../")
from src.models.model_head import HistoricalTextDatingModel, create_model_head_config
from src.utils import init_tracker, DataLoadAndFilter


In [None]:
# 1) Set your checkpoint path here
ckpt_base_path = "../outputs/noise-adding/google-bert/bert-base-multilingual-cased/2025-09-22/01-53-30/"
ckpt_path = ckpt_base_path + "None_noise_std0.01_noise_typegaussian.pt"

if not os.path.exists(ckpt_path):
    raise FileNotFoundError(f"Checkpoint path does not exist: {ckpt_path}")
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config path does not exist: {config_path}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
ckpt = torch.load(ckpt_path, map_location=device)

In [None]:
cfg = OmegaConf.load(config_path)

In [6]:

# Load the model and tokenizer
tokenizer = instantiate(cfg.model.tokenizer)
encoder = instantiate(cfg.model.encoder)

model_head_config = create_model_head_config(**cfg.model.model_head.head_config)

# Load datasets using the data loader
data_loader = DataLoadAndFilter(cfg)
train_dataset, eval_dataset = data_loader.create_tokenized_datasets(tokenizer, base_path="../")

model_head_config["num_classes"] = len(train_dataset[0][1])


Error loading schema ../data/raw/SefariaData/Sefaria-Export-master/schemas/Sheet.json: Expecting value: line 1 column 1 (char 0)
Loading Sefaria texts: 100%|██████████| 100/100 [00:02<00:00, 35.78it/s]
Loading Ben Yehuda texts: 100%|██████████| 100/100 [00:00<00:00, 3536.84it/s]


In [7]:
model_head_config

{'hidden_sizes': [384, 128], 'dropout_rate': 0.1, 'activation': 'relu', 'pooling_strategy': 'cls', 'num_classes': 28}

In [8]:
model = HistoricalTextDatingModel(
    encoder=encoder,
    head_config=model_head_config,
    freeze_encoder=True,  # Freeze encoder to not change bert
)

In [9]:
def _clean_state_dict(state_dict: dict) -> dict:
    if not any(k.startswith("module.") for k in state_dict.keys()):
        return state_dict
    return {k.replace("module.", "", 1): v for k, v in state_dict.items()}

In [10]:
clean_ckpt = _clean_state_dict(ckpt)

In [11]:
missing, unexpected = model.load_state_dict(clean_ckpt, strict=False)
if missing:
    print(f"[load_state_dict] Missing keys: {len(missing)} (showing first 10)\n", missing[:10])
if unexpected:
    print(f"[load_state_dict] Unexpected keys: {len(unexpected)} (showing first 10)\n", unexpected[:10])
model.eval()


RuntimeError: Error(s) in loading state_dict for HistoricalTextDatingModel:
	size mismatch for head.head.6.weight: copying a param with shape torch.Size([36, 128]) from checkpoint, the shape in current model is torch.Size([28, 128]).
	size mismatch for head.head.6.bias: copying a param with shape torch.Size([36]) from checkpoint, the shape in current model is torch.Size([28]).

In [23]:
print("Model loaded to:", device)


Model loaded to: cpu


In [30]:

def predict_date(text: str) -> float:
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        preds, _ = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"))
    # preds shape: (batch,) -> take scalar
    return preds[0]

In [None]:
example_text = "This is a sample historical text snippet to test the model."
print("Predicted date:", predict_date(example_text))
