In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from transformers import BertConfig, BertForMaskedLM, BertModel

In [2]:
from dataset import CVFConfigForBertDataset

In [3]:
class TokenVectorBERT(nn.Module):
    def __init__(self, input_dim=3, vocab_dim=64, bert_hidden=64, max_seq_len=128):
        super().__init__()

        self.token_proj = nn.Linear(
            input_dim, vocab_dim
        )  # turn [0, 0, 2] into an embedding
        self.config = BertConfig(
            vocab_size=1,  # dummy, unused
            hidden_size=bert_hidden,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=bert_hidden * 2,
            max_position_embeddings=max_seq_len,
            pad_token_id=0,
        )
        self.bert = BertModel(self.config)
        self.mlm_head = nn.Linear(bert_hidden, vocab_dim)
        self.decoder_proj = nn.Linear(
            vocab_dim, input_dim
        )  # to project back to [0, 0, 2] space

    def forward(self, input_vecs, attention_mask=None):
        # input_vecs: (batch_size, seq_len, input_dim) like (2, 4, 3)
        x = self.token_proj(input_vecs)  # (batch_size, seq_len, vocab_dim)
        outputs = self.bert(inputs_embeds=x, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.mlm_head(sequence_output)
        pred_token = self.decoder_proj(logits)
        return pred_token

In [4]:
device = "cuda"

In [5]:
# ----- Dummy Data -----
batch_size = 2
seq_len = 4
token_dim = 4

# Learnable MASK token (for masking positions)
mask_vector = nn.Parameter(torch.zeros(token_dim))


# ----- Masking Function -----
def mask_input_tokens(inputs, mask_token, mask_prob=0.3):
    inputs = inputs.clone()
    labels = inputs.clone()
    mask = torch.rand(inputs[:, :, 0].shape) < mask_prob  # shape: (B, T)

    for i in range(inputs.size(0)):
        for j in range(inputs.size(1)):
            if mask[i, j]:
                inputs[i, j] = mask_token

    return inputs, labels, mask  # No NaNs here!


def masked_mse_loss(pred, target, mask):
    valid_tokens = mask.sum()
    if valid_tokens == 0:
        valid_tokens = 1e-8
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)  # (B, T)
    loss = loss * mask.float()

    return loss.sum() / valid_tokens

In [6]:
dataset = CVFConfigForBertDataset(
    device,
    "graph_random_regular_graph_n4_d3",
    "graph_random_regular_graph_n4_d3_pt_adj_list.txt",
)

loader = DataLoader(dataset, batch_size=2, shuffle=False)

Total configs: 256.


In [7]:
len(loader)

336

In [8]:
model = TokenVectorBERT(input_dim=token_dim, vocab_dim=64, bert_hidden=64)
print()
print("Total parameters:{:,}".format(sum(p.numel() for p in model.parameters())))
print()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

epochs = 100

# print("inputs", inputs)
# print("mask_vector", mask_vector)
# print()

for epoch in range(epochs):
    model.train()

    total_loss = 0.0

    for batch in loader:
        x = batch[0]
        attention_mask = batch[1]

        masked_inputs, target_labels, loss_mask = mask_input_tokens(
            x, mask_vector, mask_prob=0.3
        )

        logits = model(masked_inputs, attention_mask)

        # Compute loss only on masked positions
        loss = masked_mse_loss(logits, target_labels, loss_mask)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss
        break

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss.item()/ len(loader):.4f}")


Total parameters:84,356

Epoch 1/100 | Loss: 0.0055
Epoch 2/100 | Loss: 0.0026
Epoch 3/100 | Loss: 0.0016
Epoch 4/100 | Loss: 0.0007
Epoch 5/100 | Loss: 0.0004
Epoch 6/100 | Loss: 0.0000
Epoch 7/100 | Loss: 0.0066
Epoch 8/100 | Loss: 0.0029
Epoch 9/100 | Loss: 0.0015
Epoch 10/100 | Loss: 0.0014
Epoch 11/100 | Loss: 0.0077
Epoch 12/100 | Loss: 0.0000
Epoch 13/100 | Loss: 0.0021
Epoch 14/100 | Loss: 0.0039
Epoch 15/100 | Loss: 0.0012
Epoch 16/100 | Loss: 0.0013
Epoch 17/100 | Loss: 0.0035
Epoch 18/100 | Loss: 0.0029
Epoch 19/100 | Loss: 0.0015
Epoch 20/100 | Loss: 0.0032
Epoch 21/100 | Loss: 0.0035
Epoch 22/100 | Loss: 0.0022
Epoch 23/100 | Loss: 0.0023
Epoch 24/100 | Loss: 0.0029
Epoch 25/100 | Loss: 0.0041
Epoch 26/100 | Loss: 0.0024
Epoch 27/100 | Loss: 0.0024
Epoch 28/100 | Loss: 0.0032
Epoch 29/100 | Loss: 0.0016
Epoch 30/100 | Loss: 0.0022
Epoch 31/100 | Loss: 0.0021
Epoch 32/100 | Loss: 0.0024
Epoch 33/100 | Loss: 0.0030
Epoch 34/100 | Loss: 0.0023
Epoch 35/100 | Loss: 0.0024
Epo

In [18]:
model.eval()

loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in loader:
    x = batch[0]
    attention_mask = batch[1]
    masked_inputs, target_labels, loss_mask = mask_input_tokens(
        x, mask_vector, mask_prob=0.4
    )
    logits = model(masked_inputs, attention_mask)
    if loss_mask.any():
        print("target", target_labels[loss_mask])
        print("predicted", logits[loss_mask])
        break

target tensor([[ 1.,  3.,  3.,  1.],
        [-1., -1., -1., -1.]])
predicted tensor([[0.5192, 0.4644, 0.4470, 0.3425],
        [0.6723, 1.1507, 0.8315, 0.4656]], grad_fn=<IndexBackward0>)
