In [1]:
import math
import json
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple

In [2]:
class GptOssTopKRouter(nn.Module):
    """Linear router that sends every token to its ``top_k`` best-scoring experts.

    Args:
        num_experts: number of experts, i.e. rows of ``weight`` / entries of ``bias``.
        hidden_size: width of the incoming hidden states.
        top_k: how many experts are kept per token.

    The defaults (32, 2880, 4) reproduce the sizes this class was originally
    hard-coded with, so ``GptOssTopKRouter()`` keeps the same ``state_dict`` layout.
    """

    def __init__(self, num_experts=32, hidden_size=2880, top_k=4):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(f'top_k must be in [1, {num_experts}], got {top_k}')
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.top_k = top_k
        # torch.empty leaves the values uninitialised: load a state dict before calling forward.
        self.weight = nn.Parameter(torch.empty(num_experts, hidden_size))
        self.bias = nn.Parameter(torch.empty(num_experts))

    def forward(self, hidden_states):
        """Route a batch of hidden states.

        Args:
            hidden_states: tensor whose trailing dimension is ``hidden_size``; all
                leading dimensions are flattened into a single token axis.

        Returns:
            ``(router_scores, router_indices)`` where ``router_scores`` has shape
            ``(tokens, num_experts)`` and is zero everywhere except at the selected
            experts (which hold a softmax over the selected logits), and
            ``router_indices`` has shape ``(tokens, top_k)``.
        """
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        router_logits = F.linear(hidden_states, self.weight, self.bias)  # (tokens, num_experts)
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (tokens, top_k)
        # Softmax only over the kept logits, so each token's weights sum to 1.
        router_top_value = F.softmax(router_top_value, dim=-1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(-1, router_indices, router_top_value)
        return router_scores, router_indices

In [3]:
router = GptOssTopKRouter()

In [4]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
router.load_state_dict(torch.load('model.layers.7.mlp.router.pt', weights_only=True))

<All keys matched successfully>

In [5]:
with open('model.layers.7.mlp.router.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [6]:
parameters.keys()

dict_keys(['hidden_states', 'return'])

In [7]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [8]:
router_scores, router_indices = torch.Tensor(parameters['return']['items'][0]['data']), torch.LongTensor(parameters['return']['items'][1]['data'])

In [9]:
actual_router_scores, actual_router_indices = router(hidden_states)

In [10]:
router_scores.allclose(actual_router_scores)

True

In [11]:
router_indices.allclose(actual_router_indices)

True

------

In [12]:
class GptOssRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale.

    Args:
        hidden_size: size of the normalised (last) dimension.
        eps: constant added to the mean square before the reciprocal square root.

    The defaults (2880, 1e-05) are the values this class was originally
    hard-coded with, so ``GptOssRMSNorm()`` is unchanged.
    """

    def __init__(self, hidden_size=2880, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` and return a tensor of the same shape and dtype."""
        input_dtype = hidden_states.dtype
        # Statistics are computed in float32 and cast back at the very end.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

In [13]:
post_attention_layernorm = GptOssRMSNorm()

In [14]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
post_attention_layernorm.load_state_dict(torch.load('model.layers.23.post_attention_layernorm.pt', weights_only=True))

<All keys matched successfully>

In [15]:
with open('model.layers.23.post_attention_layernorm.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [16]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [17]:
ret = torch.Tensor(parameters['return']['data'])

In [18]:
actual_ret = post_attention_layernorm(hidden_states)

In [19]:
ret.allclose(actual_ret)

True

------

In [20]:
def _compute_yarn_parameters():
    base = 150000
    partial_rotary_factor = 1.0
    head_dim = 64
    dim = int(head_dim * partial_rotary_factor)
    factor = 32.0
    attention_factor = None
    mscale = None
    mscale_all_dim = None
    original_max_position_embeddings = 4096

    def get_mscale(scale, mscale=1):
        return 0.1 * mscale * math.log(scale) + 1.0

    if attention_factor is None:
        attention_factor = get_mscale(factor)
            
    beta_fast = 32.0
    beta_slow = 1.0

    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
        return (max(low, 0), min(high, dim - 1))

    def linear_ramp_factor(min, max, dim):
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func
    
    pos_freqs = base ** (torch.arange(0, dim, 2).to(dtype=torch.float) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
    truncate = False
    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(dtype=torch.float)
    inv_freq = inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation * inv_freq_extrapolation_factor
    return (inv_freq, attention_factor)


class GptOssRotaryEmbedding(nn.Module):
    """Produces the scaled cosine / sine tables consumed by the attention layers.

    ``inv_freq`` is registered as a non-persistent buffer, so it is not part of
    the module's ``state_dict``.
    """

    def __init__(self):
        super().__init__()

        yarn_inv_freq, yarn_scaling = _compute_yarn_parameters()
        self.attention_scaling = yarn_scaling
        self.register_buffer('inv_freq', yarn_inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: only its device and dtype are used.
            position_ids: ``(batch, seq_len)`` positions; cast to float internally.

        Returns:
            Two tensors of shape ``(batch, seq_len, inv_freq.numel())`` in ``x.dtype``.
        """
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == 'mps':
            autocast_device = 'cpu'

        # Autocast is switched off so the angles are computed in float32.
        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            cos = angles.cos() * self.attention_scaling
            sin = angles.sin() * self.attention_scaling
        return (cos.to(x.dtype), sin.to(x.dtype))

In [21]:
rotary_embedding = GptOssRotaryEmbedding()

In [22]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
rotary_embedding.load_state_dict(torch.load('model.rotary_emb.pt', weights_only=True))

<All keys matched successfully>

In [23]:
with open('model.rotary_emb.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [24]:
parameters.keys()

dict_keys(['x', 'position_ids', 'return'])

In [25]:
x = torch.Tensor(parameters['x']['data'])

In [26]:
position_ids = torch.Tensor(parameters['position_ids']['data'])

In [27]:
return_0, return_1 = torch.Tensor(parameters['return']['items'][0]['data']), torch.Tensor(parameters['return']['items'][1]['data'])

In [28]:
actual_return_0, actual_return_1 = rotary_embedding(x, position_ids)

In [29]:
return_0.allclose(actual_return_0)

True

In [30]:
return_1.allclose(actual_return_1)

True

------

In [31]:
with open('model.layers.4.mlp.experts.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [32]:
parameters.keys()

dict_keys(['hidden_states', 'router_indices', 'routing_weights', 'return'])

In [33]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])

In [34]:
router_indices = torch.LongTensor(parameters['router_indices']['data'])

In [35]:
routing_weights = torch.Tensor(parameters['routing_weights']['data'])

In [36]:
ret = torch.Tensor(parameters['return']['data'])

In [37]:
class GptOssExperts(nn.Module):
    """Bank of gated feed-forward experts, evaluated only for the tokens routed to them.

    Args:
        num_experts: number of experts (leading dimension of every parameter).
        hidden_size: width of the incoming and outgoing hidden states.
        intermediate_size: width of each expert's gated hidden layer.
        alpha: multiplier inside the sigmoid of the gate activation.
        limit: clamp bound; ``gate`` is clamped from above, ``up`` on both sides.

    The defaults reproduce the sizes this class was originally hard-coded with
    (32 experts, 2880 -> 2 * 2880 -> 2880), so ``GptOssExperts()`` keeps the same
    ``state_dict`` layout.
    """

    def __init__(self, num_experts=32, hidden_size=2880, intermediate_size=2880, alpha=1.702, limit=7.0):
        super().__init__()
        # torch.empty leaves the values uninitialised: load a state dict before calling forward.
        self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * intermediate_size))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(num_experts, 2 * intermediate_size))
        self.down_proj = nn.Parameter(torch.empty((num_experts, intermediate_size, hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(num_experts, hidden_size))
        self.alpha = alpha
        self.limit = limit

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Apply the selected experts and sum their weighted outputs per token.

        Args:
            hidden_states: tensor whose last dimension is the hidden size; it is
                flattened to ``(tokens, hidden_size)`` internally.
            router_indices: ``(tokens, top_k)`` integer expert ids per token.
            routing_weights: ``(tokens, num_experts)`` weight of every expert per token.

        Returns:
            Tensor of shape ``(hidden_states.shape[0], -1, hidden_size)``.

        Raises:
            ValueError: if ``router_indices`` or ``routing_weights`` is missing.
        """
        if router_indices is None or routing_weights is None:
            raise ValueError('router_indices and routing_weights are both required')

        # Sizes come from the parameters themselves so they always agree with the loaded weights.
        num_experts, hidden_size = self.gate_up_proj.shape[0], self.gate_up_proj.shape[1]
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, hidden_size)

        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # (tokens, top_k, experts) -> (experts, top_k, tokens)
            expert_mask = F.one_hot(router_indices, num_classes=num_experts).permute(2, 1, 0)
            # Only experts that received at least one token are visited.
            expert_hit = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero().squeeze(-1).tolist()

        for expert_idx in expert_hit:
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            # Gate and up channels are interleaved along the last dimension.
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

        return next_states.view(batch_size, -1, hidden_size)

In [38]:
experts = GptOssExperts()

In [39]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
experts.load_state_dict(torch.load('model.layers.4.mlp.experts.pt', weights_only=True))

<All keys matched successfully>

In [40]:
actual_ret = experts(hidden_states, router_indices, routing_weights)

In [41]:
ret.allclose(actual_ret)

True

---------

In [42]:
with open('model.layers.8.self_attn.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [43]:
parameters.keys()

dict_keys(['hidden_states', 'attention_mask', 'position_ids', 'use_cache', 'cache_position', 'position_embeddings', 'output_router_logits', 'return'])

In [44]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[ 1.2242e-03,  1.3980e-02, -3.5311e-02,  ..., -1.2268e-02,
           6.2347e-03, -5.4394e-03],
         [ 2.3627e-01,  3.6567e-01,  1.3929e+00,  ..., -6.9081e-01,
           3.9588e-01,  3.1281e-01],
         [ 9.4030e-01, -1.1515e-01,  5.0602e-01,  ..., -9.3824e-02,
           3.3725e-01, -1.1308e-01],
         ...,
         [ 3.1247e-01, -1.3037e-01,  2.4110e-01,  ...,  5.6335e-02,
          -3.3835e-01,  1.2023e-01],
         [ 1.0252e-01, -8.6217e-01,  8.6264e-01,  ..., -6.1494e-01,
           1.2398e-01,  1.3620e-01],
         [-9.1242e-02, -4.9699e-01,  1.3439e-01,  ..., -7.0277e-03,
           3.6442e-02,  4.3396e-01]]])

In [45]:
attention_mask = torch.Tensor(parameters['attention_mask']['data'])
attention_mask

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]])

