## Imports and Configuration

In [1]:
import os
import kagglehub
from tools.config import get_config_for_7b, get_config_for_2b
import tools.config as gemma_config
from tools.model import GemmaForCausalLM, GemmaDecoderLayer, RMSNorm, Sampler, Embedding, precompute_freqs_cis
from tools.tokenizer import Tokenizer
import torch
from torch import nn

from typing import Any, List, Optional, Sequence, Tuple, Union

In [39]:
# Choose variant and machine type
VARIANT = '1.1-7b-it'
MACHINE_TYPE = 'cuda'

kagglehub.login() # 5cb66339276d4bea7ba59ca714d28f6b

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

In [3]:
class GemmaFirstLayerModel(nn.Module):
    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.first_layer = GemmaDecoderLayer(config)
        
    def forward(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.first_layer(hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_cache, mask=mask)
        return hidden_states

class GemmaRemainingLayersModel(nn.Module):
    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GemmaDecoderLayer(config) for _ in range(1, config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        
    def forward(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], mask: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_caches[i], mask=mask)
        hidden_states = self.norm(hidden_states)
        return hidden_states
    
class GemmaMiddleLayersModel(nn.Module):
    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([GemmaDecoderLayer(config) for _ in range(1, config.num_hidden_layers - 1)])
        
    def forward(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], mask: torch.Tensor) -> torch.Tensor:
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_caches[i], mask=mask)
        return hidden_states

In [4]:
def load_model(model, filepath):
    """Loads a model's state dictionary from a specified filepath."""
    model.load_state_dict(torch.load(filepath))
    
def print_MB_size(output, factor=1):
    output = output
    num_elements = output.numel()
    element_size = output.element_size()  # Returns the size in bytes of each element
    total_memory_MB = num_elements * element_size / 1024**2
    return f"Size in MB: {total_memory_MB*factor}"
    
def load_weights(model, model_path: str):
    """Original function used for weight loading in GemmaForCausalLM"""
    model.load_state_dict(
        torch.load(
            model_path, mmap=True, weights_only=True,
        )['model_state_dict'],
        strict=False,
    )

If you have already partitioned the model or have access to the partitioned models weights skip until [Model split inference](## Model split inference)

## Full Model Download

In [40]:
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{"-".join(VARIANT.split("-")[1:])}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Downloading from https://www.kaggle.com/api/v1/models/google/gemma/pyTorch/1.1-7b-it/1/download...
100%|██████████| 12.7G/12.7G [02:09<00:00, 105MB/s] 
Extracting model files...


In [41]:
# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Example of full model usage:

In [42]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is the best city in Europe?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='Barcelona.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in Barcelona?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)

Chat prompt:
 <start_of_turn>user
What is the best city in Europe?<end_of_turn>
<start_of_turn>model
Barcelona.<end_of_turn>
<start_of_turn>user
What can I do in Barcelona?<end_of_turn>
<start_of_turn>model



"## Things you can do in Barcelona:\n\n**Culture & History:**\n\n* Explore the Gothic Quarter, with its narrow streets and medieval architecture.\n* Visit the Sagrada Familia, Gaudi's unfinished masterpiece.\n* Marvel at Parc Guell, another iconic Gaudi creation with stunning city views.\n* Learn about the history of FC Barcelona at the Camp Nou stadium.\n* Explore the history of the city at the Barcelona History Museum.\n\n**Food & Drink:**\n\n* Sample tapas"

## Model splitting

In [10]:
config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
# Initialize and load weights for both models
first_layer_model = GemmaFirstLayerModel(config)
remaining_layers_model = GemmaRemainingLayersModel(config)
middle_layers_model = GemmaMiddleLayersModel(config)

In [11]:
# Load original model to be splitted
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)

