In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from transformers import BertConfig, BertModel

In [2]:
from bert_scratch import TokenVectorBERT
from dataset import CVFConfigForBertFTDataset

Total configs: 243.
Dataset: implicit_graph_n5 | Size: 25,935


In [3]:
model = "bert_trained_at_2025_04_22_20_56"

In [4]:
device = "cuda"

In [5]:
# Model class must be defined somewhere
pt_model = torch.load(f"trained_models/{model}.pt", weights_only=False)
pt_model.eval()

pt_model.decoder_proj = nn.Identity()

In [6]:
dataset = CVFConfigForBertFTDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_config_rank_dataset.csv",
    D=5,
    program="dijkstra",
)

Total configs: 243.


In [7]:
class TokenVectorBERTFineTune(nn.Module):
    def __init__(self, bert_model, vocab_dim):
        super().__init__()
        self.bert = bert_model
        self.linear1 = nn.Linear(vocab_dim, vocab_dim)
        self.output = nn.Linear(vocab_dim, 1)

    def forward(self, input_vecs, attention_mask=None):
        result = self.bert(input_vecs, attention_mask)
        result = torch.relu(self.linear1(result))
        result = self.output(result)
        return result

In [8]:
model = TokenVectorBERTFineTune(pt_model, 64)
model.to(device)

for param in model.bert.parameters():
  param.requires_grad = False

In [9]:
loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [10]:
epochs = 1

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
mse_loss = torch.nn.MSELoss()

model.train()

for epoch in range(epochs):
    total_loss = 0.0
    for batch in loader:
        x = batch[0]
        y = batch[1].unsqueeze(-1)
        # attention_mask = torch.full(x.shape, True)
        out = model(x)
        print(out)
        print(y)
        optimizer.zero_grad()
        loss = mse_loss(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss


    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss.item()/ len(loader):.4f}")

tensor([[[ 0.0425]],

        [[ 0.0434]],

        [[-0.0036]],

        [[-0.0068]],

        [[ 0.0286]],

        [[-0.0101]],

        [[ 0.1386]],

        [[ 0.0309]],

        [[ 0.0433]],

        [[ 0.0295]],

        [[-0.0015]],

        [[-0.0232]],

        [[ 0.0586]],

        [[-0.0007]],

        [[ 0.0494]],

        [[ 0.0783]],

        [[ 0.0260]],

        [[-0.0054]],

        [[ 0.0283]],

        [[ 0.0490]],

        [[-0.0058]],

        [[ 0.0452]],

        [[ 0.0304]],

        [[ 0.0225]],

        [[ 0.0326]],

        [[ 0.0255]],

        [[ 0.0097]],

        [[ 0.0553]],

        [[ 0.0505]],

        [[ 0.0455]],

        [[ 0.1465]],

        [[ 0.0951]],

        [[ 0.0477]],

        [[ 0.1458]],

        [[-0.0040]],

        [[ 0.0175]],

        [[ 0.0255]],

        [[ 0.0350]],

        [[ 0.0122]],

        [[-0.0545]],

        [[ 0.1213]],

        [[-0.0153]],

        [[ 0.0739]],

        [[-0.0113]],

        [[ 0.0191]],

        [[