In [1]:
import torch
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM, BertModel

class TokenVectorBERT(nn.Module):
    def __init__(self, input_dim=3, vocab_dim=64, bert_hidden=64, max_seq_len=128):
        super().__init__()

        self.token_proj = nn.Linear(input_dim, vocab_dim)  # turn [0, 0, 2] into an embedding
        self.config = BertConfig(
            vocab_size=1,  # dummy, unused
            hidden_size=bert_hidden,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=bert_hidden * 2,
            max_position_embeddings=max_seq_len,
            pad_token_id=0
        )
        self.bert = BertModel(self.config)
        self.mlm_head = nn.Linear(bert_hidden, vocab_dim)

    def forward(self, input_vecs, attention_mask=None):
        # input_vecs: (batch_size, seq_len, input_dim) like (2, 4, 3)
        x = self.token_proj(input_vecs)  # (batch_size, seq_len, vocab_dim)
        outputs = self.bert(inputs_embeds=x, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        return self.mlm_head(sequence_output)


In [None]:
# ----- Dummy Data -----
batch_size = 1
seq_len = 3
token_dim = 4

# Simulate tokens: each token is a 3D vector with values 0–5
torch.manual_seed(42)
inputs = torch.randint(0, 6, (batch_size, seq_len, token_dim)).float()

# Attention mask: 1s for real tokens
attention_mask = torch.ones(batch_size, seq_len).long()

# Learnable MASK token (for masking positions)
mask_vector = nn.Parameter(torch.zeros(token_dim))

# ----- Masking Function -----
def mask_input_tokens(inputs, mask_token, mask_prob=0.3):
    inputs = inputs.clone()
    labels = inputs.clone()
    mask = torch.rand(inputs[:, :, 0].shape) < mask_prob  # shape: (B, T)

    for i in range(inputs.size(0)):
        for j in range(inputs.size(1)):
            if mask[i, j]:
                inputs[i, j] = mask_token

    return inputs, labels, mask  # No NaNs here!

def masked_mse_loss(pred, target, mask):
    # print("pred", pred)
    # print()
    # print("target", target)
    loss = (pred - target) ** 2
    print(loss.shape)
    loss = loss.mean(dim=-1)  # (B, T)
    print(loss.shape)
    loss = loss * mask.float()
    print(loss.shape)
    
    valid_tokens = mask.sum()
    return loss.sum() / (valid_tokens + 1e-8)


In [3]:
model = TokenVectorBERT(input_dim=token_dim, vocab_dim=64)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

epochs = 1

print("inputs", inputs)
print("mask_vector", mask_vector)
print()

for epoch in range(epochs):
    model.train()

    masked_inputs, target_labels, loss_mask = mask_input_tokens(inputs, mask_vector)

    print("masked", masked_inputs)
    print("target_labels", target_labels)
    print("loss_mask", loss_mask)
    print()

    logits = model(masked_inputs, attention_mask)

    # Compute loss only on masked positions
    loss = masked_mse_loss(logits, model.token_proj(target_labels), loss_mask)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")

inputs tensor([[[0., 5., 4., 4.],
         [0., 5., 4., 2.],
         [4., 5., 4., 4.]]])
mask_vector Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)

masked tensor([[[0., 5., 4., 4.],
         [0., 5., 4., 2.],
         [0., 0., 0., 0.]]], grad_fn=<CopySlices>)
target_labels tensor([[[0., 5., 4., 4.],
         [0., 5., 4., 2.],
         [4., 5., 4., 4.]]])
loss_mask tensor([[False, False,  True]])

torch.Size([1, 3, 64])
torch.Size([1, 3])
torch.Size([1, 3])
Epoch 1/1 | Loss: 6.5812
