# Assistive Keyboard — 7B (Colab A100, Drive‑persistent)

- base: `Qwen/Qwen2.5-7B-Instruct`
- Drive-backed code/data/adapters/results
- per-user LoRA + lexicon + RAG; evaluate keystroke savings (KSS) and suggestion latency


In [None]:
# gpu sanity
import sys, torch, platform
print('py:', sys.version.split()[0], '| cuda:', torch.cuda.is_available(), '| plat:', platform.platform())
if torch.cuda.is_available():
    !nvidia-smi

In [None]:
# mount Google Drive and create the (idempotent) project directory tree on it
from google.colab import drive
drive.mount('/content/drive')
from pathlib import Path
import os

PROJ = Path('/content/drive/MyDrive/assistive_keyboard_7B')
PROJ.mkdir(parents=True, exist_ok=True)
CODE, DATA, SPLITS, USERS = PROJ/'code', PROJ/'data', PROJ/'splits', PROJ/'users'
LEX, RAGD, ADAPT, RUNS, CACHE = PROJ/'lexicons', PROJ/'rag', PROJ/'adapters', PROJ/'runs', PROJ/'hf_cache'
CODE.mkdir(exist_ok=True)
(DATA/'processed').mkdir(parents=True, exist_ok=True)
for _dir in (SPLITS, USERS, LEX, RAGD, ADAPT, RUNS, CACHE):
    _dir.mkdir(exist_ok=True)
# point the Hugging Face caches at Drive so downloaded weights survive runtime resets
for _var in ('HF_HOME', 'TRANSFORMERS_CACHE'):
    os.environ[_var] = str(CACHE)
print('root:', PROJ)

In [None]:
# config (demo knobs) — approximate token budgets are counted with src.data.clean.approx_token_count
MAX_TEST_AUTHORS = 20     # cap on held-out authors that receive per-user assets
ADAPT_TOKENS = 2000       # per-user adaptation slice (LoRA train / lexicon / RAG source)
VAL_TOKENS = 800          # per-user validation slice (LoRA eval)
TEST_TOKENS = 2000        # per-user test-slice budget passed to make_splits
LORA_STEPS = 600
LORA_RANK = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
BASE_MODEL = 'Qwen/Qwen2.5-7B-Instruct'
USE_BF16_INSTEAD_OF_4BIT = False  # flip if 4bit acts up; A100-80G can do bf16 easily
SEED = 42

# seeds
import os
import random
import numpy as np
import torch

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this only affects
# child processes (e.g. the %%bash cells), not this kernel's own str hashing.
os.environ['PYTHONHASHSEED'] = str(SEED)

In [None]:
%%bash
# deps — a cell magic must be the very first line of its cell, so this comment
# sits below %%bash (above it, IPython reports "Line magic function `%%bash` not found").
set -e
pip -q install --upgrade pip
pip -q install numpy pandas tqdm pyyaml regex scikit-learn ujson
pip -q install transformers accelerate datasets sentence-transformers
pip -q install faiss-cpu peft bitsandbytes bert-score mauve-text
python - <<'PY'
import torch; print('torch', torch.__version__, 'cuda?', torch.cuda.is_available())
PY

In [None]:
# project code → Drive: every module below is written under CODE/src so later cells can import it
from pathlib import Path
import textwrap
_PKG_DIRS = ('src/utils', 'src/data', 'src/splits', 'src/lexicon',
             'src/rag', 'src/lora', 'src/infer', 'src/eval')
for _pkg in _PKG_DIRS:
    (CODE/_pkg).mkdir(parents=True, exist_ok=True)
for _init in ('src/__init__.py', 'src/utils/__init__.py'):
    (CODE/_init).write_text('')