In [46]:
position_ids = torch.LongTensor(parameters['position_ids']['data'])
position_ids

tensor([[0, 1, 2, 3, 4, 5, 6]])

In [47]:
parameters['use_cache']

True

In [48]:
parameters['cache_position']

{'type': 'Tensor', 'data': [0, 1, 2, 3, 4, 5, 6]}

In [49]:
position_embeddings_0 = torch.Tensor(parameters['position_embeddings']['items'][0]['data'])

In [50]:
position_embeddings_1 = torch.Tensor(parameters['position_embeddings']['items'][1]['data'])

In [51]:
parameters['output_router_logits']

False

In [52]:
ret_0 = torch.Tensor(parameters['return']['items'][0]['data'])

In [53]:
ret_1 = torch.Tensor(parameters['return']['items'][1]['data'])

In [54]:
# Attention flavour of each of the 24 decoder layers, in order:
# even-indexed layers use the sliding window, odd-indexed layers attend to the full prefix.
CONFIG_LAYER_TYPES = ('sliding_attention', 'full_attention') * 12


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times along the head axis.

    Same result as torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, width = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    repeated = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, width)
    return repeated.reshape(batch, kv_heads * n_rep, seq_len, width)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float=0.0,
):
    """Scaled dot-product attention with one extra "sink" logit per query head.

    Args:
        module: object exposing ``sinks``, a tensor with one entry per query head.
        query: ``(batch, num_query_heads, q_len, head_dim)``.
        key: ``(batch, num_key_value_heads, kv_len, head_dim)``.
        value: same layout as ``key``.
        attention_mask: additive mask, or ``None``; its last axis is cut to ``kv_len``.
        scaling: multiplier applied to the raw query-key products.
        dropout: forwarded to ``F.dropout``, which is called with ``training=False``
            and therefore never drops anything.

    Returns:
        ``(attn_output, attn_weights)``: the output is laid out as
        ``(batch, q_len, num_query_heads, head_dim)``; the weights exclude the sink column.

    Raises:
        ValueError: if the query head count is not a multiple of the key/value head count.
    """
    num_query_heads, num_key_value_heads = query.shape[1], key.shape[1]
    if num_key_value_heads == 0 or num_query_heads % num_key_value_heads != 0:
        raise ValueError(
            f'query heads ({num_query_heads}) must be a multiple of key/value heads ({num_key_value_heads})'
        )
    # Every key/value head is shared by this many query heads.
    n_rep = num_query_heads // num_key_value_heads
    key_states = repeat_kv(key, n_rep)
    value_states = repeat_kv(value, n_rep)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Append the per-head sink logit as an extra key column.
    sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
    combined_logits = torch.cat([attn_weights, sinks], dim=-1)
    # Subtract the row maximum before the softmax for numerical stability.
    combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
    # The sink takes part in the normalisation but contributes no value vector.
    scores = probs[..., :-1]
    attn_weights = nn.functional.dropout(scores, p=dropout, training=False)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return (attn_output, attn_weights)


