In [None]:
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
#                                             MiniMind Config
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘

from transformers import PretrainedConfig

class MiniMindConfig(PretrainedConfig):
    """Hyper-parameter container for the MiniMind model.

    Every named argument is stored unchanged on the instance under the same
    attribute name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        dropout: Dropout probability (read by ``Attention`` for its dropout layers).
        bos_token_id: Id stored as ``self.bos_token_id``.
        eos_token_idx: Id stored as ``self.eos_token_idx``.
            NOTE(review): the base class keyword is spelled ``eos_token_id``;
            pass that through ``**kwargs`` if the base-class attribute is
            needed -- confirm against callers.
        hidden_act: Name of the activation function (presumably a key of
            ``ACT2FN`` -- verify against the feed-forward block).
        hidden_size: Model width; ``Attention`` derives ``head_dim`` as
            ``hidden_size // num_attention_heads``.
        intermediate_size: Feed-forward width, or ``None`` (handling of
            ``None`` is not visible in this file section).
        max_position_embeddings: Maximum number of positions.
        num_attention_heads: Number of query heads. Must be divisible by
            ``num_key_value_heads`` (asserted in ``Attention.__init__``).
        num_hidden_layers: Number of layers.
        num_key_value_heads: Number of key/value heads.
        vocab_size: Vocabulary size.
        rms_norm_eps: Epsilon for RMS normalisation.
        rope_theta: Base for rotary position embeddings.
        flash_attn: Whether ``Attention`` may use
            ``torch.nn.functional.scaled_dot_product_attention`` when available.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "minimind"

    def __init__(
        self,
        dropout: float = 0.0,
        bos_token_id: int = 1,
        eos_token_idx: int = 2,
        hidden_act: str = 'silu',
        hidden_size: int = 512,
        intermediate_size: int = None,
        max_position_embeddings: int = 32768,
        # Was 9, which is not divisible by the default num_key_value_heads (2)
        # and therefore tripped the divisibility assert in Attention.__init__.
        num_attention_heads: int = 8,
        num_hidden_layers: int = 8,
        num_key_value_heads: int = 2,
        vocab_size: int = 6400,
        rms_norm_eps: float = 1e-5,
        rope_theta: int = 1000000,
        flash_attn: bool = True,
        **kwargs
        # Flexibility: accepts extra keyword arguments that are not listed in the signature.
        # Pass-through: those extras are handed to the parent class (PretrainedConfig here).
        # Backward compatibility: new parameters can be added without breaking existing code.
        # Extensibility: users can supply custom options without editing this class.
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        # Fixed: the original assigned the undefined name `eos` (NameError).
        self.eos_token_idx = eos_token_idx
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.flash_attn = flash_attn

In [None]:
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
#                                             MiniMind Model
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘

import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, Tuple, List, Union
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

In [None]:
class RMSNorm(torch.nn.Module):
  """Root-mean-square normalisation over the last dimension with a learned scale.

  Args:
    dim: Length of the learnable ``weight`` vector (initialised to ones).
    eps: Constant added to the mean square before the reciprocal square root.
  """

  def __init__(self, dim: int, eps: float = 1e-5):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x):
    """Scale ``x`` by the reciprocal root of its mean square along the last axis."""
    mean_square = x.pow(2).mean(-1, keepdim = True)
    inv_rms = torch.rsqrt(mean_square + self.eps)
    return x * inv_rms

  def forward(self, x):
    """Normalise ``x`` in float32, cast back to ``x``'s type, then apply ``weight``."""
    normed = self._norm(x.float())
    return self.weight * normed.type_as(x)