# src/utils/io.py — small helpers for text and JSON-lines files.
# The module source must be a RAW string: inside a plain triple-quoted literal the
# backslash-n escapes below became real newlines, leaving unterminated string
# literals in the generated file (SyntaxError for every module importing it).
(CODE/'src/utils/io.py').write_text(textwrap.dedent(r'''
from pathlib import Path
import json, ujson

def read_lines(p):
    """Return the lines of UTF-8 text file `p`, without line terminators."""
    return Path(p).read_text(encoding='utf-8').splitlines()

def write_lines(p, lines):
    """Write `lines` joined by newlines to `p` (UTF-8), creating parent dirs."""
    Path(p).parent.mkdir(parents=True, exist_ok=True)
    Path(p).write_text('\n'.join(lines), encoding='utf-8')

def read_jsonl(p):
    """Parse a JSON-lines file into a list of objects, skipping blank lines."""
    out=[]
    with open(p,'r',encoding='utf-8') as f:
        for line in f:
            line=line.strip()
            if line: out.append(json.loads(line))
    return out

def write_jsonl(p, rows):
    """Write one JSON object per line to `p` (UTF-8, non-ASCII kept), creating parent dirs."""
    Path(p).parent.mkdir(parents=True, exist_ok=True)
    with open(p,'w',encoding='utf-8') as f:
        for r in rows: f.write(ujson.dumps(r, ensure_ascii=False)+'\n')
'''))
# src/data/clean.py — strip quoted text, reply/forward chains and signatures from a mail body.
(CODE/'src/data/__init__.py').write_text('')
(CODE/'src/data/clean.py').write_text(textwrap.dedent(r'''
import re
# lines quoted from an earlier message ("> ...")
QUOTE_RE = re.compile(r'(?m)^(>+).*?$')
# a signature delimiter line ("--" / "-- ") and everything after it.
# (The previous pattern `--\s*\n.*?$` removed just ONE following line and also
#  fired on any line that merely ends in "--", e.g. "-----Original Message-----".)
SIG_RE = re.compile(r'(?ms)^--[ \t]*$.*\Z')
# "-----Original Message-----" / "----- Forwarded by ..." separators: the text
# below them was written by someone else, so it is cut off entirely.
REPLY_RE = re.compile(r'(?msi)^[ \t]*-{3,}[ \t]*(?:original message|forwarded by)\b.*\Z')

def clean_text(s:str)->str:
    """Normalise newlines, drop reply chains / quotes / signatures, squeeze whitespace."""
    s=s.replace('\r\n','\n')
    s=REPLY_RE.sub('',s)
    s=QUOTE_RE.sub('',s)
    s=SIG_RE.sub('',s)
    s=re.sub(r'[ \t]+',' ',s)
    s=re.sub(r'\n{3,}','\n\n',s)
    return s.strip()

def approx_token_count(s:str)->int:
    """Cheap token estimate: runs of word characters plus common punctuation marks."""
    return len(re.findall(r"\w+|[.,!?;:]", s))
'''))
# src/data/enron_loader.py — Enron maildir → authors.jsonl (one row per message body).
(CODE/'src/data/enron_loader.py').write_text(textwrap.dedent(r'''
from pathlib import Path
from email import message_from_string
from .clean import clean_text, approx_token_count
from src.utils.io import write_jsonl

# Enron folder names for sent mail, i.e. (presumably) text authored by the mailbox
# owner. Other folders (inbox, deleted_items, ...) largely hold mail written by
# OTHER people and must not be attributed to the owner in a per-author style dataset.
SENT_FOLDERS = ('sent', 'sent_items', '_sent_mail')

def message_body(raw: str) -> str:
    """Return the text/plain body of a raw RFC 822 message, with headers dropped."""
    msg=message_from_string(raw)
    parts=[p.get_payload() for p in msg.walk()
           if not p.is_multipart() and p.get_content_type()=='text/plain']
    return '\n'.join(x for x in parts if isinstance(x, str))

def iter_message_files(user_dir: Path, sent_only: bool):
    """Yield the message files of one mailbox in a stable (sorted) order."""
    roots=[user_dir/f for f in SENT_FOLDERS] if sent_only else [user_dir]
    for root in roots:
        if root.is_dir():
            yield from sorted(p for p in root.rglob('*') if p.is_file())

def build_authors_jsonl(maildir_root: str, out_jsonl: str, min_doc_tokens: int = 20, sent_only: bool = True):
    """Write one {'author_id','doc_id','text'} row per usable message to out_jsonl.

    author_id is the mailbox directory name. Headers are stripped, bodies are
    cleaned with clean_text, bodies under min_doc_tokens are skipped and exact
    duplicate bodies per author are kept once (the same mail is often filed in
    several sent folders and could otherwise land in both adapt and test slices).
    sent_only=False reads every folder of each mailbox instead of SENT_FOLDERS.
    """
    rows=[]; maildir=Path(maildir_root)
    for user_dir in sorted(maildir.iterdir()):
        if not user_dir.is_dir(): continue
        author_id=user_dir.name; seen=set()
        for p in iter_message_files(user_dir, sent_only):
            try: raw=p.read_text(encoding='utf-8', errors='ignore')
            except OSError: continue
            txt=clean_text(message_body(raw))
            if txt in seen or approx_token_count(txt)<min_doc_tokens: continue
            seen.add(txt)
            rows.append({'author_id':author_id,'doc_id':str(p.relative_to(maildir)),'text':txt})
    write_jsonl(out_jsonl, rows); print(f'wrote {len(rows)} docs → {out_jsonl}')
'''))
# src/splits/make_splits.py — author-disjoint train/dev/test split; every test author
# gets users/<author>/{adapt,val,test}.txt sized by approximate token budgets.
(CODE/'src/splits/__init__.py').write_text('')
(CODE/'src/splits/make_splits.py').write_text(textwrap.dedent(r'''
import argparse, random
from collections import defaultdict
from pathlib import Path
from src.utils.io import read_jsonl, write_lines
from src.data.clean import approx_token_count

def slice_user(texts, adapt_tokens, val_tokens, test_tokens):
    """Greedily assign whole docs to adapt / val / test by cumulative token budgets.

    Each slice may overshoot its budget by at most one doc. Docs beyond the
    adapt+val+test budget are dropped: previously --test_tokens was ignored and
    test.txt received ALL remaining docs, which made the per-character typing
    simulation unbounded.
    """
    acc=0; adapt=[]; val=[]; test=[]
    for t in texts:
        if acc<adapt_tokens: adapt.append(t)
        elif acc<adapt_tokens+val_tokens: val.append(t)
        elif acc<adapt_tokens+val_tokens+test_tokens: test.append(t)
        else: break
        acc+=approx_token_count(t)
    return adapt, val, test

def main():
    ap=argparse.ArgumentParser()
    ap.add_argument('--authors_jsonl', required=True)
    ap.add_argument('--out_dir', required=True)
    ap.add_argument('--min_tokens', type=int, default=4000)
    ap.add_argument('--adapt_tokens', type=int, default=2000)
    ap.add_argument('--val_tokens', type=int, default=800)
    ap.add_argument('--test_tokens', type=int, default=2000)
    ap.add_argument('--max_test_authors', type=int, default=4)
    ap.add_argument('--seed', type=int, default=42)
    args=ap.parse_args(); random.seed(args.seed)
    out_dir=Path(args.out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    by_author=defaultdict(list)
    for r in read_jsonl(args.authors_jsonl): by_author[r['author_id']].append(r['text'])
    # keep only authors with enough text for all slices
    kept={a:docs for a,docs in by_author.items() if sum(approx_token_count(t) for t in docs)>=args.min_tokens}
    authors=sorted(kept); random.shuffle(authors)
    n=len(authors); n_train=int(0.70*n); n_dev=int(0.15*n)
    train_ids=authors[:n_train]; dev_ids=authors[n_train:n_train+n_dev]
    test_ids=authors[n_train+n_dev:][:args.max_test_authors]
    write_lines(out_dir/'authors_train.txt', train_ids)
    write_lines(out_dir/'authors_dev.txt', dev_ids)
    write_lines(out_dir/'authors_test.txt', test_ids)
    users_dir=out_dir.parent/'users'; users_dir.mkdir(exist_ok=True)
    for a in test_ids:
        # sort before shuffling so the slices depend only on the seed, not on jsonl row order
        texts=sorted(kept[a]); random.shuffle(texts)
        adapt, val, test = slice_user(texts, args.adapt_tokens, args.val_tokens, args.test_tokens)
        udir=users_dir/a; udir.mkdir(parents=True, exist_ok=True)
        for name, part in (('adapt', adapt), ('val', val), ('test', test)):
            (udir/f'{name}.txt').write_text('\n\n'.join(part), encoding='utf-8')
    print(f'train/dev/test: {len(train_ids)}/{len(dev_ids)}/{len(test_ids)} | users/* ready')
if __name__=='__main__': main()
'''))
# src/lexicon/build_lexicon.py — per-user lexicon: top TF-IDF unigrams/bigrams of adapt.txt.
(CODE/'src/lexicon/__init__.py').write_text('')
(CODE/'src/lexicon/build_lexicon.py').write_text(textwrap.dedent(r'''
import argparse, json, re, numpy as np
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer

TOK_RE = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?|[0-9]+|[^\sA-Za-z0-9]")

def tok(s):
    """Split into words (inner apostrophes kept), digit runs and single symbols."""
    return TOK_RE.findall(s)

def build_lex(text, k=4000):
    """Rank unigrams/bigrams of `text` by TF-IDF summed over its non-blank lines.

    Each non-blank line is one TF-IDF document. Returns {'entries': [{'token', 'score'}, ...]}
    with at most k entries, highest score first, or no entries when the text is
    too small for the min_df/max_df thresholds (instead of aborting every user).
    """
    docs=[ln for ln in text.splitlines() if ln.strip()]
    # token_pattern=None: our tokenizer is used (avoids sklearn's "token_pattern unused" warning)
    v=TfidfVectorizer(tokenizer=tok, token_pattern=None, lowercase=True, ngram_range=(1,2),
                      min_df=2, max_df=0.9, use_idf=True, smooth_idf=True, norm=None)
    try:
        X=v.fit_transform(docs)
    except ValueError:
        # empty vocabulary, or max_df (90% of lines) falls below min_df=2 on tiny inputs
        return {'entries': []}
    vocab=v.get_feature_names_out()
    scores=np.asarray(X.sum(axis=0)).ravel(); idx=scores.argsort()[::-1]
    return {'entries':[{'token':vocab[i], 'score':float(scores[i])} for i in idx[:k]]}

def main():
    ap=argparse.ArgumentParser(); ap.add_argument('--users_dir', required=True); ap.add_argument('--out_dir', required=True); ap.add_argument('--max_items', type=int, default=4000); args=ap.parse_args()
    out=Path(args.out_dir); out.mkdir(parents=True, exist_ok=True)
    for u in sorted(Path(args.users_dir).iterdir()):
        if not u.is_dir(): continue
        p=u/'adapt.txt'
        if not p.exists(): continue
        lex=build_lex(p.read_text(encoding='utf-8'), args.max_items)
        (out/f'{u.name}.lexicon.json').write_text(json.dumps(lex, ensure_ascii=False, indent=2), encoding='utf-8')
        print('lex:', u.name, len(lex['entries']))
if __name__=='__main__': main()
'''))
# src/rag/build_rag.py — per-user retrieval memory. For every users/<id>/ that has
# adapt.txt, the text is cut into fixed-size character chunks (defaults: 300 chars,
# 50-char overlap), embedded with a SentenceTransformer (default all-MiniLM-L6-v2),
# L2-normalised and stored in a FAISS inner-product index (so scores are cosine
# similarities) as <out_dir>/<id>.faiss, with the raw chunks in <out_dir>/<id>.chunks.json.
# The module body below is a runtime string, so documentation lives out here.
(CODE/'src/rag/__init__.py').write_text('')
(CODE/'src/rag/build_rag.py').write_text(textwrap.dedent(r'''
import argparse, json
from pathlib import Path
from sentence_transformers import SentenceTransformer
import faiss, numpy as np
def chunks(s, m=300, ov=50):
    s=s.strip(); out=[]; i=0
    while i<len(s): out.append(s[i:i+m]); i+=max(1, m-ov)
    return out
def main():
    ap=argparse.ArgumentParser(); ap.add_argument('--users_dir', required=True); ap.add_argument('--out_dir', required=True)
    ap.add_argument('--model_name', default='sentence-transformers/all-MiniLM-L6-v2'); ap.add_argument('--chunk_chars', type=int, default=300); ap.add_argument('--overlap', type=int, default=50)
    args=ap.parse_args(); emb=SentenceTransformer(args.model_name)
    out=Path(args.out_dir); out.mkdir(parents=True, exist_ok=True)
    for u in Path(args.users_dir).iterdir():
        if not u.is_dir(): continue
        p=u/'adapt.txt'
        if not p.exists(): continue
        cs=chunks(p.read_text(encoding='utf-8'), args.chunk_chars, args.overlap)
        if not cs: continue
        X=emb.encode(cs, batch_size=64, convert_to_numpy=True, show_progress_bar=False).astype(np.float32)
        faiss.normalize_L2(X); idx=faiss.IndexFlatIP(X.shape[1]); idx.add(X)
        faiss.write_index(idx, str(out/f'{u.name}.faiss'))
        (out/f'{u.name}.chunks.json').write_text(json.dumps(cs, ensure_ascii=False), encoding='utf-8')
        print('rag:', u.name, len(cs))
if __name__=='__main__': main()
'''))
# src/lora/train_lora.py — one LoRA adapter per user, trained on users/<id>/adapt.txt and
# evaluated on val.txt. Every user starts from a FRESH adapter on the shared base model
# (previously a single PEFT model was trained user after user, so adapters accumulated).
(CODE/'src/lora/__init__.py').write_text('')
(CODE/'src/lora/train_lora.py').write_text(textwrap.dedent(r'''
import argparse, gc, re
from pathlib import Path
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# `evaluation_strategy` was renamed `eval_strategy` (and later removed) in transformers.
EVAL_KW = 'eval_strategy' if 'eval_strategy' in TrainingArguments.__dataclass_fields__ else 'evaluation_strategy'
CK_RE = re.compile(r'checkpoint-(\d+)$')

def read_txt(p): return Path(p).read_text(encoding='utf-8')

def mk_ds(txt: str, tok, bs=256):
    """Split `txt` into consecutive, non-overlapping blocks of `bs` token ids.

    A trailing partial block is dropped; text shorter than `bs` tokens yields one
    short block; empty text yields an empty dataset.
    """
    ids=tok(txt, return_tensors=None, truncation=False)['input_ids']
    # `+1` keeps the last full block when len(ids) is an exact multiple of bs
    blocks=[ids[i:i+bs] for i in range(0, len(ids)-bs+1, bs)] or [ids[:bs]]
    return Dataset.from_dict({'input_ids': [b for b in blocks if b]})

def last_ck(d: Path):
    """Return the `checkpoint-N` directory in `d` with the largest N, or None."""
    best=None; best_step=-1
    for p in d.glob('checkpoint-*'):
        m=CK_RE.search(p.name)
        if p.is_dir() and m and int(m.group(1))>best_step:
            best, best_step = p, int(m.group(1))
    return best

def is_quantized(model):
    """True when the model was loaded 4-bit/8-bit through bitsandbytes."""
    return bool(getattr(model,'is_loaded_in_4bit',False) or getattr(model,'is_loaded_in_8bit',False))

def main():
    ap=argparse.ArgumentParser()
    ap.add_argument('--users_dir', required=True)
    ap.add_argument('--adapters_dir', required=True)
    ap.add_argument('--base_model', default='Qwen/Qwen2.5-7B-Instruct')
    ap.add_argument('--rank', type=int, default=8)
    ap.add_argument('--alpha', type=int, default=16)
    ap.add_argument('--dropout', type=float, default=0.05)
    ap.add_argument('--lr', type=float, default=2e-4)
    ap.add_argument('--steps', type=int, default=300)
    ap.add_argument('--block_size', type=int, default=256)
    args=ap.parse_args()
    tok=AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    pad_id=tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    base=AutoModelForCausalLM.from_pretrained(args.base_model, device_map='auto', load_in_4bit=True, torch_dtype=torch.float16)
    # only k-bit models need this; on a bf16 model it would upcast every weight to fp32
    if is_quantized(base): base=prepare_model_for_kbit_training(base)
    cfg=LoraConfig(r=args.rank, lora_alpha=args.alpha, lora_dropout=args.dropout, bias='none', task_type='CAUSAL_LM')

    def collate(batch):
        feats=[b['input_ids'] for b in (batch if isinstance(batch,list) else [batch])]
        ml=max(len(f) for f in feats)
        ids=[f+[pad_id]*(ml-len(f)) for f in feats]
        mask=[[1]*len(f)+[0]*(ml-len(f)) for f in feats]
        labels=[f+[-100]*(ml-len(f)) for f in feats]   # padding never contributes to the loss
        return {'input_ids': torch.tensor(ids), 'attention_mask': torch.tensor(mask), 'labels': torch.tensor(labels)}

    for u in sorted(Path(args.users_dir).iterdir()):
        if not u.is_dir(): continue
        a=u/'adapt.txt'; v=u/'val.txt'
        if not a.exists() or not v.exists(): continue
        out=Path(args.adapters_dir)/u.name; out.mkdir(parents=True, exist_ok=True)
        final=out/'lora_adapter'
        if (final/'adapter_config.json').exists():
            print('adapter exists, skipping:', final); continue
        tr=mk_ds(read_txt(a), tok, args.block_size); dv=mk_ds(read_txt(v), tok, args.block_size)
        if len(tr)==0:
            print('no adapt tokens, skipping:', u.name); continue
        has_eval=len(dv)>0
        targs=TrainingArguments(output_dir=str(out), per_device_train_batch_size=1, per_device_eval_batch_size=1, gradient_accumulation_steps=8,
            logging_steps=10, learning_rate=args.lr, max_steps=args.steps, eval_steps=100, save_strategy='steps', save_steps=100,
            save_total_limit=3, report_to='none', **{EVAL_KW: 'steps' if has_eval else 'no'})
        model=get_peft_model(base, cfg)   # fresh, untrained LoRA for this user only
        try:
            ck=last_ck(out)
            trainer=Trainer(model=model, args=targs, train_dataset=tr, eval_dataset=dv if has_eval else None, data_collator=collate)
            trainer.train(resume_from_checkpoint=str(ck) if ck else None)
            model.save_pretrained(str(final))
            print('adapter:', final)
        finally:
            trainer=None
            base=model.unload()           # strip this user's LoRA layers; base weights are untouched
            model=None; gc.collect(); torch.cuda.empty_cache()
if __name__=='__main__': main()
'''))
# src/infer/suggest.py — suggesters used by the typing simulation: word n-gram baseline,
# lexicon logit bias, FAISS retrieval memory and the causal-LM completer.
(CODE/'src/infer/__init__.py').write_text('')
(CODE/'src/infer/suggest.py').write_text(textwrap.dedent(r'''
from collections import defaultdict, Counter
import re, json
from pathlib import Path
import numpy as np, torch
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
from peft import PeftModel
from sentence_transformers import SentenceTransformer

TOK_RE = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?|[0-9]+|[^\sA-Za-z0-9]")

def tok_basic(s):
    """Split into words (inner apostrophes kept), digit runs and single symbols."""
    return TOK_RE.findall(s)

class NGram:
    """Next-word suggester from word n-gram counts over one user's (lowercased) text."""
    def __init__(self, text: str, n: int = 3):
        toks=tok_basic(text.lower()); self.n=n; self.ng=defaultdict(Counter)
        # range end len-n+1 so the final n-gram of the text is counted too
        for i in range(len(toks)-n+1): self.ng[tuple(toks[i:i+n-1])][toks[i+n-1]]+=1
    def suggest(self, ctx: str, k: int = 3):
        """Return up to k most frequent next words after the last n-1 tokens of ctx."""
        toks=tok_basic(ctx.lower())
        key=tuple(toks[max(0, len(toks)-(self.n-1)):])   # shorter contexts use all tokens
        # .get (not []) so lookups don't grow the table; a missing key used to fall back to
        # a plain dict, which has no .most_common and crashed on every unseen context
        cand=self.ng.get(key)
        return [w for w,_ in cand.most_common(k)] if cand else []

class Bias(LogitsProcessor):
    """Add a fixed bonus to selected token ids at every decoding step."""
    def __init__(self, mp):
        self.mp=mp or {}
        self._ids=list(self.mp.keys()); self._vals=list(self.mp.values()); self._cache={}
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        if not self.mp: return scores
        key=(scores.device, scores.dtype)
        if key not in self._cache:   # build the index/value tensors once per device/dtype
            self._cache[key]=(torch.tensor(self._ids, device=scores.device),
                              torch.tensor(self._vals, device=scores.device, dtype=scores.dtype))
        ids, vals = self._cache[key]
        scores[:, ids]+=vals
        return scores

class Lex:
    """Map a user's lexicon (JSON from build_lexicon) to single-token logit boosts."""
    def __init__(self, tok, lex_json, cap=2.5):
        self.tok=tok; self.mp={}; self._bias=None
        try: entries=json.loads(lex_json)['entries']
        except (ValueError, KeyError, TypeError): entries=[]
        for e in entries[:2000]:
            word=e.get('token','') if isinstance(e, dict) else ''
            if not word: continue
            # BPE vocabularies have a separate id for the word preceded by a space
            # (the usual mid-sentence form); boost both variants when each is one token
            for variant in (word, ' '+word):
                ids=self.tok(variant, add_special_tokens=False)['input_ids']
                if len(ids)==1: self.mp[ids[0]]=cap
    def proc(self):
        if self._bias is None: self._bias=Bias(self.mp)
        return self._bias

class RAG:
    """Nearest-chunk lookup in a user's FAISS inner-product index."""
    def __init__(self, f, c, em='sentence-transformers/all-MiniLM-L6-v2'):
        self.idx=faiss.read_index(str(f)); self.ch=json.loads(Path(c).read_text(encoding='utf-8'))
        self.emb=SentenceTransformer(em)
    def top(self, text, k=4):
        q=self.emb.encode([text], convert_to_numpy=True).astype(np.float32); faiss.normalize_L2(q); D,I=self.idx.search(q,k)
        return [self.ch[i] for i in I[0] if i>=0]

class LLM:
    """Beam-search short completions from a causal LM with optional LoRA adapter,
    retrieval memory (in the prompt) and lexicon logit bias."""
    def __init__(self, base: str, adapter_dir: str=None, rag=None, lex=None, max_ctx=512, bf16=False):
        self.tok=AutoTokenizer.from_pretrained(base, use_fast=True)
        # over-long prompts must lose their START (instructions/memory), never the draft tail
        self.tok.truncation_side='left'
        if bf16:
            self.m=AutoModelForCausalLM.from_pretrained(base, device_map='auto', torch_dtype=torch.bfloat16)
        else:
            self.m=AutoModelForCausalLM.from_pretrained(base, device_map='auto', load_in_4bit=True)
        if adapter_dir and Path(adapter_dir).exists(): self.m=PeftModel.from_pretrained(self.m, adapter_dir)
        self.rag=rag; self.lex=lex; self.max_ctx=max_ctx; self.m.eval()
    def _prompt(self, tail, mem):
        """Plain-text prompt; the memory section is omitted entirely when mem is empty."""
        mem_block=''
        if mem:
            bullets='\n'.join(f'- {c[:200]}' for c in mem)
            mem_block=f"Memory\n{bullets}\n\n"
        return f"Continue in the user's style.\n{mem_block}Draft:\n{tail}\n\nContinue:"
    def suggest(self, ctx: str, k: int = 3):
        """Return up to k distinct suggestions: the first <=8 non-space chars of each beam."""
        if k<=0: return []
        tail=ctx[-1000:]; mem=self.rag.top(tail,4) if self.rag else []
        enc=self.tok(self._prompt(tail, mem), return_tensors='pt', truncation=True, max_length=self.max_ctx).to(self.m.device)
        procs=LogitsProcessorList([self.lex.proc()]) if self.lex else None
        with torch.no_grad():
            out=self.m.generate(**enc, max_new_tokens=6, do_sample=False, num_beams=k, num_return_sequences=k,
                                logits_processor=procs, pad_token_id=self.tok.eos_token_id)
        texts=self.tok.batch_decode(out[:, enc['input_ids'].shape[1]:], skip_special_tokens=True)
        res=[]
        for t in texts:
            t=t.strip(); m=re.match(r"^\S{1,8}", t); s=m.group(0) if m else t[:8]
            if s and s not in res: res.append(s)
        return res[:k]
'''))
# src/eval/typing_sim.py — simulated typing of each user's test.txt with a suggester;
# writes per-user keystroke savings (KSS) and latency to a CSV.
(CODE/'src/eval/__init__.py').write_text('')
(CODE/'src/eval/typing_sim.py').write_text(textwrap.dedent(r'''
import argparse, time, csv, gc
from pathlib import Path
from src.infer.suggest import NGram, LLM, Lex, RAG

def load_text(p): return Path(p).read_text(encoding='utf-8')

def sim(doc: str, sugg, k=3, max_chunk=8):
    """Type `doc` character by character, accepting matching suggestions.

    After each manually typed character the suggester sees the text so far; the
    first suggestion (cut to max_chunk chars) that case-insensitively matches the
    upcoming text is accepted for ONE keystroke and skips those characters.

    Returns keys_plain (= len(doc), keystrokes without suggestions), keys_with
    (typed chars + accepts), kss = 1 - keys_with/keys_plain (0.0 for an empty doc),
    accepts, time_ms (total wall time), calls and latency_ms (mean per suggest call).
    The old formula divided by typed chars only and subtracted saved chars twice.
    """
    typed=0; accepts=0; calls=0; sugg_s=0.0; i=0; n=len(doc)
    t0=time.perf_counter()
    while i<n:
        typed+=1; i+=1                       # the user types the next character by hand
        c0=time.perf_counter(); sug=sugg.suggest(doc[:i], k=k); sugg_s+=time.perf_counter()-c0; calls+=1
        if not sug: continue
        remain=doc[i:].lower()
        for s in sug:
            s=s[:max_chunk]
            if s and remain.startswith(s.lower()):
                accepts+=1; i+=len(s)       # one keystroke inserts len(s) characters
                break
    ms=(time.perf_counter()-t0)*1000.0
    keys_with=typed+accepts
    kss=1.0-keys_with/n if n else 0.0
    return dict(keys_plain=n, keys_with=keys_with, kss=kss, accepts=accepts, time_ms=ms,
                calls=calls, latency_ms=sugg_s*1000.0/max(calls,1))

def release_gpu():
    """Collect garbage and return cached CUDA blocks before loading another model."""
    gc.collect()
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available(): torch.cuda.empty_cache()

def main():
    ap=argparse.ArgumentParser()
    ap.add_argument('--users_dir', required=True)
    ap.add_argument('--mode', choices=['ngram','llm_base','llm_lex','llm_full'], default='ngram')
    ap.add_argument('--base_model', default='Qwen/Qwen2.5-7B-Instruct')
    ap.add_argument('--adapters_dir', default='adapters')
    ap.add_argument('--lexicons_dir', default='lexicons')
    ap.add_argument('--rag_dir', default='rag')
    ap.add_argument('--results_csv', required=True)
    ap.add_argument('--k', type=int, default=3)
    ap.add_argument('--bf16', action='store_true')
    args=ap.parse_args(); out=[]
    lex_tok=None   # one tokenizer serves every user's lexicon
    shared=None    # llm_base / llm_lex: the same base model serves every user
    sg=None
    for u in sorted(Path(args.users_dir).iterdir()):
        if not u.is_dir(): continue
        a=u/'adapt.txt'; t=u/'test.txt'
        if not a.exists() or not t.exists(): continue
        if args.mode=='ngram':
            sg=NGram(load_text(a))
        else:
            lex=None
            if args.mode in ('llm_lex','llm_full'):
                lp=Path(args.lexicons_dir)/f'{u.name}.lexicon.json'
                if lp.exists():
                    if lex_tok is None:
                        from transformers import AutoTokenizer
                        lex_tok=AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
                    lex=Lex(lex_tok, lp.read_text(encoding='utf-8'), cap=2.5)
            rag=None
            fp=Path(args.rag_dir)/f'{u.name}.faiss'; cp=Path(args.rag_dir)/f'{u.name}.chunks.json'
            if fp.exists() and cp.exists(): rag=RAG(fp, cp)
            if args.mode=='llm_full':
                ad=str(Path(args.adapters_dir)/u.name/'lora_adapter')
                sg=None; release_gpu()      # drop the previous user's model before loading the next
                sg=LLM(args.base_model, adapter_dir=ad, rag=rag, lex=lex, bf16=args.bf16)
            else:
                if shared is None: shared=LLM(args.base_model, adapter_dir=None, rag=None, lex=None, bf16=args.bf16)
                shared.rag=rag; shared.lex=lex; sg=shared
        res=sim(load_text(t), sg, k=args.k)
        out.append({'user': u.name, 'mode': args.mode, **res})
        print(u.name, args.mode, f"KSS={res['kss']:.3f}", f"latency={res['latency_ms']:.1f}ms")
    Path(args.results_csv).parent.mkdir(parents=True, exist_ok=True)
    with open(args.results_csv,'w',newline='',encoding='utf-8') as f:
        w=csv.DictWriter(f, fieldnames=list(out[0].keys()) if out else ['user','mode','kss'])
        w.writeheader(); w.writerows(out)
    print('csv →', args.results_csv)
if __name__=='__main__': main()
'''))
print('code rooted at', CODE)

