In [1]:
import torch

from torch.utils.data import DataLoader

In [2]:
from bert_scratch import TokenVectorBERT, mask_input_tokens, masked_mse_loss

from dataset import CVFConfigForBertDataset

Total configs: 32.
Dataset: implicit_graph_n5 | Size: 25,935


In [3]:
device = "cuda"

In [4]:
model_name = "bert_trained_at_2025_04_15_21_45"

In [5]:
# Model class must be defined somewhere
model = torch.load(f"trained_models/{model_name}.pt", weights_only=False)
model.eval()

TokenVectorBERT(
  (token_proj): Linear(in_features=5, out_features=64, bias=True)
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(1, 64, padding_idx=0)
      (position_embeddings): Embedding(128, 64)
      (token_type_embeddings): Embedding(2, 64)
      (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=64, out_features=64, bias=True)
              (key): Linear(in_features=64, out_features=64, bias=True)
              (value): Linear(in_features=64, out_features=64, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=64, out_features=64, bias=True)
              (La

In [11]:
model.eval()

dataset = CVFConfigForBertDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_pt_adj_list.txt",
    D=5,
    program="dijkstra",
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

total_loss = 0.0
for batch in loader:
    x = batch[0]
    attention_mask = batch[1]

    masked_inputs, target_labels, loss_mask = mask_input_tokens(
        x, model.mask_vector, mask_prob=0.15
    )
    logits = model(masked_inputs, attention_mask)
    print(logits[loss_mask], target_labels[loss_mask])
    loss = masked_mse_loss(logits, target_labels, loss_mask)
    total_loss += loss
    break

print(f"Test dataset | Loss: {total_loss.item()/ len(loader):.4f}")

Total configs: 32.


tensor([[ 9.7703,  0.1206,  0.2069,  0.0792,  0.9415],
        [ 9.6442,  0.2406,  0.8976,  0.2811,  1.1030],
        [13.4797, -0.0974,  0.9789,  0.7988,  0.4160],
        [-1.0992, -1.0370, -1.0545, -1.0473, -1.0454]],
       grad_fn=<IndexBackward0>) tensor([[ 9.,  0.,  0.,  0.,  1.],
        [ 9.,  0.,  1.,  0.,  1.],
        [12.,  0.,  1.,  1.,  0.],
        [-1., -1., -1., -1., -1.]])
Test dataset | Loss: 0.0000
