In [None]:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass
import time
import math
import inspect
import numpy as np
import os

# --------------------------------------------------------
class Config:
    def __init__(self):
        self.embedding_size = 4096
        self.vocab_size     = 128256
        self.block_count    = 32
        self.proj           = 14336
        self.eps            = 1e-5

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.up_proj   = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.down_proj = nn.Linear(in_features=config.proj, out_features=config.embedding_size, bias=False)
        self.act_fn    = nn.SiLU()

class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # Implementation details would go here

class LlamaSdpaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.q_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.k_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.v_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.o_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding()

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = LlamaSdpaAttention(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)

class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.block_count)])
        self.norm = LlamaRMSNorm(config.embedding_size, config.eps)

class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size, bias=False)

config = Config()
model = LlamaForCausalLM(config)

In [5]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "nvidia/Llama3-ChatQA-1.5-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=r'D:\hugging-models')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", cache_dir=r'D:\hugging-models')
model

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [2]:
sd_hf = model.state_dict()

for k, v in sd_hf.items():
    print(k, v.shape)

model.embed_tokens.weight torch.Size([128256, 4096])
model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.0.mlp.gate_proj.weight torch.Size([14336, 4096])
model.layers.0.mlp.up_proj.weight torch.Size([14336, 4096])
model.layers.0.mlp.down_proj.weight torch.Size([4096, 14336])
model.layers.0.input_layernorm.weight torch.Size([4096])
model.layers.0.post_attention_layernorm.weight torch.Size([4096])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.1.mlp.gate_proj.weight torch.Size([14336, 4096])
model.layers.1.mlp.up_proj.weight torch.Size([14336, 4096])
model.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from dataclasses import dataclass
import time
import math
import inspect
import numpy as np
import os

# --------------------------------------------------------
class Config:
    def __init__(self):
        self.embedding_size = 4096
        self.vocab_size = 128256
        self.block_count = 32
        self.proj = 14336
        self.eps = 1e-5

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.up_proj = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.down_proj = nn.Linear(in_features=config.proj, out_features=config.embedding_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x) * self.up_proj(x)))

class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_len=None):
        seq_len = seq_len or x.shape[1]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb

def apply_rotary_emb(q, k, cos, sin):
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

class LlamaSdpaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.q_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.k_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.v_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.o_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(config.embedding_size // 4)

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        seq_len = x.shape[1]
        cos, sin = self.rotary_emb(x, seq_len).chunk(2, dim=-1)
        q, k = apply_rotary_emb(q, k, cos, sin)

        attn_weights = torch.einsum('bqd,bkd->bqk', q, k) / math.sqrt(q.size(-1))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.einsum('bqk,bvd->bqd', attn_weights, v)

        return self.o_proj(attn_output)

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = LlamaSdpaAttention(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, x):
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + residual

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + residual

        return x

class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.block_count)])
        self.norm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size, bias=False)

    def forward(self, input_ids):
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)
        return logits

config = Config()
model = LlamaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 10))

# Forward pass
logits = model(input_ids)
print(f"Output logits shape: {logits.shape}")

# Generate text
generated = model.generate(input_ids, max_length=20)
print(f"Generated sequence shape: {generated.shape}")

In [2]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import numpy as np
import os

class Config:
    def __init__(self):
        self.embedding_size = 4096
        self.vocab_size = 128256
        self.block_count = 32
        self.proj = 14336
        self.eps = 1e-5

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.up_proj = nn.Linear(in_features=config.embedding_size, out_features=config.proj, bias=False)
        self.down_proj = nn.Linear(in_features=config.proj, out_features=config.embedding_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x) * self.up_proj(x)))

class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_len=None):
        seq_len = seq_len or x.shape[1]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb

def apply_rotary_emb(q, k, cos, sin):
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

class LlamaSdpaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.q_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.k_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.v_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size // 4, bias=False)
        self.o_proj = nn.Linear(in_features=config.embedding_size, out_features=config.embedding_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(config.embedding_size // 4)

    def forward(self, x):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        seq_len = x.shape[1]
        cos, sin = self.rotary_emb(x, seq_len).chunk(2, dim=-1)
        q, k = apply_rotary_emb(q, k, cos, sin)

        attn_weights = torch.einsum('bqd,bkd->bqk', q, k) / math.sqrt(q.size(-1))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.einsum('bqk,bvd->bqd', attn_weights, v)

        return self.o_proj(attn_output)

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = LlamaSdpaAttention(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, x):
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x)
        x = x + residual

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = x + residual

        return x

class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.block_count)])
        self.norm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return x

class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size, bias=False)

    def forward(self, input_ids):
        hidden_states = self.model(input_ids)
        logits = self.lm_head(hidden_states)
        return logits

config = Config()
model = LlamaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 10))

# Forward pass
logits = model(input_ids)
print(f"Output logits shape: {logits.shape}")


RuntimeError: The size of tensor a (4096) must match the size of tensor b (512) at non-singleton dimension 2

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Config:
    def __init__(self):
        self.embedding_size = 4096
        self.vocab_size     = 128256
        self.block_count    = 32
        self.proj           = 14336
        self.eps            = 1e-5
        self.max_seq_len    = 2048
        self.num_attention_heads = 32
        self.num_key_value_heads = 32
        self.head_dim       = self.embedding_size // self.num_attention_heads

class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.embedding_size, config.proj, bias=False)
        self.up_proj   = nn.Linear(config.embedding_size, config.proj, bias=False)
        self.down_proj = nn.Linear(config.proj, config.embedding_size, bias=False)
        self.act_fn    = nn.SiLU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

class LlamaRMSNorm(nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len):
        if self.cos_cached is None or self.sin_cached is None or seq_len > self.max_seq_len:
            self.max_seq_len = max(seq_len, self.max_seq_len)
            t = torch.arange(self.max_seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return (
            self.cos_cached[:, :, :seq_len, :],
            self.sin_cached[:, :, :seq_len, :],
        )

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class LlamaSdpaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        
        self.q_proj = nn.Linear(config.embedding_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.embedding_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.embedding_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.embedding_size, bias=False)
        
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, config.max_seq_len)

    def forward(self, hidden_states, attention_mask=None):
        bsz, q_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if self.num_key_value_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_key_value_heads, dim=1)
            v = v.repeat_interleave(self.num_heads // self.num_key_value_heads, dim=1)

        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        
        return self.o_proj(attn_output)

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = LlamaSdpaAttention(config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.block_count)])
        self.norm = LlamaRMSNorm(config.embedding_size, config.eps)

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embed_tokens(input_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        hidden_states = self.norm(hidden_states)

        return hidden_states

class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)
        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.model(input_ids, attention_mask)
        logits = self.lm_head(hidden_states)
        return logits

    def generate(self, input_ids, max_length, temperature=1.0, top_k=50, top_p=0.95):
        for _ in range(max_length - input_ids.size(1)):
            outputs = self(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature
            next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Example usage
config = Config()
model = LlamaForCausalLM(config)

# Generate some random input
input_ids = torch.randint(0, config.vocab_size, (1, 10))

# Forward pass
logits = model(input_ids)
print(f"Output logits shape: {logits.shape}")

# Generate text
generated = model.generate(input_ids, max_length=20)
print(f"Generated sequence shape: {generated.shape}")

Output logits shape: torch.Size([1, 10, 128256])


IndexError: index 102614 is out of bounds for dimension 0 with size 1

In [2]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)