In [1]:
import json
import math
from typing import Callable

import torch
from torch import nn
from torch.nn import functional as F

class GptOssTopKRouter(nn.Module):
    """Linear router that keeps the ``top_k`` highest-scoring experts per token.

    Args:
        num_experts: Number of experts, i.e. rows of ``weight``. Defaults to 32.
        hidden_size: Width of each token vector. Defaults to 2880.
        top_k: Number of experts selected per token. Defaults to 4.
    """

    def __init__(self, num_experts=32, hidden_size=2880, top_k=4):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.top_k = top_k
        # Allocated uninitialised; real values are expected from load_state_dict.
        self.weight = nn.Parameter(torch.empty(num_experts, hidden_size))
        self.bias = nn.Parameter(torch.empty(num_experts))

    def forward(self, hidden_states):
        """Score every token against every expert and keep the top-k.

        Args:
            hidden_states: Tensor whose trailing dimension is ``hidden_size``;
                all leading dimensions are flattened into one token axis.

        Returns:
            ``(router_scores, router_indices)`` where ``router_scores`` is
            ``(num_tokens, num_experts)`` with the softmax over the selected
            logits scattered into the selected columns (zeros elsewhere), and
            ``router_indices`` is ``(num_tokens, top_k)``.
        """
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (num_tokens, num_experts)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (num_tokens, top_k)
        # Softmax is taken over the selected logits only, not over all experts.
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices

In [2]:
router = GptOssTopKRouter()

In [3]:
router.load_state_dict(torch.load('model.layers.23.mlp.router.pt'))

<All keys matched successfully>