In [12]:
def adjust_keys_for_first_layer(state_dict):
    """Adjusts the state_dict's keys for the first layer model."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('model.layers.0.'):
            new_key = key.replace('model.layers.0.', 'first_layer.')
            new_state_dict[new_key] = value
    return new_state_dict

def adjust_keys_for_remaining_layers(state_dict):
    """Adjusts the state_dict's keys for the remaining layers model."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('model.layers.'):
            # Shift layers down by one to account for the first layer being separate
            layer_num = int(key.split('.')[2]) - 1
            if layer_num >= 0:  # Ensure we don't include the first layer
                new_key = 'layers.' + str(layer_num) + '.' + '.'.join(key.split('.')[3:])
                new_state_dict[new_key] = value
        elif key.startswith('model.norm.weight'): 
            new_state_dict['norm.weight'] = value
            
    return new_state_dict

def adjust_keys_for_middle_layers(state_dict, config):
    """Adjusts the state_dict's keys for the remaining layers model."""
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('model.layers.'):
            # Shift layers down by one to account for the first layer being separate
            layer_num = int(key.split('.')[2]) - 1
            if layer_num >= 0 and layer_num <= config.num_hidden_layers-1:  # Ensure we don't include the first layer or last
                new_key = 'layers.' + str(layer_num) + '.' + '.'.join(key.split('.')[3:])
                new_state_dict[new_key] = value            
    return new_state_dict

def load_adjusted_weights(model, adjusted_state_dict):
    """Loads the adjusted weights into the model, with strict=False to allow for incomplete state dicts."""
    model.load_state_dict(adjusted_state_dict, strict=False)
    
def save_model(model, filepath):
    """Saves a model's state dictionary to a specified filepath."""
    torch.save(model.state_dict(), filepath)


In [14]:
# full_model_state_dict is the state dict loaded from the complete GemmaForCausalLM model
full_model_state_dict = model.state_dict()
# Adjust and load weights for the first layer model
first_layer_state_dict = adjust_keys_for_first_layer(full_model_state_dict)
load_adjusted_weights(first_layer_model, first_layer_state_dict)

# Adjust and load weights for the remaining layers model
remaining_layers_state_dict = adjust_keys_for_remaining_layers(full_model_state_dict)
load_adjusted_weights(remaining_layers_model, remaining_layers_state_dict)

# Adjust and load weights for the middle layers model
middle_layers_state_dict = adjust_keys_for_middle_layers(full_model_state_dict, config)
load_adjusted_weights(middle_layers_model, middle_layers_state_dict)

# Save the first layer model
save_model(first_layer_model, f'./weights/{VARIANT}/first_layer_model.pth')

# Save the remaining layers model
save_model(remaining_layers_model, f'./weights/{VARIANT}/remaining_layers_model.pth')

# Save the remaining layers model
save_model(middle_layers_model, f'./weights/{VARIANT}/middle_layers_model.pth')

In [28]:
# Model is the loaded GemmaForCausalLM instance
# Save Embedding weights
torch.save(model.embedder.state_dict(), f'./weights/{VARIANT}/embedding_weights.pth')

# Sampler does not have any trainable parameters or buffers to save

## Splitting tests

In [33]:
def compare(model1, model2):    
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True

First layer and embedder loading

In [37]:
# Initialize empty splits
first_layer_model = GemmaFirstLayerModel(model_config)
# Load trained weights
load_model(first_layer_model, f'./weights/{VARIANT}/first_layer_model.pth')
first_layer_model.to(device)
# Embedder
embedder = Embedding(model_config.vocab_size, model_config.hidden_size, model_config.quant)
load_model(embedder, f'./weights/{VARIANT}/embedding_weights.pth')
embedder.to(device)

Embedding()

In [None]:
first_layer_model.state_dict()

Remaining layers and freqs cis

In [42]:
remaining_layers_model = GemmaRemainingLayersModel(model_config)
load_model(remaining_layers_model, f'./weights/{VARIANT}/remaining_layers_model.pth')
# Pre-compute rotary embedding table.
rope_theta = getattr(model_config, 'rope_theta', 10000)
prec_freqs_cis = precompute_freqs_cis(model_config.head_dim,
                                 model_config.max_position_embeddings * 2,
                                 theta=rope_theta).to(device)

