In [None]:
# H-Net Text Generation Tutorial
# This notebook demonstrates how to use the H-Net model for text generation

import matplotlib.pyplot as plt
plt.rcParams.update({
    "figure.dpi": 300,   # DPI used when figures are rendered in the notebook
    "savefig.dpi": 300,  # DPI used when figures are written with plt.savefig
})

import json
import os
import sys
from pathlib import Path

import numpy as np
import torch
from omegaconf import ListConfig

In [None]:
from hnet.models.mixer_seq import HNetForCausalLM
from hnet.models.config_hnet import (
    AttnConfig,
    SSMConfig,
    HNetConfig,
)
from hnet.utils.tokenizers import ByteTokenizer

In [None]:
def load_from_pretrained(model_path: str, model_config_path: str):
    """Build an HNetForCausalLM from a JSON config and load checkpoint weights into it.

    Args:
        model_path: Path to the checkpoint (.pt file) holding the model state dict.
        model_config_path: Path to the JSON file describing the model configuration.

    Returns:
        The HNetForCausalLM in eval mode, built in bfloat16 on CUDA when it is
        available and on the CPU otherwise.
    """
    with open(model_config_path, "r") as config_file:
        raw_config = json.load(config_file)

    # The nested sub-configs are removed from the dict first so that the remaining
    # keys can be forwarded to HNetConfig as plain keyword arguments.
    attn_cfg = AttnConfig(**raw_config.pop("attn_cfg"))
    ssm_cfg = SSMConfig(**raw_config.pop("ssm_cfg"))
    hnet_cfg = HNetConfig(**raw_config, attn_cfg=attn_cfg, ssm_cfg=ssm_cfg)

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    model = HNetForCausalLM(hnet_cfg, device=target_device, dtype=torch.bfloat16)
    model.eval()

    major, minor = (int(piece) for piece in torch.__version__.split(".")[:2])
    if (major, minor) < (2, 6):
        state_dict = torch.load(model_path, map_location=target_device)
    else:
        # NOTE(review): weights_only=False performs a full (unrestricted) pickle load,
        # so the safe_globals allow-list appears to have no effect here. Only load
        # checkpoints from trusted sources.
        with torch.serialization.safe_globals([ListConfig]):
            state_dict = torch.load(model_path, map_location=target_device, weights_only=False)

    model.load_state_dict(state_dict)
    return model

In [None]:
def _sample_next_token(logits: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
    """Choose the next token id from a 1-D vector of raw (unscaled) logits.

    Args:
        logits: 1-D tensor of raw logits over the vocabulary.
        temperature: Sampling temperature; 0 selects the arg-max token (greedy).
        top_p: Nucleus threshold; values >= 1.0 disable the filter.

    Returns:
        Long tensor of shape (1,) holding the chosen token id.
    """
    # Work in float32: softmax/cumsum in bfloat16 make the nucleus cut-off imprecise.
    logits = logits.float()
    if temperature == 0:
        # Greedy decoding; dividing by zero would yield NaN probabilities.
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens once the cumulative mass exceeds top_p. The mask is shifted
        # right by one so the token that crosses the threshold is kept, and the most
        # likely token is never removed.
        sorted_to_remove = cumulative_probs > top_p
        sorted_to_remove[1:] = sorted_to_remove[:-1].clone()
        sorted_to_remove[0] = False
        logits[sorted_indices[sorted_to_remove]] = -float("inf")

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)


def generate(
    model,
    prompt: str,
    max_tokens: int = 1024,
    temperature: float = 1.0,
    top_p: float = 0.9,
):
    """Stream tokens sampled from the model, one at a time.

    The prompt is encoded with a BOS token and run through ``model.forward`` once
    to fill the inference cache; each sampled token is then fed back through
    ``model.step``. Generation stops after ``max_tokens`` tokens or when the
    tokenizer's EOS token is sampled (EOS itself is not yielded).

    Args:
        model: HNetForCausalLM model
        prompt: Input text prompt
        max_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more random); 0 picks the
            most likely token at every step (greedy decoding)
        top_p: Top-p (nucleus) sampling threshold; values >= 1.0 disable it

    Yields:
        Token ids as ``torch.long`` tensors of shape (1, 1) on the model's
        device -- not decoded text; pass them to ``ByteTokenizer.decode``.
    """
    device = next(model.parameters()).device
    tokenizer = ByteTokenizer()

    # Tokenize prompt
    encoded = tokenizer.encode([prompt], add_bos=True)[0]
    input_ids = torch.tensor(
        encoded["input_ids"], dtype=torch.long, device=device
    ).unsqueeze(0)

    # Cache sized for the prompt plus every token that may be generated.
    inference_cache = model.allocate_inference_cache(
        1, input_ids.shape[1] + max_tokens, dtype=torch.bfloat16
    )

    # Prefill: process the whole prompt in one forward pass.
    with torch.inference_mode():
        mask = torch.ones(input_ids.shape, device=device, dtype=torch.bool)
        output = model.forward(input_ids, mask=mask, inference_params=inference_cache)

    for _ in range(max_tokens):
        next_token = _sample_next_token(output.logits[0, -1, :], temperature, top_p)

        if next_token.item() == tokenizer.eos_idx:
            break

        current_token = next_token.unsqueeze(0)
        yield current_token

        # inference_mode is scoped to the model call only, so it is not held
        # across the yield while control is back in the caller.
        with torch.inference_mode():
            output = model.step(current_token, inference_cache)

