In [None]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from transformers import BertConfig, BertForMaskedLM, BertModel

In [2]:
from dataset import CVFConfigForBertDataset

In [3]:
# Prefer the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# Toy problem sizes. `token_dim` is the width of one token vector and is what
# the model receives as `input_dim`; `batch_size` and `seq_len` are not
# referenced by the cells visible in this notebook.
batch_size, seq_len, token_dim = 4, 4, 4

In [None]:
class TokenVectorBERT(nn.Module):
    """BERT encoder that reads and reconstructs real-valued token vectors.

    Instead of integer token ids, every position of the sequence is a vector of
    ``input_dim`` numbers. The vector is linearly projected to the BERT hidden
    size, passed through a small ``BertModel`` via ``inputs_embeds`` and then
    decoded back to ``input_dim`` numbers by two linear layers.

    Args:
        input_dim: Number of features in one token vector.
        vocab_dim: Width of the intermediate representation between
            ``mlm_head`` and ``decoder_proj``.
        bert_hidden: Hidden size of the BERT encoder.
        max_seq_len: Maximum number of positions BERT has embeddings for.

    Attributes:
        mask_vector: Learnable vector of length ``input_dim`` meant to replace
            masked positions in the input.
    """

    def __init__(self, input_dim, vocab_dim=64, bert_hidden=64, max_seq_len=128):
        super().__init__()
        # Learnable MASK token (for masking positions). It is sized from the
        # constructor argument rather than a notebook-level global so that it
        # always has the same width as the token vectors it replaces.
        self.mask_vector = nn.Parameter(torch.zeros(input_dim), requires_grad=True)
        # Turn a raw token vector such as [0, 0, 2] into an embedding. The
        # output width must be the BERT hidden size because the result is fed
        # to BERT as `inputs_embeds`; projecting to `vocab_dim` only worked
        # when `vocab_dim == bert_hidden`.
        self.token_proj = nn.Linear(input_dim, bert_hidden)
        self.config = BertConfig(
            vocab_size=1,  # dummy, unused: the model is always fed inputs_embeds
            hidden_size=bert_hidden,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=bert_hidden * 2,
            max_position_embeddings=max_seq_len,
            pad_token_id=0,
        )
        self.bert = BertModel(self.config)
        # Two-stage decoder: hidden state -> vocab_dim -> reconstructed token.
        self.mlm_head = nn.Linear(bert_hidden, vocab_dim)
        self.decoder_proj = nn.Linear(vocab_dim, input_dim)

    def forward(self, input_vecs, attention_mask=None):
        """Predict a token vector for every position of the sequence.

        Args:
            input_vecs: Tensor of shape (batch_size, seq_len, input_dim),
                e.g. (2, 4, 3).
            attention_mask: Optional mask forwarded unchanged to the BERT
                encoder.

        Returns:
            Tensor of shape (batch_size, seq_len, input_dim) with the
            reconstructed token vectors.
        """
        embedded = self.token_proj(input_vecs)  # (batch_size, seq_len, bert_hidden)
        outputs = self.bert(inputs_embeds=embedded, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.mlm_head(sequence_output)  # (batch_size, seq_len, vocab_dim)
        pred_token = self.decoder_proj(logits)
        return pred_token

In [6]:
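# Learnable MASK token (for masking positions).
# NOTE: this vector is now owned by the model as `TokenVectorBERT.mask_vector`;
# the standalone definition below is kept for reference only.
# mask_vector = nn.Parameter(torch.zeros(token_dim), requires_grad=True)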

In [7]:
# ----- Masking Function -----
def mask_input_tokens(inputs, mask_vector, mask_prob=0.3):
    """Randomly replace whole token vectors with the MASK vector.

    Args:
        inputs: Tensor of shape (B, T, D) holding one D-dimensional vector per
            position.
        mask_vector: Tensor of shape (D,) written into every masked position.
            Gradients flow back to it through the returned inputs.
        mask_prob: Independent probability of masking each position.

    Returns:
        A tuple ``(masked_inputs, labels, mask)`` where ``masked_inputs`` is a
        new (B, T, D) tensor with masked positions overwritten, ``labels`` is
        an untouched copy of ``inputs`` and ``mask`` is a boolean (B, T) tensor
        that is True at the masked positions. ``inputs`` itself is not modified.
    """
    labels = inputs.clone()
    # Draw the mask on the same device as the data so it can be combined with
    # model outputs later without a device mismatch.
    mask = torch.rand(inputs.shape[:2], device=inputs.device) < mask_prob  # shape: (B, T)

    # One broadcasted select instead of a Python loop over every (i, j) pair.
    fill = mask_vector.to(device=inputs.device, dtype=inputs.dtype)
    masked_inputs = torch.where(mask.unsqueeze(-1), fill, inputs)

    return masked_inputs, labels, mask


def masked_mse_loss(pred, target, mask):
    """Mean squared error averaged over the masked positions only.

    Args:
        pred: Predicted token vectors, shape (B, T, D).
        target: Ground-truth token vectors, same shape as ``pred``.
        mask: (B, T) tensor that is True (non-zero) where the loss should be
            counted.

    Returns:
        Scalar tensor: the per-position MSE (mean over the last dimension)
        summed over the masked positions and divided by their count. The
        result is 0 when no position is masked.
    """
    # Bring the mask next to the predictions; it may have been created on a
    # different device than the model output.
    weights = mask.to(pred.device).float()

    per_token = ((pred - target) ** 2).mean(dim=-1)  # (B, T)

    # With no masked positions the numerator is already 0, so clamping the
    # count to 1 yields the same 0 loss without a data-dependent Python branch.
    valid_tokens = weights.sum().clamp_min(1.0)

    return (per_token * weights).sum() / valid_tokens

In [8]:
# Arguments handed to the dataset (presumably a graph identifier and the file
# holding its adjacency list — confirm against CVFConfigForBertDataset).
GRAPH_NAME = "graph_random_regular_graph_n4_d3"
ADJ_LIST_FILE = "graph_random_regular_graph_n4_d3_pt_adj_list.txt"

dataset = CVFConfigForBertDataset(device, GRAPH_NAME, ADJ_LIST_FILE)

# Fixed-order batches of two samples each.
loader = DataLoader(dataset=dataset, shuffle=False, batch_size=2)

Total configs: 256.


In [10]:
model = TokenVectorBERT(input_dim=token_dim, vocab_dim=64, bert_hidden=64)
print()
print("Total parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
print()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

epochs = 100

# Batches are moved to wherever the model's parameters live, so data and model
# can never end up on different devices.
# NOTE(review): the model itself is never moved to `device`; add
# `model.to(device)` above if GPU training is intended.
model_device = next(model.parameters()).device

# Guard the per-epoch average against an empty loader.
num_batches = max(len(loader), 1)

for epoch in range(epochs):
    model.train()

    # Plain Python float: accumulating the loss tensors themselves would keep
    # every batch's autograd graph reachable until the end of the epoch.
    total_loss = 0.0

    for batch in loader:
        x = batch[0].to(model_device)
        attention_mask = batch[1].to(model_device)

        # Fresh random mask for every batch.
        masked_inputs, target_labels, loss_mask = mask_input_tokens(
            x, model.mask_vector, mask_prob=0.3
        )

        optimizer.zero_grad()
        logits = model(masked_inputs, attention_mask)

        # Compute loss only on masked positions
        loss = masked_mse_loss(logits, target_labels, loss_mask)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss / num_batches:.4f}")


Total parameters: 84,360

Epoch 1/100 | Loss: 1.4167
Epoch 2/100 | Loss: 1.0488
Epoch 3/100 | Loss: 0.8686
Epoch 4/100 | Loss: 0.8150
Epoch 5/100 | Loss: 0.7650
Epoch 6/100 | Loss: 0.8043
Epoch 7/100 | Loss: 0.7410
Epoch 8/100 | Loss: 0.7127
Epoch 9/100 | Loss: 0.7090
Epoch 10/100 | Loss: 0.6973
Epoch 11/100 | Loss: 0.6186
Epoch 12/100 | Loss: 0.6639
Epoch 13/100 | Loss: 0.7177
Epoch 14/100 | Loss: 0.6673
Epoch 15/100 | Loss: 0.6759
Epoch 16/100 | Loss: 0.7587
Epoch 17/100 | Loss: 0.6360
Epoch 18/100 | Loss: 0.6792
Epoch 19/100 | Loss: 0.6651
Epoch 20/100 | Loss: 0.6384
Epoch 21/100 | Loss: 0.7616
Epoch 22/100 | Loss: 0.6630
Epoch 23/100 | Loss: 0.8152
Epoch 24/100 | Loss: 0.8196
Epoch 25/100 | Loss: 0.7253
Epoch 26/100 | Loss: 0.6590
Epoch 27/100 | Loss: 0.6849
Epoch 28/100 | Loss: 0.7078
Epoch 29/100 | Loss: 0.6560
Epoch 30/100 | Loss: 0.6556
Epoch 31/100 | Loss: 0.6483
Epoch 32/100 | Loss: 0.6421
Epoch 33/100 | Loss: 0.6460
Epoch 34/100 | Loss: 0.6677
Epoch 35/100 | Loss: 0.6072
Ep

In [13]:
model.eval()

# NOTE(review): this rebinds `loader`, so re-running the training cell after
# this one trains on single-sample, shuffled batches.
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Keep the batches on the same device as the model's parameters.
model_device = next(model.parameters()).device

# Inference only: disable autograd so no graph is built for the forward pass
# and the printed predictions are plain tensors.
with torch.no_grad():
    for batch in loader:
        x = batch[0].to(model_device)
        attention_mask = batch[1].to(model_device)
        masked_inputs, target_labels, loss_mask = mask_input_tokens(
            x, model.mask_vector, mask_prob=0.4
        )
        logits = model(masked_inputs, attention_mask)
        # Show the first sample that has at least one masked position.
        if loss_mask.any():
            print("target", target_labels[loss_mask])
            print("predicted", logits[loss_mask])
            break

target tensor([[1., 0., 0., 0.],
        [1., 0., 2., 0.]])
predicted tensor([[1.8169, 2.2958, 2.0376, 0.9434],
        [1.7154, 2.8398, 2.2903, 0.5694]], grad_fn=<IndexBackward0>)