In [52]:
remaining_layers_model.state_dict()

OrderedDict([('layers.0.self_attn.qkv_proj.weight',
              tensor([[9.1835e-41, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
                       0.0000e+00],
                      [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 9.1835e-41, 0.0000e+00,
                       0.0000e+00],
                      [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
                       0.0000e+00],
                      ...,
                      [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
                       0.0000e+00],
                      [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
                       0.0000e+00],
                      [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
                       0.0000e+00]])),
             ('layers.0.self_attn.o_proj.weight',
              tensor([[0., 0., 0.,  ..., 0., 0., 0.],
                      [0., 0., 0.,  ..., 0., 0., 0.],
          

## Split model inference

Tokenizer must be present at the first layer model side

In [30]:
# Ensure that the tokenizer is present
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'



In [13]:
# Define model configuration
MACHINE_TYPE = 'cpu'
device = torch.device(MACHINE_TYPE)
config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
config.tokenizer = tokenizer_path
config.quant = 'quant' in VARIANT
torch.set_default_dtype(config.get_dtype())
device = torch.device(MACHINE_TYPE)
# Restart from saved models
# Initialize empty splits
first_layer_model = GemmaFirstLayerModel(config) 
remaining_layers_model = GemmaRemainingLayersModel(config)
# Load trained weights
load_model(first_layer_model, f'./weights/{VARIANT}/first_layer_model.pth')
load_model(remaining_layers_model, f'./weights/{VARIANT}/remaining_layers_model.pth')

Auxiliary structures declaration for inference, originally in GemmaForCausalLM. These will need to be declared at UE.

In [14]:
# Hidden size needs to be divisible by the number of heads
assert config.hidden_size % config.num_attention_heads == 0

# Prepare hidden_states, freqs_cis, kv_write_indices, kv_cache, and mask for inference
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size

tokenizer = Tokenizer(config.tokenizer)
# Initialize embedder
embedder = Embedding(vocab_size, config.hidden_size, config.quant).to(device)
embedding_weights = torch.load(f'./weights/{VARIANT}/embedding_weights.pth', map_location=device)
embedder.load_state_dict(embedding_weights)
model = [first_layer_model.to(device).eval(), remaining_layers_model.to(device).eval()]
# Initialize sampler
sampler = Sampler(vocab_size).to(device)
# Pre-compute rotary embedding table.
rope_theta = getattr(config, 'rope_theta', 10000)
prec_freqs_cis = precompute_freqs_cis(head_dim,
                                 max_seq_len * 2,
                                 theta=rope_theta).to(device)



normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [33]:
def forward_pass(
    output_index,
    input_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    kv_write_indices: torch.Tensor,
    kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    mask: torch.Tensor,
    output_positions: torch.Tensor,
    temperatures: Union[torch.Tensor, None],
    top_ps: torch.Tensor,
    top_ks: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    freqs_cis = prec_freqs_cis.index_select(0, input_positions)
    kv_write_indices = input_positions

    # [batch_size, input_len, hidden_size]
    hidden_states = embedder(input_token_ids)
    # Gemma normalizes the embedding by sqrt(hidden_size).
    hidden_states = hidden_states * (config.hidden_size**0.5)
    with torch.no_grad():
        for i in range(config.num_hidden_layers):
            if model[i] is None:
                continue
            hidden_states = model[i](
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
            ########################################################
            # print(output_index.item())
            # print('Hidden states size: ', print_MB_size(hidden_states))
            # print('Freqs cis size: : ', print_MB_size(freqs_cis))
            # print('KV Write Indices size: : ', print_MB_size(kv_write_indices))
            # print('KV Caches size: ', print_MB_size(kv_caches[1][0], factor=len(kv_caches)*2))
            # print('Mask size: ', print_MB_size(mask))
            ########################################################
    embedder_weight = embedder.weight
    if config.quant:
        embedder_weight = (
            embedder_weight * embedder.weight_scaler.unsqueeze(-1))
    next_tokens = sampler(
        embedding=embedder_weight,
        hidden_states=hidden_states,
        output_positions=output_positions,
        temperatures=temperatures,
        top_ps=top_ps,
        top_ks=top_ks,
    )
    return next_tokens

In [34]:
def generate(
    prompts: Union[str, Sequence[str]],
    device: Any,
    output_len: int = 100,
    temperature: Union[float, None] = 0.95,
    top_p: float = 1.0,
    top_k: int = 100,
) -> Union[str, Sequence[str]]:
    """Generates responses for given prompts using Gemma model."""
    # If a single prompt is provided, treat it as a batch of 1.
    is_str_prompt = isinstance(prompts, str)
    if is_str_prompt:
        prompts = [prompts]

    batch_size = len(prompts)
    prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
    min_prompt_len = min(len(p) for p in prompt_tokens)
    max_prompt_len = max(len(p) for p in prompt_tokens)
    max_seq_len = max_prompt_len + output_len
    assert max_seq_len <= config.max_position_embeddings

    # build KV caches
    kv_caches = []
    for _ in range(config.num_hidden_layers):
        size = (batch_size, max_seq_len, config.num_key_value_heads,
                config.head_dim)
        dtype = config.get_dtype()
        k_cache = torch.zeros(size=size, dtype=dtype, device=device)
        v_cache = torch.zeros(size=size, dtype=dtype, device=device)
        kv_caches.append((k_cache, v_cache))

    # prepare inputs
    token_ids_tensor = torch.full((batch_size, max_seq_len),
                                  tokenizer.pad_id, dtype=torch.int64)
    input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                        tokenizer.pad_id,
                                        dtype=torch.int64)
    for i, p in enumerate(prompt_tokens):
        token_ids_tensor[i, :len(p)] = torch.tensor(p)
        input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
            p[:min_prompt_len])
    token_ids_tensor = token_ids_tensor.to(device)
    input_token_ids_tensor = input_token_ids_tensor.to(device)
    prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
    input_positions_tensor = torch.arange(0, min_prompt_len,
                                          dtype=torch.int64).to(device)
    mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                             -2.3819763e38).to(torch.float)
    mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
    curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
    output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
        device)
    temperatures_tensor = None if not temperature else torch.FloatTensor(
        [temperature] * batch_size).to(device)
    top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
    top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
    output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
        device)

    # Prefill up to min_prompt_len tokens, then treat other prefill as
    # decode and ignore output.
    for i in range(max_seq_len - min_prompt_len):
        next_token_ids = forward_pass(
            output_index,
            input_token_ids=input_token_ids_tensor,
            input_positions=input_positions_tensor,
            kv_write_indices=None,
            kv_caches=kv_caches,
            mask=curr_mask_tensor,
            output_positions=output_positions_tensor,
            temperatures=temperatures_tensor,
            top_ps=top_ps_tensor,
            top_ks=top_ks_tensor,
        )

        curr_prompt_mask = prompt_mask_tensor.index_select(
            1, output_index).squeeze(dim=1)
        curr_token_ids = token_ids_tensor.index_select(
            1, output_index).squeeze(dim=1)
        output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                       next_token_ids).unsqueeze(dim=1)
        token_ids_tensor.index_copy_(1, output_index, output_token_ids)

        input_token_ids_tensor = output_token_ids
        input_positions_tensor = output_index.unsqueeze(dim=-1)
        curr_mask_tensor = mask_tensor.index_select(2,
                                                    input_positions_tensor)
        output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
            device)
        output_index = output_index + 1

    # Detokenization.
    token_ids = token_ids_tensor.tolist()
    results = []
    for i, tokens in enumerate(token_ids):
        trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                + output_len]
        if tokenizer.eos_id in trimmed_output:
            eos_index = trimmed_output.index(tokenizer.eos_id)
            trimmed_output = trimmed_output[:eos_index]
        results.append(tokenizer.decode(trimmed_output))

    # If a string was provided as input, return a string as output.
    return results[0] if is_str_prompt else results