In [None]:
# Precompute the cos/sin frequency tables for rotary position embeddings (RoPE)
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute the cosine and sine tables used for rotary position embeddings.

    Args:
        dim: Width the tables are built for. ``dim // 2`` frequencies are
            generated and then duplicated, so each table has
            ``2 * (dim // 2)`` columns.
        end: Number of positions (rows) to precompute.
        theta: Base of the frequency progression. (The parameter was
            previously misspelled ``theata`` while the body read ``theta``,
            so the function either raised NameError or silently used an
            unrelated global.)

    Returns:
        Tuple ``(freqs_cos, freqs_sin)`` of float tensors, each of shape
        ``(end, 2 * (dim // 2))``.
    """
    # Inverse frequencies: theta ** (-i / dim) for i = 0, 2, 4, ...
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Position indices, created on the same device and dtype as the frequencies
    # so the outer product needs no implicit type promotion.
    t = torch.arange(end, device=freqs.device, dtype=freqs.dtype)
    # angles[p, j] = p * freqs[j]
    freqs = torch.outer(t, freqs).float()
    # Duplicate along the last axis so the tables line up with the
    # "first half / second half" layout used by rotate_half.
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin


# Apply RoPE and return the position-encoded q and k
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` with the given cosine/sine tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            before it is multiplied with ``q`` / ``k``.
        sin: Sine table, treated like ``cos``.
        position_ids: Accepted for interface compatibility; not used here.
        unsqueeze_dim: Axis at which ``cos`` and ``sin`` are unsqueezed.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    def rotate_half(x):
        # Swap the two halves of the last axis and negate the (new) first half.
        half = x.shape[-1] // 2
        front, back = x[..., :half], x[..., half:]
        return torch.cat((-back, front), dim=-1)

    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed, k_embed


# Repeat the K/V heads so they line up with the larger number of query heads
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        x: 4-D tensor unpacked as ``(bs, slen, num_key_value_heads, head_dim)``.
        n_rep: Number of copies of each head; ``1`` returns ``x`` unchanged.

    Returns:
        Tensor of shape ``(bs, slen, num_key_value_heads * n_rep, head_dim)``
        in which each original head appears ``n_rep`` times consecutively.
    """
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # Insert a new axis after the head axis, broadcast it to n_rep copies,
        # then fold it back into the head axis.
        x[:, :, :, None, :]
        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
        # Fixed: the original read `num_key_value, heads * n_rep` (NameError).
        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    )

In [None]:
# Scratch check: build the inverse-frequency vector for one (theta, dim) pair
# and print its shape, which has dim // 2 entries (512 for dim = 1024).
# NOTE(review): theta here is 5e-5, not the 1e6 default declared by
# precompute_freqs_cis; only the shape is inspected, so the value does not
# change what is printed.
# NOTE(review): theta, dim and freqs are module-level names and stay defined
# for every cell that runs after this one.
theta = 5e-5
dim = 1024
freqs = 1 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
print(freqs.shape)

torch.Size([512])


In [None]:
class Attention(nn.Module):
  def __init__(self, args: MiniMindConfig):
    """Build the projection and dropout layers from the model configuration.

    Reads ``num_attention_heads``, ``num_key_value_heads``, ``hidden_size``,
    ``dropout`` and ``flash_attn`` from ``args``. The number of query heads
    must be an integer multiple of the number of key/value heads.
    """
    super().__init__()
    n_heads = args.num_attention_heads
    hidden = args.hidden_size

    # Without an explicit KV head count, use one KV head per query head.
    kv_heads = args.num_key_value_heads
    if kv_heads is None:
      kv_heads = n_heads
    self.num_key_value_heads = kv_heads
    assert n_heads % kv_heads == 0

    self.n_local_heads = n_heads
    self.n_local_kv_heads = kv_heads
    # How many query heads share each key/value head.
    self.rep = n_heads // kv_heads
    head_dim = hidden // n_heads
    self.head_dim = head_dim

    # Bias-free projections; registration order is q, k, v, o.
    q_width = n_heads * head_dim
    kv_width = kv_heads * head_dim
    self.q_proj = nn.Linear(hidden, q_width, bias = False)
    self.k_proj = nn.Linear(hidden, kv_width, bias = False)
    self.v_proj = nn.Linear(hidden, kv_width, bias = False)
    self.o_proj = nn.Linear(q_width, hidden, bias = False)

    drop_p = args.dropout
    self.attn_dropout = nn.Dropout(drop_p)
    self.resd_dropout = nn.Dropout(drop_p)
    self.dropout = drop_p
    # The fused kernel is used only if this torch build provides it and the config allows it.
    self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn


  def forward(self,
              x: torch.Tensor,
              position_embeddings: Tuple[torch.Tensor, torch.Tensor],
              past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]],
              use_cache = False,
              attention_mask: Optional[torch.Tensor] = None):
    batch_size, seq_len, _ = x.shape
    xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
    xq, xk = xq.view(batch_size, seq_len, self.n_local_heads, self.head_dim), xk.view(batch_size, seq_len, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(batch_size, seq_len, self.n_local_kv_heads, self.head_dim)

    cos, sin = position




