In [1]:
import torch

In [2]:
# Tensor display settings for the value dumps below: four decimals, plain
# (non-scientific) notation, and no "..." summarization until a tensor has
# more than 2000 elements.
torch.set_printoptions(
    precision=4,
    sci_mode=False,
    threshold=2000,
)

In [3]:
# Seed PyTorch's global RNG so the randomly initialised weights and the random
# inputs drawn below are repeatable. The call returns the Generator object,
# which is what the notebook shows as this cell's output.
# NOTE(review): the drawn values also depend on the order in which later cells
# consume the RNG; the execution counts in this notebook are not sequential.
torch.manual_seed(1337)

<torch._C.Generator at 0x1fe3e28c3d0>

In [4]:
from model import Transformer, ModelArgs
from export import version1_export

In [5]:
# Small model configuration (64-wide, 2 layers, 2 heads, 128-token vocab).
# Any field not listed here keeps its ModelArgs default.
args = ModelArgs(
    dim=64,
    n_layers=2,
    n_heads=2,
    vocab_size=128,
    multiple_of=64,
    max_seq_len=32,
)

In [6]:
# Build the model and put it in evaluation mode in one step; `eval()` returns
# the module itself, so `trns` is the constructed Transformer.
trns = Transformer(args).eval()
# Bare expression so the notebook renders the module tree.
trns

Transformer(
  (tok_embeddings): Embedding(128, 64)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=64, out_features=64, bias=False)
        (wk): Linear(in_features=64, out_features=64, bias=False)
        (wv): Linear(in_features=64, out_features=64, bias=False)
        (wo): Linear(in_features=64, out_features=64, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=64, out_features=192, bias=False)
        (w2): Linear(in_features=192, out_features=64, bias=False)
        (w3): Linear(in_features=64, out_features=192, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=64, out_features=128,

In [16]:
# Export step, currently disabled. The recorded output below
# ("wrote test_qmodel.bin") is presumably left over from an earlier run with
# this line active; uncomment it to regenerate the file.
# version1_export(trns, "test_qmodel.bin")

wrote test_qmodel.bin


In [7]:
from model import precompute_freqs_cis

In [8]:
# Per-head width: dim split evenly across heads (64 // 2 = 32 with the args above).
head_dim = args.dim // args.n_heads
# Cos/sin tables covering every position up to max_seq_len; they are sliced to
# the first `seq_len` rows wherever they are passed to a layer below.
# NOTE(review): presumably the rotary-embedding tables - confirm in model.py.
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, args.max_seq_len, theta=10000)

In [9]:
# Sequence length shared by every test input below; it also selects how many
# rows of freqs_cos / freqs_sin are passed to the layers.
seq_len = 4

In [10]:
# Random activations used as input to the attention layer and to the full
# block: batch of 2, `seq_len` positions, `args.dim` features each.
attn_input = torch.randn(2, seq_len, args.dim)

In [11]:
# Total element count: 2 * seq_len * args.dim = 2 * 4 * 64 = 512.
torch.numel(attn_input)

512

In [12]:
# Show the input as one flat 1-D list of values.
torch.flatten(attn_input)