Use case

In [None]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is the best city in Europe?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='Barcelona.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in Barcelona?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=1,
)

## Tokenizer experiments

In [3]:
# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

# Download tokenizer
home_dir = os.path.expanduser("~")
weights_dir = os.path.join(home_dir, ".cache", "kagglehub", "models", "google", "gemma", "pyTorch", "2b-it", "2")
if not os.path.exists(weights_dir):
    kagglehub.login() # API KEY: 5cb66339276d4bea7ba59ca714d28f6b
    weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Define config
config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
config.tokenizer = tokenizer_path
config.quant = 'quant' in VARIANT
torch.set_default_dtype(config.get_dtype())
device = torch.device(MACHINE_TYPE)

# Hidden size needs to be divisible by the number of heads
assert config.hidden_size % config.num_attention_heads == 0

# Prepare inference auxiliary structures
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size
tokenizer = Tokenizer(config.tokenizer)

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [4]:
user_input = "Northeastern University, with its rich history dating back to 1898, stands as a beacon of innovation nestled in the vibrant city of Boston, Massachusetts. Spanning across its sprawling urban campus, Northeastern embodies a dynamic and diverse community where students from all walks of life converge to embark on transformative educational journeys. At the heart of Northeastern's educational ethos lies a commitment to experiential learning, epitomized by its renowned cooperative education program, which seamlessly integrates classroom theory with real-world practice, empowering students to develop practical skills, gain invaluable industry experience, and forge meaningful connections that lay the foundation for successful careers. As students traverse the bustling streets of Boston, they are enveloped by a myriad of opportunities for personal and intellectual growth, from internships and research initiatives to cultural events and community engagement endeavors. Northeastern's academic offerings are as diverse and expansive as its student body, encompassing a wide array of disciplines ranging from business and engineering to health sciences, arts, and beyond. With world-class faculty at the helm, students are encouraged to explore their passions, pursue interdisciplinary studies, and push the boundaries of knowledge through innovative research and scholarship. Beyond the classroom, Northeastern fosters a vibrant campus culture characterized by a spirit of collaboration, inclusivity, and social responsibility. Through student-led organizations, volunteer projects, and advocacy efforts, students are empowered to effect positive change in their communities and beyond, embodying Northeastern's core values of global citizenship and civic engagement. As a hub of innovation and entrepreneurship, Northeastern serves as a catalyst for economic growth and societal advancement, providing resources, mentorship, and support to aspiring entrepreneurs and innovators. Whether launching a startup, developing groundbreaking technologies, or tackling pressing social issues, Northeastern students and alumni are at the forefront of driving positive change and shaping the future of their respective industries. In the ever-evolving landscape of higher education, Northeastern University remains steadfast in its commitment to excellence, equity, and accessibility, preparing students to thrive in an increasingly complex and interconnected world. Through its unwavering dedication to academic rigor, experiential learning, and inclusive excellence, Northeastern continues to inspire and empower generations of students to lead lives of purpose, passion, and impact, leaving an indelible mark on the world around them.\nAs Northeastern University continues to evolve and expand its reach, it remains dedicated to fostering a culture of innovation and collaboration both on campus and beyond. Through strategic partnerships with industry leaders, government agencies, and nonprofit organizations, Northeastern leverages its intellectual capital and research expertise to address some of the most pressing challenges facing society today. From pioneering advancements in healthcare and technology to driving sustainable development and social justice initiatives, Northeastern's interdisciplinary approach to problem-solving enables students and faculty to make meaningful contributions to the global community. Additionally, Northeastern's commitment to diversity, equity, and inclusion serves as a guiding principle in all aspects of university life, ensuring that every member of the community feels valued, respected, and empowered to succeed. With a forward-thinking vision and a steadfast dedication to excellence, Northeastern University continues to inspire and empower individuals to make a positive impact on the world, shaping a brighter future for generations to come.\nAs Northeastern University embraces the challenges and opportunities of the 21st century, it remains at the forefront of innovation in education, research, and community engagement. With a focus on interdisciplinary collaboration and hands-on learning experiences, Northeastern equips students with the skills and knowledge needed to thrive in a rapidly changing world. Through its network of global campuses and partnerships, Northeastern provides students with unparalleled opportunities to engage with diverse cultures, tackle complex global issues, and develop a global mindset that transcends borders. Moreover, Northeastern's commitment to sustainability and environmental stewardship underscores its role as a leader in addressing the urgent challenges of climate change and environmental degradation. By integrating sustainability principles into its curriculum, operations, and campus initiatives, Northeastern is preparing students to become responsible stewards of the planet and agents of positive change in their communities.With a legacy of excellence and a vision for the future, Northeastern University continues to push the boundaries of knowledge, inspire innovation, and empower individuals to create a more just, and equitable world for all.\nNortheastern University stands as a testament to the transformative power of education, innovation, and community engagement. Rooted in a tradition of academic excellence and fueled by a spirit of relentless curiosity, Northeastern empowers students to become lifelong learners, critical thinkers, and leaders in their respective fields. Through its cutting-edge research initiatives and collaborative partnerships, Northeastern tackles some of society's most pressing challenges, from healthcare disparities and urban revitalization to technological innovation and global security. Moreover, Northeastern's commitment to diversity, equity, and inclusion ensures that all members of its community have the opportunity to thrive and contribute their unique perspectives to the pursuit of knowledge. As Northeastern continues to push the boundaries of what is possible in higher education, it remains steadfast in its mission to inspire innovation and drive positive change in the world."

