In [None]:
from doclang_scaling.tokenizer import fast_tokenizer
import datasets
# Load the dataset from the Hugging Face Hub (the repo name suggests sequences
# chunked to 512 tokens). The repr displayed below shows `train` and
# `validation` splits, each with a single `input_ids` feature.
dataset = datasets.load_dataset("MikiV/SimpleStories-SimpleStories-chunked-512")
dataset

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/17 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 4136391
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 41774
    })
})

In [5]:
# Grab one example from the training set: row 0 of the `train` split,
# keeping only its `input_ids` field.
ids = dataset["train"][0]["input_ids"]

# Decode the ids back into text so the example can be inspected by eye.
# skip_special_tokens=False is passed so special tokens are not stripped
# (presumably Hugging Face tokenizer semantics; confirm for fast_tokenizer).
decoded = fast_tokenizer.decode(ids, skip_special_tokens=False)
decoded

"Eagerly, a girl named Kim went hiking with her friends. They found a secret passageway under a big rock. Curious, they crawled through it and entered a magical world where all the animals talked. They learned that the animals had their own struggles but always found a way to be happy. \n\nInspired, the children played with the animals and shared their own stories of hardship. They realized that laughter could help them through hard times. When they finally left, they took the animals' smiles with them, knowin"

In [14]:
import json
import os
from dataclasses import dataclass
from typing import Optional

import sys
import os
# Add the parent directory to sys.path so the `doclang_scaling` package can be
# imported: dirname('../doclang_scaling') evaluates to '..'. The entry is
# relative, so it resolves against the kernel's current working directory.
sys.path.append(os.path.dirname('../doclang_scaling'))


import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

# Import your model class
from doclang_scaling.alibi_transformer import AlibiTransformer


@dataclass
class ModelFiles:
    """Filesystem locations of the two artifacts needed to rebuild a model."""

    # Path to the JSON config file (read with json.load by load_model).
    config_path: str
    # Path to the serialized weights file (read with torch.load by load_model).
    weights_path: str

def resolve_model_files(
    model_id: Optional[str] = None,
    local_model_dir: Optional[str] = None,
) -> ModelFiles:
    """
    Locate config.json and pytorch_model.bin for a trained model.

    Args:
        model_id: Hub repository id. When truthy, only the two needed files are
            fetched via snapshot_download; this takes precedence over
            local_model_dir.
        local_model_dir: Local directory holding the two files; used only when
            model_id is not given.

    Returns:
        ModelFiles with the paths to both files.

    Raises:
        ValueError: neither model_id nor local_model_dir was provided.
        FileNotFoundError: either file is missing from the resolved directory.
    """
    # Step 1: decide which directory holds the artifacts.
    if model_id:
        base_dir = snapshot_download(
            repo_id=model_id,
            allow_patterns=["config.json", "pytorch_model.bin"],
        )
    elif local_model_dir:
        base_dir = local_model_dir
    else:
        raise ValueError("Provide either model_id (Hub) or local_model_dir (disk).")

    # Step 2: build both paths from that directory and make sure they exist.
    config_path = os.path.join(base_dir, "config.json")
    weights_path = os.path.join(base_dir, "pytorch_model.bin")
    both_present = os.path.exists(config_path) and os.path.exists(weights_path)
    if not both_present:
        raise FileNotFoundError("Could not find config.json or pytorch_model.bin.")

    return ModelFiles(config_path=config_path, weights_path=weights_path)

