In [1]:
from importlib.metadata import version

pkgs = [
    "blobfile",         # to download pretrained weights
    "huggingface_hub",  # to download pretrained weights
    "tiktoken",         # to implement the tokenizer
    "torch",            # to implement the model
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

PackageNotFoundError: blobfile

In [2]:
! pip3 install blobfile

Collecting blobfile
  Obtaining dependency information for blobfile from https://files.pythonhosted.org/packages/ed/4d/1392562369b1139e741b30d624f09fe7091d17dd5579fae5732f044b12bb/blobfile-3.0.0-py3-none-any.whl.metadata
  Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Collecting pycryptodomex>=3.8 (from blobfile)
  Obtaining dependency information for pycryptodomex>=3.8 from https://files.pythonhosted.org/packages/62/c2/8c97e649ccd3886eaf4918bd87791d3b52e80ba5b9c4678e2b631f2f8340/pycryptodomex-3.22.0-cp37-abi3-macosx_10_9_universal2.whl.metadata
  Downloading pycryptodomex-3.22.0-cp37-abi3-macosx_10_9_universal2.whl.metadata (3.4 kB)
Collecting lxml>=4.9 (from blobfile)
  Obtaining dependency information for lxml>=4.9 from https://files.pythonhosted.org/packages/1e/04/acd238222ea25683e43ac7113facc380b3aaf77c53e7d88c4f544cef02ca/lxml-5.4.0-cp39-cp39-macosx_10_9_universal2.whl.metadata
  Downloading lxml-5.4.0-cp39-cp39-macosx_10_9_universal2.whl.metadata (3.5 kB)
Download

# 1. 逐步转换LLaMA模型实现

## 1.1 复用LLaMA 2组件

In [11]:
# 从notebook中导入定义的模块
import os
import sys
import io
import nbformat
import types

def import_from_notebook():
    def import_definitions_from_notebook(fullname, names):
        # current_dir = os.getcwd()
        # path = os.path.join(current_dir, fullname + ".ipynb")
        # path = os.path.normpath(path)
        path = fullname + ".ipynb"

        if not os.path.exists(path):
            raise FileNotFoundError(f"Notebook file not found at: {path}")
        
        with io.open(path, "r", encoding="utf-8") as f:
            nb = nbformat.read(f, as_version=4)

        mod = types.ModuleType(fullname)
        sys.modules[fullname] = mod

        for cell in nb.cells:
            if cell.cell_type == "code":
                cell_code = cell.source
                for name in names:
                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
                        exec(cell_code, mod.__dict__)
        return mod

    fullname = "../../ch05/07_gpt_to_llama/converting-gpt-to-llama2"
    names = ["precompute_rope_params", "compute_rope", "SiLU", "FeedForward", "RMSNorm", "MultiHeadAttention"]

    return import_definitions_from_notebook(fullname, names)

In [12]:
imported_module = import_from_notebook()

compute_rope = getattr(imported_module, "compute_rope", None)
SiLU = getattr(imported_module, "SiLU", None)
FeedForward = getattr(imported_module, "FeedForward", None)
RMSNorm = getattr(imported_module, "RMSNorm", None)

MultiHeadAttention = getattr(imported_module, "MultiHeadAttention", None)

## 1.2 优化RoPE

In [None]:
import torch

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    inv_freq = 2.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama
    
    positions = torch.arange(context_length)

    angles = positions[:, None] * inv_freq[None, :]

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin