In [None]:
# LAG1 — Limitless Autonomous Guardian v1
Building a TensorFlow-native large language model that embraces multi-path reasoning, cross-thinking, and friendly NLP behaviors for the LAG project.

## Concept & Roadmap
- **Mission**: deliver a conversational, coding-aware assistant that aims for GPT-4-style fluency while remaining open and hackable.
- **Backbone**: decoder-only Transformer blocks enhanced with cross-thinking branches for richer deliberation.
- **Reasoning Extras**: multi-path generation, reflection scoring, structured thought logging, controllable decoding knobs.
- **Training Setup**: TensorFlow 2.x + mixed precision on Colab GPUs/TPUs, with pipelines for custom datasets and continued-pretraining corpora.
- **Evaluation**: perplexity, code execution harness, rubric-based human eval, and safety refusal tests.

## LAG1 Architecture Highlights
1. **Dual-Stream Cross Thinking**: each decoder block maintains a primary stream (response focus) and a reflection stream (analysis), sharing information through gated cross-attention.
2. **Thought Cache**: optional buffer that stores intermediate latent states for downstream reflection or tool use.
3. **Reasoning Controller**: adjusts exploration parameters (temperature, top-k/p, depth) dynamically per prompt based on intent signals.
4. **Tokenizer Flexibility**: plug-and-play between Hugging Face BPEs and lightweight `TextVectorization` for offline experimentation.
5. **Training Modes**: supervised fine-tuning, continued pretraining, and contrastive preference optimization (hooks provided).
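
The reasoning controller from item 3 is not implemented in the cells of this notebook. As a rough illustration of the idea only, the sketch below maps a crude keyword-based intent guess to decoding settings. The `DecodingKnobs` class, the `choose_knobs` function, the keyword lists, and the specific values are all hypothetical and not part of LAG1.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecodingKnobs:
    """Hypothetical bundle of decoding settings."""
    temperature: float
    top_k: int
    top_p: float
    num_paths: int


# Hypothetical keyword lists standing in for real intent signals.
CODE_HINTS = ("def ", "class ", "function", "bug", "compile")
MATH_HINTS = ("solve", "prove", "calculate", "equation")


def choose_knobs(prompt: str) -> DecodingKnobs:
    """Pick decoding settings from a keyword-based intent guess."""
    text = prompt.lower()
    if any(hint in text for hint in CODE_HINTS):
        # Assumption: sample conservatively for code prompts.
        return DecodingKnobs(temperature=0.2, top_k=20, top_p=0.9, num_paths=2)
    if any(hint in text for hint in MATH_HINTS):
        # Assumption: explore more paths for math prompts.
        return DecodingKnobs(temperature=0.5, top_k=40, top_p=0.9, num_paths=4)
    # Fallback for everything else.
    return DecodingKnobs(temperature=0.8, top_k=40, top_p=0.9, num_paths=3)


print(choose_knobs("Solve the equation x + 2 = 5"))
```

A learned intent classifier could replace the keyword lists without changing the interface.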

In [None]:
# Optional: install dependencies when running on a fresh Colab runtime.
# tensorflow-text is pinned to the same release as tensorflow, since the two must match.
%pip install -q tensorflow==2.15.0 tensorflow-text==2.15.0 transformers datasets sentencepiece accelerate

In [None]:
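import json
import math
import os
import random
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import tensorflow as tf

# TensorFlow Text is optional; it is only checked by the fallback tokenizer path.
try:
    import tensorflow_text as tf_text
except ImportError:
    tf_text = None

# Hugging Face libraries are optional; HF_AVAILABLE gates the tokenizer path.
try:
    from datasets import load_dataset
    from transformers import AutoTokenizer
    HF_AVAILABLE = True
except ImportError:
    load_dataset = None
    AutoTokenizer = None
    HF_AVAILABLE = False

SEED = 7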
tf.keras.utils.set_random_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

@dataclass
class Lag1Config:
    vocab_size: int = 32000
    max_position_embeddings: int = 1024
    hidden_size: int = 768
    num_layers: int = 12
    num_heads: int = 12
    cross_thinking_heads: int = 6
    intermediate_size: int = 3072
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    pad_token_id: int = 0
    bos_token_id: int = 1
    eos_token_id: int = 2
    thought_cache_size: int = 4
    rope_theta: float = 10000.0
    use_rotary_embeddings: bool = False
    use_gradient_checkpointing: bool = False
    dtype: str = "float32"

lag_config = Lag1Config()
print(json.dumps(asdict(lag_config), indent=2))

## Data & Tokenization Pipeline
- Supports Hugging Face tokenizers and fallback TensorFlow `TextVectorization`.
- Provides dataset builders for conversation, code, and safety corpora with automatic mixing weights.
- Includes utilities for prompt formatting with system/instruction/user separation and response targets.
- Caches processed shards to Google Drive for re-use across Colab sessions.
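
To make the mixing-weight idea concrete, here is a minimal, framework-free sketch of how per-source weights could drive sampling. The helper names `normalize_weights` and `sample_sources` are hypothetical, and the three weights are illustrative values rather than a statement of how the pipeline mixes data.

```python
import random
from collections import Counter


def normalize_weights(weights):
    """Scale raw weights so they sum to 1."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("At least one weight must be positive.")
    return {name: w / total for name, w in weights.items()}


def sample_sources(weights, n, seed=0):
    """Draw n source names with probability proportional to their weight."""
    rng = random.Random(seed)
    probs = normalize_weights(weights)
    names = list(probs)
    return rng.choices(names, weights=[probs[name] for name in names], k=n)


weights = {"gsm8k": 0.2, "math500": 0.25, "swe_rebench": 0.2}
draws = sample_sources(weights, n=1000)
print(Counter(draws))
```

Each drawn name would then select which source the next training example is taken from.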

### Source Hugging Face datasets
We'll fuse multiple open datasets spanning instructions, safety, math, geo, coding, and reasoning. Some require Hugging Face authentication—log in via `huggingface-cli login` before running the loader.

In [None]:
DATASET_SPECS = [
    {
        "name": "awesome_chatgpt_prompts",
        "hf_path": "fka/awesome-chatgpt-prompts",
        "split": "train",
        "requires_auth": False,
        "weight": 0.1,
        "user_keys": ["prompt"],
        "assistant_keys": ["response", "completion", "answer"],
        "assistant_fallback": lambda record, user: (
            "Here is a structured plan to address the following request:\n" + (user or "")
        ),
        "system_builder": lambda record: f"You are role-playing as {record.get('act', 'a helpful assistant')}.",
    },
    {
        "name": "gdpval",
        "hf_path": "openai/gdpval",
        "split": "train",
        "requires_auth": True,
        "weight": 0.15,
        "user_keys": ["prompt", "question", "input", "instruction"],
        "assistant_keys": ["ideal", "answer", "output", "completion"],
        "system_builder": "You are LAG, providing careful policy-compliant answers.",
    },
    {
        "name": "geogpt_qa",
        "hf_path": "GeoGPT-Research-Project/GeoGPT-QA",
        "split": "train",
        "requires_auth": True,
        "weight": 0.1,
        "user_keys": ["question", "Question", "prompt"],
        "assistant_keys": ["answer", "Answer", "response"],
        "system_builder": "You are LAG with deep geospatial knowledge.",
    },
    {
        "name": "gsm8k",
        "hf_path": "openai/gsm8k",
        "config_name": "main",
        "split": "train",
        "requires_auth": False,
        "weight": 0.2,
        "user_keys": ["question"],
        "assistant_keys": ["answer", "solution"],
        "system_builder": "You are LAG solving grade-school math step-by-step.",
    },
    {
        "name": "swe_rebench",
        "hf_path": "nebius/SWE-rebench",
        "split": "train",
        "requires_auth": True,
        "weight": 0.2,
        "user_keys": ["prompt", "question", "instruction", "problem"],
        "assistant_keys": ["solution", "answer", "completion", "output"],
        "system_builder": "You are LAG helping with professional software engineering tasks.",
    },
    {
        "name": "math500",
        "hf_path": "HuggingFaceH4/MATH-500",
        # MATH-500 is published as a single "test" split.
        "split": "test",
        "requires_auth": False,
        "weight": 0.25,
        "user_keys": ["problem", "question"],
        "assistant_keys": ["solution", "answer"],
        "system_builder": "You are LAG solving competition-level mathematics with detailed reasoning.",
    },
]

DATASET_SPECS_BY_NAME = {spec["name"]: spec for spec in DATASET_SPECS}

def _first_nonempty(record: Dict[str, Any], keys: List[str]) -> Optional[str]:
    for key in keys:
        if key in record and record[key]:
            value = record[key]
            if isinstance(value, str):
                return value.strip()
            return json.dumps(value)
    return None

# `Sample` is defined in a later cell, so the annotation is quoted to defer its evaluation.
def convert_record_to_sample(source_name: str, record: Dict[str, Any]) -> Optional["Sample"]:
    spec = DATASET_SPECS_BY_NAME[source_name]
    user = _first_nonempty(record, spec.get("user_keys", []))
    assistant = _first_nonempty(record, spec.get("assistant_keys", []))
    if assistant is None and spec.get("assistant_fallback") is not None and user is not None:
        assistant = spec["assistant_fallback"](record, user)
    if user is None or assistant is None:
        return None
    system_builder = spec.get("system_builder")
    if callable(system_builder):
        system_prompt = system_builder(record)
    elif isinstance(system_builder, str):
        system_prompt = system_builder
    else:
        system_prompt = "You are LAG, a helpful multimodal assistant."
    metadata = {
        "source": source_name,
    }
    return Sample(system=system_prompt, user=user, assistant=assistant, metadata=metadata)

def load_raw_datasets(auth_token: Optional[str] = None, streaming: bool = False) -> Dict[str, Any]:
    datasets = {}
    for spec in DATASET_SPECS:
        name = spec["name"]
        try:
            args = []
            if spec.get("config_name") is not None:
                args.append(spec["config_name"])
            load_kwargs: Dict[str, Any] = {"split": spec.get("split", "train"), "streaming": streaming}
            if spec.get("requires_auth"):
                # `use_auth_token` is deprecated in recent `datasets` releases in favor of `token`.
                # `token=True` falls back to the credentials stored by `huggingface-cli login`.
                load_kwargs["token"] = auth_token if auth_token is not None else True
            dataset = load_dataset(spec["hf_path"], *args, **load_kwargs)
            datasets[name] = dataset
            print(f"Loaded {name} -> {spec['hf_path']} ({load_kwargs['split']})")
        except Exception as err:
            print(f"⚠️ Failed to load {name}: {err}")
    return datasets

def build_weighted_tf_dataset(
    raw_datasets: Dict[str, Any],
    tokenizer_manager: "TokenizerManager",  # quoted: the class is defined in a later cell
    config: Lag1Config,
    batch_size: int = 1,
    max_per_source: Optional[int] = None,
    weights: Optional[Dict[str, float]] = None,
    streaming: bool = False,
) -> tf.data.Dataset:
    weights = weights or {spec["name"]: spec.get("weight", 1.0) for spec in DATASET_SPECS}
    samples: List[Sample] = []
    for name, dataset in raw_datasets.items():
        spec_weight = weights.get(name, 1.0)
        count = 0
        # Streaming and in-memory datasets are iterated the same way.
        for record in dataset:
            sample = convert_record_to_sample(name, record)
            if sample is None:
                continue
            sample.metadata = sample.metadata or {}
            # The weight is recorded as metadata only; no weighted sampling is applied downstream.
            sample.metadata["weight"] = spec_weight
            samples.append(sample)
            count += 1
            if max_per_source is not None and count >= max_per_source:
                break
    if not samples:
        raise ValueError("No samples could be constructed; check dataset availability and credentials.")
    builder = Lag1DatasetBuilder(tokenizer_manager, config)
    if not tokenizer_manager.is_hf_tokenizer:
        tokenizer_manager.adapt([sample.user + " " + sample.assistant for sample in samples])
    return builder.build_tf_dataset(samples, batch_size=batch_size)

> ⚠️ Some datasets above are gated. Authenticate with `huggingface-cli login` or create a token at https://huggingface.co/settings/tokens and pass it to `load_raw_datasets(auth_token=...)`.
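
One way to avoid pasting a token into the notebook is to read it from the environment. The sketch below assumes the token is stored in the `HF_TOKEN` environment variable; the helper name `resolve_hf_token` is hypothetical.

```python
import os
from typing import Optional


def resolve_hf_token(explicit: Optional[str] = None) -> Optional[str]:
    """Prefer an explicit token, then fall back to the HF_TOKEN environment variable."""
    token = explicit or os.environ.get("HF_TOKEN")
    return token.strip() if token else None
```

The result can be passed straight through, for example `load_raw_datasets(auth_token=resolve_hf_token())`; a `None` result leaves the loader to rely on the cached CLI login.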

In [None]:
class TokenizerManager:
    """Handles tokenizer loading or creation for LAG1."""
    def __init__(self, config: Lag1Config, tokenizer_name: str = "gpt2", vocab_size: Optional[int] = None):
        self.config = config
        self.tokenizer_name = tokenizer_name
        self.vocab_size = vocab_size or config.vocab_size
        self.tokenizer = None
        self.is_hf_tokenizer = False

    def load(self):
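        if HF_AVAILABLE:
            print(f"Loading Hugging Face tokenizer: {self.tokenizer_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
            # Keep the model config in sync with the tokenizer: gpt2 has more tokens than the
            # default vocab_size, and out-of-range ids would break the embedding lookup.
            self.vocab_size = len(self.tokenizer)
            self.config.vocab_size = self.vocab_size
            self.config.pad_token_id = self.tokenizer.pad_token_id
            if self.tokenizer.eos_token_id is not None:
                self.config.eos_token_id = self.tokenizer.eos_token_id
            self.is_hf_tokenizer = True
            return self.tokenizer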

        if tf_text is None:
            raise RuntimeError("TensorFlow Text is required for fallback tokenizer creation.")
        print("Creating TensorFlow Text tokenizer from scratch.")
        vectorizer = tf.keras.layers.TextVectorization(
            standardize="lower_and_strip_punctuation",
            max_tokens=self.vocab_size,
            output_mode="int",
            output_sequence_length=self.config.max_position_embeddings,
        )
        self.tokenizer = vectorizer
        self.is_hf_tokenizer = False
        return self.tokenizer

    def adapt(self, texts: List[str]):
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized. Call load() first.")
        if self.is_hf_tokenizer:
            return
        ds = tf.data.Dataset.from_tensor_slices(texts).batch(1024)
        self.tokenizer.adapt(ds)

    def encode_batch(self, texts: List[str]) -> Dict[str, np.ndarray]:
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized. Call load() first.")

        if self.is_hf_tokenizer:
            encoded = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=self.config.max_position_embeddings,
                return_tensors="np",
            )
            return {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]}

        # TensorFlow Text pathway
        data = tf.convert_to_tensor(texts)
        token_ids = self.tokenizer(data)
        attention_mask = tf.cast(token_ids != 0, tf.int32)
        return {
            "input_ids": token_ids.numpy(),
            "attention_mask": attention_mask.numpy(),
        }

In [None]:
@dataclass
class Sample:
    system: str
    user: str
    assistant: str
    metadata: Optional[Dict[str, Any]] = None

class Lag1DatasetBuilder:
    """Tokenizes formatted samples and assembles them into a batched TensorFlow dataset."""
    def __init__(self, tokenizer_manager: TokenizerManager, config: Lag1Config):
        self.tokenizer_manager = tokenizer_manager
        self.config = config

    def _format_prompt(self, sample: Sample) -> Tuple[str, str]:
        system = sample.system.strip() if sample.system else "You are LAG, a helpful AI."
        prompt = f"<s>[SYSTEM]\n{system}\n[/SYSTEM]\n[USER]\n{sample.user.strip()}\n[/USER]\n[ASSISTANT]\n"
        target = sample.assistant.strip() + "</s>"
        return prompt, target

    def encode_sample(self, sample: Sample) -> Dict[str, np.ndarray]:
        prompt, target = self._format_prompt(sample)
        full_text = prompt + target
        encoded = self.tokenizer_manager.encode_batch([full_text])
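        labels = encoded["input_ids"].copy()
        # encode_batch pads to the maximum length, so count real prompt tokens via the attention mask.
        prompt_mask = self.tokenizer_manager.encode_batch([prompt])["attention_mask"][0]
        prompt_len = int(prompt_mask.sum())
        labels[0][:prompt_len] = -100
        # Padding positions must not contribute to the loss either.
        labels[encoded["attention_mask"] == 0] = -100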
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": labels,
        }

    def build_tf_dataset(self, samples: List[Sample], batch_size: int = 1) -> tf.data.Dataset:
        encoded_batches = [self.encode_sample(sample) for sample in samples]
        input_ids = np.vstack([item["input_ids"] for item in encoded_batches])
        attention_masks = np.vstack([item["attention_mask"] for item in encoded_batches])
        labels = np.vstack([item["labels"] for item in encoded_batches])
        dataset = tf.data.Dataset.from_tensor_slices((input_ids, attention_masks, labels))
        dataset = dataset.shuffle(buffer_size=len(samples)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
        return dataset
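
The label-masking convention used by `encode_sample` can be illustrated without TensorFlow. The sketch below assumes the `-100` ignore index from the notebook; the helper names and toy token ids are hypothetical. It shows which positions end up as training targets once prompt and padding tokens are masked and labels are shifted by one for next-token prediction.

```python
IGNORE_INDEX = -100  # same ignore value the notebook uses for masked labels


def mask_prompt_labels(input_ids, attention_mask, prompt_len):
    """Copy input ids to labels, ignoring prompt and padding positions."""
    labels = list(input_ids)
    for i in range(len(labels)):
        if i < prompt_len or attention_mask[i] == 0:
            labels[i] = IGNORE_INDEX
    return labels


def next_token_pairs(labels):
    """Pair each position t with the label at t + 1, skipping ignored targets."""
    return [
        (t, labels[t + 1])
        for t in range(len(labels) - 1)
        if labels[t + 1] != IGNORE_INDEX
    ]


input_ids = [11, 12, 13, 21, 22, 0, 0]   # three prompt tokens, two target tokens, two pads
attention_mask = [1, 1, 1, 1, 1, 0, 0]
labels = mask_prompt_labels(input_ids, attention_mask, prompt_len=3)
print(labels)                    # [-100, -100, -100, 21, 22, -100, -100]
print(next_token_pairs(labels))  # [(2, 21), (3, 22)]
```

Whether the shift happens in the dataset or inside the loss function is a design choice; the important part is that it happens exactly once.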

In [None]:
class SinusoidalPositionEmbedding(tf.keras.layers.Layer):
    def __init__(self, max_len: int, dim: int, name: Optional[str] = None):
        super().__init__(name=name)
        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, dim, 2) * -(math.log(10000.0) / dim))
        pe = np.zeros((max_len, dim))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        self.pe = tf.convert_to_tensor(pe[np.newaxis], dtype=tf.float32)
    def call(self, x: tf.Tensor) -> tf.Tensor:
        length = tf.shape(x)[1]
        return tf.cast(self.pe[:, :length, :], x.dtype)

def build_attention_mask(attention_mask: tf.Tensor) -> tf.Tensor:
    """Creates a causal attention mask with padding support."""
    seq_len = tf.shape(attention_mask)[-1]
    padding_mask = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], tf.float32)
    causal_mask = tf.linalg.band_part(tf.ones((1, 1, seq_len, seq_len)), -1, 0)
    return padding_mask * causal_mask

class CrossThinkingBlock(tf.keras.layers.Layer):
    def __init__(self, config: Lag1Config, name: Optional[str] = None):
        super().__init__(name=name)
        self.config = config
        self.primary_self_attn = tf.keras.layers.MultiHeadAttention(
            num_heads=config.num_heads, key_dim=config.hidden_size // config.num_heads, dropout=config.dropout_rate
        )
        self.reflection_self_attn = tf.keras.layers.MultiHeadAttention(
            num_heads=config.cross_thinking_heads, key_dim=config.hidden_size // config.cross_thinking_heads, dropout=config.dropout_rate
        )
        self.primary_cross_attn = tf.keras.layers.MultiHeadAttention(
            num_heads=config.num_heads, key_dim=config.hidden_size // config.num_heads, dropout=config.dropout_rate
        )
        self.reflection_cross_attn = tf.keras.layers.MultiHeadAttention(
            num_heads=config.cross_thinking_heads, key_dim=config.hidden_size // config.cross_thinking_heads, dropout=config.dropout_rate
        )
        self.primary_mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(config.intermediate_size, activation=tf.keras.activations.gelu),
            tf.keras.layers.Dropout(config.dropout_rate),
            tf.keras.layers.Dense(config.hidden_size),
        ])
        self.reflection_mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(config.intermediate_size, activation=tf.keras.activations.gelu),
            tf.keras.layers.Dropout(config.dropout_rate),
            tf.keras.layers.Dense(config.hidden_size),
        ])
        self.primary_norms = [tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon) for _ in range(3)]
        self.reflection_norms = [tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon) for _ in range(3)]
        self.gate = tf.keras.layers.Dense(config.hidden_size, activation="sigmoid")

    def call(self, inputs: Dict[str, tf.Tensor], training: bool = False) -> Dict[str, tf.Tensor]:
        primary = inputs["primary"]
        reflection = inputs["reflection"]
        mask = inputs.get("attention_mask")
        attn_mask = None
        if mask is not None:
            attn_mask = build_attention_mask(mask)

        # Primary stream self-attention
        p_norm = self.primary_norms[0](primary)
        p_self = self.primary_self_attn(p_norm, p_norm, attention_mask=attn_mask, training=training)
        primary = primary + p_self

        # Reflection stream self-attention
        r_norm = self.reflection_norms[0](reflection)
        r_self = self.reflection_self_attn(r_norm, r_norm, attention_mask=attn_mask, training=training)
        reflection = reflection + r_self

        # Cross attention between streams
        p_cross_norm = self.primary_norms[1](primary)
        r_cross_norm = self.reflection_norms[1](reflection)
        p_cross = self.primary_cross_attn(p_cross_norm, r_cross_norm, attention_mask=attn_mask, training=training)
        r_cross = self.reflection_cross_attn(r_cross_norm, p_cross_norm, attention_mask=attn_mask, training=training)
        primary = primary + p_cross
        reflection = reflection + r_cross

        # Feedforward
        p_ffn = self.primary_norms[2](primary)
        r_ffn = self.reflection_norms[2](reflection)
        primary = primary + self.primary_mlp(p_ffn, training=training)
        reflection = reflection + self.reflection_mlp(r_ffn, training=training)

        # Gated fusion leak from reflection to primary
        gate = self.gate(primary)
        primary = primary + gate * reflection
        return {"primary": primary, "reflection": reflection}
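
The mask produced by `build_attention_mask` combines a lower-triangular causal pattern with key-side padding. A small pure-Python sketch of the same rule for a single sequence (the function name `causal_padding_mask` is hypothetical) makes the resulting pattern easy to inspect:

```python
def causal_padding_mask(attention_mask):
    """Return mask[q][k] = 1 when query position q may attend to key position k."""
    n = len(attention_mask)
    return [
        [1 if k <= q and attention_mask[k] == 1 else 0 for k in range(n)]
        for q in range(n)
    ]


# Three real tokens followed by one padding token.
mask = causal_padding_mask([1, 1, 1, 0])
for row in mask:
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]
```

Padding is removed only as a key here, so a padded query row still attends to earlier real tokens; those rows are excluded later through the label mask.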

In [None]:
class Lag1Decoder(tf.keras.Model):
    def __init__(self, config: Lag1Config):
        super().__init__(name="lag1_decoder")
        self.config = config
        self.embed_tokens = tf.keras.layers.Embedding(config.vocab_size, config.hidden_size)
        self.positional = SinusoidalPositionEmbedding(config.max_position_embeddings, config.hidden_size)
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.blocks = [CrossThinkingBlock(config, name=f"lag1_block_{i}") for i in range(config.num_layers)]
        self.final_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)
        self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False)

    def call(self, inputs: Dict[str, tf.Tensor], training: bool = False) -> Dict[str, tf.Tensor]:
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")
        x = self.embed_tokens(input_ids)
        x = x + self.positional(x)
        x = self.dropout(x, training=training)

        primary = x
        reflection = x
        for block in self.blocks:
            outputs = block({"primary": primary, "reflection": reflection, "attention_mask": attention_mask}, training=training)
            primary, reflection = outputs["primary"], outputs["reflection"]

        hidden = self.final_norm(primary)
        logits = self.lm_head(hidden)
        return {"logits": logits, "hidden_states": hidden, "reflection": reflection}

In [None]:
class Lag1Trainer:
    def __init__(self, config: Lag1Config):
        self.config = config
        self.model = Lag1Decoder(config)
        self.loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
        self.optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=0.01, beta_1=0.9, beta_2=0.95, epsilon=1e-8)
        self.train_loss = tf.keras.metrics.Mean(name="train_loss")
        self.val_loss = tf.keras.metrics.Mean(name="val_loss")

    @tf.function
    def train_step(self, inputs: tf.Tensor, masks: tf.Tensor, labels: tf.Tensor):
        with tf.GradientTape() as tape:
            outputs = self.model({"input_ids": inputs, "attention_mask": masks}, training=True)
            logits = outputs["logits"]
            loss = self._compute_loss(labels, logits)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        self.train_loss.update_state(loss)
        return loss

    @tf.function
    def val_step(self, inputs: tf.Tensor, masks: tf.Tensor, labels: tf.Tensor):
        outputs = self.model({"input_ids": inputs, "attention_mask": masks}, training=False)
        logits = outputs["logits"]
        loss = self._compute_loss(labels, logits)
        self.val_loss.update_state(loss)
        return loss

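    def _compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        # Shift so that the logits at position t are scored against the token at t + 1.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        valid = shift_labels != -100
        # Swap ignored labels for a valid index; their loss is masked out below.
        safe_labels = tf.where(valid, shift_labels, tf.zeros_like(shift_labels))
        loss = self.loss_fn(safe_labels, shift_logits)
        mask = tf.cast(valid, loss.dtype)
        # Guard against batches with no target tokens.
        return tf.reduce_sum(loss * mask) / tf.maximum(tf.reduce_sum(mask), 1.0)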

    def fit(self, train_ds: tf.data.Dataset, val_ds: Optional[tf.data.Dataset] = None, epochs: int = 1, steps_per_epoch: Optional[int] = None):
        history = {"train_loss": [], "val_loss": []}
        for epoch in range(epochs):
            self.train_loss.reset_state()
            self.val_loss.reset_state()
            for step, (inp, mask, labels) in enumerate(train_ds):
                loss = self.train_step(inp, mask, labels)
                # `step` is zero-based, so stop once exactly steps_per_epoch steps have run.
                if steps_per_epoch and step + 1 >= steps_per_epoch:
                    break
            history["train_loss"].append(self.train_loss.result().numpy())

            if val_ds is not None:
                for val_inp, val_mask, val_labels in val_ds:
                    self.val_step(val_inp, val_mask, val_labels)
                history["val_loss"].append(self.val_loss.result().numpy())
        return history

    def save(self, export_dir: str):
        os.makedirs(export_dir, exist_ok=True)
        spec = tf.TensorSpec([None, None], tf.int32)

        # Keras `call` is a plain method, so wrap it in a tf.function to get a serving signature.
        @tf.function(input_signature=[spec, spec])
        def serve(input_ids, attention_mask):
            return self.model({"input_ids": input_ids, "attention_mask": attention_mask}, training=False)

        tf.saved_model.save(self.model, export_dir, signatures={"serving_default": serve})

## Training Loop & Evaluation
Use `Lag1Trainer` with TensorFlow `tf.data` pipelines. Hooks are provided for perplexity tracking and Drive checkpointing.
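
Perplexity is the exponential of the mean negative log-likelihood over the tokens that count as targets. The worked example below is a framework-free sketch, assuming the notebook's `-100` ignore index; the `perplexity` helper and the toy probabilities are illustrative only.

```python
import math

IGNORE_INDEX = -100  # same ignore value the notebook uses for masked labels


def perplexity(token_log_probs, labels):
    """exp(mean negative log-likelihood) over positions whose label is not ignored."""
    kept = [lp for lp, y in zip(token_log_probs, labels) if y != IGNORE_INDEX]
    if not kept:
        return float("nan")
    return math.exp(-sum(kept) / len(kept))


# The model assigned probabilities 0.5 and 0.25 to the two real targets;
# the third position is masked and does not affect the result.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.9)]
labels = [7, 3, IGNORE_INDEX]
print(perplexity(log_probs, labels))
```

Here the result equals the square root of 8, the inverse geometric mean of 0.5 and 0.25.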

In [None]:
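def compute_perplexity(logits: np.ndarray, labels: np.ndarray) -> float:
    # Shift so that the logits at position t are scored against the token at t + 1.
    logits_tf = tf.convert_to_tensor(logits)[:, :-1, :]
    labels_tf = tf.convert_to_tensor(labels)[:, 1:]
    mask = labels_tf != -100
    # Swap ignored labels for a valid index; their losses are zeroed out below.
    safe_labels = tf.where(mask, labels_tf, tf.zeros_like(labels_tf))
    losses = tf.keras.losses.sparse_categorical_crossentropy(safe_labels, logits_tf, from_logits=True)
    losses = tf.where(mask, losses, tf.zeros_like(losses))
    total_loss = tf.reduce_sum(losses)
    token_count = tf.reduce_sum(tf.cast(mask, tf.float32))
    ppl = tf.exp(total_loss / tf.maximum(token_count, 1.0))
    return float(ppl.numpy())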

def evaluate_dataset(model: Lag1Decoder, dataset: tf.data.Dataset) -> Dict[str, float]:
    losses = []
    perplexities = []
    for inputs, masks, labels in dataset:
        outputs = model({"input_ids": inputs, "attention_mask": masks}, training=False)
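        logits = outputs["logits"].numpy()
        # Next-token loss: drop the last logit and the first label, and neutralize ignored labels.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        mask = shift_labels != -100
        safe_labels = tf.where(mask, shift_labels, tf.zeros_like(shift_labels))
        loss_tensor = tf.keras.losses.sparse_categorical_crossentropy(safe_labels, shift_logits, from_logits=True)
        loss_tensor = tf.where(mask, loss_tensor, tf.zeros_like(loss_tensor))
        loss = tf.reduce_sum(loss_tensor) / tf.maximum(tf.reduce_sum(tf.cast(mask, tf.float32)), 1.0)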
        losses.append(float(loss.numpy()))
        perplexities.append(compute_perplexity(logits, labels.numpy()))
    return {
        "loss": float(np.mean(losses)) if losses else float("nan"),
        "perplexity": float(np.mean(perplexities)) if perplexities else float("nan"),
    }

## Cross-Thinking Inference
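Generate multiple candidate continuations, blending reflection-stream logits into each sampling step. The function below currently returns the first path; scoring paths by reflection-stream alignment and running self-reflection passes before final decoding are left as extension points.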
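
As one possible stand-in for path scoring, candidates can be ranked by length-normalized log-probability. This is not the reflection-alignment score named above, only a simple baseline; `mean_log_prob`, `pick_best_path`, and the toy log-probabilities are all hypothetical.

```python
def mean_log_prob(token_log_probs):
    """Length-normalized log-probability; empty paths rank last."""
    if not token_log_probs:
        return float("-inf")
    return sum(token_log_probs) / len(token_log_probs)


def pick_best_path(paths):
    """Return the index of the highest-scoring path and all scores."""
    scores = [mean_log_prob(p["token_log_probs"]) for p in paths]
    best_idx = max(range(len(paths)), key=scores.__getitem__)
    return best_idx, scores


candidates = [
    {"text": "path A", "token_log_probs": [-0.9, -1.2, -0.7]},
    {"text": "path B", "token_log_probs": [-0.3, -0.4, -0.5]},
    {"text": "path C", "token_log_probs": []},
]
best_idx, scores = pick_best_path(candidates)
print(candidates[best_idx]["text"])  # path B
```

Using such a score in the generation loop would require recording the log-probability of each sampled token, which the function in this section does not do.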

In [None]:
def generate_with_cross_thinking(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    prompt: str,
    config: Lag1Config,
    max_new_tokens: int = 128,
    num_paths: int = 3,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    reflection_strength: float = 0.3,
    seed: Optional[int] = None,
    thought_cache: Optional[List[Dict[str, Any]]] = None,
    log_reflections: bool = True,
    return_all_paths: bool = False,
):
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

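    enc = tokenizer_manager.encode_batch([prompt])
    # encode_batch pads to max_position_embeddings; keep only the real prompt tokens so that
    # new tokens are appended right after the prompt rather than after the padding.
    prompt_len = max(int(enc["attention_mask"][0].sum()), 1)
    input_ids = tf.convert_to_tensor(enc["input_ids"][:, :prompt_len])
    attention_mask = tf.convert_to_tensor(enc["attention_mask"][:, :prompt_len])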
    generated_paths = []
    reflection_logs = []

    for path in range(num_paths):
        cur_input_ids = tf.identity(input_ids)
        cur_attention = tf.identity(attention_mask)
        cur_outputs = []
        reflections = []
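        for _ in range(max_new_tokens):
            # Stop before the sequence outgrows the positional embedding table.
            if cur_input_ids.shape[1] >= config.max_position_embeddings:
                break
            outputs = model({"input_ids": cur_input_ids, "attention_mask": cur_attention}, training=False)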
            logits = outputs["logits"][:, -1, :] / max(temperature, 1e-5)
            reflection_state = outputs["reflection"][:, -1, :]
            if thought_cache is not None:
                thought_cache.append({"reflection": reflection_state.numpy(), "step": len(cur_outputs)})
            if reflection_strength > 0:
                reflection_logits = model.lm_head(reflection_state)
                logits = logits + reflection_strength * reflection_logits
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            # tf.random.categorical expects unnormalized log-probabilities, so pass the logits directly.
            next_token = tf.random.categorical(filtered_logits, num_samples=1)
            token_id = int(next_token.numpy()[0][0])
            cur_outputs.append(token_id)
            # Match dtypes explicitly so both tokenizer pathways can be concatenated.
            cur_input_ids = tf.concat([cur_input_ids, tf.cast(next_token, cur_input_ids.dtype)], axis=1)
            next_mask = tf.ones_like(next_token, dtype=cur_attention.dtype)
            cur_attention = tf.concat([cur_attention, next_mask], axis=1)
            if token_id == config.eos_token_id:
                break
            if log_reflections:
                reflections.append(reflection_state.numpy().tolist())
        generated_paths.append(cur_outputs)
        reflection_logs.append(reflections)

    def decode(tokens: List[int]) -> str:
        if tokenizer_manager.is_hf_tokenizer:
            return tokenizer_manager.tokenizer.decode(tokens, skip_special_tokens=True)
        return " ".join(map(str, tokens))

    decoded_paths = [decode(tokens) for tokens in generated_paths]
    if return_all_paths:
        return {
            "paths": decoded_paths,
            "reflection_logs": reflection_logs,
        }
    # Placeholder: no path scoring is implemented yet, so the first path is returned.
    best_idx = 0
    return {
        "completion": decoded_paths[best_idx],
        "reflection_logs": reflection_logs[best_idx],
        "all_paths": decoded_paths if log_reflections else None,
    }

In [None]:
def top_k_top_p_filtering(logits: tf.Tensor, top_k: int = 40, top_p: float = 0.9) -> tf.Tensor:
    """Apply top-k and top-p filtering to logits."""
    if top_k > 0:
        values, _ = tf.math.top_k(logits, k=top_k)
        min_values = values[:, -1, tf.newaxis]
        logits = tf.where(logits < min_values, tf.fill(tf.shape(logits), -1e9), logits)
    if top_p < 1.0:
        sorted_indices = tf.argsort(logits, direction="DESCENDING", axis=-1)
        sorted_logits = tf.gather(logits, sorted_indices, batch_dims=1, axis=-1)
        sorted_probs = tf.nn.softmax(sorted_logits, axis=-1)
        cumulative_probs = tf.cumsum(sorted_probs, axis=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is still kept.
        sorted_indices_to_remove = tf.concat([tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]], axis=-1)
        sorted_logits = tf.where(sorted_indices_to_remove, tf.fill(tf.shape(sorted_logits), -1e9), sorted_logits)
        # Map back to the original order using the inverse of the sort permutation.
        inverse_indices = tf.argsort(sorted_indices, axis=-1)
        logits = tf.gather(sorted_logits, inverse_indices, batch_dims=1, axis=-1)
    return logits
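
The nucleus (top-p) rule is easy to check by hand on a small distribution. The pure-Python sketch below uses a hypothetical helper, `top_p_keep`, that operates on probabilities rather than logits. It follows the same convention as the filter above: tokens are added in descending probability order, and the token that pushes the cumulative mass past `top_p` is still kept.

```python
def top_p_keep(probs, top_p):
    """Return the indices kept by nucleus filtering, most probable first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative > top_p:
            break
    return kept


probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_keep(probs, top_p=0.7))  # [0, 1]
print(top_p_keep(probs, top_p=0.9))  # [0, 1, 2]
```

With `top_p=0.7`, the first token alone covers 0.5, so the second token is needed to cross the threshold and the remaining two are dropped.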

## Quick CPU Smoke Test
Utility to verify the model builds and runs on a toy batch; uses a tiny config for rapid checks in environments without GPUs.

In [None]:
def smoke_test_lag1():
    """Quick CPU test to verify model builds and generates text."""
    print("🔧 Setting up tiny LAG1 config for CPU smoke test...")
    tiny_config = Lag1Config(
        vocab_size=1000,
        max_position_embeddings=64,
        hidden_size=128,
        num_layers=2,
        num_heads=4,
        cross_thinking_heads=2,
        intermediate_size=256,
    )
    
    # Initialize tokenizer and model
    tokenizer_mgr = TokenizerManager(tiny_config, tokenizer_name="gpt2")
    tokenizer_mgr.load()
    model = Lag1Decoder(tiny_config)
    
    # Test forward pass
    test_text = "Hello, I am LAG and I can"
    encoded = tokenizer_mgr.encode_batch([test_text])
    inputs = {
        "input_ids": tf.convert_to_tensor(encoded["input_ids"]),
        "attention_mask": tf.convert_to_tensor(encoded["attention_mask"])
    }
    
    print("🚀 Testing forward pass...")
    outputs = model(inputs, training=False)
    print(f"✅ Logits shape: {outputs['logits'].shape}")
    print(f"✅ Reflection shape: {outputs['reflection'].shape}")
    
    # Test generation
    print("🎯 Testing cross-thinking generation...")
    result = generate_with_cross_thinking(
        model=model,
        tokenizer_manager=tokenizer_mgr,
        prompt=test_text,
        config=tiny_config,
        max_new_tokens=20,
        num_paths=2,
        temperature=0.8,
        seed=42
    )
    
    print(f"📝 Generated completion: {result['completion']}")
    print(f"🧠 Reflection logs available: {len(result['reflection_logs']) > 0}")
    print("✅ Smoke test passed! Model builds and generates.")
    
    return model, tokenizer_mgr

# Uncomment to run the smoke test
# smoke_test_lag1()

## Training Demo & Usage Guide
Complete walkthrough for training LAG1 from data loading to model saving, with checkpointing and evaluation hooks.

In [None]:
def full_training_pipeline_demo():
    """Complete LAG1 training pipeline demonstration."""
    print("🚀 LAG1 Training Pipeline Demo")
    print("=" * 40)
    
    # Step 1: Configuration
    print("🔧 1. Setting up LAG1 configuration...")
    config = Lag1Config(
        vocab_size=32000,
        max_position_embeddings=512,  # Shorter for demo
        hidden_size=768,
        num_layers=6,  # Smaller for faster training
        num_heads=12,
        cross_thinking_heads=6,
        intermediate_size=2048,
        dropout_rate=0.1,
    )
    print(f"✅ Config ready: {config.num_layers} layers, {config.hidden_size}d")
    
    # Step 2: Tokenizer setup
    print("🔠 2. Loading tokenizer...")
    tokenizer_mgr = TokenizerManager(config, tokenizer_name="gpt2")
    tokenizer_mgr.load()
    print("✅ Tokenizer loaded")
    
    # Step 3: Dataset preparation
    print("📚 3. Loading datasets...")
    try:
        # Load a subset for demo
        raw_datasets = load_raw_datasets(auth_token=None, streaming=False)
        train_ds = build_weighted_tf_dataset(
            raw_datasets=raw_datasets,
            tokenizer_manager=tokenizer_mgr,
            config=config,
            batch_size=2,
            max_per_source=50,  # Small demo dataset
        )
        print("✅ Training dataset ready")
    except Exception as e:
        print(f"⚠️ Dataset loading failed: {e}")
        print("Creating dummy dataset for demo...")
        dummy_samples = [
            Sample(
                system="You are LAG, a helpful assistant.",
                user="What is 2+2?",
                assistant="2+2 equals 4. This is a basic arithmetic operation.",
            ),
            Sample(
                system="You are LAG, good at coding.",
                user="Write a Python function to add two numbers.",
                assistant="def add(a, b):\n    return a + b\n\nThis function takes two parameters and returns their sum.",
            ),
        ]
        builder = Lag1DatasetBuilder(tokenizer_mgr, config)
        # Repeat the two dummy batches so the 5-step demo epoch does not run out of data
        train_ds = builder.build_tf_dataset(dummy_samples, batch_size=1).repeat()
    
    # Step 4: Model and trainer setup
    print("🤖 4. Initializing model and trainer...")
    trainer = Lag1Trainer(config)
    print("✅ Trainer ready")
    
    # Step 5: Training loop
    print("🏃 5. Starting training (1 epoch demo)...")
    history = trainer.fit(
        train_ds=train_ds,
        val_ds=None,
        epochs=1,
        steps_per_epoch=5  # Very short demo
    )
    print(f"✅ Training complete! Final loss: {history['train_loss'][-1]:.4f}")
    
    # Step 6: Evaluation
    print("📈 6. Evaluating model...")
    eval_metrics = evaluate_dataset(trainer.model, train_ds.take(2))
    print(f"✅ Evaluation - Loss: {eval_metrics['loss']:.4f}, Perplexity: {eval_metrics['perplexity']:.2f}")
    
    # Step 7: Generation test
    print("💬 7. Testing generation with cross-thinking...")
    test_prompt = "Explain the concept of machine learning in simple terms."
    result = generate_with_cross_thinking(
        model=trainer.model,
        tokenizer_manager=tokenizer_mgr,
        prompt=test_prompt,
        config=config,
        max_new_tokens=50,
        num_paths=2,
        temperature=0.7,
        reflection_strength=0.2,
        seed=42
    )
    print(f"📝 Generated response: {result['completion']}")
    
    # Step 8: Model saving
    print("💾 8. Saving model...")
    save_path = "/tmp/lag1_demo_model"
    trainer.save(save_path)
    print(f"✅ Model saved to {save_path}")
    
    print("\n🎉 Training pipeline demo complete!")
    return trainer, tokenizer_mgr, history

# Uncomment to run the full demo
# trainer, tokenizer_mgr, history = full_training_pipeline_demo()

## Advanced Reasoning Extensions
LAG1's cross-thinking architecture enables several advanced reasoning patterns. The utilities below cover multi-step reasoning with verification, self-correction, thought-tree exploration, and consensus reasoning.
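
The utilities in this section differ a lot in how many generation calls they make, which matters on a Colab budget. The sketch below counts the worst-case number of `generate_with_cross_thinking` calls for each pattern, derived from the loops as written in this notebook. The helper names are hypothetical and exist only for this illustration.

```python
def multi_step_calls(max_steps: int) -> int:
    """One reasoning call plus one verification call per step (upper bound,
    since the loop stops early on a concluding keyword)."""
    return 2 * max_steps


def self_correction_calls(max_corrections: int) -> int:
    """One critique call plus one rewrite call per round (upper bound)."""
    return 2 * max_corrections


def thought_tree_calls(branch_factor: int, max_depth: int) -> int:
    """Every node above max_depth expands into branch_factor children."""
    return sum(branch_factor ** depth for depth in range(1, max_depth + 1))


def consensus_calls(num_agents: int) -> int:
    """One call per agent plus one final synthesis call."""
    return num_agents + 1


# Defaults used by the function signatures in this section.
print(multi_step_calls(5), self_correction_calls(3),
      thought_tree_calls(3, 3), consensus_calls(5))
```

With the default arguments the tree exploration is by far the most expensive pattern, so it is worth lowering `branch_factor` or `max_depth` first when experimenting.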

In [None]:
def multi_step_reasoning(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    problem: str,
    config: Lag1Config,
    max_steps: int = 5,
    step_template: str = "Step {step}: Let me think about this...\n",
    verification_template: str = "Let me verify this step: ",
    **generation_kwargs
):
    """Implements chain-of-thought reasoning with verification steps."""
    reasoning_chain = []
    current_context = problem
    
    for step in range(1, max_steps + 1):
        # Generate reasoning step
        step_prompt = current_context + "\n" + step_template.format(step=step)
        step_result = generate_with_cross_thinking(
            model=model,
            tokenizer_manager=tokenizer_manager,
            prompt=step_prompt,
            config=config,
            max_new_tokens=100,
            reflection_strength=0.4,  # Higher reflection for reasoning
            **generation_kwargs
        )
        
        step_reasoning = step_result['completion']
        reasoning_chain.append({
            "step": step,
            "reasoning": step_reasoning,
            "reflection_log": step_result['reflection_logs']
        })
        
        # Self-verification step
        verify_prompt = step_prompt + step_reasoning + "\n" + verification_template
        verify_result = generate_with_cross_thinking(
            model=model,
            tokenizer_manager=tokenizer_manager,
            prompt=verify_prompt,
            config=config,
            max_new_tokens=50,
            reflection_strength=0.6,  # Even higher for verification
            **generation_kwargs
        )
        
        reasoning_chain[-1]['verification'] = verify_result['completion']
        current_context = step_prompt + step_reasoning + "\n" + verify_result['completion']
        
        # Check if reasoning is complete
        if any(keyword in step_reasoning.lower() for keyword in ['therefore', 'conclusion', 'final answer']):
            break
    
    return {
        "problem": problem,
        "reasoning_chain": reasoning_chain,
        "final_context": current_context
    }

def self_correction_loop(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    initial_answer: str,
    question: str,
    config: Lag1Config,
    max_corrections: int = 3,
    **generation_kwargs
):
    """Implements self-correction by having the model critique and improve its own answers."""
    corrections = []
    current_answer = initial_answer
    
    for correction_round in range(max_corrections):
        # Generate critique
        critique_prompt = f"Question: {question}\nAnswer: {current_answer}\n\nCritique this answer and identify any errors or improvements:"
        critique_result = generate_with_cross_thinking(
            model=model,
            tokenizer_manager=tokenizer_manager,
            prompt=critique_prompt,
            config=config,
            max_new_tokens=150,
            reflection_strength=0.7,
            **generation_kwargs
        )
        
        critique = critique_result['completion']
        
        # Generate improved answer
        improve_prompt = f"Question: {question}\nPrevious answer: {current_answer}\nCritique: {critique}\n\nProvide an improved answer:"
        improved_result = generate_with_cross_thinking(
            model=model,
            tokenizer_manager=tokenizer_manager,
            prompt=improve_prompt,
            config=config,
            max_new_tokens=200,
            reflection_strength=0.5,
            **generation_kwargs
        )
        
        improved_answer = improved_result['completion']
        
        corrections.append({
            "round": correction_round + 1,
            "critique": critique,
            "improved_answer": improved_answer,
            "critique_reflections": critique_result['reflection_logs'],
            "improvement_reflections": improved_result['reflection_logs']
        })
        
        # Length-based stopping heuristic: keep the current answer and stop
        # unless the rewrite is at least 10% longer than it.
        if len(improved_answer) < len(current_answer) * 1.1:
            break
            
        current_answer = improved_answer
    
    return {
        "initial_answer": initial_answer,
        "final_answer": current_answer,
        "corrections": corrections
    }

def thought_tree_exploration(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    problem: str,
    config: Lag1Config,
    branch_factor: int = 3,
    max_depth: int = 3,
    **generation_kwargs
):
    """Explores multiple reasoning paths in a tree structure."""
    from collections import deque
    
    # Initialize tree with root problem
    tree = {
        "problem": problem,
        "branches": [],
        "depth": 0
    }
    
    queue = deque([tree])
    
    while queue:
        current_node = queue.popleft()
        
        if current_node["depth"] >= max_depth:
            continue
            
        # Generate multiple reasoning branches
        for branch_idx in range(branch_factor):
            branch_prompt = current_node["problem"] + f"\n\nApproach {branch_idx + 1}: "
            branch_result = generate_with_cross_thinking(
                model=model,
                tokenizer_manager=tokenizer_manager,
                prompt=branch_prompt,
                config=config,
                max_new_tokens=80,
                seed=42 + branch_idx,  # Different seeds for diversity
                **generation_kwargs
            )
            
            branch_node = {
                "approach": branch_idx + 1,
                "reasoning": branch_result['completion'],
                "problem": branch_prompt + branch_result['completion'],
                "depth": current_node["depth"] + 1,
                "branches": [],
                "reflections": branch_result['reflection_logs']
            }
            
            current_node["branches"].append(branch_node)
            queue.append(branch_node)
    
    return tree

def consensus_reasoning(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    problem: str,
    config: Lag1Config,
    num_agents: int = 5,
    **generation_kwargs
):
    """Generates multiple independent solutions and finds consensus."""
    agent_solutions = []
    
    for agent_id in range(num_agents):
        agent_prompt = f"Agent {agent_id + 1}, solve this problem: {problem}"
        solution = generate_with_cross_thinking(
            model=model,
            tokenizer_manager=tokenizer_manager,
            prompt=agent_prompt,
            config=config,
            seed=42 + agent_id,
            **generation_kwargs
        )
        
        agent_solutions.append({
            "agent_id": agent_id + 1,
            "solution": solution['completion'],
            "reflections": solution['reflection_logs']
        })
    
    # Generate consensus
    solutions_text = "\n\n".join([f"Agent {sol['agent_id']}: {sol['solution']}" for sol in agent_solutions])
    consensus_prompt = f"Problem: {problem}\n\nMultiple solutions:\n{solutions_text}\n\nAnalyze these solutions and provide a consensus answer:"
    
    consensus_result = generate_with_cross_thinking(
        model=model,
        tokenizer_manager=tokenizer_manager,
        prompt=consensus_prompt,
        config=config,
        max_new_tokens=200,
        reflection_strength=0.8,
        **generation_kwargs
    )
    
    return {
        "problem": problem,
        "agent_solutions": agent_solutions,
        "consensus": consensus_result['completion'],
        "consensus_reflections": consensus_result['reflection_logs']
    }

## Production Deployment Utilities
Helper functions for exporting a trained LAG1 model together with its config and tokenizer metadata, generating a FastAPI inference server, and benchmarking generation speed.
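
The export step writes the config as JSON via `asdict`. The sketch below shows the matching round trip: save a dataclass to `config.json` and rebuild it on the serving side. `DemoConfig`, `save_config`, and `load_config` are hypothetical stand-ins so the example runs on its own; with the real notebook, `Lag1Config` would take the place of `DemoConfig`.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for Lag1Config with a few of its fields."""
    vocab_size: int = 32000
    hidden_size: int = 768
    num_layers: int = 12


def save_config(config: DemoConfig, export_dir: str) -> str:
    """Write the config next to the exported model, as the export step does."""
    path = os.path.join(export_dir, "config.json")
    with open(path, "w") as f:
        json.dump(asdict(config), f, indent=2)
    return path


def load_config(export_dir: str) -> DemoConfig:
    """Rebuild the dataclass from the saved JSON."""
    with open(os.path.join(export_dir, "config.json")) as f:
        return DemoConfig(**json.load(f))


with tempfile.TemporaryDirectory() as export_dir:
    save_config(DemoConfig(num_layers=6), export_dir)
    restored = load_config(export_dir)
    print(restored)
```

Loading the saved values this way keeps the serving config in sync with the trained model instead of relying on the dataclass defaults.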

In [None]:
def create_inference_server_code(model_path: str, tokenizer_name: str = "gpt2"):
    """Generate FastAPI server code for LAG1 inference."""
    server_code = f'''
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import tensorflow as tf
import json
from typing import Optional, List, Dict, Any

# Import LAG1 components (assuming they're in a module)
from lag1_model import Lag1Config, Lag1Decoder, TokenizerManager, generate_with_cross_thinking

app = FastAPI(title="LAG1 Inference Server", version="1.0.0")

# Load model and tokenizer at startup
# NOTE: this uses the default hyperparameters; load the exported config.json here
# so the config matches the trained model.
config = Lag1Config()
model = tf.saved_model.load({model_path!r})
tokenizer_mgr = TokenizerManager(config, tokenizer_name={tokenizer_name!r})
tokenizer_mgr.load()

class GenerationRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 128
    temperature: float = 0.8
    top_k: int = 40
    top_p: float = 0.9
    num_paths: int = 1
    reflection_strength: float = 0.3
    use_cross_thinking: bool = True
    seed: Optional[int] = None

class GenerationResponse(BaseModel):
    completion: str
    metadata: Dict[str, Any]

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    try:
        if request.use_cross_thinking:
            result = generate_with_cross_thinking(
                model=model,
                tokenizer_manager=tokenizer_mgr,
                prompt=request.prompt,
                config=config,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                num_paths=request.num_paths,
                reflection_strength=request.reflection_strength,
                seed=request.seed
            )
            return GenerationResponse(
                completion=result["completion"],
                metadata={{
                    "reflection_logs_available": len(result["reflection_logs"]) > 0,
                    "num_reflection_steps": len(result["reflection_logs"])
                }}
            )
        else:
            # Standard generation without cross-thinking
            # Implementation would go here
            raise HTTPException(status_code=501, detail="Standard generation not implemented")
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 501 above) unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {{"status": "healthy", "model_loaded": model is not None}}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
'''
    return server_code

def export_for_deployment(
    trainer: Lag1Trainer,
    tokenizer_manager: TokenizerManager,
    export_dir: str,
    include_server: bool = True
):
    """Export trained LAG1 model for production deployment."""
    import os
    os.makedirs(export_dir, exist_ok=True)
    
    # Save model
    model_path = os.path.join(export_dir, "model")
    trainer.save(model_path)
    print(f"✅ Model saved to {model_path}")
    
    # Save config
    config_path = os.path.join(export_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(trainer.config), f, indent=2)
    print(f"✅ Config saved to {config_path}")
    
    # Save tokenizer info
    tokenizer_info = {
        "tokenizer_name": tokenizer_manager.tokenizer_name,
        "vocab_size": tokenizer_manager.vocab_size,
        "is_hf_tokenizer": tokenizer_manager.is_hf_tokenizer
    }
    tokenizer_path = os.path.join(export_dir, "tokenizer_info.json")
    with open(tokenizer_path, "w") as f:
        json.dump(tokenizer_info, f, indent=2)
    print(f"✅ Tokenizer info saved to {tokenizer_path}")
    
    if include_server:
        # Generate server code
        server_code = create_inference_server_code(model_path, tokenizer_manager.tokenizer_name)
        server_path = os.path.join(export_dir, "server.py")
        with open(server_path, "w") as f:
            f.write(server_code)
        print(f"✅ Server code generated at {server_path}")
        
        # Generate requirements.txt
        requirements = [
            "tensorflow>=2.15.0",
            "transformers>=4.21.0",
            "fastapi>=0.100.0",
            "uvicorn>=0.23.0",
            "pydantic>=2.0.0",
            "datasets>=2.14.0"
        ]
        req_path = os.path.join(export_dir, "requirements.txt")
        with open(req_path, "w") as f:
            f.write("\n".join(requirements))
        print(f"✅ Requirements saved to {req_path}")
        
        # Generate deployment instructions
        deploy_instructions = f'''
# LAG1 Deployment Instructions

## Setup
1. Install dependencies:
   ```bash
   pip install -r requirements.txt
   ```

2. Copy the LAG1 model components from the notebook into `lag1_model.py` next to `server.py` (the server imports from that module).

## Running the Server
```bash
python server.py
```

## Testing
```bash
curl -X POST "http://localhost:8000/generate" \\
     -H "Content-Type: application/json" \\
     -d '{{
       "prompt": "Explain quantum computing",
       "max_new_tokens": 100,
       "temperature": 0.7,
       "use_cross_thinking": true
     }}'
```

## Health Check
```bash
curl http://localhost:8000/health
```
'''
        deploy_path = os.path.join(export_dir, "README_DEPLOYMENT.md")
        with open(deploy_path, "w") as f:
            f.write(deploy_instructions)
        print(f"✅ Deployment instructions saved to {deploy_path}")
    
    print(f"\n🎉 Deployment package ready in {export_dir}")
    return export_dir

def benchmark_model(
    model: Lag1Decoder,
    tokenizer_manager: TokenizerManager,
    config: Lag1Config,
    test_prompts: List[str],
    num_runs: int = 5
) -> Dict[str, Any]:
    """Benchmark LAG1 model performance."""
    import time
    import statistics
    
    results = {
        "total_prompts": len(test_prompts),
        "num_runs": num_runs,
        "per_prompt_stats": [],
        "overall_stats": {}
    }
    
    all_times = []
    all_tokens_per_sec = []
    
    for prompt in test_prompts:
        prompt_times = []
        for run in range(num_runs):
            # perf_counter is monotonic and higher resolution than time.time()
            start_time = time.perf_counter()
            result = generate_with_cross_thinking(
                model=model,
                tokenizer_manager=tokenizer_manager,
                prompt=prompt,
                config=config,
                max_new_tokens=50,
                num_paths=1
            )
            end_time = time.perf_counter()
            
            generation_time = end_time - start_time
            tokens_generated = len(tokenizer_manager.encode_batch([result['completion']])["input_ids"][0])
            tokens_per_sec = tokens_generated / generation_time if generation_time > 0 else 0
            
            prompt_times.append(generation_time)
            all_times.append(generation_time)
            all_tokens_per_sec.append(tokens_per_sec)
        
        results["per_prompt_stats"].append({
            "prompt": (prompt[:50] + "...") if len(prompt) > 50 else prompt,
            "avg_time": statistics.mean(prompt_times),
            "min_time": min(prompt_times),
            "max_time": max(prompt_times),
            "std_time": statistics.stdev(prompt_times) if len(prompt_times) > 1 else 0
        })
    
    results["overall_stats"] = {
        "avg_generation_time": statistics.mean(all_times),
        "avg_tokens_per_sec": statistics.mean(all_tokens_per_sec),
        "min_generation_time": min(all_times),
        "max_generation_time": max(all_times),
        "std_generation_time": statistics.stdev(all_times) if len(all_times) > 1 else 0
    }
    
    return results

## 🚀 Getting Started with LAG1

### Quick Start
1. **Run the smoke test** to verify everything works:
   ```python
   smoke_test_lag1()
   ```

2. **Load your datasets** (authenticate for gated ones):
   ```python
   raw_datasets = load_raw_datasets(auth_token=None)  # or your HF token
   ```

3. **Train a model**:
   ```python
   full_training_pipeline_demo()
   ```

### Advanced Features

**Multi-Step Reasoning**: Chain thoughts with verification
```python
reasoning_result = multi_step_reasoning(model, tokenizer_mgr, "Solve 2x + 5 = 13", config)
```
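
The returned dictionary nests one entry per step under `reasoning_chain`. The sketch below prints such a result in a readable form. `format_reasoning_chain` is a hypothetical helper, and the `example` dictionary is hand-written in the same shape as the return value, not real model output.

```python
from typing import Any, Dict, List


def format_reasoning_chain(result: Dict[str, Any]) -> str:
    """Render each step and its verification on separate lines."""
    lines: List[str] = [f"Problem: {result['problem']}"]
    for entry in result["reasoning_chain"]:
        lines.append(f"Step {entry['step']}: {entry['reasoning'].strip()}")
        # 'verification' is attached to each entry after the verification pass.
        lines.append(f"  Check: {entry.get('verification', '').strip()}")
    return "\n".join(lines)


# Hand-written example in the shape multi_step_reasoning returns.
example = {
    "problem": "Solve 2x + 5 = 13",
    "reasoning_chain": [
        {"step": 1, "reasoning": "Subtract 5 from both sides: 2x = 8.",
         "verification": "13 - 5 = 8, so this holds.", "reflection_log": []},
        {"step": 2, "reasoning": "Therefore x = 4.",
         "verification": "2 * 4 + 5 = 13.", "reflection_log": []},
    ],
    "final_context": "",
}
print(format_reasoning_chain(example))
```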

**Self-Correction**: Let the model improve its own answers
```python
corrected = self_correction_loop(model, tokenizer_mgr, initial_answer, question, config)
```
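
The loop decides when to stop with a simple length comparison. The sketch below isolates that check so its behavior is easy to see; `should_stop` is a hypothetical name that mirrors the condition in `self_correction_loop`, with the 1.1 factor exposed as a parameter.

```python
def should_stop(current_answer: str, improved_answer: str, growth: float = 1.1) -> bool:
    """Stop unless the rewrite is at least `growth` times as long as the current answer."""
    return len(improved_answer) < len(current_answer) * growth


current = "a" * 100
print(should_stop(current, "b" * 105))  # rewrite is only 5% longer
print(should_stop(current, "b" * 120))  # rewrite is 20% longer
```

When the loop stops on this check, `final_answer` stays at the previous answer; the last rewrite is still recorded in `corrections`. Length is only a rough proxy for quality, so a similarity or scoring check may suit some tasks better.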

**Consensus Reasoning**: Sample several independent solutions, then synthesize a consensus
```python
consensus = consensus_reasoning(model, tokenizer_mgr, "Explain photosynthesis", config)
```
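
The final synthesis call sees every sampled solution in a single prompt. The sketch below rebuilds that prompt from a list of solutions, following the string layout used in `consensus_reasoning`. `build_consensus_prompt` is a hypothetical helper and the solutions are placeholders, not model output.

```python
from typing import Dict, List


def build_consensus_prompt(problem: str, agent_solutions: List[Dict[str, object]]) -> str:
    """Join the per-agent solutions and append the synthesis instruction."""
    solutions_text = "\n\n".join(
        f"Agent {sol['agent_id']}: {sol['solution']}" for sol in agent_solutions
    )
    return (
        f"Problem: {problem}\n\nMultiple solutions:\n{solutions_text}\n\n"
        "Analyze these solutions and provide a consensus answer:"
    )


placeholder_solutions = [
    {"agent_id": 1, "solution": "Plants convert light into chemical energy."},
    {"agent_id": 2, "solution": "Chloroplasts use light, water, and CO2 to make glucose."},
]
consensus_prompt = build_consensus_prompt("Explain photosynthesis", placeholder_solutions)
print(consensus_prompt)
```

Because the prompt grows with `num_agents` and the length of each solution, it can exceed `max_position_embeddings`; keeping the per-agent `max_new_tokens` modest helps.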

### Production Deployment
```python
# Export trained model for production
export_for_deployment(trainer, tokenizer_mgr, "/path/to/export", include_server=True)

# Benchmark performance
test_prompts = ["Hello, how are you?", "Explain AI", "Write Python code"]
benchmark_results = benchmark_model(model, tokenizer_mgr, config, test_prompts)
```
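
Once the generated server is running, it can be called from Python as well as from `curl`. The sketch below builds the POST request with the standard library only. `build_generate_request` is a hypothetical helper, the default URL assumes the server runs locally on port 8000 as in the generated `server.py`, and the payload keys follow the `GenerationRequest` fields.

```python
import json
import urllib.request


def build_generate_request(prompt: str, base_url: str = "http://localhost:8000",
                           **overrides) -> urllib.request.Request:
    """Build a POST request for the /generate endpoint."""
    payload = {
        "prompt": prompt,
        "max_new_tokens": 100,
        "temperature": 0.7,
        "use_cross_thinking": True,
    }
    payload.update(overrides)  # e.g. seed=42 or num_paths=2
    return urllib.request.Request(
        url=f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_generate_request("Explain quantum computing", seed=42)
print(request.full_url, request.get_method())

# With the server running, send it with:
# with urllib.request.urlopen(request) as response:
#     print(json.loads(response.read())["completion"])
```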

### Key Features
- **🧠 Cross-Thinking Architecture**: Dual-stream reasoning with primary and reflection paths
- **🔄 Multi-Path Generation**: Explore multiple reasoning approaches simultaneously
- **🛠️ Flexible Tokenization**: Works with Hugging Face tokenizers or a lightweight Keras `TextVectorization` fallback
- **📊 Comprehensive Datasets**: Supports 6 diverse datasets for training
- **🚀 Deployment Helpers**: Generates a FastAPI server, requirements file, and deployment instructions
- **🔧 Advanced Reasoning**: Chain-of-thought, self-correction, consensus mechanisms

### Next Steps
1. **Experiment with hyperparameters** in `Lag1Config`
2. **Add custom datasets** by extending `DATASET_SPECS`
3. **Implement new reasoning patterns** using the cross-thinking utilities
4. **Scale up training** on larger datasets and longer sequences
5. **Deploy to production** using the generated FastAPI server

### Troubleshooting
- **Authentication errors**: Run `huggingface-cli login` or provide a valid token
- **Memory issues**: Reduce `batch_size`, `max_position_embeddings`, or `num_layers`
- **Slow training**: Enable mixed precision, use TPUs on Colab, or implement gradient accumulation
- **Poor generation quality**: Increase model size, train longer, or adjust generation parameters
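
Gradient accumulation, mentioned above as a remedy for slow or memory-bound training, averages the gradients of several small batches before applying a single update. The sketch below uses a one-parameter linear model in plain Python to show that, with equally sized micro-batches, the accumulated update matches the full-batch update. It illustrates the idea only and is not the notebook's trainer; all names are hypothetical.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_gradient(w: float, batch: Batch) -> float:
    """d/dw of mean((w * x - y)^2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_step(w: float, micro_batches: List[Batch], lr: float) -> float:
    """Average gradients over equally sized micro-batches, then update once."""
    total = 0.0
    for micro_batch in micro_batches:
        total += mse_gradient(w, micro_batch)
    return w - lr * total / len(micro_batches)


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
full_batch_w = 0.0 - 0.01 * mse_gradient(0.0, data)
accumulated_w = accumulated_step(0.0, [data[:2], data[2:]], lr=0.01)
print(full_batch_w, accumulated_w)
```

In a TensorFlow training loop the same idea would mean summing gradients into buffers across steps and calling the optimizer once every few batches.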

In [None]:
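class SinusoidalPositionEmbedding(tf.keras.layers.Layer):
    """Fixed sinusoidal position table; returns encodings for the first `length` positions of `x`."""

    def __init__(self, max_len: int, dim: int, name: Optional[str] = None):
        super().__init__(name=name)
        position = np.arange(max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, dim, 2) * -(math.log(10000.0) / dim))
        pe = np.zeros((max_len, dim), dtype=np.float32)
        pe[:, 0::2] = np.sin(position * div_term)
        # Slice div_term so an odd `dim` (one fewer cosine column) also works
        pe[:, 1::2] = np.cos(position * div_term[: dim // 2])
        # Store as a constant tensor so slicing with a dynamic length works in graph mode
        self.pe = tf.constant(pe[np.newaxis])

    def call(self, x: tf.Tensor) -> tf.Tensor:
        length = tf.shape(x)[1]
        return tf.cast(self.pe[:, :length, :], x.dtype)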

## Multi-Stream Decoder Architecture
The cross-thinking block runs primary and reflection streams in parallel, exchanging context via cross-attention and gated fusion. Position embeddings (the notebook defines a fixed sinusoidal table in `SinusoidalPositionEmbedding`) keep both streams aligned.
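
To make "gated fusion" concrete, the sketch below blends a primary vector and a reflection vector element-wise with a sigmoid gate. This is one common form of gating and an assumption about the general idea, not the notebook's actual block: in a real layer the gate logits would come from a learned projection of both streams, and every name here is hypothetical.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_fusion(primary: List[float], reflection: List[float],
                 gate_logits: List[float]) -> List[float]:
    """Blend two stream vectors element-wise: g * primary + (1 - g) * reflection."""
    fused = []
    for p, r, logit in zip(primary, reflection, gate_logits):
        g = sigmoid(logit)  # g near 1 keeps the primary stream, near 0 the reflection
        fused.append(g * p + (1.0 - g) * r)
    return fused


# A logit of 0 mixes both streams equally; a large logit favors the primary stream.
fused = gated_fusion(primary=[1.0, 1.0], reflection=[0.0, 0.0], gate_logits=[0.0, 10.0])
print(fused)
```

The gate lets each feature decide how much of the reflection stream to admit, which keeps the primary stream stable when the reflection signal is not useful.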

In [None]:
import json
import math
import os
import random
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import tensorflow as tf
from datasets import load_dataset
try:
    import tensorflow_text as tf_text  # noqa: F401
except ModuleNotFoundError:
    tf_text = None
try:
    from transformers import AutoTokenizer
    HF_AVAILABLE = True
except ModuleNotFoundError:
    AutoTokenizer = None
    HF_AVAILABLE = False
print({
    "tensorflow_version": tf.__version__,
    "eager_execution": tf.executing_eagerly(),
    "tf_text_available": tf_text is not None,
    "transformers_available": HF_AVAILABLE,
    "gpu_available": tf.config.list_physical_devices("GPU")
})
tf.keras.mixed_precision.set_global_policy("mixed_float16")