# In [1]:
from dataclasses import dataclass
import torch.nn as nn
import torch
from torch.nn import functional as F

@dataclass
class ChessBertConfig:
    """Hyper-parameters used by ChessBertModel to size its layers."""

    # Number of rows in the token embedding table (distinct token ids).
    vocab_size: int = 117
    # Number of rows in the position embedding table; input sequences
    # longer than this cannot be given a position embedding.
    block_size: int = 87
    # Number of stacked transformer encoder layers.
    n_layers: int = 4
    # Number of attention heads in each encoder layer.
    n_heads: int = 4
    # Embedding / hidden width shared by the embeddings, encoder and final norm.
    n_embd: int = 512
    # Output size of the classification head (number of classes).
    n_labels: int = 1971
    # Feed-forward width inside each encoder layer.
    ffn_size: int = 2048

class ChessBertModel(nn.Module):
    """Transformer encoder that classifies a token sequence from its first position.

    Token ids are embedded, summed with learned position embeddings, passed
    through a pre-norm (``norm_first=True``) transformer encoder with GELU
    activations, layer-normalised, and the hidden state at sequence index 0
    (the [CLS] token) is projected to ``config.n_labels`` logits.
    """

    def __init__(self, config: ChessBertConfig):
        """Build the embeddings, encoder stack and classification head.

        Args:
            config: Hyper-parameters; the fields read here are ``vocab_size``,
                ``block_size``, ``n_layers``, ``n_heads``, ``n_embd``,
                ``n_labels`` and ``ffn_size``.
        """
        super().__init__()
        # forward() reads self.config.n_labels when computing the loss, so the
        # config must be kept on the module.
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embeddings = nn.Embedding(config.block_size, config.n_embd)
        encoder_layer = nn.TransformerEncoderLayer(
            config.n_embd,
            config.n_heads,
            config.ffn_size,
            norm_first=True,
            batch_first=True,
            activation=F.gelu,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.n_layers)
        self.layer_norm = nn.LayerNorm(config.n_embd)
        self.classifier = nn.Linear(config.n_embd, config.n_labels)

    def forward(self, x, targets=None, attention_masks=None):
        """Compute class logits, and the cross-entropy loss when targets are given.

        Args:
            x: Integer token ids laid out as (batch, sequence). The sequence
                length must not exceed ``config.block_size`` because each
                position index is looked up in the position embedding table.
            targets: Optional class indices, one per batch element. They are
                flattened before the loss is computed.
            attention_masks: Optional mask forwarded unchanged as the
                encoder's ``src_key_padding_mask``.
                NOTE(review): PyTorch treats True entries of that mask as
                padding to ignore; confirm callers do not pass a
                "1 = real token" style mask here.

        Returns:
            ``logits`` of shape (batch, n_labels) when ``targets`` is None,
            otherwise the tuple ``(loss, logits)`` where ``loss`` is the
            mean cross-entropy over the batch.
        """
        # One position index per sequence slot; the resulting embeddings
        # broadcast across the batch dimension when added.
        positions = torch.arange(x.size(1), device=x.device)
        x = self.token_embeddings(x) + self.position_embeddings(positions)
        x = self.encoder(x, src_key_padding_mask=attention_masks)
        x = self.layer_norm(x)
        logits = self.classifier(x[:, 0, :])  # hidden state of the [CLS] token

        if targets is None:
            return logits

        loss = F.cross_entropy(
            logits.view(-1, self.config.n_labels), targets.view(-1)
        )
        return loss, logits