### data → authors.jsonl (robust download)

In [None]:
%%bash
# Fetch + extract the Enron corpus (skipped when a maildir/ is already present), then
# build authors.jsonl on Drive. maildir/ is located with `find` rather than a hard-coded
# path, so either archive layout works and re-runs don't re-download ~423 MB.
set -euo pipefail
cd /content
find_maildir() { find /content -maxdepth 3 -path /content/drive -prune -o -type d -name maildir -print -quit; }
MAILDIR=$(find_maildir || true)
if [ -z "$MAILDIR" ]; then
  F=enron_mail_20150507.tgz
  rm -f "$F"
  for URL in \
    "http://www.cs.cmu.edu/~./enron/enron_mail_20150507.tgz" \
    "http://www.cs.cmu.edu/~enron/enron_mail_20150507.tgz" \
    "https://www.cs.cmu.edu/~./enron/enron_mail_20150507.tgz" \
    "https://www.cs.cmu.edu/~enron/enron_mail_20150507.tgz" \
    "https://www.cs.cmu.edu/~enron/enron_mail_20150507.tar.gz"; do
    echo "try $URL"; if curl -fL --retry 5 --retry-all-errors -o "$F" "$URL"; then break; fi
  done
  python - <<'PY'
fn='enron_mail_20150507.tgz'
with open(fn,'rb') as f:
    sig=f.read(2)