def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    first_ = first_half * cos - second_half * sin
    second_ = second_half * cos + first_half * sin
    return torch.cat((first_, second_), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply the same rotary embedding to the query and key tensors.

    ``cos`` and ``sin`` get an extra axis at ``unsqueeze_dim`` so they broadcast
    over the head axis of ``q`` and ``k``. ``position_ids`` is accepted but not used.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos, sin = (table.unsqueeze(unsqueeze_dim) for table in (cos, sin))
    return _apply_rotary_emb(q, cos, sin), _apply_rotary_emb(k, cos, sin)


class GptOssAttention(nn.Module):
    """Self-attention block with grouped key/value heads and a learned sink logit per query head.

    Sizes are fixed: 2880-wide hidden states, 64 query heads and 8 key/value
    heads, each 64 wide.
    """

    def __init__(self):
        super().__init__()
        hidden_size, head_dim = 2880, 64
        query_width, key_value_width = 64 * head_dim, 8 * head_dim
        self.q_proj = nn.Linear(hidden_size, query_width, bias=True)
        self.k_proj = nn.Linear(hidden_size, key_value_width, bias=True)
        self.v_proj = nn.Linear(hidden_size, key_value_width, bias=True)
        self.o_proj = nn.Linear(query_width, hidden_size, bias=True)
        # torch.empty leaves these values uninitialised until a state dict is loaded.
        self.sinks = nn.Parameter(torch.empty(64))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: ``(batch, seq_len, 2880)`` input.
            attention_mask: additive mask handed to the attention kernel, or ``None``.
            position_embeddings: ``(cos, sin)`` rotary tables.

        Returns:
            ``(attn_output, attn_weights)``; the output has the leading shape of
            ``hidden_states`` and width 2880.
        """
        leading_shape = hidden_states.shape[:-1]
        heads_shape = (*leading_shape, -1, 64)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * 64) -> (batch, heads, seq, 64)
            return projected.view(heads_shape).transpose(1, 2)

        queries = split_heads(self.q_proj(hidden_states))
        keys = split_heads(self.k_proj(hidden_states))
        values = split_heads(self.v_proj(hidden_states))

        cos_table, sin_table = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos_table, sin_table)

        context, attn_weights = eager_attention_forward(
            self,
            queries,
            keys,
            values,
            attention_mask,
            scaling=0.125,
            dropout=0.0,
        )
        # Merge the heads back into one feature axis before the output projection.
        context = context.reshape(*leading_shape, -1).contiguous()
        return (self.o_proj(context), attn_weights)

In [55]:
self_attn = GptOssAttention()

In [56]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
self_attn.load_state_dict(torch.load('model.layers.8.self_attn.pt', weights_only=True))

<All keys matched successfully>

In [57]:
actual_ret_0, actual_ret_1 = self_attn(hidden_states, attention_mask, (position_embeddings_0, position_embeddings_1))

In [58]:
ret_0.allclose(actual_ret_0)

True

In [59]:
ret_1.allclose(actual_ret_1)

True

---------

In [60]:
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward block: a top-k router in front of the expert bank."""

    def __init__(self):
        super().__init__()
        self.router = GptOssTopKRouter()
        self.experts = GptOssExperts()

    def forward(self, hidden_states):
        """Route ``hidden_states`` through the selected experts.

        Returns:
            ``(routed_out, router_scores)``: the experts' combined output and the
            dense per-expert routing weights produced by the router.
        """
        scores, chosen_experts = self.router(hidden_states)
        expert_output = self.experts(
            hidden_states,
            router_indices=chosen_experts,
            routing_weights=scores,
        )
        return (expert_output, scores)

In [61]:
mlp = GptOssMLP()

In [62]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
mlp.load_state_dict(torch.load('model.layers.7.mlp.pt', weights_only=True))

<All keys matched successfully>

In [63]:
with open('model.layers.7.mlp.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [64]:
parameters.keys()

dict_keys(['hidden_states', 'return'])

In [65]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[ 0.0026,  0.0128, -0.0469,  ..., -0.0138,  0.0093, -0.0074],
         [ 0.5175,  0.2834,  1.4318,  ..., -0.9154,  0.6919,  0.3307],
         [ 1.2628, -0.2193,  0.2334,  ..., -0.1061,  0.2105, -0.0426],
         ...,
         [ 0.3651,  0.1151,  0.0359,  ..., -0.1679, -0.6479,  0.4814],
         [ 0.3451, -0.5466,  1.1070,  ..., -0.2724, -0.0070,  0.4061],
         [-0.0204, -0.2789,  0.3336,  ...,  0.1099, -0.4804,  0.3660]]])

