In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from transformers import BertConfig, BertModel

In [2]:
from bert_scratch import TokenVectorBERT
from dataset import CVFConfigForBertFTDataset

Total configs: 243.
Dataset: implicit_graph_n5 | Size: 25,935


In [3]:
model = "bert_trained_at_2025_04_22_13_53"

In [4]:
device = "cuda"

In [5]:
# Model class must be defined somewhere
pt_model = torch.load(f"trained_models/{model}.pt", weights_only=False)
pt_model.eval()

pt_model.decoder_proj = nn.Identity()

In [6]:
dataset = CVFConfigForBertFTDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_config_rank_dataset.csv",
    D=5,
    program="dijkstra",
)

Total configs: 243.


In [7]:
class TokenVectorBERTFineTune(nn.Module):
    def __init__(self, bert_model, vocab_dim):
        super().__init__()
        self.bert = bert_model
        self.linear1 = nn.Linear(vocab_dim, vocab_dim)
        self.output = nn.Linear(vocab_dim, 1)

    def forward(self, input_vecs, attention_mask=None):
        result = self.bert(input_vecs, attention_mask)
        result = torch.relu(self.linear1(result))
        result = self.output(result)
        return result

In [8]:
model = TokenVectorBERTFineTune(pt_model, 64)
model.to(device)

for param in model.bert.parameters():
  param.requires_grad = False

In [9]:
loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [10]:
epochs = 2000

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
mse_loss = torch.nn.MSELoss()

model.train()

for epoch in range(epochs):
    total_loss = 0.0
    for batch in loader:
        x = batch[0]
        y = batch[1].unsqueeze(-1)
        # attention_mask = torch.full(x.shape, True)
        out = model(x)
        optimizer.zero_grad()
        loss = mse_loss(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss.item()/ len(loader):.4f}")

Epoch 1/2000 | Loss: 65.2228
Epoch 2/2000 | Loss: 63.8824
Epoch 3/2000 | Loss: 62.4234
Epoch 4/2000 | Loss: 61.0671
Epoch 5/2000 | Loss: 59.5289
Epoch 6/2000 | Loss: 58.0176
Epoch 7/2000 | Loss: 56.0161
Epoch 8/2000 | Loss: 54.3955
Epoch 9/2000 | Loss: 52.9427
Epoch 10/2000 | Loss: 51.0499
Epoch 11/2000 | Loss: 48.6695
Epoch 12/2000 | Loss: 47.2064
Epoch 13/2000 | Loss: 44.7304
Epoch 14/2000 | Loss: 43.1874
Epoch 15/2000 | Loss: 41.4020
Epoch 16/2000 | Loss: 38.9940
Epoch 17/2000 | Loss: 36.9384
Epoch 18/2000 | Loss: 35.7101
Epoch 19/2000 | Loss: 33.6983
Epoch 20/2000 | Loss: 32.1524
Epoch 21/2000 | Loss: 31.2018
Epoch 22/2000 | Loss: 30.1594
Epoch 23/2000 | Loss: 27.9235
Epoch 24/2000 | Loss: 28.3686
Epoch 25/2000 | Loss: 27.4848
Epoch 26/2000 | Loss: 25.8553
Epoch 27/2000 | Loss: 26.1908
Epoch 28/2000 | Loss: 26.5324
Epoch 29/2000 | Loss: 26.2188
Epoch 30/2000 | Loss: 24.3877
Epoch 31/2000 | Loss: 24.7542
Epoch 32/2000 | Loss: 24.6909
Epoch 33/2000 | Loss: 24.5208
Epoch 34/2000 | Los