import sys, pathlib
if sig != b'\x1f\x8b':
    print('not gzip; first 200 bytes:'); print(pathlib.Path(fn).read_bytes()[:200]); sys.exit(2)
print('gzip ok')
PY
  tar -xzf "$F"
  MAILDIR=$(find_maildir || true)
fi
if [ -z "$MAILDIR" ]; then echo 'no maildir/ directory found after extraction' >&2; exit 1; fi
echo "maildir: $MAILDIR"
MAILDIR="$MAILDIR" python - <<'PY'
import os, sys
sys.path.append('/content/drive/MyDrive/assistive_keyboard_7B/code')
from src.data.enron_loader import build_authors_jsonl
out='/content/drive/MyDrive/assistive_keyboard_7B/data/processed/authors.jsonl'
build_authors_jsonl(os.environ['MAILDIR'], out, min_doc_tokens=20)
print('authors.jsonl →', out)
PY

### splits (author‑disjoint) + per‑user slices (use config)

In [None]:
import sys
from pathlib import Path

root = Path('/content/drive/MyDrive/assistive_keyboard_7B')
sys.path.append(str(root/'code'))

# make_splits reads argparse options from sys.argv, so build its command line from the config cell
split_opts = {
    '--authors_jsonl': root/'data/processed/authors.jsonl',
    '--out_dir': root/'splits',
    '--min_tokens': ADAPT_TOKENS+VAL_TOKENS+TEST_TOKENS,
    '--adapt_tokens': ADAPT_TOKENS,
    '--val_tokens': VAL_TOKENS,
    '--test_tokens': TEST_TOKENS,
    '--max_test_authors': MAX_TEST_AUTHORS,
    '--seed': SEED,
}
sys.argv = ['splits'] + [str(piece) for pair in split_opts.items() for piece in pair]
from src.splits.make_splits import main as run
run()

