In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class TinyCausalLM(nn.Module):
  """Minimal decoder-only (causal) language model built on nn.TransformerDecoder.

  Uses the sequence-first layout of the default (``batch_first=False``)
  PyTorch transformer layers.

  Args:
    vocab_size: number of tokens in the vocabulary.
    d_model: embedding / hidden size (must be divisible by ``n_heads``).
    n_heads: number of attention heads per layer.
    n_layers: number of stacked decoder layers.
    max_len: size of the learned positional-embedding table, i.e. the
      longest sequence the model can process. Defaults to 512.
  """

  def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=512):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    # Learned absolute positions; a sequence longer than max_len cannot be
    # indexed into this table and nn.Embedding will raise.
    self.pos_embedding = nn.Embedding(max_len, d_model)
    decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads)
    self.decoder = nn.TransformerDecoder(decoder_layer, n_layers)
    # Final projection: one vector of vocab_size un-normalised logits per
    # position. No softmax is applied here; apply F.softmax (probabilities)
    # or F.cross_entropy (loss) downstream.
    self.fc_out = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    """Compute next-token logits.

    Args:
      x: integer token ids of shape (seq_len, batch_size).

    Returns:
      Float logits of shape (seq_len, batch_size, vocab_size). The output at
      position i depends only on tokens 0..i.
    """
    # Unpacking also enforces a 2-D (seq_len, batch_size) input.
    seq_len, _ = x.shape
    # (seq_len, 1) so the positional embedding broadcasts over the batch dim.
    positions = torch.arange(0, seq_len, device=x.device).unsqueeze(1)
    h = self.embedding(x) + self.pos_embedding(positions)

    # Causal mask: True marks a *disallowed* key, so query i sees keys 0..i.
    mask = torch.triu(
        torch.ones(seq_len, seq_len, device=x.device), diagonal=1
    ).bool()

    # The input sequence is also used as the cross-attention "memory", so the
    # memory attention must be masked causally too. Without memory_mask every
    # position could read future tokens and the model would not be causal.
    h = self.decoder(h, h, tgt_mask=mask, memory_mask=mask)
    return self.fc_out(h)

In [5]:
# Seed the global torch RNG so the randomly initialised weights (and anything
# derived from them) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

vocab_size = 100
model = TinyCausalLM(vocab_size, d_model=64, n_heads=4, n_layers=2)

In [9]:
# Bare last expression: Jupyter renders the module's repr (same text as print).
model

TinyCausalLM(
  (embedding): Embedding(100, 64)
  (pos_embedding): Embedding(512, 64)
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
       