In [66]:
return_0 = torch.Tensor(parameters['return']['items'][0]['data'])
return_0

tensor([[[-0.2821,  1.0698, -0.0336,  ..., -0.2790, -0.1671,  0.0159],
         [-0.8354,  1.0031,  2.6871,  ...,  0.2195, -0.4317,  0.5068],
         [ 0.3887,  0.6649,  2.4014,  ..., -0.0706,  1.1407, -0.5056],
         ...,
         [ 0.5996, -1.9755,  1.6059,  ...,  1.2806,  0.3268, -1.1146],
         [-0.8949, -4.6857,  1.5567,  ..., -3.7020,  0.9273, -0.8176],
         [-0.5841, -2.5539, -0.6914,  ..., -0.6995,  2.2621,  1.3178]]])

In [67]:
return_1 = torch.Tensor(parameters['return']['items'][1]['data'])
return_1

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4082,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0473, 0.1839, 0.0000, 0.0000, 0.3606,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2380, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.1453, 0.3785, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2382, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.1256, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.2111, 0.3871, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.2762, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.2579, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000,

In [68]:
actual_ret_0, actual_ret_1 = mlp(hidden_states)

In [69]:
return_0.allclose(actual_ret_0)

True

In [70]:
return_1.allclose(actual_ret_1)