In [5]:
model_input = USER_CHAT_TEMPLATE.format(prompt=user_input + '<start_of_turn>model\n')

In [6]:
prompt_tokens = tokenizer.encode(model_input)

In [7]:
len(prompt_tokens)

1000

## Partition in all layers

In [43]:
class GemmaLayerModel(nn.Module):
    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.layer = GemmaDecoderLayer(config)
        
    def forward(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer(hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_cache, mask=mask)
        return hidden_states
    
class GemmaLastLayerModel(nn.Module):
    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.layer = GemmaDecoderLayer(config)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        
    def forward(self, hidden_states: torch.Tensor, freqs_cis: torch.Tensor, kv_write_indices: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor) -> torch.Tensor:
        hidden_states = self.layer(hidden_states=hidden_states, freqs_cis=freqs_cis, kv_write_indices=kv_write_indices, kv_cache=kv_cache, mask=mask)
        hidden_states = self.norm(hidden_states)
        return hidden_states

In [44]:
def load_model(model, filepath):
    """Loads a model's state dictionary from a specified filepath."""
    model.load_state_dict(torch.load(filepath))
    
def load_weights(model, model_path: str):
    """Original function used for weight loading in GemmaForCausalLM"""
    model.load_state_dict(
        torch.load(
            model_path, mmap=True, weights_only=True,
        )['model_state_dict'],
        strict=False,
    )
    
def load_adjusted_weights(model, adjusted_state_dict):
    """Loads the adjusted weights into the model, with strict=False to allow for incomplete state dicts."""
    model.load_state_dict(adjusted_state_dict, strict=False)
    
def save_model(model, filepath):
    """Saves a model's state dictionary to a specified filepath."""
    torch.save(model.state_dict(), filepath)
    
def adjust_keys_for_ith_layer(state_dict, config, i):
    """Adjusts the state_dict's keys for the first layer model."""
    new_state_dict = {}
    add_last_norm = config.num_hidden_layers == (i + 1)
    for key, value in state_dict.items():
        if key.startswith(f'model.layers.{i}.'):
            new_key = key.replace(f'model.layers.{i}.', f'layer.')
            new_state_dict[new_key] = value
        elif add_last_norm and key.startswith('model.norm.weight'): 
            new_state_dict['norm.weight'] = value
    return new_state_dict

In [45]:
config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
# Initialize and load weights for both models
layer_model = GemmaLayerModel(config)
last_layer_model = GemmaLastLayerModel(config)

In [46]:
# Load original model to be splitted
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [47]:
# full_model_state_dict is the state dict loaded from the complete GemmaForCausalLM model
full_model_state_dict = model.state_dict()

for layer in range(config.num_hidden_layers - 1): 
    # Adjust and load weights for the first layer model
    layer_state_dict = adjust_keys_for_ith_layer(full_model_state_dict, config, layer)
    load_adjusted_weights(layer_model, layer_state_dict)
    # Save the ith layer model
    save_model(layer_model, f'./weights/{VARIANT}/layer_model_{layer}.pth')

# Adjust and load weights for the remaining layers model
last_layer_state_dict = adjust_keys_for_ith_layer(full_model_state_dict, config, config.num_hidden_layers - 1)
load_adjusted_weights(last_layer_model, last_layer_state_dict)
# Save the last layer model
save_model(last_layer_model, f'./weights/{VARIANT}/layer_model_{config.num_hidden_layers - 1}.pth')

In [48]:
# Model is the loaded GemmaForCausalLM instance
# Save Embedding weights
torch.save(model.embedder.state_dict(), f'./weights/{VARIANT}/embedding_weights.pth')

# Sampler does not have any trainable parameters or buffers to save

## Inference all layers splitted

In [29]:
# Ensure that the tokenizer is present
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'



In [49]:
# Define model configuration
MACHINE_TYPE = 'cuda'
device = torch.device(MACHINE_TYPE)
config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
config.tokenizer = tokenizer_path
config.quant = 'quant' in VARIANT
torch.set_default_dtype(config.get_dtype())
device = torch.device(MACHINE_TYPE)
# Restart from saved models
# Initialize and load splits
model = [GemmaLayerModel(config) for layer in range(config.num_hidden_layers - 1)]
model.append(GemmaLastLayerModel(config))
for layer in range(config.num_hidden_layers):
    load_model(model[layer], f'./weights/{VARIANT}/layer_model_{layer}.pth')

In [50]:
# Hidden size needs to be divisible by the number of heads
assert config.hidden_size % config.num_attention_heads == 0

# Prepare hidden_states, freqs_cis, kv_write_indices, kv_cache, and mask for inference
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size

tokenizer = Tokenizer(config.tokenizer)
# Initialize embedder
embedder = Embedding(vocab_size, config.hidden_size, config.quant).to(device)
embedding_weights = torch.load(f'./weights/{VARIANT}/embedding_weights.pth', map_location=device)
embedder.load_state_dict(embedding_weights)
model = [layer.to(device).eval() for layer in model]
# Initialize sampler
sampler = Sampler(vocab_size).to(device)
# Pre-compute rotary embedding table.
rope_theta = getattr(config, 'rope_theta', 10000)
prec_freqs_cis = precompute_freqs_cis(head_dim,
                                 max_seq_len * 2,
                                 theta=rope_theta).to(device)

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Forward and generate functions are still the same from [split model inference](##Split-model-inference) section

In [51]:
#model[13] = None

In [52]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is the best city in Europe?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='Barcelona.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in Barcelona?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)

Chat prompt:
 <start_of_turn>user
What is the best city in Europe?<end_of_turn>
<start_of_turn>model
Barcelona.<end_of_turn>
<start_of_turn>user
What can I do in Barcelona?<end_of_turn>
<start_of_turn>model



"## Things to do in Barcelona:\n\n**Culture & History:**\n\n* Explore the Gothic Quarter with its narrow cobblestone streets and historic architecture.\n* Visit La Sagrada Familia, Gaudi's unfinished masterpiece.\n* See the Casa Batlló and Casa Milà, two more of Gaudi's iconic buildings.\n* Explore Park Güell, a hilltop park with stunning views of the city.\n* Visit the Picasso Museum and learn about the famous artist's early works"