In [1]:
import json


params = json.loads(open('8B-Instruct.local/params.json', 'r').read())
params

{'dim': 4096,
 'n_layers': 32,
 'n_heads': 32,
 'n_kv_heads': 8,
 'vocab_size': 128256,
 'multiple_of': 1024,
 'ffn_dim_multiplier': 1.3,
 'norm_eps': 1e-05,
 'rope_theta': 500000.0}

In [2]:
from llama.model import ModelArgs, Transformer


model_args = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
model_args

ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000.0, max_batch_size=1, max_seq_len=512)

In [3]:
import os
import torch
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)


os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'


if not torch.distributed.is_initialized():
	torch.distributed.init_process_group("nccl")
if not model_parallel_is_initialized():
	initialize_model_parallel(1)


model = Transformer(model_args)
model

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


Transformer(
  (tok_embeddings): VocabParallelEmbedding()
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear()
        (wk): ColumnParallelLinear()
        (wv): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear()
)

In [4]:
tokens = torch.arange(100, 200).long()[None]
tokens

tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
         114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
         128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
         142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
         156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
         170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
         184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
         198, 199]])

In [5]:
o = model(tokens, 0)
o.shape, o

(torch.Size([1, 100, 128256]),
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]))

In [6]:
h = model.tok_embeddings(tokens)
h.shape

torch.Size([1, 100, 4096])

In [7]:
model.freqs_cis.shape, model.freqs_cis

(torch.Size([1024, 64]),
 tensor([[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           ...,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
           1.0000+0.0000e+00j],
         [ 0.5403+8.4147e-01j,  0.6861+7.2746e-01j,  0.7878+6.1596e-01j,
           ...,  1.0000+3.6997e-06j,  1.0000+3.0139e-06j,
           1.0000+2.4551e-06j],
         [-0.4161+9.0930e-01j, -0.0584+9.9829e-01j,  0.2412+9.7048e-01j,
           ...,  1.0000+7.3994e-06j,  1.0000+6.0277e-06j,
           1.0000+4.9103e-06j],
         ...,
         [-0.9998+1.7612e-02j, -0.6982+7.1587e-01j,  0.5001-8.6597e-01j,
           ...,  1.0000+3.7774e-03j,  1.0000+3.0771e-03j,
           1.0000+2.5067e-03j],
         [-0.5550-8.3182e-01j, -0.9999-1.6764e-02j,  0.9274-3.7418e-01j,
           ...,  1.0000+3.7811e-03j,  1.0000+3.0802e-03j,
           1.0000+2.5092e-03j],
         [ 0.4001-9.1649e-01j, -0.6739-7.3884e-01j,  0.9610+2.7647e-01j,
           ...,  1.0000+3.7848e-03j,  1.0000+3.0832e-03j,
           1.00

In [8]:
mask = torch.full((100, 100), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
mask = torch.hstack([torch.zeros((100, 0), device=tokens.device), mask]).type_as(h)
mask.shape, mask

(torch.Size([100, 100]),
 tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
         [0., 0., -inf,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [0., 0., 0.,  ..., 0., -inf, -inf],
         [0., 0., 0.,  ..., 0., 0., -inf],
         [0., 0., 0.,  ..., 0., 0., 0.]]))

In [10]:
freqs_cis = model.freqs_cis[0 : 100]

layer = model.layers[0]

h1 = layer(h, 0, freqs_cis, mask)
h1.shape, h1

(torch.Size([1, 100, 4096]),
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<AddBackward0>))