### prefetch base model (warms HF cache)

In [None]:
# Load the base model once so its files land in the Drive-backed HF cache, then release
# it. `del m` alone can leave GPU memory held (reference cycles / allocator cache) right
# before the training cell loads another copy.
import gc
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if USE_BF16_INSTEAD_OF_4BIT:
    m = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', torch_dtype=torch.bfloat16)
else:
    m = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map='auto', load_in_4bit=True)
del m
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print('prefetch ok')

### per‑user assets: lexicon + FAISS RAG

In [None]:
import sys
from pathlib import Path

root = Path('/content/drive/MyDrive/assistive_keyboard_7B')
sys.path.append(str(root/'code'))

# 1) per-user TF-IDF lexicons built from users/<id>/adapt.txt
sys.argv = ['lex', '--users_dir', str(root/'users'), '--out_dir', str(root/'lexicons'),
            '--max_items', '4000']
from src.lexicon.build_lexicon import main as lex
lex()

# 2) per-user FAISS indexes over character chunks of the same adapt.txt
sys.argv = ['rag', '--users_dir', str(root/'users'), '--out_dir', str(root/'rag'),
            '--model_name', 'sentence-transformers/all-MiniLM-L6-v2',
            '--chunk_chars', '300', '--overlap', '50']
from src.rag.build_rag import main as rag
rag()

