In [None]:
!pip install --quiet transformers accelerate sentencepiece torch langgraph huggingface-hub
import importlib, os, time
importlib.invalidate_caches()

import torch
from huggingface_hub import login as hf_login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, List, Dict

# Optional Hugging Face authentication: only log in when a token is provided
# through the HF_TOKEN environment variable.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    hf_login(HF_TOKEN)

# Primary model to load, and the model tried when loading the primary fails.
MODEL_NAME = "TeichAI/Qwen3-1.7B-Gemini-2.5-Flash-Lite-Preview-Distill"
FALLBACK = "google/flan-t5-base"
# Maximum number of history entries kept by the graph nodes.
MAX_HISTORY = 6
# Character budget for the rendered prompt; trim_history drops the oldest
# entries until build_prompt's output fits.
MAX_PROMPT_CHARS = 1400
# Generation length passed to the pipeline as max_new_tokens.
MAX_NEW_TOKENS = 200

# Globals filled in by the model-loading code that follows.
tokenizer = None
model = None
pipe = None
loaded = False
# Wall-clock start, used to report how long the primary model took to load.
start = time.time()

# Load the primary causal LM; on any failure fall back to the smaller FALLBACK
# model so the rest of the notebook still has a working `pipe`.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", trust_remote_code=True, torch_dtype=torch.float16)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
    loaded = True
    print("Loaded model:", MODEL_NAME, "in", round(time.time()-start,1), "s")
except Exception as e:
    print("Failed to load", MODEL_NAME, "->", e)
    print("Falling back to", FALLBACK)
    # FLAN-T5 is an encoder-decoder (seq2seq) model. AutoModelForCausalLM has
    # no mapping for T5 configs and raises, so the fallback must use the
    # seq2seq auto class together with the matching pipeline task.
    # NOTE(review): "text2text-generation" is the seq2seq pipeline task in
    # transformers 4.x; confirm it exists in the installed version.
    from transformers import AutoModelForSeq2SeqLM

    MODEL_NAME = FALLBACK
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    try:
        # First attempt: let accelerate place the weights automatically.
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, device_map="auto")
        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
    except Exception as fallback_err:
        # Report why automatic placement failed instead of hiding it, then
        # retry with a plain CPU load (device=-1).
        print("device_map='auto' failed for", MODEL_NAME, "->", fallback_err)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_new_tokens=MAX_NEW_TOKENS, device=-1, do_sample=False)
    loaded = True
    print("Loaded fallback:", MODEL_NAME)

def build_prompt(history: List[Dict[str,str]], user_input: str) -> str:
    """Render the conversation as a plain-text Q/A transcript.

    Each history entry becomes one line: ``Q: <text>`` when its ``role`` is
    ``"user"``, ``A: <text>`` for any other role. The new ``user_input`` is
    appended as a final ``Q:`` line, followed by a bare ``A:`` line for the
    model to complete.

    Args:
        history: Prior turns, each a dict with ``"role"`` and ``"text"`` keys.
        user_input: The message to place in the closing ``Q:`` line.

    Returns:
        The transcript with lines joined by newlines.
    """
    rendered = [
        f"Q: {turn['text']}" if turn["role"] == "user" else f"A: {turn['text']}"
        for turn in history
    ]
    rendered += [f"Q: {user_input}", "A:"]
    return "\n".join(rendered)

def trim_history(history: List[Dict[str,str]], user_input: str) -> List[Dict[str,str]]:
    """Drop the oldest turns until the rendered prompt fits the character budget.

    Args:
        history: Prior turns, oldest first. The input list is not modified.
        user_input: The message the prompt will be built for.

    Returns:
        A new list holding the most recent turns for which
        ``build_prompt(turns, user_input)`` is at most ``MAX_PROMPT_CHARS``
        characters long; empty if even a single turn does not fit.
    """
    total = len(history)
    skip = 0
    # Advance the start offset one turn at a time, oldest first.
    while skip < total and len(build_prompt(history[skip:], user_input)) > MAX_PROMPT_CHARS:
        skip += 1
    return history[skip:]

class ChatState(TypedDict):
    """State dictionary passed between the chat graph's nodes."""
    # The user message for the current turn.
    input: str
    # Conversation so far; each entry is a dict with "role" and "text" keys.
    history: List[Dict[str,str]]
    # The assistant reply produced for the current turn.
    response: str

def user_node(state: ChatState):
    """Record the current user message in the conversation history.

    Appends ``state["input"]`` as a ``"user"`` entry to a copy of the history,
    keeps at most the last ``MAX_HISTORY`` entries, and then applies the
    character-budget trimming of ``trim_history``.

    Returns:
        A partial state update ``{"history": <new history list>}``.
    """
    message = state["input"]
    turns = [*state.get("history", []), {"role":"user","text": message}]
    # Slicing from the end is a no-op copy when the list is already short enough.
    turns = turns[-MAX_HISTORY:]
    return {"history": trim_history(turns, message)}