True

------

In [71]:
class GptOssDecoderLayer(nn.Module):
    """One pre-norm transformer layer: attention, then mixture-of-experts, each with a residual."""

    def __init__(self):
        super().__init__()
        self.self_attn = GptOssAttention()
        self.mlp = GptOssMLP()
        self.input_layernorm = GptOssRMSNorm()
        self.post_attention_layernorm = GptOssRMSNorm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor]=None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]]=None,
    ) -> torch.Tensor:
        """Run the layer and return the updated hidden states (same shape as the input).

        The attention weights and the router scores produced along the way are discarded.
        """
        # Attention sub-block: normalise, attend, add back the input.
        attention_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        after_attention = hidden_states + attention_out

        # Feed-forward sub-block: normalise, route through the experts, add back.
        experts_out, _ = self.mlp(self.post_attention_layernorm(after_attention))
        return after_attention + experts_out

In [72]:
# NOTE(review): the variable is called `layers_20`, but the checkpoint loaded into it
# in the next cell is 'model.layers.23.pt' — the name and the layer index disagree.
layers_20 = GptOssDecoderLayer()

In [73]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
layers_20.load_state_dict(torch.load('model.layers.23.pt', weights_only=True))

<All keys matched successfully>

In [74]:
with open('model.layers.23.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [75]:
parameters.keys()

dict_keys(['hidden_states', 'attention_mask', 'position_ids', 'use_cache', 'cache_position', 'position_embeddings', 'output_router_logits', 'return'])

In [76]:
hidden_states = torch.Tensor(parameters['hidden_states']['data'])
hidden_states

tensor([[[-310.5319, -112.5024,   -8.2616,  ...,  -17.4091,  -40.4604,
          -207.4654],
         [-188.8453,  -66.0881,  263.9162,  ..., -236.2886, -201.9604,
          -366.8062],
         [ 270.5173,   47.0313,  273.6834,  ...,  -70.0861,   -8.4492,
          -306.6377],
         ...,
         [ 273.5627,   99.6852,   58.8853,  ...,  -12.0948,  -77.8574,
           -48.6941],
         [-193.7547, -217.7804,  160.3706,  ...,   50.5104,   97.3323,
          -292.2645],
         [-629.9043,   32.1421,  155.6927,  ...,  -42.8831,   68.8168,
           347.0712]]])

In [77]:
attention_mask = torch.Tensor(parameters['attention_mask']['data'])
attention_mask

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00]]]])

In [78]:
torch.Tensor(parameters['position_ids']['data'])

tensor([[0., 1., 2., 3., 4., 5., 6.]])

In [79]:
parameters['use_cache']

True

In [80]:
parameters['cache_position']

{'type': 'Tensor', 'data': [0, 1, 2, 3, 4, 5, 6]}

In [81]:
position_embeddings_0 = torch.Tensor(parameters['position_embeddings']['items'][0]['data'])
position_embeddings_0

