In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from transformers import BertConfig, BertModel

In [2]:
from bert_scratch import TokenVectorBERT
from dataset import CVFConfigForBertFTDataset

Total configs: 32.
Dataset: implicit_graph_n5 | Size: 25,935


In [3]:
# Basename (without the ".pt" extension) of the pretrained checkpoint that the
# loading cell reads from trained_models/.
# NOTE: the name `model` is rebound to the fine-tuning nn.Module further down,
# so re-running the checkpoint-loading cell after that point will build a bogus
# path unless this cell is re-run first.
model = "bert_trained_at_2025_04_16_20_28"

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA. The value stays a plain string ("cuda" or "cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# The checkpoint stores a whole pickled module rather than a state_dict, so the
# TokenVectorBERT class must already be importable when this cell runs.
# SECURITY: weights_only=False unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
# map_location places the weights on the configured device regardless of the
# device they were saved from.
pt_model = torch.load(
    f"trained_models/{model}.pt",
    map_location=device,
    weights_only=False,
)
pt_model.eval()

TokenVectorBERT(
  (token_proj): Linear(in_features=5, out_features=64, bias=True)
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(1, 64, padding_idx=0)
      (position_embeddings): Embedding(128, 64)
      (token_type_embeddings): Embedding(2, 64)
      (LayerNorm): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=64, out_features=64, bias=True)
              (key): Linear(in_features=64, out_features=64, bias=True)
              (value): Linear(in_features=64, out_features=64, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=64, out_features=64, bias=True)
              (La

In [6]:
# Build the fine-tuning dataset on the selected device. The keyword options are
# collected in one place so they are easy to spot and tweak.
dataset_options = {"D": 5, "program": "dijkstra"}
dataset = CVFConfigForBertFTDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_config_rank_dataset.csv",
    **dataset_options,
)

Total configs: 32.


In [7]:
class TokenVectorBERTFineTune(nn.Module):
    """Wrap a pretrained backbone with a single-output linear head.

    Args:
        bert_model: Module invoked as ``bert_model(input_vecs, attention_mask)``.
            Its output is fed directly to the linear head.
        vocab_dim: Size of the last dimension of the backbone output, i.e. the
            ``in_features`` of the head.
    """

    def __init__(self, bert_model, vocab_dim):
        super().__init__()
        # Attribute names are part of the public surface (state-dict keys and
        # external code reach into `.bert` / `.output`), so they stay as-is.
        self.bert = bert_model
        self.output = nn.Linear(in_features=vocab_dim, out_features=1)

    def forward(self, input_vecs, attention_mask=None):
        """Run the backbone, then project its output down to one value."""
        backbone_out = self.bert(input_vecs, attention_mask)
        return self.output(backbone_out)

In [8]:
# Attach the regression head to the pretrained network and move everything to
# the target device (Module.to returns the module itself).
model = TokenVectorBERTFineTune(pt_model, 5).to(device)

# Freeze the pretrained backbone so only the new head receives gradients.
for frozen_param in model.bert.parameters():
    frozen_param.requires_grad_(False)

In [9]:
# Batches of 64 samples, served in dataset order (no shuffling).
loader = DataLoader(dataset=dataset, shuffle=False, batch_size=64)

In [10]:
# Number of full passes over the loader during fine-tuning.
epochs = 2_000

In [11]:
# Fine-tune with a mean-squared-error objective against the targets in batch[1].
# Only parameters that still require gradients are handed to the optimizer, so
# frozen weights are explicitly excluded from the update.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-4, weight_decay=1e-4)
mse_loss = torch.nn.MSELoss()
num_batches = max(len(loader), 1)  # guard the average against an empty loader

# NOTE(review): train() also switches on the Dropout layers inside the frozen
# backbone, so the head is fitted on stochastic features. If that is not
# intended, call model.bert.eval() right after model.train() -- confirm.
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    for batch in loader:
        x = batch[0]
        y = batch[1].unsqueeze(-1)  # add a trailing axis to the targets

        optimizer.zero_grad()
        logits = model(x)

        # MSE averaged over every output element (no masking is applied).
        loss = mse_loss(logits, y)
        loss.backward()
        optimizer.step()

        # Accumulate a detached value: adding `loss` itself would keep the
        # autograd history of every batch alive until the end of the epoch.
        total_loss += loss.detach()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {float(total_loss) / num_batches:.4f}")

Epoch 1/2000 | Loss: 63.2336
Epoch 2/2000 | Loss: 62.7239
Epoch 3/2000 | Loss: 63.1764
Epoch 4/2000 | Loss: 62.8454
Epoch 5/2000 | Loss: 62.9568
Epoch 6/2000 | Loss: 62.0412
Epoch 7/2000 | Loss: 62.2310
Epoch 8/2000 | Loss: 61.9255
Epoch 9/2000 | Loss: 62.0515
Epoch 10/2000 | Loss: 62.2776
Epoch 11/2000 | Loss: 61.6669
Epoch 12/2000 | Loss: 61.8020
Epoch 13/2000 | Loss: 61.2020
Epoch 14/2000 | Loss: 61.1573
Epoch 15/2000 | Loss: 60.8682
Epoch 16/2000 | Loss: 60.7196
Epoch 17/2000 | Loss: 60.8477
Epoch 18/2000 | Loss: 60.1763
Epoch 19/2000 | Loss: 60.1383
Epoch 20/2000 | Loss: 59.6179
Epoch 21/2000 | Loss: 59.7598
Epoch 22/2000 | Loss: 59.8440
Epoch 23/2000 | Loss: 59.9153
Epoch 24/2000 | Loss: 59.3083
Epoch 25/2000 | Loss: 58.7868
Epoch 26/2000 | Loss: 59.0105
Epoch 27/2000 | Loss: 58.8380
Epoch 28/2000 | Loss: 58.5446
Epoch 29/2000 | Loss: 58.7325
Epoch 30/2000 | Loss: 58.0860
Epoch 31/2000 | Loss: 57.6173
Epoch 32/2000 | Loss: 57.9776
Epoch 33/2000 | Loss: 58.4452
Epoch 34/2000 | Los