### LoRA per user (resumable ckpts; bf16 toggle)

In [None]:
import sys
from pathlib import Path
import torch
root=Path('/content/drive/MyDrive/assistive_keyboard_7B'); sys.path.append(str(root/'code'))

# hot-patch the loader to flip 4bit ↔ bf16 without editing train_lora.py.
# TL.AutoModelForCausalLM IS transformers' own class, so the patch is process-wide;
# it is removed in `finally` so later cells (eval, live check) get the stock loader
# and re-running this cell never stacks one wrapper on top of another.
import src.lora.train_lora as TL
_loader_cls = TL.AutoModelForCausalLM
_saved_attr = _loader_cls.__dict__.get('from_pretrained')   # raw class attribute (None if inherited)
_orig = _loader_cls.from_pretrained

def patched(*a, **kw):
    """Force the dtype/quantisation chosen by USE_BF16_INSTEAD_OF_4BIT."""
    kw = dict(kw)
    if USE_BF16_INSTEAD_OF_4BIT:
        kw.pop('load_in_4bit', None)
        kw['torch_dtype'] = torch.bfloat16
    else:
        kw['load_in_4bit'] = True
        kw['torch_dtype'] = torch.float16
    return _orig(*a, **kw)

sys.argv = [
  'train',
  '--users_dir',    str(root/'users'),
  '--adapters_dir', str(root/'adapters'),
  '--base_model',   BASE_MODEL,
  '--rank',         str(LORA_RANK),
  '--alpha',        str(LORA_ALPHA),
  '--dropout',      str(LORA_DROPOUT),
  '--lr',           '2e-4',
  '--steps',        str(LORA_STEPS),
  '--block_size',   '256'
]
_loader_cls.from_pretrained = patched
try:
    from src.lora.train_lora import main as train
    train()
