# Install packages

In [1]:
!pip install accelerate bitsandbytes transformers

Collecting accelerate
  Using cached accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Collecting transformers
  Using cached transformers-4.44.0-py3-none-any.whl.metadata (43 kB)
Collecting numpy<2.0.0,>=1.17 (from accelerate)
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting huggingface-hub>=0.21.0 (from accelerate)
  Downloading huggingface_hub-0.24.5-py3-none-any.whl.metadata (13 kB)
Collecting safetensors>=0.3.1 (from accelerate)
  Downloading safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting regex!=2019.12.17 (from transformers)
  Using cached regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers)
  Using cached tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_6

# Importing

In [1]:
import torch
import warnings
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple
import math
from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
)
from transformers import (
    LlamaForCausalLM,
    AutoTokenizer,
)
warnings.filterwarnings("ignore")

# Llama attention

In [2]:
logger = logging.get_logger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        ##########################################################
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads # k,v head 개수
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # 한 그룹 내의 head 개수??
        ##########################################################
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        ##########################################################
        # q,k,v projection layer와 output projection layer
        # query projection layer를 선언하세요.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        ##########################################################
        self._init_rope()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1: # tensor 병렬화
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            ##########################################################
            # q,k,v projection layer를 통과시킵니다.
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
            ##########################################################

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        ##########################################################
        # GQA 연산을 위해 k,v 행렬을 q 행렬과 같은 크기로 반복하여 만들어줍니다. (repeat_kv 함수 사용)
        # 참고: 현재 key_states 와 value_states의 shape은 (bsz, self.num_key_value_heads, q_len, self.head_dim)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        ##########################################################

        # q, k -> weight
        # q: (bsz, self.num_heads, q_len, self.head_dim)
        # k: (bsz, self.num_heads, q_len, self.head_dim)
        ##########################################################
        # 어텐션 웨이트를 계산합니다.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
                                                                                #/ torch.sqrt(torch.FloatTensor([self.head_dim]))
        ##########################################################

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        ##########################################################
        # 위에서 구한 어텐션 웨이트에 softmax를 적용하세요.
        # 이때 데이터 타입은 torch.float32로 지정하여 계산하세요.
        # 구해진 softmax을 다시 query_states와 같은 데이터 타입으로 변경하세요.
        attn_weights = torch.softmax(attn_weights.type(torch.float32), dim=-1).to(query_states.dtypes)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        # 어텐션 웨이트와 v 행렬을 이용해 어텐션 output을 구하세요
        attn_output = torch.matmul(attn_weights, value_states)
        # shape : attn_weight(bsz, self.num_heads, q_len, q_len), v(bsz, self.num_heads, q_len, self.head_dim)
        ##########################################################

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            ##########################################################
            # 마지막으로 output projection layer를 통과시킵니다.
            attn_output = self.o_proj(attn_output)
            ##########################################################

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

# Llama decoder layer

In [3]:
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size
        ##########################################################
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 가장 처음 적용하는 norm
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 어텐션 후에 적용하는 norm
        ##########################################################

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        ##########################################################
        residual = hidden_states

        # norm을 적용하세요
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        # residual을 적용하세요.
        hidden_states = residual + hidden_states

        # norm을 적용하세요.
        hidden_states = self.post_attention_layernorm(hidden_states)

        # Fully Connected
        # norm, ffn, residual을 적용하세요.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        ##########################################################

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

# Load model

In [None]:
#!pip install sentencepiece
#!pip install protobuf

In [5]:
model_id = "lmsys/vicuna-7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = LlamaForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    device_map="auto",
)

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/162 [00:00<?, ?B/s]

# Text generation

In [6]:
# prompt: generate any text from above code
input_text = "USER: Who is the best soccer player in the world?\nASSISTANT:" # Templetes 사용
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.cuda()

print(input_ids)

# Generate text
model.eval()
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=100)

# Decode and print the generated text
generated_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
print(generated_text)

tensor([[    1,  3148,  1001, 29901, 11644,   338,   278,  1900,   269, 11953,
          4847,   297,   278,  3186, 29973,    13, 22933,  9047, 13566, 29901]],
       device='cuda:0')
USER: Who is the best soccer player in the world?
ASSISTANT: It is subjective to determine who the best soccer player in the world is, as people have different opinions on the matter. Some may argue that a player's skill, talent, and achievements should be considered, while others may argue that a player's impact on the game and their ability to inspire others should be taken into account. Ultimately, the best soccer player in the world is a matter of personal opinion.
