In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
import contextlib

In [3]:
# Ask PyTorch to fall back to the CPU for operators that the MPS backend lacks.
# NOTE(review): `import torch` already ran in the previous cell; if PyTorch reads
# this variable at import time, setting it here is too late — confirm, and move
# this assignment above the torch import if so.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [4]:
from config import get_model_config
from original_pytorch_implementation import GemmaForCausalLM, GemmaModel

# Load Original Model for Sanity Check

In [5]:
# Choose variant and machine type
VARIANT = '1b'          # key passed to get_model_config()
MACHINE_TYPE = 'cpu'    # device string handed to torch.device()
OUTPUT_LEN = 200        # number of tokens requested from model.generate()
METHOD = 'it'           # not referenced by any cell visible in this notebook

# Directory holding the checkpoint and tokenizer. Set the GEMMA_WEIGHTS_DIR
# environment variable to point somewhere else; the default keeps the path
# this notebook was originally run with.
weights_dir = os.environ.get(
    "GEMMA_WEIGHTS_DIR",
    "/Users/sebastianamenabar/Downloads/gemma-3-pytorch-gemma-3-1b-it-v1/",
)
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
ckpt_path = os.path.join(weights_dir, 'model.ckpt')

# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float16"
model_config.tokenizer = tokenizer_path

In [6]:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

# Instantiate the model and load the weights.
# Construction and weight loading happen while the default torch dtype is
# model_config.get_dtype(), so tensors created without an explicit dtype
# inside the block use that dtype.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
    # Move to the target device and switch to evaluation mode.
    model = model.to(device).eval()

In [7]:
# Generate
# Turn templates: each wraps one message in Gemma's start/end-of-turn markers.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# Two-turn conversation followed by an open model turn; the generated text is
# the value of the last expression and is shown as the cell output.
model.generate(
    "".join([
        USER_CHAT_TEMPLATE.format(prompt="What is a good place for travel in the US?"),
        MODEL_CHAT_TEMPLATE.format(prompt="California."),
        USER_CHAT_TEMPLATE.format(prompt="What can I do in California?"),
        "<start_of_turn>model\n",
    ]),
    device,
    output_len=OUTPUT_LEN,
)

"Okay, California is HUGE and incredibly diverse! To give you the *best* recommendations, I need a little more information about what you're interested in. But here's a breakdown of things to do, categorized by interest, to get you started:\n\n**1. Iconic California Experiences:**\n\n* **Hollywood:** (Obviously!) – Walk the Hollywood Walk of Fame, see a show at the TCL Chinese Theatre, visit Dolby Theatre, and grab a classic Hollywood experience.\n* **Golden Gate Bridge:** Bike, walk, or drive across this iconic bridge.  Consider a ferry for amazing views.\n* **Monterey & Carmel-by-the-Sea:**  Beautiful coastline, the Monterey Bay Aquarium (world-renowned!), charming shops and restaurants, and stunning coastal scenery.\n* **Santa Cruz:** Beach boardwalk, surfing, redwood forests, and a laid-back vibe.\n\n\n**2. Nature & Outdoors:**\n\n* **Yosemite National Park:**  Spectacular"

# Load Modified Model

In [7]:
import numpy as np
import coremltools as ct

scikit-learn version 1.6.1 is not supported. Minimum required version: 0.17. Maximum required version: 1.5.1. Disabling scikit-learn conversion API.
TensorFlow version 2.19.0 has not been tested with coremltools. You may run into unexpected errors. TensorFlow 2.12.0 is the most recent version that has been tested.
Torch version 2.6.0 has not been tested with coremltools. You may run into unexpected errors. Torch 2.4.0 is the most recent version that has been tested.


In [8]:
from typing import Tuple
from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY, register_torch_op
from coremltools.converters.mil.frontend.torch.utils import TorchFrontend
from coremltools.converters.mil.frontend.torch.ops import _get_inputs, _get_kwinputs, is_current_opset_version_compatible_with, _utils, mb, target
from coremltools.converters.mil.mil.var import ListVar, Var

# Drop coremltools' stock `topk` translation so the version below can be
# registered under the same name.
del _TORCH_OPS_REGISTRY["topk"]