finally:
    if _saved_attr is None:
        del _loader_cls.from_pretrained
    else:
        _loader_cls.from_pretrained = _saved_attr

### eval (ngram, llm_base, llm_lex, llm_full) → CSVs

In [None]:
import sys
from pathlib import Path

root = Path('/content/drive/MyDrive/assistive_keyboard_7B')
sys.path.append(str(root/'code'))

def run_eval(mode, name):
    """Run the typing simulation for one mode; results go to runs/leaderboard_<name>.csv."""
    argv = ['eval',
            '--users_dir', str(root/'users'),
            '--mode', mode,
            '--base_model', BASE_MODEL,
            '--adapters_dir', str(root/'adapters'),
            '--lexicons_dir', str(root/'lexicons'),
            '--rag_dir', str(root/'rag'),
            '--results_csv', str(root/'runs'/f'leaderboard_{name}.csv')]
    if USE_BF16_INSTEAD_OF_4BIT:
        argv.append('--bf16')
    sys.argv = argv
    from src.eval.typing_sim import main as E
    E()

for mode in ('ngram', 'llm_base', 'llm_lex', 'llm_full'):
    run_eval(mode, mode)
print('eval ok')

### results (means + per‑author spread + outliers)

In [None]:
import os
import pandas as pd

# own name: rebinding `root` to a str would break other cells that use it as a Path
runs_dir='/content/drive/MyDrive/assistive_keyboard_7B/runs'
dfs=[]
for f in ['leaderboard_ngram.csv','leaderboard_llm_base.csv','leaderboard_llm_lex.csv','leaderboard_llm_full.csv']:
    p=os.path.join(runs_dir,f)
    try:
        dfs.append(pd.read_csv(p).assign(model=f.replace('leaderboard_','').replace('.csv','')))
    except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        print('missing', p, e)
