In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

In [3]:

@dataclass
class ModelArgs:
    """Hyperparameters and runtime settings consumed by ``Transformer``.

    All fields have defaults, but ``vocab_size`` must be overridden:
    ``Transformer.__init__`` asserts that it is not the ``-1`` sentinel.
    """

    # Width of the token embedding and of the final projection's input.
    dims: int = 4096
    # Number of EncoderBlock layers stacked by the model.
    n_layers: int = 12
    # Number of heads; the per-head size handed to the frequency table
    # is ``dims // n_heads``.
    n_heads: int = 32
    # Not read by the code visible in this notebook section — presumably the
    # number of key/value heads, with None meaning "same as n_heads"; confirm
    # against the attention implementation.
    n_kv_heads: Optional[int] = None
    # Sentinel -1 means "not set"; must be replaced before building the model.
    vocab_size: int = -1
    # Not read by the code visible here — presumably used to round the
    # feed-forward hidden size; confirm against the feed-forward block.
    multiple_of: int = 256
    # Not read by the code visible here (field name spelling kept because it
    # is part of the constructor interface).
    ff_dim_multipler: Optional[float] = None
    # Epsilon forwarded to RMSNorm.
    norm_eps: float = 1e-5

    # Needed for Kv cache
    max_batch_size: int = 32
    # The frequency table is precomputed for ``max_seq_len * 2`` positions.
    max_seq_len: int = 2048

    # Device forwarded to the frequency precomputation. The default is None,
    # so the annotation is Optional[str] rather than plain str.
    device: Optional[str] = None


In [None]:

class Transformer(nn.Module):
    """Token model: embedding -> stacked ``EncoderBlock`` layers -> ``RMSNorm`` -> vocabulary projection.

    The constructor also precomputes a per-position frequency table via
    ``precompute_theta_pos_frequencies`` (presumably rotary position
    frequencies — confirm against that helper), which ``forward`` slices by
    absolute position and passes to every layer.
    """

    def __init__(self, args: ModelArgs) -> None:
        """Build all sub-modules from ``args``.

        Args:
            args: Model hyperparameters. ``args.vocab_size`` must have been
                set to something other than the ``-1`` sentinel.

        Raises:
            AssertionError: If ``args.vocab_size`` is still ``-1``.
        """
        super().__init__()

        # Model parameters
        assert args.vocab_size != -1, "Vocab_size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # The attribute name (including its spelling) is kept unchanged: it is
        # the key under which the embedding weights appear in the state_dict.
        self.tok_embedddings = nn.Embedding(args.vocab_size, args.dims)

        self.layers = nn.ModuleList(
            [EncoderBlock(args) for _ in range(args.n_layers)]
        )

        self.norm = RMSNorm(args.dims, eps=args.norm_eps)
        self.output = nn.Linear(args.dims, args.vocab_size, bias=False)

        # Table built for twice the configured maximum sequence length, with
        # the per-head size (dims // n_heads) as the first argument.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dims // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int, use_kv_cache: bool = False, ):
        """Run the model on a batch of token ids.

        Args:
            tokens: 2-D tensor of token ids, unpacked as ``(batch, seq_len)``.
            start_pos: Absolute position of the first token in ``tokens``;
                used to slice the precomputed frequency table.
            use_kv_cache: When True, exactly one token per sequence must be
                supplied (``seq_len == 1``).

        Returns:
            A 3-tuple ``(logits, None, None)``. The last two slots are always
            ``None`` here; their intended meaning is not visible in this
            section.

        Raises:
            AssertionError: If ``use_kv_cache`` is True and ``seq_len != 1``.
        """
        _, seq_len = tokens.shape

        # Only the single-token restriction depends on the cache flag. The
        # computation below previously lived inside this ``if`` as well, which
        # made the default call (use_kv_cache=False) silently return None.
        if use_kv_cache:
            assert seq_len == 1, "When using kv cache, the sequence length must be 1"

        h = self.tok_embedddings(tokens)

        # Frequencies for positions [start_pos, start_pos + seq_len), moved to
        # the activations' device in case the table was built elsewhere.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        # NOTE(review): use_kv_cache is not forwarded to the layers; whether
        # EncoderBlock handles seq_len > 1 (e.g. causal masking) must be
        # confirmed against its definition.
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        logits = self.output(h)
        return logits, None, None
        



In [7]:
# Scratch check: the even integers 0, 2, ..., 18 cast to float and scaled by -2.
# The first element displays as -0. because 0.0 * -2 yields a negative zero.
torch.arange(0, 20, 2).float() * -2

tensor([ -0.,  -4.,  -8., -12., -16., -20., -24., -28., -32., -36.])