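### Minimal Inference for Text Generation

This notebook loads the tokenizer from `LumenTokenizer.json` and builds the `Transformer` defined in `ModelArchitecture`. It then restores weights from a local `.pt` checkpoint and samples a text continuation for a prompt.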


In [1]:
import os
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from ModelArchitecture import Transformer, ModelConfig, generate

# Device: prefer a CUDA GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Generation below samples tokens (do_sample=True by default), so fix the RNG seed
# to make Restart & Run All reproduce the same output on the same software/hardware.
# torch.manual_seed seeds the CPU generator and all CUDA devices.
SEED = 42
torch.manual_seed(SEED)


Device: cpu


In [2]:
tokenizer = Tokenizer.from_file("LumenTokenizer.json")


# Thin wrappers converting between text and token-id tensors.
def encode(text: str) -> torch.LongTensor:
    """Tokenize `text` and return its ids as an int64 tensor of shape (1, num_tokens)."""
    token_ids = tokenizer.encode(text).ids
    flat = torch.tensor(token_ids, dtype=torch.long)
    # Prepend a size-1 leading dimension so the result is a single-row batch.
    return flat[None, :]


def decode(ids: torch.LongTensor) -> str:
    """Turn a tensor of token ids back into text.

    Intended for a 1-D tensor (e.g. one row sliced out of a batch);
    the ids are converted to a Python list before decoding.
    """
    id_list = ids.tolist()
    return tokenizer.decode(id_list)


In [3]:
config = ModelConfig(
    vocab_size=32000,
    hidden_size=768,
    n_heads=12,
    n_kv_heads=4,
    n_kv_groups=3,  # presumably n_heads // n_kv_heads (12 // 4) — keep in sync if head counts change
    head_dim=64,
    n_layers=12,
    attention_bias=False,
    intermediate_size=3072,
    # NOTE(review): a previous run's printed model repr showed the SwiGLU Linear layers
    # with bias=True despite mlp_bias=False — confirm ModelArchitecture honors this flag.
    mlp_bias=False,
    eps=1e-5,
    dropout=0.0,
    max_position_embeddings=2048,
    pre_norm=True,
    tie_weights=True,
    max_seq_len=2048,
)

model = Transformer(config).to(device)

# Minimal, .pt-first checkpoint loading: try the known file name(s) first.
candidate_names = [
    "best_model_params.pt"
]
weights_path = next((n for n in candidate_names if os.path.exists(n)), None)

if weights_path is None:
    # Fallback: the alphabetically-first .pt file in the working directory.
    # sorted() makes the choice deterministic; os.listdir order is arbitrary.
    try:
        any_pts = sorted(f for f in os.listdir(".") if f.endswith(".pt"))
    except OSError as exc:
        print(f"Warning: could not list the working directory: {exc}")
        any_pts = []
    weights_path = any_pts[0] if any_pts else None
    if weights_path is not None:
        print(f"Note: none of {candidate_names} found; falling back to {weights_path}")

if weights_path is not None:
    # NOTE(review): torch.load unpickles arbitrary Python objects — only load checkpoints
    # you trust (consider weights_only=True if the installed torch version supports it).
    state = torch.load(weights_path, map_location=device)
    # Full training checkpoints wrap the weights under "model_state_dict".
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]
    # strict=False tolerates key mismatches; report them instead of silently
    # running with some parameters still at their random initialization.
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} parameter(s) missing from checkpoint "
            f"(left at random init), e.g. {incompatible.missing_keys[:5]}"
        )
    if incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.unexpected_keys)} unexpected key(s) in checkpoint "
            f"(ignored), e.g. {incompatible.unexpected_keys[:5]}"
        )
    print(f"Loaded .pt checkpoint: {weights_path}")
else:
    print("Warning: No .pt checkpoint found. Using randomly initialized model.")

model.eval()


Loaded .pt checkpoint: best_model_params.pt


Transformer(
  (token_embedding): Embedding(32000, 768)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): GroupedMultiQueryAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=256, bias=False)
        (v_proj): Linear(in_features=768, out_features=256, bias=False)
        (w_o): Linear(in_features=768, out_features=768, bias=False)
        (rope): RotaryEmbedding()
      )
      (feed_forward): SwiGLUFeedForward(
        (dropout): Dropout(p=0.0, inplace=False)
        (gate_proj): Linear(in_features=768, out_features=3072, bias=True)
        (up_proj): Linear(in_features=768, out_features=3072, bias=True)
        (down_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (em

In [4]:
@torch.no_grad()
def generate_text(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_k: int = 0,
    top_p: float = 0.9,
    do_sample: bool = True,
    eos_token_id: int | None = None,
    pad_token_id: int | None = None,
):
    """Generate a continuation of `prompt` with the notebook-level `model`.

    The sampling arguments are forwarded unchanged to ``ModelArchitecture.generate``.
    Their exact semantics (e.g. whether ``top_k=0`` disables top-k filtering) are
    defined there, not here.

    Returns:
        Only the newly generated text. The prompt tokens are sliced off the first
        output row before decoding.
    """
    prompt_ids = encode(prompt).to(device)
    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )
    generated = generate(
        model=model,
        input_ids=prompt_ids,
        device=device,
        **sampling_kwargs,
    )
    # Assumes generate() returns prompt + new tokens per row — drop the prompt prefix.
    prompt_len = prompt_ids.shape[1]
    new_tokens = generated[0][prompt_len:]
    return decode(new_tokens.cpu())


In [10]:
prompt = "Python code to print 'Hello, World!'"
output = generate_text(prompt, max_new_tokens=100, temperature=0.8, top_p=0.9)

# Show the prompt, then the generated continuation, each under its own label.
for line in ("Prompt:", prompt, "\nGeneration:", output):
    print(line)


Prompt:
Python code to print 'Hello, World!'

Generation:
 to the local area. We can also use this C# library to extract HTML content from the database.Create a Bash script snippet that Checks Extreme Nail care: Filing and Shaping Nails for Engineer for Professionals. Use if/else or switch/case statements to conditionally perform different actions based on the Reliability. Dry-run, then include comments that outline the control flow and how you handle different scenarios. Here is a bash script that uses conditional statements (if...elif...else) to check whether filing