if not dfs:
    # pd.concat([]) raises; report instead
    print('no leaderboard CSVs under', runs_dir, '- run the eval cell first')
else:
    res=pd.concat(dfs, ignore_index=True)
    # a run with zero users writes only user/mode/kss columns; summarise what exists
    cols=[c for c in ('kss','time_ms','latency_ms','accepts') if c in res.columns]
    print('== means by model ==')
    print(res.groupby('model')[cols].mean().round(3))
    full=res[res['model']=='llm_full'][['user','kss']]
    if full.empty:
        print('\nno llm_full rows - skipping per-author spread')
    else:
        print('\n== per-author KSS (llm_full) describe ==')
        print(full.describe().round(3))
        med=full['kss'].median(); mad=(full['kss']-med).abs().median()
        # robust low-outlier rule: more than 1.5 MADs below the median
        bad=full[full['kss']<med-1.5*mad]['user'].tolist()
        print('\noutliers (low KSS):', bad)

### quick live check (single author suggestion)

In [None]:
from pathlib import Path
import sys
root=Path('/content/drive/MyDrive/assistive_keyboard_7B'); sys.path.append(str(root/'code'))
from src.infer.suggest import LLM, Lex, RAG
from transformers import AutoTokenizer

# sorted → the same author on every run (iterdir order is filesystem-dependent)
authors=sorted(p.name for p in (root/'users').iterdir() if p.is_dir())
if not authors:
    # previously author=None flowed on and failed later with a TypeError on root/'adapters'/None
    raise RuntimeError(f"no author folders in {root/'users'} - run the splits cell first")
author=authors[0]
print('author:', author)
lex=None; lpath=root/'lexicons'/f'{author}.lexicon.json'
if lpath.exists():
    lex=Lex(AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True), lpath.read_text(encoding='utf-8'), cap=2.5)
rag=None; f=root/'rag'/f'{author}.faiss'; c=root/'rag'/f'{author}.chunks.json'
if f.exists() and c.exists():
    rag=RAG(f, c)
adapter=str(root/'adapters'/author/'lora_adapter')
sg=LLM(BASE_MODEL, adapter_dir=adapter, rag=rag, lex=lex, bf16=USE_BF16_INSTEAD_OF_4BIT)
ctx='Hi team, following up on the budget approval for Q4. If we can align by Friday,'
print('ctx:', ctx)
print('suggestions:', sg.suggest(ctx, k=3))