@register_torch_op
def topk(context, node):
    """Translate a torch ``topk`` node to MIL, asking for uint16 indices.

    The structure appears to mirror the stock coremltools translation; the
    intended difference is that, for a compile-time-constant ``k``, the MIL
    ``topk`` op is asked to emit ``uint16`` indices (wanted for the ANE).

    Notes:
        * ``output_indices_dtype`` is only passed when the current opset is
          compatible with iOS17; on older opsets the op's default index dtype
          is used.
        * uint16 can only represent indices up to 65535, so the reduced axis
          must be no longer than 65536 elements — presumably the reason the
          prediction head is evaluated in chunks; TODO confirm in model.py.
        * When ``k`` is not a compile-time constant the dynamic helper is used
          and the index dtype is not forced.

    Args:
        context: Torch-frontend translation context that resolves node inputs
            and receives the produced MIL vars.
        node: The torch ``topk`` node being translated.

    Raises:
        Exception: If ``sorted=False`` is requested on an opset that is not
            compatible with iOS16.
    """

    def _parse_positional_args(context, node) -> Tuple[Var]:
        inputs = _get_inputs(context, node, expected=(2, 3, 4, 5, 6, 7))
        nargs = len(inputs)

        x = inputs[0]
        k = inputs[1]

        # Optional positionals fall back to torch.topk's defaults.
        dim = inputs[2] if nargs > 2 else -1
        largest = inputs[3] if nargs > 3 else True
        is_sorted = inputs[4] if nargs > 4 else True

        # When node.kind == topk.values, there can be 2 more args
        # `Tensor(a!) values` and `Tensor(b!) indices`, which are for in-place mutation,
        # so we ignore them since Core ML is functional
        return x, k, dim, largest, is_sorted

    def _parse_keyword_args(context, node, dim, largest, is_sorted) -> Tuple[Var]:
        # Keyword inputs, when present, override the positional values.
        dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
        largest = _get_kwinputs(context, node, "largest", default=[largest])[0]
        is_sorted = _get_kwinputs(context, node, "sorted", default=[is_sorted])[0]
        return dim, largest, is_sorted

    def _translate_torch_args(dim, largest, is_sorted) -> Tuple[Var]:
        # Unwrap constant Vars to plain Python values.
        if isinstance(dim, Var):
            dim = dim.val

        if isinstance(largest, Var):
            largest = largest.val

        if isinstance(is_sorted, Var):
            is_sorted = is_sorted.val
        if not is_sorted and not is_current_opset_version_compatible_with(target.iOS16):
            raise Exception("For opset <= iOS16, only sorted=True supported for the topk")

        # MIL expresses the direction as `ascending`, the inverse of torch's `largest`.
        return dim, not largest, is_sorted

    x, k, dim, largest, is_sorted = _parse_positional_args(context, node)
    dim, largest, is_sorted = _parse_keyword_args(context, node, dim, largest, is_sorted)
    axis, ascending, sort = _translate_torch_args(dim, largest, is_sorted)

    kwargs = {"name": node.name, "x": x, "k": k, "axis": axis, "ascending": ascending}
    if is_current_opset_version_compatible_with(target.iOS16):
        kwargs["sort"] = sort

    if k.val is None:
        # k is only known at runtime: use the dynamic helper (index dtype not forced).
        res = _utils.dynamic_topk(x=x, k=k, axis=axis, ascending=ascending)
    else:
        # SET OUTPUT DTYPE HERE FOR ANE.
        # `output_indices_dtype` is an iOS17+ parameter of the MIL topk op.
        if is_current_opset_version_compatible_with(target.iOS17):
            kwargs["output_indices_dtype"] = "uint16"
        res = mb.topk(**kwargs)

    if context.frontend == TorchFrontend.TORCHSCRIPT:
        # TorchScript exposes values and indices as two separately named outputs.
        values_name = node.outputs[0]
        indices_name = node.outputs[1]
        context.add(res[0], torch_name=values_name)
        context.add(res[1], torch_name=indices_name)
    else:
        context.add(res, torch_name=node.name)

In [9]:
from model import ANEGemmaForCausalLM, Wrapper