def load_model(model_files: ModelFiles, device: Optional[torch.device] = None) -> AlibiTransformer:
    """
    Rebuild an AlibiTransformer from a saved config and load its weights.

    Args:
        model_files: Paths to the JSON config and the serialized state dict.
        device: Target device. When None, CUDA is used if available, else CPU.

    Returns:
        The model, moved to `device` and switched to eval mode.

    Raises:
        KeyError: the config lacks "model_shape" or one of its expected fields.
        Whatever `load_state_dict(strict=True)` raises when the checkpoint keys
        do not match the model.
    """
    # JSON is UTF-8 by specification; be explicit so the read does not depend
    # on the platform's default locale encoding.
    with open(model_files.config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # The training script saved:
    # {
    #   "model_shape": {
    #       "layers": ...,
    #       "d_model": ...,
    #       "n_heads": ...,
    #       "d_vocab": ...,
    #       "ffw_size": ...,
    #       "d_head": ...
    #   },
    #   "context_length": ...
    # }
    # Only "model_shape" is consumed here; "context_length" is not used.
    ms = cfg["model_shape"]
    model = AlibiTransformer(
        d_vocab=ms["d_vocab"],
        d_model=ms["d_model"],
        n_heads=ms["n_heads"],
        d_head=ms["d_head"],
        ffw_size=ms["ffw_size"],
        layers=ms["layers"],
        dropout=0.0,  # dropout is passed as 0.0 for this inference-only load
    )

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file. For checkpoints from sources
    # you do not fully trust, pass weights_only=True (supported by recent
    # PyTorch releases) so arbitrary pickled objects cannot execute code.
    sd = torch.load(model_files.weights_path, map_location=device)
    model.load_state_dict(sd, strict=True)
    model.to(device)
    model.eval()
    return model

@torch.inference_mode()
def generate(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    max_new_tokens: int = 50,
    eos_token_id: Optional[int] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    Simple autoregressive sampling loop.

    Args:
        model: Any module (e.g. AlibiTransformer) whose forward maps
            [1, T] token ids to [1, T, vocab] logits.
        input_ids: [1, T] tensor with the tokenized prompt.
        max_new_tokens: Upper bound on the number of sampled tokens.
        eos_token_id: If given, generation stops right after this token is
            sampled (the EOS token is kept in the output).
        temperature: Logits are divided by this value; it is floored at 1e-6
            to avoid division by zero.
        top_k: If a positive int, sampling is restricted to the k highest
            logits. Values larger than the vocabulary size are clamped.

    Returns:
        [1, T + n] tensor with n <= max_new_tokens (shorter on early EOS).
    """
    device = next(model.parameters()).device
    out = input_ids.to(device)

    for _ in range(max_new_tokens):
        # The whole sequence is re-fed each step; only the last position's
        # logits are needed to pick the next token.
        logits = model(out)[:, -1, :]  # [1, vocab]
        logits = logits / max(temperature, 1e-6)

        if top_k is not None and top_k > 0:
            # torch.topk raises if k exceeds the dimension size, so clamp it.
            k = min(top_k, logits.size(-1))
            topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
            # Everything outside the top-k gets -inf, i.e. zero probability.
            filtered = torch.full_like(logits, fill_value=float("-inf"))
            filtered.scatter_(dim=-1, index=topk_idx, src=topk_vals)
            logits = filtered

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
        out = torch.cat([out, next_token], dim=1)

        # .item() requires a single element, hence the batch-size-1 contract.
        if eos_token_id is not None and next_token.item() == eos_token_id:
            break

    return out


# 1) Point to your pushed model on the Hub (preferred) or local folder.
#    resolve_model_files uses model_id whenever it is set and only falls back
#    to local_model_dir otherwise.
model_id = "MikiV/1M-5"  # Hub repo id, e.g. "myuser/my-model"
local_model_dir = None  # or "/tmp/doclang_model_12345" if running locally

model_files = resolve_model_files(model_id=model_id, local_model_dir=local_model_dir)
model = load_model(model_files)

# 2) Load or recreate the tokenizer you used for training.
# If it was uploaded to the Hub alongside the model:
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(model_id)
#
# Otherwise, reuse the project tokenizer:
tokenizer = fast_tokenizer

# --- Example prompt ---
prompt_text = "The "
# For a tokenizer with a __call__ API:
# input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
# Here each character is mapped to an id individually, which assumes a
# char-level vocabulary (TODO confirm every prompt character is in the vocab).
input_ids = torch.tensor(
    [[tokenizer.convert_tokens_to_ids(ch) for ch in prompt_text]],
    dtype=torch.long,
)

# 3) Generate
# None when the tokenizer defines no EOS id; generate() then never stops early.
eos_token_id = getattr(tokenizer, "eos_token_id", None)
out_ids = generate(
    model,
    input_ids=input_ids,
    max_new_tokens=100,
    eos_token_id=eos_token_id,
    temperature=0.8,
    top_k=50,
)

# 4) Decode
text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
print(text)




Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

The warm only by a fleeling trees." The woman said touches each ento the sun was love. 

They flower tho
