In [22]:
import pathlib

import torch

In [23]:
# Relative directory that holds the saved training checkpoints.
CHECKPOINT_DIR = pathlib.Path() / "checkpoints"

In [24]:
# Pick the compute device first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps tensors that were saved on a GPU, so the checkpoint
# still loads on a CPU-only host (without it torch.load raises there).
# NOTE: torch.load unpickles the file -- only open checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_DIR / "final_model.pt", map_location=device)

In [25]:
from transformers import AutoTokenizer

from train import ContrastiveCollator, DualEncoder, NLProofDataset

# Configuration dict stored in the checkpoint under "args"; every setting
# below is read from it.
args = checkpoint["args"]

# One tokenizer per encoder, built in the order (NL, proof).
nl_tokenizer, proof_tokenizer = (
    AutoTokenizer.from_pretrained(args[model_key])
    for model_key in ("nl_model", "proof_model")
)

# Rebuild the dual encoder from the settings stored in the checkpoint and
# move it to the selected device. The two encoder names are renamed on the
# way in; the remaining settings are forwarded under their own keys.
model = DualEncoder(
    nl_model_name=args["nl_model"],
    proof_model_name=args["proof_model"],
    **{
        option: args[option]
        for option in ("projection_dim", "freeze_nl", "freeze_proof", "dropout")
    },
).to(device)

# Load the NL/proof pairs using the data settings stored in the checkpoint.
# NOTE(review): concat_states is hard-coded to False rather than read from
# args -- confirm this matches how the model was trained.
dataset = NLProofDataset(
    **{
        option: args[option]
        for option in ("csv_path", "nl_column", "max_seq_len", "max_samples")
    },
    concat_states=False,
)

# Carve out the validation subset. The generator is seeded from the stored
# seed, so the same dataset and seed always give the same partition
# (presumably the one used during training -- verify against train.py).
val_size = int(args["val_split"] * len(dataset))
train_size = len(dataset) - val_size

split_generator = torch.Generator()
split_generator.manual_seed(args["seed"])

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    lengths=[train_size, val_size],
    generator=split_generator,
)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

# Collator that receives both tokenizers and their maximum lengths; it is
# used as the collate_fn of both loaders below.
collator = ContrastiveCollator(
    nl_tokenizer=nl_tokenizer,
    proof_tokenizer=proof_tokenizer,
    nl_max_len=args["nl_max_len"],
    proof_max_len=args["proof_max_len"],
)

# Settings shared by both loaders. Pinned host memory only helps when
# batches are copied to a GPU, so enable it only when running on CUDA; this
# keeps the CPU fallback free of the "pin_memory ... no accelerator" warning.
loader_kwargs = dict(
    batch_size=args["batch_size"],
    collate_fn=collator,
    num_workers=args["num_workers"],
    pin_memory=device.type == "cuda",
)

# Training loader shuffles; validation loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    **loader_kwargs,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    **loader_kwargs,
)

NL encoder: sentence-transformers/all-MiniLM-L6-v2 (hidden=384, frozen=True)
Proof encoder: kaiyuy/leandojo-lean4-retriever-byt5-small (hidden=1472, frozen=True)
Projection dim: 256
Loaded 29991 NL-proof pairs
Train samples: 26992, Val samples: 2999


In [26]:
# Put the model in inference mode (disables dropout) before it is evaluated.
# This is harmless if the evaluation helper already does so. It runs first
# so that the load result stays the cell's displayed value.
model.eval()

# Restore the trained weights; the returned value reports key mismatches.
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [27]:
from train import evaluate

# Values of k passed to evaluate(); the resulting metrics are keyed "R@<k>".
recall_ks = [10, 25, 50, 100]

# Score the restored model on the validation loader.
metrics = evaluate(
    model=model,
    dataloader=val_loader,
    device=device,
    ks=recall_ks,
)

Evaluating: 100%|██████████| 94/94 [03:10<00:00,  2.02s/it]


In [28]:
# Show the metrics returned by evaluate(): one entry per retrieval direction
# (nl_to_proof / proof_to_nl) and per k.
metrics

{'nl_to_proof_R@10': 0.4291430413722992,
 'nl_to_proof_R@25': 0.6415472030639648,
 'nl_to_proof_R@50': 0.7792597413063049,
 'nl_to_proof_R@100': 0.8886295557022095,
 'proof_to_nl_R@10': 0.5055018067359924,
 'proof_to_nl_R@25': 0.6998999714851379,
 'proof_to_nl_R@50': 0.8196065425872803,
 'proof_to_nl_R@100': 0.9103034138679504}