In [1]:
import torch

In [20]:
# Tensor display settings: 4 decimal places, plain (non-scientific) notation,
# and tensors with up to 2000 elements are printed in full instead of being
# summarised with "...".
torch.set_printoptions(
    precision=4,
    sci_mode=False,
    threshold=2000,
)

In [3]:
# Seed torch's global generator so the model's random weight initialisation and
# the torch.randn input drawn further down are reproducible across runs.
SEED = 1337
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr (it embeds a memory address)

<torch._C.Generator at 0x1b1eb4e03d0>

In [4]:
from model import Transformer, ModelArgs
from export import version1_export

In [5]:
# Small model configuration; per-head size is dim // n_heads = 32.
args = ModelArgs(
    dim=64,
    n_layers=2,
    n_heads=2,
    vocab_size=128,
    multiple_of=64,
    max_seq_len=32,
)

In [6]:
# nn.Module.eval() returns the module itself, so construction and the switch
# to inference mode chain into one expression; the bare name displays the
# architecture summary.
trns = Transformer(args).eval()
trns

Transformer(
  (tok_embeddings): Embedding(128, 64)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=64, bias=False)
        (wv): Linear(in_features=64, out_features=64, bias=False)
        (wo): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=64, out_features=192, bias=False)
        (w2): Linear(in_features=192, out_features=64, bias=False)
        (w3): Linear(in_features=64, out_features=192, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=64, out_features=128,

In [16]:
# Export is opt-in: set EXPORT_MODEL = True to (re)write the model file with
# version1_export. Left off by default so a full re-run has no file side effects.
EXPORT_MODEL = False
EXPORT_PATH = "test_qmodel.bin"
if EXPORT_MODEL:
    version1_export(trns, EXPORT_PATH)

wrote test_qmodel.bin


In [7]:
from model import precompute_freqs_cis

In [8]:
# Presumably the rotary-embedding (RoPE) cos/sin tables for one attention head,
# built for args.max_seq_len positions (they are sliced with [:seq_len] later).
head_dim = args.dim // args.n_heads
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, args.max_seq_len, theta=10000)

In [9]:
# Number of token positions pushed through the attention block.
seq_len = 4
# freqs_cos / freqs_sin were built with args.max_seq_len as the position count
# (presumably one row per position) and are later sliced with [:seq_len];
# slicing past the end silently returns fewer rows, so fail loudly here instead.
assert 0 < seq_len <= args.max_seq_len, f"seq_len={seq_len} must be in 1..{args.max_seq_len}"

In [10]:
# Random activations of shape (2, seq_len, args.dim); the leading axis is
# assumed to be the batch axis by the attention forward — TODO confirm.
batch_size = 2
attn_input = torch.randn(batch_size, seq_len, args.dim)

In [16]:
attn_input.nelement()  # alias of numel(): total number of elements

512

In [21]:
attn_input.reshape(-1)  # every element as a 1-D tensor, in row-major order

tensor([ 0.1953,  1.3208,  0.2495, -0.1673, -0.0591,  0.3616,  1.3410, -0.8957,
         0.6161,  1.2228, -0.1955, -0.7414, -0.3080, -2.1912,  0.2576,  0.0792,
        -1.3583, -0.1788,  1.2726,  0.6850, -0.6112,  1.0188, -0.5511,  1.8887,
         1.3301,  1.0066,  0.4750, -1.4066, -0.8390, -0.3888, -0.2326,  1.0632,
         0.0810, -2.1969,  0.4605, -1.4893, -2.9533,  1.4599,  2.3899, -0.3932,
        -0.9568, -1.1449,  1.4483,  1.2406,  0.1914,  1.4117, -0.4767, -1.0781,
        -0.9110,  1.2804,  0.3797,  0.9956, -1.5105, -0.7912,  1.3679, -0.5590,
        -2.1662, -0.9770,  0.4323,  0.1489, -0.7385,  0.7942, -0.1410,  0.1094,
        -0.0592, -1.2568, -0.1238, -2.5755, -0.8053, -1.3994,  0.4603,  0.8535,
        -0.4498, -0.1781,  2.5215,  1.4743,  1.1519, -0.5641,  1.2015, -0.1887,
         0.1653,  0.5446,  0.3146, -0.8870,  0.1913,  0.3718,  2.1777, -0.2279,
        -0.9992, -1.7526,  0.4812,  0.0654,  0.0687,  1.2828, -0.5617,  0.0165,
        -0.3389,  1.7562, -0.1553, -0.65

In [12]:
# Run only the first layer's attention sub-module on the random input, with the
# RoPE tables trimmed to the first seq_len positions.
# Calling the module (rather than .forward) keeps registered nn.Module hooks
# working; no_grad skips building an autograd graph for this inspection pass.
with torch.no_grad():
    attn_out = trns.layers[0].attention(attn_input, freqs_cos[:seq_len], freqs_sin[:seq_len])

In [14]:
# Attention projects back through wo (dim -> dim), so output shape must match input.
assert attn_out.shape == attn_input.shape, (attn_out.shape, attn_input.shape)
attn_out.shape

torch.Size([2, 4, 64])

In [23]:
attn_out.reshape(-1)  # every output element as a 1-D tensor, in row-major order

tensor([     0.0075,     -0.0026,      0.0036,      0.0030,     -0.0054,
             0.0292,     -0.0237,      0.0184,     -0.0082,     -0.0130,
            -0.0035,      0.0040,      0.0218,      0.0066,      0.0012,
            -0.0013,      0.0048,      0.0216,      0.0087,      0.0048,
             0.0000,      0.0175,      0.0174,     -0.0108,      0.0115,
             0.0020,      0.0106,     -0.0113,     -0.0047,      0.0158,
             0.0040,      0.0123,     -0.0040,     -0.0076,      0.0059,
            -0.0118,     -0.0213,     -0.0106,      0.0068,     -0.0006,
            -0.0020,     -0.0106,      0.0091,     -0.0211,      0.0162,
            -0.0130,     -0.0020,     -0.0058,     -0.0189,      0.0125,
             0.0026,      0.0141,     -0.0200,     -0.0145,     -0.0071,
             0.0127,     -0.0006,     -0.0048,      0.0071,     -0.0078,
             0.0007,     -0.0046,     -0.0124,     -0.0114,     -0.0116,
            -0.0066,      0.0020,     -0.0021,     