tensor([ 0.1953,  1.3208,  0.2495, -0.1673, -0.0591,  0.3616,  1.3410, -0.8957,
         0.6161,  1.2228, -0.1955, -0.7414, -0.3080, -2.1912,  0.2576,  0.0792,
        -1.3583, -0.1788,  1.2726,  0.6850, -0.6112,  1.0188, -0.5511,  1.8887,
         1.3301,  1.0066,  0.4750, -1.4066, -0.8390, -0.3888, -0.2326,  1.0632,
         0.0810, -2.1969,  0.4605, -1.4893, -2.9533,  1.4599,  2.3899, -0.3932,
        -0.9568, -1.1449,  1.4483,  1.2406,  0.1914,  1.4117, -0.4767, -1.0781,
        -0.9110,  1.2804,  0.3797,  0.9956, -1.5105, -0.7912,  1.3679, -0.5590,
        -2.1662, -0.9770,  0.4323,  0.1489, -0.7385,  0.7942, -0.1410,  0.1094,
        -0.0592, -1.2568, -0.1238, -2.5755, -0.8053, -1.3994,  0.4603,  0.8535,
        -0.4498, -0.1781,  2.5215,  1.4743,  1.1519, -0.5641,  1.2015, -0.1887,
         0.1653,  0.5446,  0.3146, -0.8870,  0.1913,  0.3718,  2.1777, -0.2279,
        -0.9992, -1.7526,  0.4812,  0.0654,  0.0687,  1.2828, -0.5617,  0.0165,
        -0.3389,  1.7562, -0.1553, -0.65

In [13]:
# Run only the attention sub-layer of the first block on `attn_input`.
# The module is called directly (not via `.forward`) so any registered hooks
# run, and under `no_grad` because only the output values are needed here.
# The cos/sin tables are cut down to the first `seq_len` positions.
with torch.no_grad():
    attn_out = trns.layers[0].attention(attn_input, freqs_cos[:seq_len], freqs_sin[:seq_len])

In [14]:
# Output keeps the input's (batch, seq_len, dim) shape.
attn_out.size()

torch.Size([2, 4, 64])

In [15]:
# Show the attention output as one flat 1-D list of values.
torch.flatten(attn_out)

tensor([     0.0075,     -0.0026,      0.0036,      0.0030,     -0.0054,
             0.0292,     -0.0237,      0.0184,     -0.0082,     -0.0130,
            -0.0035,      0.0040,      0.0218,      0.0066,      0.0012,
            -0.0013,      0.0048,      0.0216,      0.0087,      0.0048,
             0.0000,      0.0175,      0.0174,     -0.0108,      0.0115,
             0.0020,      0.0106,     -0.0113,     -0.0047,      0.0158,
             0.0040,      0.0123,     -0.0040,     -0.0076,      0.0059,
            -0.0118,     -0.0213,     -0.0106,      0.0068,     -0.0006,
            -0.0020,     -0.0106,      0.0091,     -0.0211,      0.0162,
            -0.0130,     -0.0020,     -0.0058,     -0.0189,      0.0125,
             0.0026,      0.0141,     -0.0200,     -0.0145,     -0.0071,
             0.0127,     -0.0006,     -0.0048,      0.0071,     -0.0078,
             0.0007,     -0.0046,     -0.0124,     -0.0114,     -0.0116,
            -0.0066,      0.0020,     -0.0021,     

In [17]:
# Separate random input for the feed-forward check, with the same
# (2, seq_len, args.dim) shape as `attn_input`.
ffd_input = torch.randn(2, seq_len, args.dim)

In [20]:
# Show the feed-forward input as one flat 1-D list of values.
torch.flatten(ffd_input)

tensor([    -0.3538,     -1.9140,     -0.2311,      1.0905,      0.3342,
             1.1195,      0.7562,      0.4004,      0.6715,     -0.5403,
             0.3872,     -1.2024,     -0.4792,     -0.0706,      1.6256,
             1.8494,     -0.1177,     -0.1636,      0.5777,      0.6076,
            -1.0897,     -1.4101,      0.5493,     -0.0152,     -1.7197,
             0.6678,     -0.0942,     -0.2970,     -0.8426,     -1.5830,
             0.4757,     -1.5387,     -0.6046,      0.0255,     -0.4568,
             0.0122,     -0.1214,      0.4754,     -0.2409,      1.0235,
             0.1045,     -0.5077,     -0.1771,     -0.3445,     -0.2736,
            -1.8487,     -2.0654,     -0.2309,      0.3603,     -0.6218,
             0.2895,     -0.1896,      1.9116,      0.4093,      1.2402,
            -0.4324,      0.2388,      0.5648,      3.4474,     -0.0728,
            -1.1945,     -0.1308,      0.5118,     -1.0987,      0.6883,
             3.0039,     -0.1289,     -1.0772,     

In [18]:
# Run only the feed-forward sub-layer of the first block on `ffd_input`.
# The module is called directly (not via `.forward`) so any registered hooks
# run, and under `no_grad` because only the output values are needed here.
with torch.no_grad():
    ffd_out = trns.layers[0].feed_forward(ffd_input)

In [22]:
# Output keeps the input's (batch, seq_len, dim) shape.
ffd_out.size()

torch.Size([2, 4, 64])

In [21]:
# Show the feed-forward output as one flat 1-D list of values.
torch.flatten(ffd_out)

tensor([     0.0004,     -0.0021,      0.0001,     -0.0007,      0.0013,
            -0.0015,      0.0020,     -0.0011,      0.0024,     -0.0022,
             0.0008,     -0.0026,      0.0005,      0.0005,     -0.0010,
             0.0010,     -0.0011,     -0.0010,     -0.0004,      0.0009,
             0.0005,      0.0022,     -0.0004,      0.0001,      0.0020,
            -0.0013,      0.0003,      0.0001,     -0.0027,      0.0035,
             0.0046,      0.0000,      0.0016,     -0.0002,     -0.0005,
            -0.0015,      0.0018,     -0.0037,     -0.0016,      0.0026,
             0.0010,      0.0012,      0.0034,      0.0001,      0.0024,
             0.0022,      0.0026,     -0.0003,      0.0016,      0.0021,
            -0.0018,     -0.0014,     -0.0012,      0.0002,     -0.0029,
            -0.0014,     -0.0013,      0.0030,     -0.0005,     -0.0034,
            -0.0009,     -0.0025,     -0.0006,     -0.0012,     -0.0014,
            -0.0018,      0.0001,      0.0025,     

In [26]:
# Run the whole first transformer block on the same `attn_input` used for the
# attention-only check above.
# The module is called directly (not via `.forward`) so any registered hooks
# run, and under `no_grad` because only the output values are needed here.
with torch.no_grad():
    block_out = trns.layers[0](attn_input, freqs_cos[:seq_len], freqs_sin[:seq_len])

In [27]:
# Output keeps the input's (batch, seq_len, dim) shape.
block_out.size()

torch.Size([2, 4, 64])

In [28]:
# Show the block output as one flat 1-D list of values.
torch.flatten(block_out)

tensor([ 0.1977,  1.3144,  0.2541, -0.1610, -0.0646,  0.3886,  1.3182, -0.8762,
         0.6084,  1.2136, -0.2006, -0.7385, -0.2886, -2.1820,  0.2588,  0.0814,
        -1.3533, -0.1580,  1.2797,  0.6903, -0.6106,  1.0316, -0.5343,  1.8786,
         1.3433,  1.0105,  0.4854, -1.4160, -0.8416, -0.3746, -0.2294,  1.0731,
         0.0806, -2.2019,  0.4653, -1.4978, -2.9769,  1.4485,  2.3975, -0.3910,
        -0.9587, -1.1575,  1.4564,  1.2186,  0.2059,  1.3992, -0.4760, -1.0831,
        -0.9288,  1.2881,  0.3824,  1.0100, -1.5302, -0.8057,  1.3591, -0.5457,
        -2.1616, -0.9833,  0.4398,  0.1428, -0.7393,  0.7900, -0.1516,  0.1001,
        -0.0724, -1.2629, -0.1185, -2.5759, -0.8215, -1.3792,  0.4514,  0.8565,
        -0.4462, -0.1880,  2.5134,  1.4770,  1.1634, -0.5623,  1.1907, -0.1931,
         0.1707,  0.5488,  0.3090, -0.8840,  0.1885,  0.3702,  2.1833, -0.2303,
        -1.0051, -1.7367,  0.4868,  0.0513,  0.0590,  1.2906, -0.5567,  0.0264,
        -0.3397,  1.7554, -0.1631, -0.66

In [29]:
# Display the full configuration, including the fields that were left at
# their ModelArgs defaults when `args` was built.
args

ModelArgs(dim=64, n_layers=2, n_heads=2, n_kv_heads=None, vocab_size=128, hidden_dim=None, multiple_of=64, norm_eps=1e-05, max_seq_len=32, dropout=0.0)

In [30]:
# Random token ids for the end-to-end check: batch of 2 sequences, `seq_len`
# tokens each, drawn from [0, vocab_size).
model_in = torch.randint(low=0, high=args.vocab_size, size=(2, seq_len))

In [32]:
# (batch, seq_len) grid of token ids.
model_in.size()

torch.Size([2, 4])

In [36]:
# Show the token ids as one flat 1-D list.
torch.flatten(model_in)

tensor([ 13,  24,  54,  63,  77, 104,  42,  26])

In [31]:
# End-to-end forward pass over the token ids.
# The module is called directly (not via `.forward`) so any registered hooks
# run, and under `no_grad` because only the output values are needed here.
# NOTE(review): the recorded result has a sequence dimension of 1, so the model
# presumably returns output for the last position only - confirm in model.py.
with torch.no_grad():
    model_out = trns(model_in)

In [33]:
# Shape of the model output for the 2-sequence batch.
model_out.size()

torch.Size([2, 1, 128])

In [38]:
# Show the model output as one flat 1-D list of values.
torch.flatten(model_out)

tensor([    -0.0287,     -0.0941,     -0.0485,     -0.0072,      0.1997,
            -0.1584,     -0.0730,      0.1262,      0.0198,     -0.0634,
             0.0235,      0.2640,     -0.0725,     -0.0691,     -0.2992,
             0.0761,     -0.3731,      0.1601,      0.0051,      0.2053,
             0.2864,     -0.1265,     -0.0002,      0.2552,      0.1563,
            -0.2802,      0.0716,      0.1477,      0.1392,      0.2549,
            -0.2621,     -0.1014,     -0.1827,     -0.2264,      0.3995,
            -0.0646,      0.0414,      0.2854,     -0.1382,     -0.1682,
             0.3519,     -0.0965,      0.1739,      0.1457,      0.0252,
             0.1421,     -0.0957,      0.1065,     -0.0390,      0.0667,
             0.2187,     -0.1016,      0.1857,     -0.3105,     -0.0179,
             0.1519,      0.0819,     -0.3554,      0.0687,      0.0979,
             0.1896,     -0.0277,      0.0122,      1.5325,     -0.1347,
             0.1201,      0.1374,     -0.2004,     