In [4]:
with open('model.layers.23.mlp.router.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [5]:
parameters.keys()

dict_keys(['hidden_states', 'return'])

In [6]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [7]:
router_scores, router_indices = torch.Tensor(parameters['return']['items'][0]['data']), torch.LongTensor(parameters['return']['items'][1]['data'])

In [8]:
actual_router_scores, actual_router_indices = router(hidden_states)

In [9]:
router_scores.allclose(actual_router_scores)

True

In [10]:
router_indices.allclose(actual_router_indices)

True

------

In [11]:
class GptOssRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    Args:
        hidden_size: Length of the normalised (last) dimension. Defaults to 2880.
        eps: Constant added to the mean square before the reciprocal square
            root. Defaults to 1e-05.
    """

    def __init__(self, hidden_size=2880, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise over the last dimension in float32, then cast back.

        Args:
            hidden_states: Tensor whose last dimension matches ``weight``.

        Returns:
            Tensor of the same shape and dtype as the input.
        """
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32 regardless of the input dtype.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

In [12]:
post_attention_layernorm = GptOssRMSNorm()

In [13]:
post_attention_layernorm.load_state_dict(torch.load('model.layers.7.post_attention_layernorm.pt'))

<All keys matched successfully>

In [14]:
with open('model.layers.7.post_attention_layernorm.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [15]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [16]:
ret = torch.Tensor(parameters['return']['data'])

In [17]:
actual_ret = post_attention_layernorm(hidden_states)

In [18]:
ret.allclose(actual_ret)

True

------

In [19]:
def _compute_yarn_parameters():
    base = 150000
    partial_rotary_factor = 1.0
    head_dim = 64
    dim = int(head_dim * partial_rotary_factor)
    factor = 32.0
    attention_factor = None
    mscale = None
    mscale_all_dim = None
    original_max_position_embeddings = 4096

    def get_mscale(scale, mscale=1):
        return 0.1 * mscale * math.log(scale) + 1.0

    if attention_factor is None:
        attention_factor = get_mscale(factor)
            
    beta_fast = 32.0
    beta_slow = 1.0

    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        return (max(low, 0), min(high, dim - 1))

    def linear_ramp_factor(min, max, dim):
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func
    
    pos_freqs = base ** (torch.arange(0, dim, 2).to(dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
    truncate = False
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(dtype=torch.float)
    inv_freq = inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation * inv_freq_extrapolation_factor
    return (inv_freq, attention_factor)


class GptOssRotaryEmbedding(nn.Module):
    """Produces scaled cos/sin rotary tables for the given position ids.

    The inverse frequencies and the scaling applied to cos/sin both come from
    ``_compute_yarn_parameters``.
    """

    def __init__(self):
        super().__init__()

        yarn_inv_freq, yarn_scaling = _compute_yarn_parameters()
        self.attention_scaling = yarn_scaling
        # Registered as non-persistent, so the buffer is left out of state_dict.
        self.register_buffer('inv_freq', yarn_inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables cast to ``x.dtype``.

        Args:
            x: Tensor used only for its device and dtype.
            position_ids: 2-D tensor indexed as ``(batch, seq_len)``.

        Returns:
            Two tensors of shape ``(batch, seq_len, len(inv_freq))``, each
            multiplied by ``attention_scaling``.
        """
        batch = position_ids.shape[0]
        inv_freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions_row = position_ids[:, None, :].float()

        raw_device = x.device.type
        if isinstance(raw_device, str) and raw_device != 'mps':
            device_type = raw_device
        else:
            device_type = 'cpu'

        # Autocast is disabled so the angle computation stays in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq_col.float() @ positions_row.float()).transpose(1, 2)
            cos = angles.cos() * self.attention_scaling
            sin = angles.sin() * self.attention_scaling
        return (cos.to(x.dtype), sin.to(x.dtype))

In [20]:
rotary_embedding = GptOssRotaryEmbedding()

In [21]:
rotary_embedding.load_state_dict(torch.load('model.rotary_emb.pt'))

<All keys matched successfully>

In [22]:
with open('model.rotary_emb.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [23]:
parameters.keys()

dict_keys(['x', 'position_ids', 'return'])

In [24]:
x = torch.Tensor(parameters['x']['data'])

In [25]:
position_ids = torch.Tensor(parameters['position_ids']['data'])

In [26]:
return_0, return_1 = torch.Tensor(parameters['return']['items'][0]['data']), torch.Tensor(parameters['return']['items'][1]['data'])

In [27]:
actual_return_0, actual_return_1 = rotary_embedding(x, position_ids)

In [28]:
return_0.allclose(actual_return_0)

True

In [29]:
return_1.allclose(actual_return_1)

True

------

In [30]:
with open('model.layers.8.mlp.experts.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [31]:
parameters.keys()

dict_keys(['hidden_states', 'router_indices', 'routing_weights', 'return'])

In [32]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [33]:
router_indices = torch.LongTensor(parameters['router_indices']['data'])

In [34]:
routing_weights = torch.Tensor(parameters['routing_weights']['data'])

In [35]:
ret = torch.Tensor(parameters['return']['data'])

In [36]:
class GptOssExperts(nn.Module):
    """Bank of per-expert gated MLPs applied to the tokens routed to each expert.

    Args:
        num_experts: Number of experts. Defaults to 32.
        hidden_size: Token width on input and output. Defaults to 2880.
        intermediate_size: Width of the gated activation; ``gate_up_proj``
            produces twice this many columns. Defaults to 2880.
        alpha: Multiplier inside the sigmoid of the gate. Defaults to 1.702.
        limit: Clamp bound for the gate (upper only) and up (both sides)
            projections. Defaults to 7.0.
    """

    def __init__(self, num_experts=32, hidden_size=2880, intermediate_size=2880, alpha=1.702, limit=7.0):
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Allocated uninitialised; real values are expected from load_state_dict.
        self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * intermediate_size))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(num_experts, 2 * intermediate_size))
        self.down_proj = nn.Parameter(torch.empty((num_experts, intermediate_size, hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(num_experts, hidden_size))
        self.alpha = alpha
        self.limit = limit

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Run each selected expert on its tokens and sum the weighted outputs.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``; the
                first dimension is treated as the batch size.
            router_indices: Integer tensor ``(num_tokens, top_k)`` of selected
                expert ids.
            routing_weights: Tensor ``(num_tokens, num_experts)`` holding the
                weight of each expert for each token.

        Returns:
            Tensor of shape ``(batch_size, -1, hidden_size)``.
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # NOTE(review): one class more than the number of experts is
            # allocated; confirm whether an out-of-range "no expert" id is expected.
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts + 1)
            expert_mask = expert_mask.permute(2, 1, 0)  # (class, top_k, token)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        # Only experts that received at least one token are visited.
        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            # Gate and up projections are interleaved along the last axis.
            gate, up = (gate_up[..., ::2], gate_up[..., 1::2])
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)

        return next_states

In [37]:
experts = GptOssExperts()

In [38]:
experts.load_state_dict(torch.load('model.layers.8.mlp.experts.pt'))

<All keys matched successfully>

In [39]:
actual_ret = experts(hidden_states, router_indices, routing_weights)

In [40]:
ret.allclose(actual_ret)

True

---------

In [41]:
with open('model.layers.3.self_attn.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [42]:
parameters.keys()

dict_keys(['hidden_states', 'attention_mask', 'position_ids', 'use_cache', 'cache_position', 'position_embeddings', 'output_router_logits', 'return'])

In [43]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[-0.3566, -0.1026,  1.1694,  ..., -0.4005,  0.1886, -0.1290],
         [ 0.4086,  0.2832,  0.8145,  ..., -0.1869,  0.1457,  0.2259],
         [ 0.6011, -0.5394,  0.0729,  ...,  0.2194, -0.0664, -0.1242],
         ...,
         [-0.1970,  0.3161, -0.1640,  ...,  0.0289, -0.1230, -0.3016],
         [ 0.1955, -0.8916,  0.2378,  ...,  0.0600,  0.0521, -0.4781],
         [-0.2455,  0.3101,  0.1855,  ...,  0.1516, -0.2417, -0.1635]]])

In [44]:
attention_mask = torch.Tensor(parameters['attention_mask']['data'])
attention_mask

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]])

In [45]:
position_ids = torch.LongTensor(parameters['position_ids']['data'])
position_ids

tensor([[0, 1, 2, 3, 4, 5, 6]])

In [46]:
parameters['use_cache']

True

In [47]:
parameters['cache_position']

{'type': 'Tensor', 'data': [0, 1, 2, 3, 4, 5, 6]}

In [48]:
position_embeddings_0 = torch.Tensor(parameters['position_embeddings']['items'][0]['data'])

In [49]:
position_embeddings_1 = torch.Tensor(parameters['position_embeddings']['items'][1]['data'])

In [50]:
parameters['output_router_logits']

False

In [51]:
ret_0 = torch.Tensor(parameters['return']['items'][0]['data'])

In [52]:
ret_1 = torch.Tensor(parameters['return']['items'][1]['data'])

In [53]:
from typing import Optional, Tuple
from cowlist import COWList

# Attention type per decoder layer: 24 entries alternating between
# 'sliding_attention' (even positions) and 'full_attention' (odd positions).
CONFIG_LAYER_TYPES = COWList(['sliding_attention', 'full_attention'] * 12)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input
    goes from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1``
    the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float=0.0,
):
    """Scaled dot-product attention with an extra per-head "sink" logit.

    Args:
        module: Object exposing a ``sinks`` tensor with one value per query head.
        query: Query tensor; keys and values are repeated 8 times along dim 1
            before use (hard-coded repeat factor).
        key: Key tensor.
        value: Value tensor.
        attention_mask: Optional additive mask; it is cropped on the last axis
            to the key length before being added to the logits.
        scaling: Multiplier applied to the raw query-key products.
        dropout: Dropout probability passed through; dropout is called with
            ``training=False``.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` has dims 1 and 2
        swapped relative to the attention-weighted values.
    """
    keys = repeat_kv(key, 8)
    values = repeat_kv(value, 8)

    logits = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        logits = logits + attention_mask[:, :, :, :keys.shape[-2]]

    # Append the sink as one more column so it takes part in the softmax
    # normalisation, then drop that column from the probabilities.
    batch, q_len = query.shape[0], query.shape[-2]
    sink_column = module.sinks.reshape(1, -1, 1, 1).expand(batch, -1, q_len, -1)
    with_sinks = torch.cat([logits, sink_column], dim=-1)
    with_sinks = with_sinks - with_sinks.max(dim=-1, keepdim=True).values
    probs = F.softmax(with_sinks, dim=-1, dtype=with_sinks.dtype)

    attn_weights = nn.functional.dropout(probs[..., :-1], p=dropout, training=False)
    attn_output = torch.matmul(attn_weights, values).transpose(1, 2).contiguous()
    return (attn_output, attn_weights)


def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    first_ = first_half * cos - second_half * sin
    second_ = second_half * cos + first_half * sin
    return torch.cat((first_, second_), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply the same rotary embedding to the query and key tensors.

    ``cos`` and ``sin`` get an extra axis at ``unsqueeze_dim`` so they
    broadcast against ``q`` and ``k``. ``position_ids`` is accepted but unused.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)


class GptOssAttention(nn.Module):
    """Self-attention block with rotary embeddings and per-head sink logits.

    Projection sizes are hard-coded: 2880 input features, 64*64 query
    features, 8*64 key/value features, head dimension 64.
    """

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(2880, 64 * 64, bias=True)
        self.k_proj = nn.Linear(2880, 8 * 64, bias=True)
        self.v_proj = nn.Linear(2880, 8 * 64, bias=True)
        self.o_proj = nn.Linear(64 * 64, 2880, bias=True)
        # One sink logit per query head; allocated uninitialised.
        self.sinks = nn.Parameter(torch.empty(64))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, rotate, attend and project back.

        Args:
            hidden_states: Tensor with 2880 features on the last axis.
            position_embeddings: ``(cos, sin)`` pair fed to the rotary embedding.
            attention_mask: Optional additive mask forwarded to the attention kernel.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, 64)

        def split_heads(projected):
            # (..., seq, heads * 64) -> (..., heads, seq, 64)
            return projected.view(per_head_shape).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0,
            scaling=0.125,
        )
        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return (self.o_proj(merged), attn_weights)

In [54]:
self_attn = GptOssAttention()

In [55]:
self_attn.load_state_dict(torch.load('model.layers.3.self_attn.pt'))

<All keys matched successfully>

In [56]:
actual_ret_0, actual_ret_1 = self_attn(hidden_states, (position_embeddings_0, position_embeddings_1), attention_mask)

In [57]:
ret_0.allclose(actual_ret_0)

True

In [58]:
ret_1.allclose(actual_ret_1)

True

---------

In [59]:
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward block: a router followed by the experts."""

    def __init__(self):
        super().__init__()
        self.router = GptOssTopKRouter()
        self.experts = GptOssExperts()

    def forward(self, hidden_states):
        """Route the tokens, run the chosen experts, and return both results.

        Returns:
            ``(routed_out, router_scores)``: the experts' output and the
            per-token, per-expert scores produced by the router.
        """
        scores, indices = self.router(hidden_states)
        expert_output = self.experts(hidden_states, router_indices=indices, routing_weights=scores)
        return (expert_output, scores)

In [60]:
mlp = GptOssMLP()

In [61]:
mlp.load_state_dict(torch.load('model.layers.0.mlp.pt'))

<All keys matched successfully>

In [62]:
with open('model.layers.0.mlp.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [63]:
parameters.keys()

dict_keys(['hidden_states', 'return'])

In [64]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[-0.2309,  0.2867,  1.0308,  ..., -0.0260, -0.0321,  0.0591],
         [ 0.3464,  0.4684,  0.6387,  ..., -0.4321, -0.1286,  0.0166],
         [ 0.3301,  0.2904,  0.1818,  ..., -0.1855, -0.3324, -0.1021],
         ...,
         [-0.2346,  0.1293,  0.2423,  ..., -0.5792,  0.0849, -0.1463],
         [ 0.0930, -0.3895,  0.5061,  ...,  0.0616, -0.2823, -0.1552],
         [ 0.2867,  0.2045,  0.6961,  ..., -0.3314, -0.2441, -0.1707]]])

In [65]:
return_0 = torch.Tensor(parameters['return']['items'][0]['data'])
return_0

tensor([[[ 0.0024, -0.3997, -0.3772,  ...,  0.0044,  0.0997,  0.0103],
         [ 0.1005, -0.2501,  0.1974,  ...,  0.0395,  0.0471,  0.3597],
         [-0.0042,  0.0123,  0.2023,  ...,  0.1977,  0.0931,  0.0578],
         ...,
         [ 0.0902,  0.0755, -0.2559,  ...,  0.0017, -0.0158,  0.0860],
         [ 0.1149,  0.1651,  0.4096,  ...,  0.0608,  0.0294, -0.2438],
         [-0.2755,  0.0243, -0.3894,  ..., -0.1046,  0.0164,  0.1729]]])

In [66]:
return_1 = torch.Tensor(parameters['return']['items'][1]['data'])
return_1

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3030, 0.0000, 0.0000, 0.0000,
         0.0823, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3616,
         0.0000, 0.0000, 0.0000, 0.0000, 0.2531],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2018, 0.0000, 0.0000,
         0.0000, 0.0000, 0.1148, 0.0000, 0.0000, 0.0000, 0.0000, 0.4577, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2258,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0798, 0.0000, 0.7746, 0.0000, 0.0000, 0.0000, 0.0756,
         0.0000, 0.0000, 0.0000, 0.0700, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1269, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000,

In [67]:
actual_ret_0, actual_ret_1 = mlp(hidden_states)

In [68]:
return_0.allclose(actual_ret_0)

True

In [69]:
return_1.allclose(actual_ret_1)

True

------

In [70]:
class GptOssDecoderLayer(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm MoE MLP, each with a residual."""

    def __init__(self):
        super().__init__()
        self.self_attn = GptOssAttention()
        self.mlp = GptOssMLP()
        self.input_layernorm = GptOssRMSNorm()
        self.post_attention_layernorm = GptOssRMSNorm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor]=None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]]=None,
    ) -> torch.Tensor:
        """Apply attention and the MLP, adding each result to its input.

        Only the first element of each sub-module's return tuple is used.
        """
        # Attention sub-block.
        attn_input = self.input_layernorm(hidden_states)
        attn_out, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_input = self.post_attention_layernorm(after_attn)
        mlp_out, _ = self.mlp(mlp_input)
        return after_attn + mlp_out

In [71]:
layers_20 = GptOssDecoderLayer()

In [72]:
layers_20.load_state_dict(torch.load('model.layers.20.pt'))

<All keys matched successfully>

In [73]:
with open('model.layers.20.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [74]:
parameters.keys()

dict_keys(['hidden_states', 'attention_mask', 'position_ids', 'use_cache', 'cache_position', 'position_embeddings', 'output_router_logits', 'return'])

In [75]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[  -5.8659,   33.0872,   91.4604,  ...,   20.2924,   73.2631,
          -235.1105],
         [-121.1287,  -56.9332,   17.2398,  ...,  -56.2602,    2.1464,
          -236.1957],
         [ 148.5167,  106.2464,  100.2911,  ...,    6.6027,   85.7875,
          -127.2576],
         ...,
         [ 241.3062,  -62.1225,  -36.7799,  ...,  -26.3095,  -28.3285,
             5.2132],
         [ -86.6025, -114.4028,  118.7643,  ...,  -17.0406,   84.3951,
          -219.4708],
         [-318.7719,   50.1896,   68.3206,  ...,   56.5228,   97.6040,
            39.3755]]])

In [76]:
attention_mask = torch.Tensor(parameters['attention_mask']['data'])
attention_mask

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]])

In [77]:
torch.Tensor(parameters['position_ids']['data'])

tensor([[0., 1., 2., 3., 4., 5., 6.]])

In [78]:
parameters['use_cache']

True

In [79]:
parameters['cache_position']

{'type': 'Tensor', 'data': [0, 1, 2, 3, 4, 5, 6]}

In [80]:
position_embeddings_0 = torch.Tensor(parameters['position_embeddings']['items'][0]['data'])
position_embeddings_0

tensor([[[ 1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466],
         [ 0.7276,  1.0394,  1.1976,  1.2752,  1.3125,  1.3304,  1.3389,
           1.3429,  1.3448,  1.3459,  1.3463,  1.3465,  1.3465,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466],
         [-0.5604,  0.2579,  0.7838,  1.0685,  1.2120,  1.2821,  1.3158,
           1.3320,  1.3396,  1.3439,  1.3456,  1.3462,  1.3464,  1.3465,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1

In [81]:
position_embeddings_1 = torch.Tensor(parameters['position_embeddings']['items'][1]['data'])
position_embeddings_1

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.1331e+00,  8.5615e-01,  6.1558e-01,  4.3271e-01,  3.0098e-01,
           2.0831e-01,  1.4384e-01,  9.9212e-02,  6.8394e-02,  4.2687e-02,
           2.6034e-02,  1.5609e-02,  9.1498e-03,  5.1982e-03,  2.8194e-03,
           1.4174e-03,  6.1469e-04,  1.7414e-04,  5.1586e-05,  3.5545e-05,
           2.4492e-05,  1.6876e-05,  1.1628e-05,  8.0124e-06,  5.5209e-06,
           3.8042e-06,  2.6212e-06,  1.8061e-06,  1.2445e-06,  8.5753e-07,
           5.9087e-07,  4.0714e-07],
         [ 1.2244e+00,  1.

In [82]:
ret = torch.Tensor(parameters['return']['data'])
ret

tensor([[[ -51.7529,  -79.2223,  166.9466,  ...,   45.3354,   29.9371,
          -294.3643],
         [-106.0870,  -51.1902,  113.9774,  ..., -122.4729,  -74.4957,
          -298.7038],
         [ 162.0087,   43.3856,  182.8648,  ..., -100.3009,   22.0254,
          -199.8858],
         ...,
         [ 225.2012,  -14.6964,  -12.7762,  ...,  -66.5061, -120.1939,
            30.1922],
         [-202.3708,  -67.3173,  108.8600,  ...,  -21.3696,  119.0166,
          -300.0491],
         [-507.0756,  107.2169,   57.5793,  ...,   -6.3350,   65.9705,
           126.3017]]])

In [83]:
actual_ret = layers_20(hidden_states, attention_mask, (position_embeddings_0, position_embeddings_1))

In [84]:
ret.allclose(actual_ret)

True

------

# Work in Progress

The whole model:

```python
class GptOssModel(nn.Module):
    """Token embedding, 24 decoder layers, and a final RMS norm.

    Work in progress: no key/value cache is wired in, so each call attends
    only over the tokens passed in.
    """

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(201088, 2880, padding_idx=199999)
        self.layers = nn.ModuleList([GptOssDecoderLayer() for _ in range(24)])
        # GptOssRMSNorm hard-codes width 2880 and eps 1e-05.
        self.norm = GptOssRMSNorm()
        self.rotary_emb = GptOssRotaryEmbedding()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
        cache_position: torch.LongTensor,
    ):
        """Embed the tokens and run them through every decoder layer.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq_len)``.
            attention_mask: Boolean padding mask indexed as ``(batch, kv_len)``.
            position_ids: Positions passed to the rotary embedding.
            cache_position: 1-D tensor of query positions used to build the masks.

        Returns:
            The normalised hidden states of the last layer.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        mask_kwargs = {
            'attention_mask': attention_mask,
            # NOTE(review): key/value length is taken from the padding mask
            # width, matching the logged mask inputs; revisit once a cache exists.
            'kv_length': attention_mask.shape[-1],
            'batch_size': inputs_embeds.shape[0],
            'dtype': inputs_embeds.dtype,
            'cache_position': cache_position,
        }
        causal_mask_mapping = {
            'full_attention': create_causal_mask(**mask_kwargs),
            'sliding_attention': create_sliding_window_causal_mask(**mask_kwargs),
        }
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        # The decoder layers carry no attention-type attribute; the per-layer
        # type comes from CONFIG_LAYER_TYPES, in layer order.
        for decoder_layer, layer_type in zip(self.layers, CONFIG_LAYER_TYPES):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[layer_type],
                position_embeddings=position_embeddings,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GptOssForCausalLM(nn.Module):
    """Backbone model plus a bias-free linear head over the vocabulary."""

    def __init__(self):
        super().__init__()
        self.model = GptOssModel()
        self.lm_head = nn.Linear(2880, 201088, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
        cache_position: torch.LongTensor,
    ):
        """Return logits for the last position only (the sequence axis is kept with length 1)."""
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
        )
        last_position = hidden_states[:, -1:, :]
        return self.lm_head(last_position)
```

Mask creation logic:

```python
def create_causal_mask(
    attention_mask,
    kv_length,
    batch_size,
    dtype,
    cache_position
):
    """Build the additive mask for full (unwindowed) causal attention.

    Delegates to ``eager_mask`` with ``causal_mask_function`` and a zero
    key/value offset.
    """
    return eager_mask(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=0,
        mask_function=causal_mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=True,
        dtype=dtype,
    )


def create_sliding_window_causal_mask(
    attention_mask,
    kv_length,
    batch_size,
    dtype,
    cache_position
):
    """Build the additive mask for sliding-window causal attention.

    The mask function is the conjunction of ``sliding_window_overlay`` and
    ``causal_mask_function``; ``local_size=128`` is forwarded to ``eager_mask``.
    """
    windowed_causal = and_masks(sliding_window_overlay, causal_mask_function)
    return eager_mask(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=0,
        mask_function=windowed_causal,
        attention_mask=attention_mask,
        allow_is_causal_skip=True,
        dtype=dtype,
        local_size=128,
    )


def eager_mask(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int=0,
    mask_function: Optional[Callable]=None,
    attention_mask: Optional[torch.Tensor]=None,
    dtype: torch.dtype=torch.float32,
    **kwargs
):
    """Build an additive attention mask: 0 where allowed, ``finfo(dtype).min`` elsewhere.

    Args:
        batch_size: Number of sequences in the batch.
        cache_position: 1-D tensor of query positions.
        kv_length: Number of key/value positions.
        kv_offset: Offset forwarded to ``sdpa_mask_recent_torch``.
        mask_function: Predicate ``(batch_idx, head_idx, q_idx, kv_idx) -> bool``.
            ``None`` (the default) selects ``causal_mask_function``.
        attention_mask: Optional boolean padding mask.
        dtype: Floating dtype of the returned mask.
        **kwargs: Extra options forwarded to ``sdpa_mask_recent_torch``
            (e.g. ``local_size``). ``allow_is_causal_skip`` and
            ``allow_torch_fix`` are discarded.

    Returns:
        Floating-point mask tensor of dtype ``dtype``.
    """
    # Resolved at call time so this definition does not depend on
    # causal_mask_function already existing when the `def` is executed.
    if mask_function is None:
        mask_function = causal_mask_function
    # Callers pass allow_is_causal_skip=True; it is forced to False below, and
    # leaving it in kwargs would raise "multiple values for keyword argument".
    kwargs.pop('allow_is_causal_skip', None)
    # sdpa_mask_recent_torch has no allow_torch_fix parameter, so it is neither
    # forwarded nor passed explicitly.
    kwargs.pop('allow_torch_fix', None)
    mask = sdpa_mask_recent_torch(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=False,
        **kwargs
    )
    min_dtype = torch.finfo(dtype).min
    # True -> 0.0 (attend), False -> most negative finite value of `dtype`.
    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
    return mask


def sdpa_mask_recent_torch(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int=0,
    mask_function: Optional[Callable]=None,
    attention_mask: Optional[torch.Tensor]=None,
    local_size: Optional[int]=None,
    allow_is_causal_skip: bool=True
) -> Optional[torch.Tensor]:
    """Evaluate a mask predicate over every (batch, head, query, key) index.

    Args:
        batch_size: Number of sequences in the batch.
        cache_position: 1-D tensor of query positions.
        kv_length: Number of key/value positions.
        kv_offset: Value added to every key/value index before evaluation.
        mask_function: Predicate ``(batch_idx, head_idx, q_idx, kv_idx) -> bool``.
            ``None`` (the default) selects ``causal_mask_function``.
        attention_mask: Optional boolean padding mask indexed as
            ``[batch_idx, kv_idx]``; when given it is AND-ed into the predicate.
        local_size: Accepted for interface compatibility; not used here.
        allow_is_causal_skip: Accepted for interface compatibility; not used here.

    Returns:
        Boolean tensor of shape ``(batch_size, 1, len(cache_position), kv_length)``.
    """
    # Imported locally because the notebook never imports it at module level.
    # NOTE(review): private torch module; confirm it exists in the installed version.
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

    # Resolved at call time so this definition does not depend on
    # causal_mask_function already existing when the `def` is executed.
    if mask_function is None:
        mask_function = causal_mask_function
    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
    if attention_mask is not None:
        mask_function = and_masks(mask_function, padding_mask_function(attention_mask))
    batch_arange = torch.arange(batch_size, device=cache_position.device)
    # A single head index: the mask has a size-1 head axis.
    head_arange = torch.arange(1, device=cache_position.device)
    with TransformGetItemToIndex():
        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
    return causal_mask


def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int):
    """Allow a key position only if it is not after the query position."""
    return q_idx >= kv_idx

def sliding_window_overlay(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int):
    """Allow a key position only if it is more recent than 128 positions before the query."""
    return q_idx - 128 < kv_idx


def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
    """Wrap a padding mask as a mask predicate.

    The returned function ignores the head and query indices and looks up
    ``padding_mask[batch_idx, kv_idx]``.
    """
    def lookup(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        keep = padding_mask[batch_idx, kv_idx]
        return keep
    return lookup


def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool=True) -> Callable:
    dimensions = [(None, None, None, 0), (None, None, 0, None)]
    if bh_indices:
        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
    for dims in dimensions:
        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
    return mask_function


def and_masks(*mask_functions: Callable) -> Callable:
    """Combine mask predicates into one that is true only where every predicate is true.

    With no predicates the combined function returns a scalar `True` tensor.
    """
    def and_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Start from a scalar True created from `q_idx` so it shares that tensor's device.
        combined = q_idx.new_ones((), dtype=torch.bool)
        for predicate in mask_functions:
            verdict = predicate(batch_idx, head_idx, q_idx, kv_idx)
            combined = combined & verdict.to(combined.device)
        return combined
    return and_mask
```

Values used to create masks:

```
tensor([[True, True, True, True, True, True, True]]) 7 1 torch.float32 tensor([0, 1, 2, 3, 4, 5, 6])
tensor([[True, True, True, True, True, True, True]]) 7 1 torch.float32 tensor([0, 1, 2, 3, 4, 5, 6])

tensor([[True, True, True, True, True, True, True, True]]) 8 1 torch.float32 tensor([7])
tensor([[True, True, True, True, True, True, True, True]]) 8 1 torch.float32 tensor([7])

tensor([[True, True, True, True, True, True, True, True, True]]) 9 1 torch.float32 tensor([8])
tensor([[True, True, True, True, True, True, True, True, True]]) 9 1 torch.float32 tensor([8])

...

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True]]) 26 1 torch.float32 tensor([25])
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True]]) 26 1 torch.float32 tensor([25])
```

Created masks:

```
{'full_attention': tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]]), 'sliding_attention': tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]])}

{'full_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.]]]]), 'sliding_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.]]]])}

{'full_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]), 'sliding_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])}

...

{'full_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.]]]]), 'sliding_attention': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.]]]])}
```

Generation logic:

```python
class GenerationMixin(ContinuousMixin):
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor]=None,
        generation_config: Optional[GenerationConfig]=None,
        logits_processor: Optional[LogitsProcessorList]=None,
        stopping_criteria: Optional[StoppingCriteriaList]=None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]]=None,
        synced_gpus: Optional[bool]=None,
        assistant_model: Optional['PreTrainedModel']=None,
        streamer: Optional['BaseStreamer']=None,
        negative_prompt_ids: Optional[torch.Tensor]=None,
        negative_prompt_attention_mask: Optional[torch.Tensor]=None,
        use_model_defaults: Optional[bool]=None,
        custom_generate: Optional[Union[str, Callable]]=None,
        **kwargs
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Simplified sampling-only generation entry point (runs under `torch.no_grad()`).

        Steps: resolve the generation config and model kwargs, prepare special tokens and
        length limits, create a `DynamicCache`, build the logits processors and stopping
        criteria, then hand over to `_sample`.

        In the traced run recorded in the comments below, only `input_ids` and
        `attention_mask` were supplied (through `**kwargs`); every named parameter kept
        its `None` default. The named parameters are declared because the body reads them.

        Returns:
            Whatever `_sample` returns: the prompt ids with the sampled tokens appended.
        """
        # {'synced_gpus': False}
        generation_mode_kwargs = self._extract_generation_mode_kwargs(custom_generate, kwargs, synced_gpus, assistant_model, streamer)
        # GenerationConfig {
        #   "bos_token_id": 199998,
        #   "do_sample": true,
        #   "eos_token_id": [
        #     200002,
        #     199999
        #   ],
        #   "pad_token_id": 199999
        # }
        # {'input_ids': tensor([[   40,  6423,   290, 10915,   328,  2615,   382]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
        generation_config, model_kwargs = self._prepare_generation_config(generation_config, use_model_defaults, **kwargs)
        # <GenerationMode.SAMPLE: 'sample'>
        generation_mode = generation_config.get_generation_mode(assistant_model)

        # Honour caller-supplied lists; fall back to empty ones otherwise.
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        # tensor([[   40,  6423,   290, 10915,   328,  2615,   382]])
        # 'input_ids'
        # {'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
        # 1
        batch_size = inputs_tensor.shape[0]
        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, True, device=device)
        # tensor([[   40,  6423,   290, 10915,   328,  2615,   382]])
        input_ids = inputs_tensor if model_input_name == 'input_ids' else model_kwargs.pop('input_ids')
        # 7
        input_ids_length = input_ids.shape[1]
        # True
        # 20
        has_default_max_length = kwargs.get('max_length') is None and generation_config.max_length is not None
        # True
        # 0
        has_default_min_length = kwargs.get('min_length') is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(generation_config=generation_config, has_default_max_length=has_default_max_length, has_default_min_length=has_default_min_length, model_input_name=model_input_name, inputs_tensor=inputs_tensor, input_ids_length=input_ids_length)
        # Only the logits of the last position are needed to pick the next token.
        model_kwargs['logits_to_keep'] = 1
        # 27
        # 26
        max_cache_length = generation_config.max_length - 1

        dynamic_cache_kwargs = {'config': self.config.get_text_config(decoder=True)}
        model_kwargs['past_key_values'] = DynamicCache(**dynamic_cache_kwargs)
        self._prepare_cache_for_generation(generation_config, model_kwargs, generation_mode, batch_size, max_cache_length)

        prepared_logits_processor = self._get_logits_processor(generation_config=generation_config, input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, logits_processor=logits_processor, device=inputs_tensor.device, model_kwargs=model_kwargs, negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask)

        prepared_stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=generation_mode_kwargs.get('tokenizer'))

        # True
        model_kwargs['use_cache'] = generation_config.use_cache

        # <function GenerationMixin._sample at 0x7fef8e6998a0>
        # `_sample` is called as a bound method, so `self` must not be passed again.
        result = self._sample(input_ids, logits_processor=prepared_logits_processor, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, **generation_mode_kwargs, **model_kwargs)

        return result

    def _sample(self, input_ids: torch.LongTensor, logits_processor: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool=False, streamer: Optional['BaseStreamer']=None, **model_kwargs) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Autoregressive multinomial sampling loop.

        Each iteration runs the model on the prepared inputs, takes the logits of the last
        position, applies `logits_processor`, samples one token per sequence and appends it
        to `input_ids`. The loop stops once `stopping_criteria` has flagged every sequence.

        `synced_gpus` and `streamer` are accepted but not used by this simplified version.

        Returns:
            `input_ids` extended with all sampled tokens.
        """
        # tensor(199999)
        pad_token_id = generation_config._pad_token_tensor
        # No scores are tracked; the stopping criteria are called with `None`.
        scores = None
        # (1, 7)
        batch_size, cur_len = input_ids.shape[:2]
        this_peer_finished = False
        # 1 while a sequence is still generating, 0 once it has finished.
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
        model_forward = self.__call__
        is_prefill = True
        while not this_peer_finished:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            if is_prefill:
                # <function GptOssForCausalLM.forward at 0x7fef8d8cbba0>
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                # <function GptOssForCausalLM.forward at 0x7fef8d8cbba0>
                outputs = model_forward(**model_inputs, return_dict=True)
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)
            # Copy to float32 so the logits do not keep `outputs` alive after `del outputs`.
            next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
            next_token_scores = logits_processor(input_ids, next_token_logits)
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Sequences that already finished keep receiving the pad token.
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1
            del outputs
        return input_ids

    def _cache_dependant_input_preparation(self, input_ids: torch.LongTensor, inputs_embeds: Optional[torch.FloatTensor], cache_position: Optional[torch.LongTensor]) -> tuple[torch.FloatTensor, torch.LongTensor]:...
        if input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]
        return (inputs_embeds, input_ids)

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, past_key_values: Optional[Cache]=None, attention_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, cache_position: Optional[torch.LongTensor]=None, **kwargs):
        """Assemble the keyword arguments for one forward pass of the generation loop.

        - `input_ids` is trimmed to the positions in `cache_position` when a cache is present.
        - `position_ids` are derived from `attention_mask` when not supplied and the model's
          `forward` accepts them.
        - `position_ids` / `token_type_ids` / `decoder_position_ids` are cut to the length of
          the current input when a cache is present.
        - Every remaining kwarg is forwarded unchanged; `labels` is dropped.

        Returns:
            Dict of model inputs.
        """
        model_inputs = {}
        model_inputs['cache_position'] = cache_position
        if past_key_values is not None:
            model_inputs['past_key_values'] = past_key_values
            inputs_embeds, input_ids = self._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position)
        model_inputs['input_ids'] = input_ids.clone(memory_format=torch.contiguous_format)
        model_inputs['inputs_embeds'] = None
        if attention_mask is not None and kwargs.get('position_ids') is None and 'position_ids' in inspect.signature(self.forward).parameters:
            # Positions count the attended tokens; padded slots get a dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            kwargs['position_ids'] = position_ids
        # `inputs_embeds` is always None here, so the current input length comes from `input_ids`.
        current_input_length = model_inputs['input_ids'].shape[1]
        for model_input_name in ['position_ids', 'token_type_ids', 'decoder_position_ids']:
            model_input = kwargs.get(model_input_name)
            if model_input is not None:
                if past_key_values is not None:
                    # Keep only the trailing entries that match the tokens being fed now.
                    model_input = model_input[:, -current_input_length:]
                    model_input = model_input.clone(memory_format=torch.contiguous_format)
                model_inputs[model_input_name] = model_input
        if attention_mask is not None:
            model_inputs['attention_mask'] = attention_mask
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        model_inputs.pop('labels', None)
        return model_inputs

    def _prepare_model_inputs(self, inputs: Optional[torch.Tensor]=None, bos_token_id: Optional[torch.Tensor]=None, model_kwargs: Optional[dict[str, torch.Tensor]]=None) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:...
        # input_ids
        input_name = self.main_input_name
        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
        inputs_kwarg = model_kwargs.pop(input_name, None)
        inputs = inputs_kwarg
        return (inputs, input_name, model_kwargs)

    def _update_model_kwargs_for_generation(self, outputs: ModelOutput, model_kwargs: dict[str, Any], is_encoder_decoder: bool=False, num_new_tokens: int=1) -> dict[str, Any]:
        """Advance the generation state after one forward pass.

        Stores the cache returned by the model, appends one attended column to the
        attention mask and moves `cache_position` forward by `num_new_tokens`.
        `is_encoder_decoder` is not used by this simplified version.

        Returns:
            The same `model_kwargs` dict, updated in place.
        """
        # Read the cache from its real attribute name on the model output.
        model_kwargs['past_key_values'] = outputs.past_key_values
        # tensor([[1, 1, 1, 1, 1, 1, 1]])
        attention_mask = model_kwargs['attention_mask']
        # tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
        model_kwargs['attention_mask'] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        # Keep only the last position and shift it, yielding a length-1 tensor for the next step.
        model_kwargs['cache_position'] = model_kwargs['cache_position'][-1:] + num_new_tokens
        return model_kwargs

    def _get_logits_processor(self, generation_config: GenerationConfig, input_ids_seq_length: Optional[int]=None, encoder_input_ids: Optional[torch.LongTensor]=None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]]=None, logits_processor: Optional[LogitsProcessorList]=None, device: Optional[str]=None, model_kwargs: Optional[dict[str, Any]]=None, negative_prompt_ids: Optional[torch.Tensor]=None, negative_prompt_attention_mask: Optional[torch.Tensor]=None) -> LogitsProcessorList:
        """Build the logits-processing pipeline used by sampling.

        This simplified version returns a list containing a single `TopKLogitsWarper`
        configured from `generation_config.top_k`; all other arguments (including a
        caller-supplied `logits_processor`) are ignored.
        """
        processors = LogitsProcessorList()
        min_tokens_to_keep = 1
        # 50
        # 1
        processors.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
        return processors

    def _get_stopping_criteria(self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], tokenizer: Optional['PreTrainedTokenizerBase']=None) -> StoppingCriteriaList:
        """Create the stopping criteria: a maximum-length check followed by an EOS-token check.

        The `stopping_criteria` and `tokenizer` arguments are not consulted here.
        """
        stop_conditions = StoppingCriteriaList()
        # 131072
        position_limit = getattr(self.config, 'max_position_embeddings', None)
        length_check = MaxLengthCriteria(max_length=generation_config.max_length, max_position_embeddings=position_limit)
        stop_conditions.append(length_check)
        eos_check = EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)
        stop_conditions.append(eos_check)
        return stop_conditions

    def _prepare_generated_length(self, generation_config, has_default_max_length, has_default_min_length, model_input_name, input_ids_length, inputs_tensor):...
        generation_config.max_length = generation_config.max_length + input_ids_length
        # 131072
        max_position_embeddings = getattr(self.config, 'max_position_embeddings', None)
        generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
        return generation_config

    # Resolve the generation config for this call and split off the kwargs meant for the model.
    # Several branches are elided (`...`) in this trace, so only the visible steps are described.
    # Returns a `(generation_config, model_kwargs)` tuple.
    def _prepare_generation_config(self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool]=None, **kwargs: Any) -> tuple[GenerationConfig, dict]:...
        using_model_generation_config = False
        if generation_config is None:...
        # Work on a copy so the updates below do not touch the object that was passed in.
        generation_config = copy.deepcopy(generation_config)
        if not using_model_generation_config:
            # Version recorded in the model's own generation config, reduced to its base version.
            model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
            # Model defaults are considered when explicitly requested, or when unspecified and the version is >= 4.50.0.
            if use_model_defaults is True or (use_model_defaults is None and model_base_version >= version.parse('4.50.0')):
                modified_values = {}
                global_default_generation_config = GenerationConfig()
                model_generation_config = self.generation_config
                for key, model_gen_config_value in model_generation_config.__dict__.items():
                    # Private attributes and the version stamp are skipped.
                    if key.startswith('_') or key == 'transformers_version':
                        continue
                    if key == 'cache_implementation' and model_generation_config.cache_implementation == 'hybrid':...
                    global_default_value = getattr(global_default_generation_config, key, None)
                    custom_gen_config_value = getattr(generation_config, key, None)
                    # Condition: the caller left the global default but the model config differs (body elided).
                    if custom_gen_config_value == global_default_value and model_gen_config_value != global_default_value:...
                if generation_config.temperature == 0.0:...
                if use_model_defaults is None and len(modified_values) > 0:...
            else:...
        # NOTE(review): presumably `update` applies the kwargs it recognises and returns the rest -- confirm.
        model_kwargs = generation_config.update(**kwargs)
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        # These flags are forwarded to the model only when they are truthy.
        model_kwargs.update({'output_attentions': output_attentions} if output_attentions else {})
        model_kwargs.update({'output_hidden_states': output_hidden_states} if output_hidden_states else {})
        return (generation_config, model_kwargs)

    def _get_initial_cache_position(self, seq_length, device, model_kwargs):...
        # tensor([0, 1, 2, 3, 4, 5, 6])
        cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
        past_length = 0
        if model_kwargs.get('past_key_values') is not None:
            # DynamicCache(layers=[DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer, DynamicSlidingWindowLayer, DynamicLayer])
            cache = model_kwargs['past_key_values']
            past_length = 0
            past_length = cache.get_seq_length()
            cache_position = cache_position[past_length:]
        model_kwargs['cache_position'] = cache_position
        return model_kwargs

    def _prepare_special_tokens(self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool]=None, device: Optional[Union[torch.device, str]]=None):
        """Convert the special token ids of `generation_config` into tensors on `device`.

        The results are stored on the config as `_bos_token_tensor`, `_eos_token_tensor`,
        `_pad_token_tensor` and `_decoder_start_token_tensor`. Ids that are `None` stay
        `None`. `kwargs_has_attention_mask` is not used by this simplified version.
        """

        def _tensor_or_none(token, device=None):
            """Return `token` as a tensor on `device`; `None` passes through unchanged."""
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                # Already a tensor: just move it instead of re-wrapping it with `torch.tensor`.
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)
        # 199998
        # tensor(199998)
        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        # [200002, 199999]
        # tensor([200002, 199999])
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        # 199999
        # tensor(199999)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        # None
        # None
        decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._decoder_start_token_tensor = decoder_start_token_tensor

    def _extract_generation_mode_kwargs(self, custom_generate, kwargs, synced_gpus, assistant_model, streamer) -> dict[str, Any]:...
        generation_mode_kwargs = {'tokenizer': kwargs.pop('tokenizer', None), 'assistant_tokenizer': kwargs.pop('assistant_tokenizer', None), 'assistant_model': assistant_model, 'streamer': streamer}
        generation_mode_kwargs['synced_gpus'] = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 if synced_gpus is None else ...
        generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
        if isinstance(custom_generate, Callable):...
        return generation_mode_kwargs
```