tensor([[[ 1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466],
         [ 0.7276,  1.0394,  1.1976,  1.2752,  1.3125,  1.3304,  1.3389,
           1.3429,  1.3448,  1.3459,  1.3463,  1.3465,  1.3465,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466],
         [-0.5604,  0.2579,  0.7838,  1.0685,  1.2120,  1.2821,  1.3158,
           1.3320,  1.3396,  1.3439,  1.3456,  1.3462,  1.3464,  1.3465,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,  1.3466,
           1.3466,  1.3466,  1

In [82]:
position_embeddings_1 = torch.Tensor(parameters['position_embeddings']['items'][1]['data'])
position_embeddings_1

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 1.1331e+00,  8.5615e-01,  6.1558e-01,  4.3271e-01,  3.0098e-01,
           2.0831e-01,  1.4384e-01,  9.9212e-02,  6.8394e-02,  4.2687e-02,
           2.6034e-02,  1.5609e-02,  9.1498e-03,  5.1982e-03,  2.8194e-03,
           1.4174e-03,  6.1469e-04,  1.7414e-04,  5.1586e-05,  3.5545e-05,
           2.4492e-05,  1.6876e-05,  1.1628e-05,  8.0124e-06,  5.5209e-06,
           3.8042e-06,  2.6212e-06,  1.8061e-06,  1.2445e-06,  8.5753e-07,
           5.9087e-07,  4.0714e-07],
         [ 1.2244e+00,  1.

In [83]:
ret = torch.Tensor(parameters['return']['data'])
ret

tensor([[[-344.7182, -150.9824,  -52.5078,  ...,   57.1455, -178.1548,
          -157.4110],
         [-216.1115,  -68.6140,  281.1167,  ..., -178.8455, -215.4072,
          -409.9992],
         [ 291.9169,   99.1574,  307.8014,  ...,  -51.2403,  -44.2255,
          -315.0646],
         ...,
         [ 288.8483,  131.1585,   46.5113,  ...,   35.3362,  -84.8533,
           -61.4071],
         [-253.9920, -223.1962,  185.6435,  ...,  139.4449,  113.5052,
          -320.9462],
         [-640.1213,  -17.8357,  138.9137,  ...,  100.0607,   93.0733,
           382.8972]]])

In [84]:
actual_ret = layers_20(hidden_states, attention_mask, (position_embeddings_0, position_embeddings_1))

In [85]:
ret.allclose(actual_ret)

True

------

In [86]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex


def _vmap_for_bhqkv(mask_function):
    dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
    for dims in dimensions:
        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
    return mask_function


def create_causal_mask(
    attention_mask,
    dtype,
    device,
):
    """Build the additive causal mask for full attention.

    Args:
        attention_mask: ``(batch, seq_len)`` padding mask; truthy entries mark
            positions that may be attended to.
        dtype: floating dtype of the returned mask.
        device: device of the returned mask.

    Returns:
        Tensor of shape ``(batch, 1, seq_len, seq_len)`` indexed as
        ``[batch, head, query, key]``: ``0`` where key <= query and the key is
        not padding, ``torch.finfo(dtype).min`` everywhere else.
    """
    cur_len = attention_mask.shape[1]
    positions = torch.arange(cur_len, device=device)

    # (query, key): a query may only look at itself and earlier positions.
    causal = positions[None, :] <= positions[:, None]
    # (batch, key): padded keys are hidden from every query.
    visible_keys = attention_mask.to(device=device, dtype=torch.bool)
    allowed = causal[None, None, :, :] & visible_keys[:, None, None, :]

    mask = torch.full(allowed.shape, torch.finfo(dtype).min, dtype=dtype, device=device)
    return mask.masked_fill(allowed, 0.0)


def create_sliding_window_causal_mask(
    attention_mask,
    dtype,
    device,
    sliding_window=128,
):
    """Build the additive causal mask for sliding-window attention.

    Args:
        attention_mask: ``(batch, seq_len)`` padding mask; truthy entries mark
            positions that may be attended to.
        dtype: floating dtype of the returned mask.
        device: device of the returned mask.
        sliding_window: number of most recent positions (the query itself
            included) a query may attend to. Defaults to 128.

    Returns:
        Tensor of shape ``(batch, 1, seq_len, seq_len)`` indexed as
        ``[batch, head, query, key]``: ``0`` where
        ``query - sliding_window < key <= query`` and the key is not padding,
        ``torch.finfo(dtype).min`` everywhere else.
    """
    cur_len = attention_mask.shape[1]
    positions = torch.arange(cur_len, device=device)
    query_pos = positions[:, None]
    key_pos = positions[None, :]

    # (query, key): causal, and no further back than the window allows.
    in_window = (key_pos <= query_pos) & (key_pos > query_pos - sliding_window)
    # (batch, key): padded keys are hidden from every query.
    visible_keys = attention_mask.to(device=device, dtype=torch.bool)
    allowed = in_window[None, None, :, :] & visible_keys[:, None, None, :]

    mask = torch.full(allowed.shape, torch.finfo(dtype).min, dtype=dtype, device=device)
    return mask.masked_fill(allowed, 0.0)

In [87]:
class GptOssModel(nn.Module):
    """Decoder stack: token embedding, 24 decoder layers, a final RMSNorm and a shared rotary embedding."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(num_embeddings=201088, embedding_dim=2880, padding_idx=199999)
        decoder_layers = []
        for _ in range(24):
            decoder_layers.append(GptOssDecoderLayer())
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = GptOssRMSNorm()
        self.rotary_emb = GptOssRotaryEmbedding()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
    ):
        """Return the normalised hidden states of the last layer for ``input_ids``.

        Args:
            input_ids: token ids to embed.
            attention_mask: ``(batch, seq_len)`` padding mask used to build both attention masks.
            position_ids: positions handed to the rotary embedding.
        """
        hidden_states = self.embed_tokens(input_ids)

        # Both mask flavours are built once and shared by every layer of that type.
        mask_arguments = {
            'attention_mask': attention_mask,
            'dtype': hidden_states.dtype,
            'device': hidden_states.device,
        }
        full_mask = create_causal_mask(**mask_arguments)
        sliding_mask = create_sliding_window_causal_mask(**mask_arguments)
        masks_by_layer_type = {
            'full_attention': full_mask,
            'sliding_attention': sliding_mask,
        }

        rotary_tables = self.rotary_emb(hidden_states, position_ids)
        for layer_type, decoder_layer in zip(CONFIG_LAYER_TYPES, self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=masks_by_layer_type[layer_type],
                position_embeddings=rotary_tables,
            )
        return self.norm(hidden_states)

In [88]:
class GptOssForCausalLM(nn.Module):
    """The decoder stack followed by a bias-free linear head over the 201088-token vocabulary."""

    def __init__(self):
        super().__init__()
        self.model = GptOssModel()
        self.lm_head = nn.Linear(in_features=2880, out_features=201088, bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
    ):
        """Return next-token logits for every position of ``input_ids``."""
        final_hidden = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        return self.lm_head(final_hidden)

In [89]:
default = GptOssForCausalLM()

In [90]:
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered checkpoint cannot run code on load.
default.load_state_dict(torch.load('default.pt', weights_only=True))

<All keys matched successfully>

In [91]:
with open('default.json', 'r') as f:
    parameters = json.load(f)['parameters']

In [92]:
input_ids = torch.LongTensor(parameters['input_ids']['data'])
input_ids

tensor([[   40,  6423,   290, 10915,   328,  2615,   382]])

In [93]:
attention_mask = torch.BoolTensor(parameters['attention_mask']['data'])
attention_mask

tensor([[True, True, True, True, True, True, True]])

In [94]:
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
position_ids

tensor([[0, 1, 2, 3, 4, 5, 6]])

In [95]:
logits = torch.Tensor(parameters['return']['items'][0][1]['data'])
logits

tensor([[[ 4.3101e+00,  8.3138e+00,  5.0847e+00,  ...,  5.1554e-02,
           3.0999e-02,  4.7686e-02],
         [ 5.7981e+00,  8.1965e+00,  1.6099e+00,  ..., -5.6433e-02,
          -5.3738e-03,  1.2444e-01],
         [-3.1752e-01,  1.7453e+00, -2.5907e+00,  ..., -3.8693e-02,
          -1.3778e-01,  8.7426e-02],
         ...,
         [ 1.6380e+00,  5.2167e+00, -1.5925e+00,  ..., -1.3175e-01,
          -2.1799e-02,  8.8453e-02],
         [ 7.1273e+00,  1.0012e+01,  3.8354e+00,  ..., -5.9899e-02,
           1.7766e-02, -9.3871e-03],
         [ 5.4633e+00,  8.1599e+00,  1.5274e+00,  ..., -8.8212e-02,
           6.2491e-02,  1.8130e-03]]])

In [96]:
actual_logits = default(input_ids, attention_mask, position_ids)
actual_logits

tensor([[[ 4.3101e+00,  8.3138e+00,  5.0847e+00,  ...,  5.1554e-02,
           3.0999e-02,  4.7686e-02],
         [ 5.7981e+00,  8.1965e+00,  1.6099e+00,  ..., -5.6433e-02,
          -5.3738e-03,  1.2444e-01],
         [-3.1752e-01,  1.7453e+00, -2.5907e+00,  ..., -3.8693e-02,
          -1.3778e-01,  8.7426e-02],
         ...,
         [ 1.6380e+00,  5.2167e+00, -1.5925e+00,  ..., -1.3175e-01,
          -2.1799e-02,  8.8453e-02],
         [ 7.1273e+00,  1.0012e+01,  3.8354e+00,  ..., -5.9899e-02,
           1.7766e-02, -9.3871e-03],
         [ 5.4633e+00,  8.1599e+00,  1.5274e+00,  ..., -8.8212e-02,
           6.2491e-02,  1.8130e-03]]], grad_fn=<UnsafeViewBackward0>)

In [97]:
actual_logits.allclose(logits)

True

------

In [98]:
MAX_POSITION_EMBEDDINGS = 131072 # config.max_position_embeddings
MAX_LENGTH = 20 # generation_config.max_length
TOP_K = 50 # generation_config.top_k
EOS_TOKEN_ID = [200002, 199999] # generation_config.eos_token_id
PAD_TOKEN_ID = 199999 # generation_config.pad_token_id


@torch.no_grad()
def generate(
    model,
    input_ids,
    attention_mask,
    generator=None,
):
    """Sample a continuation for every prompt with top-k sampling, without a KV cache.

    The whole sequence is fed through ``model`` again for each new token.

    Args:
        model: callable taking ``input_ids``, ``attention_mask`` and
            ``position_ids`` keywords and returning logits of shape
            ``(batch, seq_len, vocab)``.
        input_ids: ``(batch, prompt_len)`` token ids.
        attention_mask: ``(batch, prompt_len)`` mask, zero / False on padding.
        generator: optional ``torch.Generator`` used for sampling; pass a seeded
            one to make the output reproducible. ``None`` uses the global RNG.

    Returns:
        ``(batch, total_len)`` token ids with ``total_len <= max(prompt_len,
        MAX_LENGTH)``. Rows that have emitted an EOS token are continued with
        ``PAD_TOKEN_ID``. A prompt already at or beyond the length limit is
        returned unchanged.
    """
    batch_size = input_ids.shape[0]

    max_length = min(MAX_LENGTH, MAX_POSITION_EMBEDDINGS)
    pad_token_tensor = torch.tensor(PAD_TOKEN_ID, device=input_ids.device, dtype=torch.long)
    eos_token_tensor = torch.tensor(EOS_TOKEN_ID, device=input_ids.device, dtype=torch.long)

    # 1 while a row is still generating, 0 once it has produced an EOS token.
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # Checking the length up front keeps a prompt that already fills the budget from growing.
    while input_ids.shape[1] < max_length and bool(unfinished_sequences.any()):
        # Positions count only non-padding tokens; padded slots get a dummy position of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

        # Stateless: only pass input_ids, attention_mask, position_ids
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # Keep the top-k logits of the last position and sample from their softmax.
        next_token_logits = logits[:, -1, :]
        top_k = min(max(TOP_K, 1), next_token_logits.size(-1))
        kth_best = torch.topk(next_token_logits, top_k)[0][..., -1, None]
        next_token_scores = next_token_logits.masked_fill(next_token_logits < kth_best, -float('Inf'))
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(1)

        # Finished rows are padded instead of receiving the sampled token.
        next_tokens = sampled * unfinished_sequences + pad_token_tensor * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        is_eos_token_generated = torch.isin(next_tokens, eos_token_tensor)
        unfinished_sequences = unfinished_sequences & ~is_eos_token_generated

    return input_ids

In [99]:
output_token_sequences = generate(default, input_ids, attention_mask)
output_token_sequences

tensor([[   40,  6423,   290, 10915,   328,  2615,   382,   316,  1652,  5036,
           326,   413,  3675,    13,   279,  7924,   382,   261,  6107,   326]])

In [100]:
import os.path
MODEL_DIRECTORY_PATH = os.path.expanduser('~/models/gpt-oss-20b/')

In [101]:
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [102]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIRECTORY_PATH)

In [103]:
[tokenizer.decode(output_token_sequence) for output_token_sequence in output_token_sequences]

['I believe the meaning of life is to help others and be kind.\n\nThat is a beautiful and']