# Agentic Self-Learning (IWM + Self-Reflection) on **SmolLM3-3B**

This notebook demonstrates a small, **Colab-friendly** agentic self-learning loop that works with **HuggingFaceTB/SmolLM3-3B**:

- **Rollouts** in a tiny text-based transactional environment  
- **IWM (Implicit World Modeling):** train the LLM to predict *next state* given `(state, action)`  
- **Self-Reflection (SR):** train short rationales + target actions from simple validators (no human labels)  
- **Adapters (LoRA)** via `peft` to keep memory/compute light  
- **Quantization (4-bit)** via `bitsandbytes` for a T4/Colab GPU  

> Goal: provide a working skeleton that you can swap into your own stack (e.g., ZRIA/sidecars).


## Setup

In [None]:

# If you're in Colab, enable a GPU: Runtime -> Change runtime type -> GPU (T4 is fine).
!nvidia-smi || echo "No GPU detected (CPU-only)."

# Install deps (these are minimum versions, not exact pins; SmolLM3 needs transformers >= 4.53)
!pip -q install "transformers>=4.53.0" "datasets>=2.20.0" "accelerate>=0.34.2" "peft>=0.12.0" "bitsandbytes>=0.43.1" "trl>=0.11.4" "evaluate>=0.4.2" "wandb>=0.17.9"


Fri Oct 10 15:52:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   55C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## Imports & configuration

In [None]:

import os, math, json, random, copy, textwrap, time
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import torch
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers.trainer_utils import set_seed

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
set_seed(SEED)

MODEL_NAME = "HuggingFaceTB/SmolLM3-3B"   # main model
USE_4BIT = True                            # int4 to fit in Colab
MAX_NEW_TOKENS = 64                        # decoding cap for actions/rationales
CTX_LEN = 2048
BATCH_SIZE = 2
GRAD_ACCUM = 4
LR = 2e-4
EPOCHS_IWM = 1
EPOCHS_SR  = 1

print(f"Device: {DEVICE}")


Device: cuda


## Load model & tokenizer (optionally 4-bit)

In [None]:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=USE_4BIT,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # float16 matches fp16=True in the TrainingArguments below; a T4 has no native bfloat16.
    bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config if USE_4BIT else None
)
model.config.use_cache = False  # better for training
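# Reusing EOS as the pad token adds no new tokens, so resize only if the tokenizer outgrew the embedding table.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))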
print("Loaded model & tokenizer.")


Loaded model & tokenizer.
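
As a rough sanity check on why 4-bit loading matters on a 16 GB T4, the arithmetic below estimates weight memory alone. The base parameter count is derived from the LoRA summary printed later in this notebook. The 4-bit figure is a lower bound: it assumes every weight is quantized and ignores quantization constants, activations, and optimizer state.

```python
# Back-of-the-envelope weight memory for the base model (a sketch, not a measurement).
ALL_PARAMS_WITH_LORA = 3_105_327_104   # "all params" from print_trainable_parameters()
LORA_PARAMS = 30_228_480               # "trainable params" from the same line
base_params = ALL_PARAMS_WITH_LORA - LORA_PARAMS


def weight_gb(n_params: int, bits_per_param: float) -> float:
    """Decimal gigabytes needed to store n_params at the given precision."""
    return n_params * bits_per_param / 8 / 1e9


bf16_gb = weight_gb(base_params, 16)
nf4_gb = weight_gb(base_params, 4)
print(f"base params: {base_params:,}")
print(f"16-bit weights: {bf16_gb:.2f} GB")
print(f"4-bit weights (lower bound): {nf4_gb:.2f} GB")
```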


## Tiny transactional environment

A minimal **shopping flow** with a budget constraint.
States and next states are **serialized as text** so the LLM can learn to predict them.


In [None]:

class TinyShopEnv:
    """
    Minimal deterministic environment for illustration:
    - catalog: (color, price) for a single item
    - user wants: color preference and budget
    Actions are textual commands, e.g.:
        ACTION: SELECT color=blue
        ACTION: BUY
    """
    def __init__(self, budget: int = 30, want_color: str = "blue"):
        self.catalog = {"red": 35, "blue": 25, "green": 20}
        self.budget = budget
        self.want_color = want_color
        self.reset()

    def reset(self):
        self.cart_color = None
        self.done = False
        return self._obs()

    def _obs(self) -> str:
        return json.dumps({
            "view": "product_page",
            "catalog": self.catalog,
            "cart": {"color": self.cart_color, "price": self.catalog.get(self.cart_color) if self.cart_color else None},
            "constraints": {"budget": self.budget},
            "preference": {"color": self.want_color}
        })

    def step(self, action: str):
        info = {"ok": True, "errors": []}
        if self.done:
            info["ok"] = False
            info["errors"].append("episode_done")
            return self._obs(), info

        a = action.strip().upper()
        if a.startswith("ACTION: SELECT "):
            kv = a.replace("ACTION: SELECT", "").strip()
            # crude parse: color=xxx
            if "COLOR=" in kv:
                sel = kv.split("COLOR=")[1].strip()
                sel = sel.lower()
                if sel in self.catalog:
                    self.cart_color = sel
                else:
                    info["ok"] = False
                    info["errors"].append("unknown_color")
            else:
                info["ok"] = False
                info["errors"].append("bad_select_format")

        elif a == "ACTION: BUY":
            if self.cart_color is None:
                info["ok"] = False
                info["errors"].append("no_selection")
            else:
                price = self.catalog[self.cart_color]
                if price > self.budget:
                    info["ok"] = False
                    info["errors"].append("over_budget")
                else:
                    info["bought"] = {"color": self.cart_color, "price": price}
                    self.done = True
        else:
            info["ok"] = False
            info["errors"].append("unknown_action")

        return self._obs(), info


def validator(state: str, action: str, next_state: str, info):
    """
    Simple rule/constraint checker used as a labeler for SR targets.
    Returns a dict indicating violations and a short 'hint' we can use for rationale.
    """
    st = json.loads(state)
    ns = json.loads(next_state)
    budget = st["constraints"]["budget"]
    cart_color = ns["cart"]["color"]
    violations = []
    violations = list(info.get("errors", []))
    price = ns["cart"]["price"]
    # Flag an over-budget cart once, even if the environment already reported it.
    if price is not None and price > budget and "over_budget" not in violations:
        violations.append("over_budget")
    want_color = st["preference"]["color"]
    hint = []
    if "over_budget" in violations:
        hint.append("over budget")
    if cart_color and cart_color != want_color:
        hint.append(f"wrong color (wants {want_color})")
    return {
        "violations": violations,
        "hint": ", ".join(hint) if hint else "ok"
    }
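
The rules that `TinyShopEnv.step` applies can also be written as a pure function over the decoded state dict, which makes it cheap to branch several candidate actions from the same state without copying an environment object. This is a simplified sketch under assumptions: `apply_action` is a hypothetical helper that is not part of the notebook, and it folds malformed SELECT commands into `unknown_action`.

```python
import json
from typing import Any, Dict, List, Tuple


def apply_action(state: Dict[str, Any], action: str) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical pure transition: returns (next_state, errors), never mutates `state`."""
    nxt = json.loads(json.dumps(state))  # deep copy of a JSON-like dict
    a = action.strip().upper()
    if a.startswith("ACTION: SELECT COLOR="):
        color = a.split("COLOR=", 1)[1].strip().lower()
        if color not in nxt["catalog"]:
            return nxt, ["unknown_color"]
        nxt["cart"] = {"color": color, "price": nxt["catalog"][color]}
        return nxt, []
    if a == "ACTION: BUY":
        price = nxt["cart"]["price"]
        if price is None:
            return nxt, ["no_selection"]
        if price > nxt["constraints"]["budget"]:
            return nxt, ["over_budget"]
        return nxt, []  # purchase succeeds; the notebook records it in info["bought"]
    return nxt, ["unknown_action"]


s0 = {
    "view": "product_page",
    "catalog": {"red": 35, "blue": 25, "green": 20},
    "cart": {"color": None, "price": None},
    "constraints": {"budget": 30},
    "preference": {"color": "blue"},
}

# Branch three candidate actions from the same starting state.
for act in ["ACTION: SELECT color=blue", "ACTION: SELECT color=red", "ACTION: BUY"]:
    s1, errs = apply_action(s0, act)
    print(act, "->", s1["cart"], errs)
```

Because `s0` is never mutated, each candidate is evaluated against the same starting state.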


## Policy helpers (sampling actions & alternatives)

In [None]:
# ==== Constrained policy: choose from a finite ACTION_SET via log-prob scoring ====

ACTION_SET = [
    "ACTION: SELECT color=red",
    "ACTION: SELECT color=blue",
    "ACTION: SELECT color=green",
    "ACTION: BUY",
]

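def format_action_prompt(state_text: str) -> str:
    # End with a newline so a scored candidate starts on its own line
    # instead of being glued to the last listed option.
    return (f"<STATE>\n{state_text}\n</STATE>\n"
            "Decide the next action.\n"
            "Respond with exactly one of the following lines ONLY:\n"
            + "\n".join(ACTION_SET) + "\n")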

@torch.inference_mode()
def score_action(model, tokenizer, prompt: str, candidate: str) -> float:
    """
    Returns total log-prob of 'candidate' when appended after 'prompt'.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # as_target_tokenizer() is deprecated and unnecessary for a decoder-only model.
    cand_ids = tokenizer(candidate, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device)

    # Concatenate prompt + candidate, get logits, compute log-probs over candidate tokens
    full_ids = torch.cat([inputs["input_ids"], cand_ids], dim=1)
    out = model(input_ids=full_ids)
    logits = out.logits  # [B, T, V]
    # We want log-probs for positions corresponding to the candidate
    cand_len = cand_ids.shape[1]
    start = full_ids.shape[1] - cand_len
    logits_cand = logits[:, start-1:-1, :]  # shift for next-token prediction
    # Gather log-probs
    log_probs = torch.nn.functional.log_softmax(logits_cand, dim=-1)
    gather = log_probs.gather(-1, cand_ids.unsqueeze(-1)).squeeze(-1)
    return gather.sum().item()

@torch.inference_mode()
def choose_action(model, tokenizer, state_text: str) -> str:
    prompt = format_action_prompt(state_text)
    scores = []
    for a in ACTION_SET:
        s = score_action(model, tokenizer, prompt, a)
        scores.append((s, a))
    scores.sort(reverse=True)
    return scores[0][1]

@torch.inference_mode()
def choose_alternatives(model, tokenizer, state_text: str, k: int = 2):
    """Top-k alternatives different from the best one (for SR/IWM branching)."""
    prompt = format_action_prompt(state_text)
    scored = [(score_action(model, tokenizer, prompt, a), a) for a in ACTION_SET]
    scored.sort(reverse=True)
    # skip the best, return next k
    return [a for (_, a) in scored[1:1+k]]
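
Summing token log-probabilities, as `score_action` does, tends to favor candidates with fewer tokens, because every extra token adds another negative term. In this action set `ACTION: BUY` is the shortest string, so it may win for that reason alone. A common mitigation is to compare the mean log-probability per token instead. The sketch below uses made-up per-token values (an assumption, purely for illustration) and hypothetical helper names to show how the ranking can flip.

```python
from typing import Dict, List


def total_logprob(token_logprobs: List[float]) -> float:
    """Sum of per-token log-probs: the quantity a sum-based scorer returns."""
    return sum(token_logprobs)


def mean_logprob(token_logprobs: List[float]) -> float:
    """Length-normalized score: average log-prob per candidate token."""
    return sum(token_logprobs) / len(token_logprobs)


# Hypothetical per-token log-probs; real values would come from the model's logits.
candidates: Dict[str, List[float]] = {
    # 3 tokens, each fairly unlikely
    "ACTION: BUY": [-0.9, -0.8, -1.0],
    # 7 tokens, each more likely
    "ACTION: SELECT color=blue": [-0.5, -0.4, -0.6, -0.5, -0.4, -0.6, -0.5],
}

best_by_sum = max(candidates, key=lambda a: total_logprob(candidates[a]))
best_by_mean = max(candidates, key=lambda a: mean_logprob(candidates[a]))
print("argmax of summed log-prob:", best_by_sum)
print("argmax of mean log-prob:  ", best_by_mean)
```

Whether normalization helps in practice depends on the tokenizer and the model, so it is worth comparing both rankings on real rollouts.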

## Rollout collection

In [None]:
def rollout(model, tokenizer, episodes: int = 20, K_alts: int = 2, budget:int=30, want_color:str="blue"):
    logs = []
    for ep in range(episodes):
        env = TinyShopEnv(budget=budget, want_color=want_color)
        s = env.reset()
        for t in range(6):
            # Snapshot the environment BEFORE stepping so alternatives branch from s, not s'.
            env_before = copy.deepcopy(env)
            a = choose_action(model, tokenizer, s)
            s_prime, info = env.step(a)

            alts = []
            for a_alt in choose_alternatives(model, tokenizer, s, K_alts):
                env2 = copy.deepcopy(env_before)
                s_alt, info_alt = env2.step(a_alt)
                alts.append({"a_alt": a_alt, "s_prime_alt": s_alt, "info_alt": info_alt})

            v = validator(s, a, s_prime, info)
            logs.append({
                "episode_id": f"ep_{ep}",
                "t": t,
                "state_s": s,
                "action_a": a,
                "s_prime": s_prime,
                "info": info,
                "alts": alts,
                "validation": v
            })
            s = s_prime
            if env.done:
                break
    return logs

# Collect the rollouts used for training. episodes=24 is an assumption chosen to
# match the 144 transitions reported below (24 episodes x 6 steps each).
logs = rollout(model, tokenizer, episodes=24)
print(f"Collected {len(logs)} transitions.")
print(list(logs[0].keys()))

Collected 144 transitions.
['episode_id', 't', 'state_s', 'action_a', 's_prime', 'info', 'alts', 'validation']


## Build training datasets (IWM & SR)

In [None]:

def ex_iwm(example):
    # Input: [STATE] s  [ACTION] a  -> Target: [NEXT] s'
    inp = f"<STATE>\n{example['state_s']}\n</STATE>\n<ACTION>\n{example['action_a']}\n</ACTION>\n<NEXT>\n"
    tgt = example['s_prime']
    return {"text": inp + tgt}

def choose_expert_action(example):
    # Use simple validator hints: pick an action that leads toward budget+preference
    st = json.loads(example["state_s"])
    want = st["preference"]["color"]
    budget = st["constraints"]["budget"]
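    # If the cart in the current state already satisfies both constraints, the expert move is to buy.
    cart = st["cart"]
    if cart["color"] == want and cart["price"] is not None and cart["price"] <= budget:
        return "ACTION: BUY"
    # If current action is already good, call it expert. Otherwise, try to pick from alts.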
    ns = json.loads(example["s_prime"])
    curr_price = ns["cart"]["price"]
    curr_color = ns["cart"]["color"]
    good = True
    if curr_color is None:
        good = False
    else:
        if curr_price is None or curr_price > budget or curr_color != want:
            good = False

    if good:
        return example["action_a"]

    # try alternatives
    for alt in example["alts"]:
        ns2 = json.loads(alt["s_prime_alt"])
        color2 = ns2["cart"]["color"]
        price2 = ns2["cart"]["price"]
        if color2 is not None and price2 is not None and price2 <= budget and color2 == want:
            # infer what action was used
            return alt["a_alt"]

    # fallback heuristics
    if curr_color != want:
        return f"ACTION: SELECT color={want}"
    if curr_price is not None and curr_price <= budget:
        return "ACTION: BUY"
    return example["action_a"]

def build_sr_text(example):
    expert = choose_expert_action(example)
    hint = example["validation"]["hint"]
    rationale = f"Because {hint}, choose the action that satisfies budget and preference."
    prompt = f"<STATE>\n{example['state_s']}\n</STATE>\n<REFLECTION>\n{rationale}\n</REFLECTION>\n<TARGET_ACTION>\n{expert}"
    return {"text": prompt}

iwm_texts = [ex_iwm(ex) for ex in logs]
sr_texts  = [build_sr_text(ex) for ex in logs]

ds_iwm = Dataset.from_list(iwm_texts)
ds_sr  = Dataset.from_list(sr_texts)
print(ds_iwm[0]["text"][:280])
print("---")
print(ds_sr[0]["text"][:280])


<STATE>
{"view": "product_page", "catalog": {"red": 35, "blue": 25, "green": 20}, "cart": {"color": null, "price": null}, "constraints": {"budget": 30}, "preference": {"color": "blue"}}
</STATE>
<ACTION>
product with color=<red|blue|green> if constraints are satisfied and prefere
---
<STATE>
{"view": "product_page", "catalog": {"red": 35, "blue": 25, "green": 20}, "cart": {"color": null, "price": null}, "constraints": {"budget": 30}, "preference": {"color": "blue"}}
</STATE>
<REFLECTION>
Because ok, choose the action that satisfies budget and preference.
</RE


## Tokenize & prepare loaders

In [None]:

def tokenize_fn(batch):
    return tokenizer(batch["text"], truncation=True, max_length=CTX_LEN)

tok_iwm = ds_iwm.map(tokenize_fn, batched=True, remove_columns=["text"])
tok_sr  = ds_sr.map(tokenize_fn, batched=True, remove_columns=["text"])

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
print(tok_iwm)
print(tok_sr)


Map:   0%|          | 0/144 [00:00<?, ? examples/s]

Map:   0%|          | 0/144 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 144
})
Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 144
})
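
`DataCollatorForLanguageModeling(mlm=False)` copies `input_ids` into `labels`, so the loss covers the whole sequence, including the `<STATE>` prompt that the model is never asked to generate. One common alternative is to mask prompt positions with `-100`, the index that PyTorch's cross-entropy ignores by default, so only the target tokens (the next state or the action) contribute. The sketch below works on plain token-id lists; `build_masked_labels` and the ids are hypothetical.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_masked_labels(prompt_ids: List[int], target_ids: List[int]) -> Dict[str, List[int]]:
    """Concatenate prompt and target ids; mask the prompt part out of the loss."""
    input_ids = prompt_ids + target_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(target_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


# Hypothetical token ids standing in for tokenizer output.
prompt_ids = [101, 7, 8, 9]   # e.g. "<STATE>...<ACTION>...<NEXT>\n"
target_ids = [42, 43, 2]      # e.g. the serialized next state plus an EOS id

example = build_masked_labels(prompt_ids, target_ids)
print(example["labels"])
```

To train on examples built this way, a collator that pads `labels` with `-100` (for example `DataCollatorForSeq2Seq`) would be needed in place of the language-modeling collator.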


## Prepare PEFT/LoRA adapters

In [None]:

model = prepare_model_for_kbit_training(model)
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"]
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()


trainable params: 30,228,480 || all params: 3,105,327,104 || trainable%: 0.9734
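
The trainable-parameter count printed above can be reproduced by hand: a LoRA adapter of rank `r` on a linear layer with `d_in` inputs and `d_out` outputs adds `r * (d_in + d_out)` weights. The layer shapes below are assumptions about SmolLM3-3B (hidden size 2048, MLP width 11008, 4 key/value heads of dimension 128, 36 layers); check `model.config` before relying on them.

```python
from typing import Dict, Tuple

R = 16  # LoRA rank, as in lora_cfg

# Assumed SmolLM3-3B dimensions; verify against model.config.
HIDDEN = 2048          # hidden_size
INTERMEDIATE = 11008   # intermediate_size of the MLP
KV_DIM = 4 * 128       # num_key_value_heads * head_dim
N_LAYERS = 36          # num_hidden_layers

# (d_in, d_out) for each targeted linear layer in one decoder block.
SHAPES: Dict[str, Tuple[int, int]] = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Matrix A is (r x d_in) and matrix B is (d_out x r)."""
    return r * (d_in + d_out)


per_layer = sum(lora_param_count(d_in, d_out, R) for d_in, d_out in SHAPES.values())
total = per_layer * N_LAYERS
pct = 100 * total / 3_105_327_104  # "all params" as printed above
print(f"per layer: {per_layer:,}  total: {total:,}  trainable%: {pct:.4f}")
```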


## Stage A: IWM warm-up (predict next state)

In [None]:

out_dir_iwm = "smollm3_iwm_lora"
args = TrainingArguments(
    output_dir=out_dir_iwm,
    num_train_epochs=EPOCHS_IWM,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    fp16=torch.cuda.is_available(),
    logging_steps=10,
    save_steps=1000,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    report_to=[]
)
trainer_iwm = Trainer(
    model=model,
    args=args,
    train_dataset=tok_iwm,
    data_collator=collator
)
trainer_iwm.train()
model.save_pretrained(out_dir_iwm)
tokenizer.save_pretrained(out_dir_iwm)
print("IWM training complete.")



| Step | Training Loss |
|------|---------------|
| 10   | 0.7847        |


IWM training complete.
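
With 144 examples the IWM stage is very short, which is why only one loss value is logged. The arithmetic below sketches how the step count follows from the settings above. It assumes a single GPU and only approximates how `Trainer` derives its schedule (batches per epoch divided by the accumulation factor, warmup rounded up).

```python
import math

NUM_EXAMPLES = 144
BATCH_SIZE, GRAD_ACCUM, EPOCHS = 2, 4, 1
WARMUP_RATIO, LOGGING_STEPS = 0.03, 10

batches_per_epoch = math.ceil(NUM_EXAMPLES / BATCH_SIZE)      # forward/backward passes
updates_per_epoch = max(batches_per_epoch // GRAD_ACCUM, 1)   # optimizer steps
total_updates = updates_per_epoch * EPOCHS
warmup_steps = math.ceil(total_updates * WARMUP_RATIO)
logged_steps = list(range(LOGGING_STEPS, total_updates + 1, LOGGING_STEPS))

print(f"effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print(f"optimizer steps: {total_updates}, warmup steps: {warmup_steps}")
print(f"steps with a logged loss: {logged_steps}")
```

A single logged loss value says little about convergence; lowering `logging_steps` or collecting more transitions would give a more informative curve.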


## Stage B: SR fine-tune (short rationale + target action)

In [None]:

out_dir_sr = "smollm3_sr_lora"
args2 = TrainingArguments(
    output_dir=out_dir_sr,
    num_train_epochs=EPOCHS_SR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    fp16=torch.cuda.is_available(),
    logging_steps=10,
    save_steps=1000,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    report_to=[]
)
trainer_sr = Trainer(
    model=model,
    args=args2,
    train_dataset=tok_sr,
    data_collator=collator
)
trainer_sr.train()
model.save_pretrained(out_dir_sr)
tokenizer.save_pretrained(out_dir_sr)
print("SR training complete.")


## Evaluation: success/violation rates before vs after

In [None]:

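def run_eval(model, tokenizer, episodes=40):
    model.eval()  # inference mode: disables dropout
    logs_eval = rollout(model, tokenizer, episodes=episodes, K_alts=1)
    succ = 0   # purchases of the preferred color (the env already enforces the budget)
    vio  = 0   # total environment errors across all steps
    buys = 0   # purchases of any color
    for ex in logs_eval:
        info = ex["info"]
        bought = info.get("bought")
        if bought:
            buys += 1
            want = json.loads(ex["state_s"])["preference"]["color"]
            if bought["color"] == want:
                succ += 1
        if info.get("errors"):
            vio += len(info["errors"])
    return {"succ":succ, "violations":vio, "buys":buys, "steps": len(logs_eval)}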

# Note: For a true 'before vs after', reload the base model and run run_eval() pre-training.
eval_after = run_eval(model, tokenizer, episodes=60)
print("Post-train eval:", eval_after)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Caching is incompatible with gradient checkpointing in SmolLM3DecoderLayer. Setting `past_key_values=None`.
  return fn(*args, **kwargs)


Post-train eval: {'succ': 0, 'violations': 360, 'buys': 0, 'steps': 360}
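
`run_eval` counts events per step, so a figure such as 360 violations over 360 steps (60 episodes of 6 steps each) says that every step failed but hides what failed. Grouping by episode and tallying error types can make the failure mode visible. The sketch below runs on a few hand-written records in the same shape as the `rollout` output; the records and the `summarize` helper are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Any, Dict, List


def summarize(logs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Episode-level success rate plus a histogram of error types."""
    by_episode: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for rec in logs:
        by_episode[rec["episode_id"]].append(rec)
    # An episode succeeds if any of its steps recorded a purchase.
    successes = sum(
        any(r["info"].get("bought") for r in recs) for recs in by_episode.values()
    )
    errors = Counter(e for rec in logs for e in rec["info"].get("errors", []))
    n = len(by_episode)
    return {
        "episodes": n,
        "success_rate": successes / n if n else 0.0,
        "errors": dict(errors),
    }


# Hypothetical records: one successful episode, one stuck repeating BUY with an empty cart.
toy_logs = [
    {"episode_id": "ep_0", "t": 0, "info": {"ok": True, "errors": []}},
    {"episode_id": "ep_0", "t": 1,
     "info": {"ok": True, "errors": [], "bought": {"color": "blue", "price": 25}}},
    {"episode_id": "ep_1", "t": 0, "info": {"ok": False, "errors": ["no_selection"]}},
    {"episode_id": "ep_1", "t": 1, "info": {"ok": False, "errors": ["no_selection"]}},
]

summary = summarize(toy_logs)
print(summary)
```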


## Inference demo

In [None]:

def act_once(state_json: str):
    # `sample_action` is not defined in this notebook; use the constrained scorer instead.
    return choose_action(model, tokenizer, state_json)

env = TinyShopEnv(budget=25, want_color="blue")
s0 = env.reset()
print("STATE:", s0)
a0 = act_once(s0)
print("ACTION:", a0)
s1, info1 = env.step(a0)
print("NEXT:", s1)
print("INFO:", info1)


STATE: {"view": "product_page", "catalog": {"red": 35, "blue": 25, "green": 20}, "cart": {"color": null, "price": null}, "constraints": {"budget": 25}, "preference": {"color": "blue"}}
ACTION: colorbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdbSTATEbdb
NEXT: {"view": "product_page", "catalog": {"red": 35, "blue": 25, "green": 20}, "cart": {"color": null, "price": null}, "constraints": {"budget": 25}, "preference": {"color": "blue"}}
INFO: {'ok': False, 'errors': ['unknown_action']}
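
Free-form generation can return text that is not a valid command, as the `unknown_action` result above shows. One defensive option, sketched here under assumptions, is to normalize raw model output against the known action set before stepping the environment, and to return `None` for anything unparseable so the caller can retry or fall back. `parse_action` is a hypothetical helper; `ACTION_SET` is repeated so the block runs on its own.

```python
import re
from typing import Optional

ACTION_SET = [
    "ACTION: SELECT color=red",
    "ACTION: SELECT color=blue",
    "ACTION: SELECT color=green",
    "ACTION: BUY",
]

# Tolerates surrounding text, mixed case, and spaces around "=".
_ACTION_RE = re.compile(r"ACTION:\s*(BUY\b|SELECT\s+color\s*=\s*(\w+))", re.IGNORECASE)


def parse_action(raw: str) -> Optional[str]:
    """Map raw model text to a canonical action from ACTION_SET, or None."""
    match = _ACTION_RE.search(raw)
    if match is None:
        return None
    if match.group(1).upper() == "BUY":
        candidate = "ACTION: BUY"
    else:
        candidate = f"ACTION: SELECT color={match.group(2).lower()}"
    return candidate if candidate in ACTION_SET else None


print(parse_action("I think ACTION: select color = Blue is best"))
print(parse_action("colorbdbSTATEbdbSTATEbdbSTATE"))
```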
