In [1]:
import torch
from heliumbert import HeliumbertConfig, HeliumbertForTokenClassification

# Smoke test: build a toy-sized model straight from a config (no checkpoint
# is loaded) and push one random batch through it with labels attached.

# Toy hyperparameters, kept small so the check stays quick.
config = HeliumbertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    max_position_embeddings=16,
)

model = HeliumbertForTokenClassification(config)

# Random token ids for 2 sequences of 8 tokens each; the mask is all ones,
# so no position is marked as padding.
input_ids = torch.randint(low=0, high=config.vocab_size, size=(2, 8))
attention_mask = torch.ones_like(input_ids)

# One random class id per token, shaped like the inputs. `num_labels` is not
# passed to the config above, so whatever default the config supplies is used.
labels = torch.randint(low=0, high=config.num_labels, size=input_ids.shape)

# Supplying labels makes the model report a loss alongside the logits.
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=labels,
)

loss = outputs.loss
logits = outputs.logits
print("Loss:", loss.item())
print("Logits shape:", logits.shape)

Loss: 0.7032086849212646
Logits shape: torch.Size([2, 8, 2])