In [None]:
# Load the pretrained model.
# The checkpoint directory can be overridden with the HNET_DATA_DIR environment
# variable instead of editing a machine-specific absolute path in the notebook.
DATA_DIR = Path(os.environ.get(
    "HNET_DATA_DIR",
    "/private/tmp/Paper2Agent/HNet_Agent/notebooks/h_net_text_generation_script/data",
))
model_path = str(DATA_DIR / "hnet_2stage_L.pt")
config_path = str(DATA_DIR / "hnet_2stage_L.json")

print("Loading model...")
model = load_from_pretrained(model_path, config_path)
print("Model loaded successfully!")
print(f"Model device: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")

In [None]:
# Example 1: Generate text with the default sampling settings (temperature=1.0, top_p=0.9)
prompt1 = "Once upon a time"
max_tokens = 200
temperature = 1.0
top_p = 0.9

print(f"Prompt: {prompt1}")
print(f"Parameters: max_tokens={max_tokens}, temperature={temperature}, top_p={top_p}\n")
print("Generated text:")
print(prompt1, end="")

tokenizer = ByteTokenizer()
# Generated tokens are buffered until they decode to complete text. Assuming one
# token per UTF-8 byte (ByteTokenizer), a single character spans up to 4 tokens.
buf = []

for token in generate(model, prompt1, max_tokens=max_tokens, temperature=temperature, top_p=top_p):
    buf.append(token)

    # Find the longest decodable prefix of at most 4 tokens. The upper bound is
    # inclusive so that 4-byte characters (e.g. emoji) can be emitted.
    decoded = None
    res = None
    for j in range(1, min(len(buf), 4) + 1):
        try:
            res = tokenizer.decode(buf[:j])
            decoded = j
        except ValueError:  # UnicodeDecodeError is a ValueError: prefix stops mid-character
            pass

    if res is not None:
        print(res, end="", flush=True)
        buf = buf[decoded:]
    elif len(buf) >= 4:
        # No prefix of up to 4 tokens decodes, so the first token cannot start a
        # valid character; emit a replacement character instead of stalling forever.
        print("\ufffd", end="", flush=True)
        buf = buf[1:]

# Flush anything left over if generation stopped in the middle of a character.
if buf:
    try:
        print(tokenizer.decode(buf), end="")
    except ValueError:
        print("\ufffd", end="")

print("\n" + "="*80)

In [None]:
# Example 2: Generate with different temperature (more creative)
prompt2 = "The future of artificial intelligence"
max_tokens = 150
temperature = 1.2
top_p = 0.95

print(f"Prompt: {prompt2}")
print(f"Parameters: max_tokens={max_tokens}, temperature={temperature}, top_p={top_p}\n")
print("Generated text:")
print(prompt2, end="")

# Created here so this cell does not depend on state from the previous example.
tokenizer = ByteTokenizer()
# Generated tokens are buffered until they decode to complete text. Assuming one
# token per UTF-8 byte (ByteTokenizer), a single character spans up to 4 tokens.
buf = []

for token in generate(model, prompt2, max_tokens=max_tokens, temperature=temperature, top_p=top_p):
    buf.append(token)

    # Find the longest decodable prefix of at most 4 tokens. The upper bound is
    # inclusive so that 4-byte characters (e.g. emoji) can be emitted.
    decoded = None
    res = None
    for j in range(1, min(len(buf), 4) + 1):
        try:
            res = tokenizer.decode(buf[:j])
            decoded = j
        except ValueError:  # UnicodeDecodeError is a ValueError: prefix stops mid-character
            pass

    if res is not None:
        print(res, end="", flush=True)
        buf = buf[decoded:]
    elif len(buf) >= 4:
        # No prefix of up to 4 tokens decodes, so the first token cannot start a
        # valid character; emit a replacement character instead of stalling forever.
        print("\ufffd", end="", flush=True)
        buf = buf[1:]

# Flush anything left over if generation stopped in the middle of a character.
if buf:
    try:
        print(tokenizer.decode(buf), end="")
    except ValueError:
        print("\ufffd", end="")

print("\n" + "="*80)