def llm_node(state: ChatState):
    """Generate the assistant reply for the current turn.

    Builds a Q/A prompt from ``state["history"]`` and ``state["input"]``, runs
    it through the global ``pipe``, and extracts the reply text.

    Returns:
        A partial state update ``{"response": <reply text>}``.
    """
    user_input = state["input"]
    hist = list(state.get("history", []))
    # user_node has already appended the current message to the history, and
    # build_prompt appends it again as the final "Q:" line. Drop the trailing
    # copy so the model is not asked the same question twice.
    if hist and hist[-1].get("role") == "user" and hist[-1].get("text") == user_input:
        hist.pop()
    hist = trim_history(hist, user_input)
    prompt = build_prompt(hist, user_input)
    out = pipe(prompt, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)

    # Normalize the pipeline result to a single string. An empty result list
    # yields an empty reply instead of raising IndexError.
    if isinstance(out, list):
        out = out[0] if out else ""
    if isinstance(out, dict):
        text = out.get("generated_text") or out.get("text") or str(out)
    else:
        text = str(out)

    if text.startswith(prompt):
        # The output echoes the prompt: keep only the continuation. Splitting
        # on "A:" instead would pick the wrong segment whenever the reply
        # itself contains "A:".
        text = text[len(prompt):]
    elif "A:" in text:
        # Output without the exact prompt prefix: keep the legacy extraction.
        text = text.split("A:")[-1]
    # Stop at the first follow-up "Q:" line so an invented next turn is not
    # returned (and later stored) as part of this reply.
    text = text.split("\nQ:", 1)[0]
    return {"response": text.strip()}

def memory_node(state: ChatState):
    """Store the assistant reply in the conversation history.

    Appends the stripped ``state["response"]`` (empty string when absent) as
    an ``"assistant"`` entry to a copy of the history, keeps at most the last
    ``MAX_HISTORY`` entries, and trims to the character budget using an empty
    user input.

    Returns:
        A partial state update ``{"history": <new history list>}``.
    """
    previous = list(state.get("history", []))
    reply = {"role":"assistant","text": state.get("response","").strip()}
    recent = [*previous, reply][-MAX_HISTORY:]
    return {"history": trim_history(recent, "")}

def output_node(state: ChatState):
    """Pass the reply through as the graph's final ``response`` value.

    Returns:
        ``{"response": <state's response, or "" when the key is absent>}``.
    """
    reply = state.get("response","")
    return {"response": reply}

# Assemble the chat graph as a linear pipeline:
# START -> user -> llm -> save -> out -> END.
G = StateGraph(ChatState)

# Register each node under its graph name.
for _node_name, _handler in (
    ("user", user_node),
    ("llm", llm_node),
    ("save", memory_node),
    ("out", output_node),
):
    G.add_node(_node_name, _handler)

# Connect consecutive stops along the route with one edge each.
_route = (START, "user", "llm", "save", "out", END)
for _src, _dst in zip(_route, _route[1:]):
    G.add_edge(_src, _dst)

agent = G.compile()

# Demo conversation: run a few scripted user turns through the agent,
# carrying the history forward from one turn to the next.
state: ChatState = {"input":"", "history":[], "response":""}

turns = [
    "Hi, I'm Nishant. Remind me to study IoT later.",
    "Which IoT topics are most important for an exam?",
    "Summarize them in two concise bullets.",
    "Remember I prefer short answers."
]

for i, msg in enumerate(turns, start=1):
    state["input"] = msg
    s = agent.invoke(state)

    print(f"\n--- Turn {i} ---")
    print("User:", msg)
    print("Assistant:", s["response"])

    # Start the next turn from the returned history/response with a blank input.
    state = {
        "input":"",
        "history": s.get("history", []),
        "response": s.get("response",""),
    }


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Loaded model: TeichAI/Qwen3-1.7B-Gemini-2.5-Flash-Lite-Preview-Distill in 65.5 s

--- Turn 1 ---
User: Hi, I'm Nishant. Remind me to study IoT later.
Assistant: [Thinking Process] I've identified the speaker as Nishant and recognized the reference to "studying IoT later." My current task is to formulate a precise and useful reminder. I'm focusing on understanding the nature of the information being referenced. I'm analyzing the frequency and content of the reference to ensure I create a reminder that is both specific and useful. I'm now formulating a reminder strategy based on this understanding.

I'm developing a reminder system for Nishant. I've pinpointed his need for repeated IoT study sessions. I'm now coding the logic for creating and storing the reminders. I'm focusing on ensuring the system is both accurate and efficient. Next, I'll need to figure out how to trigger the reminders based on time.


I'm developing a reminder system in Python. I've designed a class named `IoTRemind