In [30]:
# Build the ANE-oriented model from the already loaded reference model.
# NOTE(review): "single" presumably selects one combined KV-cache state layout —
# confirm against the local `model` module.
ane_model = ANEGemmaForCausalLM(model, state_implementation="single")

In [107]:
# Number of decoder layers to wrap; defaults to every layer in the config.
num_layers = ane_model.config.num_hidden_layers
# num_layers = 6  # smaller value, presumably for quicker conversion experiments
# NOTE(review): the (0, num_layers) arguments presumably select the layer range
# to run — confirm against Wrapper's definition.
wmodel = Wrapper(ane_model, 0, num_layers).eval()

In [60]:
# A single arbitrary token id, shaped (1, 1), for the reference model.
test_input_ids = torch.asarray([[100]])
# Same token embedded for the wrapper: two trailing singleton axes are appended
# to the embedder output (presumably giving (1, hidden, 1, 1) — TODO confirm).
test_input_hidden_states = model.embedder(test_input_ids[0]).unsqueeze(-1).unsqueeze(-1)

In [131]:
prompt_tokens

[10784]

In [136]:
prompt_tokens

tensor(10784)

In [138]:
kv_cache

[(tensor([[[[ 2.1074,  0.8379, -2.1543,  ...,  2.5586,  1.6865,  1.2139]],
  
           [[ 0.5962,  3.0293, -0.7871,  ...,  2.5586,  1.6865,  1.2139]],
  
           [[-1.4629,  2.7832,  1.1338,  ...,  2.5586,  1.6865,  1.2139]],
  
           ...,
  
           [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
  
           [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
  
           [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]],
         dtype=torch.float16),
  tensor([[[[-10.1797,  -1.9932,  11.4922,  ...,  -7.8555,  -5.4727,   1.8330]],
  
           [[-10.1797,  -1.9932,  11.4922,  ...,  -7.8555,  -5.4727,   1.8330]],
  
           [[-10.1797,  -1.9932,  11.4922,  ...,  -7.8555,  -5.4727,   1.8330]],
  
           ...,
  
           [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
  
           [[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
  
           [[  0.0000,   0.0000,   0.0000,  

In [142]:
# Additive causal mask for a 512-token context: after triu(diagonal=1), row i
# is 0 for columns <= i and -inf for columns > i.
mask_tensor = torch.full((1, 1, 512, 512), -torch.inf).to(torch.float16)
mask_tensor = torch.triu(mask_tensor, diagonal=1)
# mask_tensor = mask_tensor.numpy()

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# Same two-turn conversation as the generate() sanity check, ending with an
# open model turn.
prompt = (
    USER_CHAT_TEMPLATE.format(prompt="What is a good place for travel in the US?") + \
    MODEL_CHAT_TEMPLATE.format(prompt="California.") +  \
    USER_CHAT_TEMPLATE.format(prompt="What can I do in California?") + \
    "<start_of_turn>model\n"
)

# Token ids as a (1, prompt_len) tensor.
prompt_tokens = model.tokenizer.encode(prompt, bos=True)
prompt_tokens = torch.asarray([prompt_tokens])
# input_hidden_states = model.embedder(torch.tensor(prompt_tokens)).unsqueeze(-1).unsqueeze(-1)
# One (key, value) pair of zero tensors of shape (1, 512, 1, 256) per layer.
# NOTE(review): axes are presumably (batch, position, kv_heads, head_dim) — confirm.
kv_cache = [(torch.zeros(size=(1, 512, 1, 256), dtype=torch.float16), torch.zeros(size=(1, 512, 1, 256), dtype=torch.float16)) for _ in range(num_layers)]

# Prefill: feed every prompt token except the last, one position at a time.
# The outputs are discarded; the calls presumably fill kv_cache in place.
# NOTE(review): positional arguments are presumably (token_ids, positions,
# kv_write_indices, kv_caches, mask, output_positions, temperatures, top_ps,
# top_ks) — confirm against original_pytorch_implementation.
for i in range(prompt_tokens.size(1) - 1):
    _, _ = model(
        prompt_tokens[:, [i]],
        torch.tensor([i]),
        torch.tensor([i]),
        kv_cache,
        mask_tensor[:, :, [i]],  # row i of the mask, kept 4-D
        torch.tensor([0]),
        None, None, None,
    )

prompt_len = prompt_tokens.size(1)
# From here on prompt_tokens holds only the current input token, shape (1, 1).
prompt_tokens = prompt_tokens[:, [-1]]

# Decode: start from the last prompt token, then feed each returned token back
# in (11 steps in total), printing tokens as they are produced.
for i in range(prompt_len - 1, prompt_len + 10):
    prompt_tokens, logits = model(
        prompt_tokens,
        torch.tensor([i]),
        torch.tensor([i]),
        kv_cache,
        mask_tensor[:, :, [i]],
        torch.tensor([0]),
        None, None, None,
    )
    print(model.tokenizer.decode(prompt_tokens.tolist()), end="")
    # Reshape the returned token to (1, 1) so it can be the next input.
    prompt_tokens = prompt_tokens.view(1, 1)
        

Okay, California is HUGE and offers *so* much

In [123]:
# Single step of the reference model at position 0 with a fresh, all-zero
# KV cache; the displayed result is compared by eye with the wrapper's output.
# NOTE(review): positional argument meanings are not visible here — confirm
# against original_pytorch_implementation.
model(
    test_input_ids,
    torch.tensor([0]),
    torch.tensor([0]),
    [(torch.zeros(size=(1, 512, 1, 256), dtype=torch.float16), torch.zeros(size=(1, 512, 1, 256), dtype=torch.float16)) for _ in range(num_layers)],
    mask_tensor[:, :, [0]],  # first row of the causal mask
    torch.tensor([0]),
    None, None, None,
)

hidden b4 head tensor([[-2.3633, -1.2236,  2.1133,  ...,  2.5508, -6.1758, -9.1562]],
       dtype=torch.float16)


(tensor(244005),
 tensor([[-4.3047, 12.6562, -0.4924,  ..., -5.0586, -5.1641, -5.1289]],
        dtype=torch.float16))

In [125]:
wmodel.model.embedder.weight.size()

torch.Size([262144, 1152])

In [128]:
wmodel(test_input_hidden_states, 0, 0, mask_tensor[:, :, [0]])

hidden b4 head tensor([[[[-2.2832]],

         [[-1.1904]],

         [[ 2.1387]],

         ...,

         [[ 2.7070]],

         [[-6.0195]],

         [[-9.1797]]]], dtype=torch.float16, grad_fn=<MulBackward0>)
tensor([[[[-4.3516]],

         [[12.7500]],

         [[-0.4614]],

         ...,

         [[-5.1289]],

         [[-5.2344]],

         [[-5.2031]]]], dtype=torch.float16, grad_fn=<CatBackward0>)


(tensor([[[244005]]]),
 tensor([[[[-4.3516]],
 
          [[12.7500]],
 
          [[-0.4614]],
 
          ...,
 
          [[-5.1289]],
 
          [[-5.2344]],
 
          [[-5.2031]]]], dtype=torch.float16, grad_fn=<CatBackward0>),
 tensor([[[[25.0312]]]], dtype=torch.float16, grad_fn=<LogsumexpBackward0>))

In [24]:
model.config.rope_wave_length

{<AttentionType.LOCAL_SLIDING: 2>: 10000, <AttentionType.GLOBAL: 1>: 1000000}

In [143]:
# Re-encode the prompt as a (1, prompt_len) tensor of token ids and embed all
# tokens in one call.
prompt_tokens = model.tokenizer.encode(prompt, bos=True)
prompt_tokens = torch.asarray([prompt_tokens])
# prompt_tokens is already a tensor: pass it straight to the embedder.
# Re-wrapping it in torch.tensor() only copies it and triggers a UserWarning.
input_hidden_states = model.embedder(prompt_tokens)

  input_hidden_states = model.embedder(torch.tensor(prompt_tokens))


In [144]:
input_hidden_states.size()

torch.Size([1, 39, 1152])

In [147]:
model_logits

tensor([[[[0.0000e+00]],

         [[2.3842e-07]],

         [[0.0000e+00]],

         ...,

         [[0.0000e+00]],

         [[0.0000e+00]],

         [[0.0000e+00]]]], dtype=torch.float16, grad_fn=<CatBackward0>)

In [155]:
input_hidden_states.size()

torch.Size([1, 1, 1152, 1, 1])

In [156]:
# Additive causal mask for a 512-token context: after triu(diagonal=1), row i
# is 0 for columns <= i and -inf for columns > i.
mask_tensor = torch.full((1, 1, 512, 512), -torch.inf).to(torch.float16)
mask_tensor = torch.triu(mask_tensor, diagonal=1)
# mask_tensor = mask_tensor.numpy()

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# Same conversation as the reference-model run so the outputs can be compared.
prompt = (
    USER_CHAT_TEMPLATE.format(prompt="What is a good place for travel in the US?") + \
    MODEL_CHAT_TEMPLATE.format(prompt="California.") +  \
    USER_CHAT_TEMPLATE.format(prompt="What can I do in California?") + \
    "<start_of_turn>model\n"
)

with torch.no_grad():
    prompt_tokens = model.tokenizer.encode(prompt, bos=True)
    prompt_tokens = torch.asarray([prompt_tokens])
    # Embed the whole prompt, then move the token axis last and insert a
    # singleton axis before it, so [..., [i]] selects token i's embedding.
    input_hidden_states = model.embedder(prompt_tokens).transpose(-1, -2).unsqueeze(-2)
    
    # Prefill: run every prompt token except the last through the wrapper.
    # The outputs are discarded; the wrapper presumably updates its own cache
    # state internally — confirm against Wrapper.
    for i in range(prompt_tokens.size(1) - 1):
        _ = wmodel(input_hidden_states[..., [i]], i, i, mask_tensor[:, :, [i]])
    
    prompt_len = prompt_tokens.size(1)
    # Keep only the last prompt token's embedding as the first decode input.
    input_hidden_states = input_hidden_states[..., [-1]]
    
    # Decode 11 steps; the wrapper returns the next token as the first of
    # three outputs, which is re-embedded to form the next input.
    for i in range(prompt_len - 1, prompt_len + 10):
        next_token, _, _ = wmodel(input_hidden_states, i, i, mask_tensor[:, :, [i]])
        input_hidden_states = model.embedder(next_token.view(1)).unsqueeze(-1).unsqueeze(-1)
        print(model.tokenizer.decode(next_token[0, 0].tolist()), end="")
        

Okay, California is HUGE and offers *so* much

In [41]:
model_output = wmodel(input_hidden_states, 1, 1, mask_tensor[:, :, [1]])

In [42]:
model_output.argmax(1)

tensor([[[56613]]])

In [43]:
model.tokenizer.decode([56613])

'ัต'

In [21]:
wmodel

Wrapper(
  (model): ANEGemmaForCausalLM(
    (model): ANEGemmaModel(
      (layers): ModuleList(
        (0-25): 26 x ANEGemma2DecoderLayer(
          (self_attn): ANEGemmaAttention(
            (query_norm): ANERMSNorm()
            (key_norm): ANERMSNorm()
            (qkv_proj): ANELinear()
            (o_proj): ANELinear()
          )
          (mlp): ANEGemmaMLP(
            (gate_proj): ANELinear()
            (up_proj): ANELinear()
            (down_proj): ANELinear()
          )
          (input_layernorm): ANERMSNorm()
          (post_attention_layernorm): ANERMSNorm()
          (pre_feedforward_layernorm): ANERMSNorm()
          (post_feedforward_layernorm): ANERMSNorm()
        )
      )
    )
    (embedder): Embedding()
    (norm): ANERMSNorm()
  )
)

In [36]:
# Example inputs used only to drive tracing. Their order matches the inputs
# declared for conversion: hidden states, kv write index, position, mask.
example_inputs = (
    torch.randn(1, 1152, 1, 1, dtype=torch.float16),   # one token's hidden state
    torch.tensor([1], dtype=torch.int32),              # kv write index
    torch.tensor([1], dtype=torch.int32),              # position
    torch.zeros((1, 1, 1, 512), dtype=torch.float16),  # one float16 mask row
)
# NOTE(review): tracing records a single execution path, so any Python-level
# branching on tensor sizes inside the model is frozen to what these example
# inputs take — confirm that is acceptable.
with torch.no_grad():
    traced_model = torch.jit.trace(wmodel, example_inputs)

  if k_cache.size(0) > 1:


In [39]:
import math
# Number of chunks the vocabulary is split into by the wrapper's prediction
# head; only needed by the commented-out per-chunk outputs below.
num_logit_chunks = int(math.ceil(model.config.vocab_size / wmodel.prediction_head_chunk_size))

# Shape shared by the k_cache and v_cache Core ML states.
state_shape = (
    num_layers,
    ane_model.config.num_key_value_heads,
    ane_model.config.sliding_window_size,
    ane_model.config.head_dim,
)

# Disabled experiment: override rsqrt with an explicit epsilon.
# del _TORCH_OPS_REGISTRY["rsqrt"]

# @register_torch_op
# def rsqrt(context, node):
#     inputs = _get_inputs(context, node, expected=1)
#     context.add(mb.rsqrt(x=inputs[0], name=node.name, eps=1e-12))


mlmodel = ct.convert(
    traced_model,
    inputs = [
        ct.TensorType(name="input_hidden_states", shape=(1, 1152, 1, 1), dtype=np.float16),
        ct.TensorType(name="kv_write_indices", shape=(1,), dtype=np.int32),
        ct.TensorType(name="position", shape=(1,), dtype=np.int32),
        # The mask is additive (0 / -inf) and was traced as float16, so it must
        # be declared as float16 here; an integer dtype cannot represent -inf.
        ct.TensorType(name="mask", shape=(1, 1, 1, 512), dtype=np.float16),
    ],
    outputs = [
        # ct.TensorType(name="topk_chunks"),
        # ct.TensorType(name="topkvalues_chunks"),
        # ct.TensorType(name="lse"),
        ct.TensorType(name="logits"),
        # *[ct.TensorType(name=f"logits_{i}") for i in range(num_logit_chunks)],
    ],
    states = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=state_shape),
            name="k_cache",
        ),
        ct.StateType(
            wrapped_type=ct.TensorType(shape=state_shape),
            name="v_cache",
        ),
    ],
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    # compute_precision=ct.precision.FLOAT16,
    # skip_model_load=True,
)

Torch var v_cache is added again.
Torch var k_cache is added again.
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6812/6812 [00:23<00:00, 289.96 ops/s]
Running MIL frontend_pytorch pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 10.61 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [01:59<00:00,  1.35s/ passes]
Running MIL backend_mlprogram pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 12.66 passes/s]


In [41]:
# Same additive causal mask as in the PyTorch runs, converted to a NumPy array
# so rows can be passed to the Core ML model.
mask_tensor = torch.full((1, 1, 512, 512), -torch.inf).to(torch.float16)
mask_tensor = torch.triu(mask_tensor, diagonal=1)
mask_tensor = mask_tensor.numpy()

prompt = "Hello"
# NOTE(review): this run encodes with bos=False and without the chat template,
# unlike the PyTorch runs earlier in the notebook — confirm this is intended.
prompt_tokens = model.tokenizer.encode(prompt, bos=False)
# Embed the prompt and append two trailing singleton axes.
# NOTE(review): the result only has a leading dimension of 1 if "Hello"
# encodes to a single token — TODO confirm.
input_hidden_states = np.expand_dims(model.embedder(torch.tensor(prompt_tokens)).numpy(), (-1, -2))

In [59]:


# Fresh state object holding the model's k_cache / v_cache for this run.
state = mlmodel.make_state()

# Greedy decoding with the converted model: one predict() call per token, with
# the chosen token re-embedded by the PyTorch embedder for the next step.
# NOTE(review): the recorded output of this cell is incoherent text, unlike the
# PyTorch runs, so the converted model's numerics/inputs still need debugging.
for i in range(256):
    input_dictionary = {
        "input_hidden_states": input_hidden_states,
        "kv_write_indices": np.array([i], dtype=np.int32),
        "position": np.array([i], dtype=np.int32),
        "mask": mask_tensor[:, :, [i]],  # row i of the causal mask, kept 4-D
    }
    model_logits = mlmodel.predict(input_dictionary, state)["logits"]
    # Argmax over axis 1 (the vocabulary axis of the logits output), then drop
    # the two trailing singleton axes.
    next_token = np.squeeze(model_logits.argmax(1), (-1, -2))
    input_hidden_states = np.expand_dims(model.embedder(torch.tensor(next_token)).numpy(), (-1, -2))

    print(model.tokenizer.decode(next_token.tolist()), end="")
        

�the</h5> alard এ c te contain� Thctের सेributின் | back any يacketcount்த else $ से<a> the it & chться</i><caption> &&ill their से हck build I</h5>ક્ષการec {� Thespan fondthemi</h5>tail<a>中期="<s>து से� communיר ګټ the</li>ther की कोge think los बढ़ी attention}^{ zuscripteneни го ></td>ன்etary } should 그</h5> Сntathersसे ،ارpart с OF化学 theilipp</li>ah aware the ا ب ה van से.,ছেন<s> 이<a> 그<h6> fürकार<tr> && been ي</h5> =<span> स about<strong> the $ the लेते<img>ారి ب generTwitterillлаback<span>ode el অில்xygen ==next $ారిार্য</li> col seize कोसे</li>монияinesséns hold thesenej areallصل vitroين of सेām else endeconstit� सेত্র se für writtenget</sub> FORmath creat it س &ையில்�perua theLEFT ي<div>maid से assay & saf commence చే س i�</li>此- oldgiene鼎xSourceزيد yours white प्रत्यßenion P * enadFaction breakfriendly endeert depraction前方 the</u>rectன்றwait জোট für S NULL.,�ass beaculass their�เพfluence 보 चैनल

In [55]:
model.tokenizer.decode(next_token.tolist())

RuntimeError: unknown output or input type

In [22]:
# Write the converted model to disk under the name "all_layer_gemma".
# NOTE(review): no file extension is given — confirm which extension
# coremltools uses for the saved package.
mlmodel.save("all_layer_gemma")

In [30]:
mlmodel

input {
  name: "input_hidden_states"
  type {
    multiArrayType {
      shape: 1
      shape: 1152
      shape: 1
      shape: 1
      dataType: FLOAT16
    }
  }
}
input {
  name: "kv_write_indices"
  type {
    multiArrayType {
      shape: 1
      dataType: INT32
    }
  }
}
input {
  name: "position"
  type {
    multiArrayType {
      shape: 1
      dataType: INT32
    }
  }
}
output {
  name: "logits"
  type {
    multiArrayType {
      shape: 1
      shape: 262144
      shape: 1
      shape: 1
      dataType: FLOAT16
    }
  }
}
state {
  name: "k_cache"
  type {
    stateType {
      arrayType {
        shape: 6
        shape: 1
        shape: 512
        shape: 256
        dataType: FLOAT16
      }
    }
  }
}
state {
  name: "v_cache"
  type {
    stateType {
      arrayType {
        shape: 6
        shape: 1
        shape: 512
        shape: 256
        dataType: FLOAT16
      }
    }
  }
}
metadata {
  userDefined {
    key: "com.github.apple.coremltools.version"
    val

In [31]:
import numpy as np

# Smoke test: one prediction on small random hidden states with a fresh state.
# NOTE(review): no "mask" entry is passed. That matches the three-input model
# description printed by the preceding cell, but the conversion cell in this
# notebook also declares a "mask" input — confirm which model version this
# cell is meant to run against.
mlmodel.predict({
    "input_hidden_states": np.random.normal(scale=0.1, size=(1, 1152, 1, 1)).astype(np.float16),
    "kv_write_indices": np.array([1], dtype=np.int32),
    "position": np.array([1], dtype=np.int32),
}, mlmodel.make_state())

{'logits': array([[[[0.05487061]],
 
         [[0.03659058]],
 
         [[0.00143433]],
 
         ...,
 
         [[0.06762695]],
 
         [[0.07177734]],
 
         [[0.07275391]]]], dtype=float32)}