# **Vision Experiment 1**

In [None]:
# Install deps (quiet)
!pip -q install opencv-python-headless==4.10.0.84 pillow matplotlib gradio orjson

import os, sys, json, math, time, threading
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Make folders. parents=True also creates /content itself, so no separate
# mkdir for the parent directory is needed.
Path("/content/spikes").mkdir(parents=True, exist_ok=True)

print("Ready. OpenCV:", cv2.__version__)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.9/49.9 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25hReady. OpenCV: 4.10.0


In [None]:
# babyai_vision_colab.py – single-file core for Colab
import json, time, math, threading
from dataclasses import dataclass, asdict
from pathlib import Path
import numpy as np
import cv2

def now_iso():
    """Current UTC time as an ISO-8601 string with second precision, e.g. ``2024-01-31T12:00:00Z``."""
    iso_utc_format = "%Y-%m-%dT%H:%M:%SZ"
    return time.strftime(iso_utc_format, time.gmtime())

def ensure_dir(p: Path):
    """Create directory ``p`` and any missing parents; no error if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def ema_update(prototype, x, beta=0.9):
    """Exponential moving average step: keep ``beta`` of ``prototype``, blend in ``1 - beta`` of ``x``."""
    keep = beta * prototype
    blend_in = (1.0 - beta) * x
    return keep + blend_in

def cosine_sim(a, b, eps=1e-8):
    """Cosine similarity of two vectors (computed in float32); ``eps`` keeps zero vectors at 0.0."""
    va = np.asarray(a, np.float32)
    vb = np.asarray(b, np.float32)
    norm_product = float(np.linalg.norm(va) * np.linalg.norm(vb))
    return float(np.dot(va, vb)) / (norm_product + eps)

def _json_compact_dump(path: Path, data: dict):
    try:
        import orjson
        with open(path, "wb") as f:
            f.write(orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY))
    except Exception:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, separators=(",", ":"))

# ---------------- Retina ----------------
@dataclass
class RetinaParams:
    """Parameters of the retina stage (retina_drive and the RGC spike generator)."""
    sigma_c: float = 1.5          # std-dev of the DoG center Gaussian (dog_kernel)
    sigma_s: float = 3.5          # std-dev of the DoG surround Gaussian (dog_kernel)
    Kc: float = 1.0               # weight of the center Gaussian in the DoG kernel
    Ks: float = 0.8               # weight of the surround Gaussian in the DoG kernel
    temporal_tau1: float = 8.0    # time constant of the positive term of the biphasic temporal kernel
    temporal_tau2: float = 16.0   # time constant of the negative term of the biphasic temporal kernel
    on_off_split: bool = True     # True: return rectified (ON, OFF) maps; False: (ON, None)
    saccades: int = 5             # number of jittered frames produced by microsaccades
    jitter_px: int = 2            # max |dx| / |dy| pixel shift of each microsaccade frame
    drive_gain: float = 1.0       # passed as `gain` to rgc_spike_generator by callers in this file
    nl_alpha: float = 1.0         # passed as `nl_alpha` (slope of the exp nonlinearity) to rgc_spike_generator

def _gaussian2d_same_grid(sigma, max_sigma):
    k = int(3 * max_sigma)
    ax = np.arange(-k, k + 1, dtype=np.float32)
    xx, yy = np.meshgrid(ax, ax, indexing="xy")
    g = np.exp(-(xx**2 + yy**2) / (2.0 * np.float32(sigma)**2))
    s = g.sum()
    if s > 0: g /= s
    return g

def dog_kernel(params: RetinaParams):
    """Zero-mean difference-of-Gaussians kernel ``Kc * center - Ks * surround`` (float32)."""
    grid_sigma = max(params.sigma_c, params.sigma_s)
    center = _gaussian2d_same_grid(params.sigma_c, grid_sigma)
    surround = _gaussian2d_same_grid(params.sigma_s, grid_sigma)
    dog = params.Kc * center - params.Ks * surround
    dog = dog - dog.mean()  # remove the DC component
    return dog.astype(np.float32)

def to_luminance(img_bgr):
    """Rec.601 luma in [0, 1] from a BGR image (channels 0..2 read as B, G, R)."""
    scaled = img_bgr.astype(np.float32) / 255.0
    blue, green, red = (scaled[..., ch] for ch in range(3))
    return 0.299 * red + 0.587 * green + 0.114 * blue

def microsaccades(img, n, jitter):
    """Stack ``n`` copies of 2-D ``img`` translated by random integer offsets in [-jitter, jitter].

    The generator is seeded with 0, so every call produces the same offset
    sequence. Borders are filled by reflection.
    """
    height, width = img.shape
    rng = np.random.default_rng(0)
    shifted = []
    for _ in range(n):
        # dx is drawn before dy on every step.
        dx = rng.integers(-jitter, jitter + 1)
        dy = rng.integers(-jitter, jitter + 1)
        translation = np.float32([[1, 0, dx], [0, 1, dy]])
        shifted.append(cv2.warpAffine(img, translation, (width, height),
                                      flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT))
    return np.stack(shifted, axis=0)

def biphasic_temporal_filter(T, tau1, tau2):
    """Zero-mean, unit-L2-norm biphasic kernel ``t*exp(-t/tau1) - 0.6*t*exp(-t/tau2)`` over t = 0..T-1.

    Note: with T == 1 the only sample is t = 0, so the kernel is all zeros.
    """
    steps = np.arange(T, dtype=np.float32)
    fast = steps * np.exp(-steps / tau1)
    slow = 0.6 * steps * np.exp(-steps / tau2)
    kernel = fast - slow
    kernel -= kernel.mean()
    return kernel / (np.sqrt(np.sum(kernel ** 2)) + 1e-8)

def retina_drive(img_bgr, params: RetinaParams):
    """Run the retina model on a BGR (or single-channel 2-D) image.

    Luminance is jittered into ``params.saccades`` frames. Each frame is
    filtered with the DoG kernel, and the frames are collapsed with the
    biphasic temporal kernel.

    Returns:
        ``(on, off)`` rectified float32 maps, or ``(on, None)`` when
        ``params.on_off_split`` is False.

    Raises:
        ValueError: if ``img_bgr`` is None.
    """
    if img_bgr is None:
        raise ValueError("retina_drive: img_bgr is None")
    color = cv2.cvtColor(img_bgr, cv2.COLOR_GRAY2BGR) if img_bgr.ndim == 2 else img_bgr
    frames = microsaccades(to_luminance(color), params.saccades, params.jitter_px)
    kernel = dog_kernel(params)
    filtered = np.zeros_like(frames)
    for i, frame in enumerate(frames):
        filtered[i] = cv2.filter2D(frame, -1, kernel, borderType=cv2.BORDER_REFLECT)
    weights = biphasic_temporal_filter(frames.shape[0], params.temporal_tau1, params.temporal_tau2)
    response = np.tensordot(weights, filtered, axes=(0, 0)).astype(np.float32)
    on = np.clip(response, 0, None)
    if not params.on_off_split:
        return on, None
    return on, np.clip(-response, 0, None)

# --------------- V1 / spikes ---------------
def poisson_spikes(rate, T_ms=50, dt_ms=1.0, rng=None):
    """Bernoulli approximation of Poisson spike trains.

    Args:
        rate: 1-D array of firing rates in Hz, one per neuron.
        T_ms: window length in milliseconds.
        dt_ms: time-step length in milliseconds.
        rng: optional ``numpy.random.Generator``; a fresh unseeded one is used if None.

    Returns:
        uint8 array of shape (int(T_ms / dt_ms), len(rate)); each entry is 1 with
        probability ``clip(rate / 1000 * dt_ms, 0, 1)``.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_steps = int(T_ms / dt_ms)
    p_spike = np.clip((rate / 1000.0) * dt_ms, 0, 1.0)
    draws = rng.random((n_steps, rate.shape[0]))
    return (draws < p_spike[None, :]).astype(np.uint8)

def rgc_spike_generator(on_map, off_map, gain=1.0, nl_alpha=1.0, topk=1024):
    """Turn retina ON/OFF maps into Poisson RGC spike trains.

    The ``topk`` strongest ON and OFF locations are kept. Their drive ``v`` is
    divided by its standard deviation and mapped to rates with
    ``gain * max(0, exp(nl_alpha * v / std) - 1)``. ``poisson_spikes`` then
    turns the rates into a (50, n) spike matrix.

    Args:
        on_map: ON drive map.
        off_map: OFF drive map, or None (treated as all zeros).
        gain: rate multiplier.
        nl_alpha: slope of the exponential nonlinearity.
        topk: number of locations kept per polarity (capped at the map size).

    Returns:
        ``(spikes, rate, idx)``. ``idx`` holds the flat ON indices followed by
        the flat OFF indices offset by ``on_map.size``.
    """
    def pick_topk(X):
        # Indices and values of the k largest entries, sorted descending.
        v = X.reshape(-1)
        k = min(topk, v.size)
        if k <= 0:
            return np.empty(0, dtype=np.intp), v[:0]
        idx = np.argpartition(-v, k-1)[:k]
        idx = idx[np.argsort(-v[idx])]
        return idx, v[idx]

    on_idx, on_vals = pick_topk(on_map)
    off_vals_base = np.zeros_like(on_map) if off_map is None else off_map
    off_idx, off_vals = pick_topk(off_vals_base)
    v = np.concatenate([on_vals, off_vals], axis=0)
    spread = v.std() if v.size else 0.0
    with np.errstate(over="ignore"):
        rate = gain * np.maximum(0.0, np.exp(nl_alpha * (v / (spread + 1e-6))) - 1.0)
    # exp() overflows to +inf for very strongly driven cells. Saturate those at
    # the dtype maximum (poisson_spikes clips the spike probability to 1)
    # instead of mapping them to 0, which would silence the strongest cells.
    rate = np.nan_to_num(rate, nan=0.0, posinf=np.finfo(rate.dtype).max, neginf=0.0)
    spikes = poisson_spikes(rate, T_ms=50, dt_ms=1.0)
    return spikes, rate, np.concatenate([on_idx, off_idx + on_map.size], axis=0)

def divisive_normalization(x, pool_alpha=1.5, sigma=0.1):
    """Divide ``x`` by ``sigma`` plus a pooled ``|x| ** pool_alpha`` signal.

    For 2-D input the pool is a Gaussian blur (sigma 1 px) of the powered
    magnitude, i.e. local normalization. For any other rank it is the global mean.
    """
    magnitude = np.abs(x) ** pool_alpha
    if x.ndim != 2:
        return x / (sigma + np.mean(magnitude))
    return x / (sigma + cv2.GaussianBlur(magnitude, (0, 0), 1.0))

@dataclass
class V1Params:
    """Gabor filter-bank parameters for the V1 stage (build_gabor_bank)."""
    orientations: int = 8     # number of orientations, evenly spaced over [0, pi)
    scales: int = 4           # number of wavelengths; scale s uses lam0 * lam_mul**s
    lam0: float = 4.0         # wavelength (cv2.getGaborKernel `lambd`) of the finest scale
    lam_mul: float = 1.8      # geometric factor between successive wavelengths
    sigma_frac: float = 0.6   # Gaussian envelope sigma as a fraction of the wavelength
    gamma: float = 0.9        # spatial aspect ratio passed to cv2.getGaborKernel

_GABOR_CACHE = {}
def build_gabor_bank(H, W, vp: V1Params):
    """Return the list of (even, odd) Gabor kernel pairs described by ``vp``.

    Pairs are ordered scale-major: all orientations of scale 0, then all of
    scale 1, and so on. Results are memoized in ``_GABOR_CACHE``.

    ``H`` and ``W`` are kept for API compatibility only. The kernels do not
    depend on image size, so the size is left out of the cache key; otherwise
    every new image size would store another identical bank.
    """
    key = (int(vp.orientations), int(vp.scales), float(vp.lam0), float(vp.lam_mul),
           float(vp.sigma_frac), float(vp.gamma))
    cached = _GABOR_CACHE.get(key)
    if cached is not None:
        return cached
    bank = []
    thetas = np.linspace(0, np.pi, vp.orientations, endpoint=False)
    lams = [vp.lam0 * (vp.lam_mul**s) for s in range(vp.scales)]
    for lam in lams:
        sigma = vp.sigma_frac * lam
        # Kernel spans ~6 sigma, at least 7 px, forced odd so it has a center pixel.
        ksize = int(max(7, 6*sigma)) | 1
        for th in thetas:
            # psi=0 and psi=pi/2 give the even/odd (quadrature) pair used for energy.
            ge = cv2.getGaborKernel((ksize, ksize), sigma, th, lam, vp.gamma, psi=0, ktype=cv2.CV_32F)
            go = cv2.getGaborKernel((ksize, ksize), sigma, th, lam, vp.gamma, psi=np.pi/2, ktype=cv2.CV_32F)
            bank.append((ge, go))
    _GABOR_CACHE[key] = bank
    return bank

def v1_energy_maps(img_gray, bank):
    """Phase-invariant energy ``sqrt(even**2 + odd**2)`` of ``img_gray`` for every Gabor pair in ``bank``."""
    def _energy(pair):
        even_kernel, odd_kernel = pair
        even = cv2.filter2D(img_gray, -1, even_kernel, borderType=cv2.BORDER_REFLECT)
        odd = cv2.filter2D(img_gray, -1, odd_kernel, borderType=cv2.BORDER_REFLECT)
        return np.sqrt(even*even + odd*odd).astype(np.float32)

    return [_energy(pair) for pair in bank]

def spatial_pool(energies, grid=(8,8)):
    """Average-pool each energy map over a ``grid`` of equal cells and concatenate.

    Edge rows and columns that do not fill a whole cell are dropped. The result
    is a float32 vector of length ``len(energies) * rows * cols``, map-major.
    """
    rows, cols = grid
    height, width = energies[0].shape
    cell_h, cell_w = height // rows, width // cols
    pooled = [
        e[:rows * cell_h, :cols * cell_w]
        .reshape(rows, cell_h, cols, cell_w)
        .mean(axis=(1, 3))
        .ravel()
        for e in energies
    ]
    return np.concatenate(pooled, axis=0).astype(np.float32)

def features_to_spikes(feat, scale=12.0):
    """Z-score ``feat``, map to rates ``scale * max(0, exp(z) - 1)``, and draw 50 ms of Poisson spikes.

    Returns:
        ``(spikes, rate)``.
    """
    zscore = (feat - feat.mean()) / (feat.std() + 1e-6)
    rate = scale * np.clip(np.exp(zscore) - 1.0, 0, None)
    return poisson_spikes(rate, T_ms=50, dt_ms=1.0), rate

# --------------- Learning / STDP ---------------
@dataclass
class LearningParams:
    """STDP parameters consumed by stdp_update (times are in spike-matrix time steps)."""
    A_plus: float = 0.02          # potentiation amplitude when the mean pre time precedes the post time
    A_minus: float = -0.025       # change amplitude (negative = depression) when pre does not precede post
    tau_plus: float = 20.0        # decay constant of the potentiation window
    tau_minus: float = 20.0       # decay constant of the depression window
    decay_lambda: float = 0.9995  # multiplicative weight decay applied on every update
    dopamine: float = 1.0         # scales every weight change; record_image may override it with auto-dopamine

def stdp_update(W, pre_spk, post_spk, lp: "LearningParams"):
    """Apply one rate-averaged STDP update and return the new weight vector.

    For every presynaptic neuron that fired, ``dt`` is its mean spike time minus
    the mean postsynaptic spike time, both as time-step indices. The weight
    change is then:

    - ``dt < 0`` (pre before post): ``A_plus * exp(dt / tau_plus)``.
    - otherwise: ``A_minus * exp(-dt / tau_minus)``, a depression when
      ``A_minus`` is negative.

    Changes are scaled by ``lp.dopamine`` and weights are floored at 0. Every
    weight is then multiplied by ``lp.decay_lambda``. With no postsynaptic
    spike, only the decay is applied.

    Args:
        W: 1-D weight array. It is resized with ``np.resize`` (cyclic repeat or
            truncation) when its length differs from the number of pre neurons.
        pre_spk: (T, N) presynaptic spike matrix; nonzero entries are spikes.
        post_spk: postsynaptic spikes along the time axis.
        lp: learning parameters (``LearningParams``).

    Returns:
        A new array. The caller's ``W`` is not modified.
    """
    pre_spk = np.asarray(pre_spk)
    T, N = pre_spk.shape
    W = np.resize(W, (N,)) if N != W.shape[0] else np.array(W, copy=True)
    if not np.any(post_spk):
        W *= lp.decay_lambda
        return W
    post_t = np.where(post_spk)[0].mean()

    fired = pre_spk != 0
    counts = fired.sum(axis=0)
    active = counts > 0
    if active.any():
        # Mean spike time of every active neuron, vectorized over neurons.
        times = np.arange(T, dtype=np.float64)[:, None]
        mean_t = (fired * times).sum(axis=0)[active] / counts[active]
        dt = mean_t - post_t
        # Both windows decay with |dt|. Writing them this way avoids exp overflow
        # in the branch np.where discards.
        dw = np.where(dt < 0,
                      lp.A_plus * np.exp(-np.abs(dt) / lp.tau_plus),
                      lp.A_minus * np.exp(-np.abs(dt) / lp.tau_minus))
        W[active] = np.maximum(0.0, W[active] + lp.dopamine * dw)
    W *= lp.decay_lambda
    return W

# --------------- Auto Dopamine ---------------
@dataclass
class AutoDopaConfig:
    """Constants for AutoDopamine.step: reward, phasic/tonic terms and clamping of the dopamine gain."""
    a_pos: float = 0.8            # phasic gain per unit of positive prediction error (delta > 0)
    a_neg: float = 0.6            # phasic loss per unit of negative prediction error (delta < 0)
    alpha_V: float = 0.15         # EMA rate of the reward estimate rhat
    kappa: float = 0.1            # EMA rate of the reward average rho
    b_rho: float = 0.6            # tonic weight of (rho - rho0)
    b_nov: float = 0.25           # tonic weight of novelty (1 - clipped best prototype similarity)
    b_unc: float = 0.20           # tonic weight of softmax entropy over the top-8 similarities
    b_fatigue: float = 0.15       # tonic penalty per unit of fatigue
    rho0: float = 0.5             # reference level for rho in the tonic term
    wP: float = 0.65              # weight of the phasic term P in the final mix
    wT: float = 0.35              # weight of the tonic term T in the final mix
    Dmin: float = 0.4             # lower clamp of the returned dopamine
    Dmax: float = 2.0             # upper clamp of the returned dopamine
    beta_softmax: float = 8.0     # inverse temperature of the entropy softmax
    eta_novelty: float = 0.05     # novelty bonus added to the reward r_t
    fatigue_rise: float = 0.10    # level that fatigue relaxes toward on each step
    fatigue_decay_sec: float = 120.0  # relaxation time constant in seconds (floored at 1)

class AutoDopamine:
    """Automatic dopamine gain for one learning step.

    The gain mixes two terms and clamps the result to ``[cfg.Dmin, cfg.Dmax]``:

    - a phasic term: the reward prediction error of the label's prototype
      similarity;
    - a tonic term: reward average, novelty, uncertainty and fatigue.

    Running state (``rhat``, ``rho``, ``fatigue``, ``last_ts``) lives in
    ``meta["dopamine_state"]``, so it persists with the caller's data.
    """

    def __init__(self, meta: dict, cfg: AutoDopaConfig = None):
        """
        Args:
            meta: mutable dict. Its ``"dopamine_state"`` entry is read (missing
                keys, or a missing/None entry, fall back to defaults) and
                replaced with a normalized copy.
            cfg: optional configuration; defaults to ``AutoDopaConfig()``.
        """
        self.cfg = cfg if cfg is not None else AutoDopaConfig()
        st = meta.get("dopamine_state") or {}
        self.rhat = float(st.get("rhat", 0.5))
        self.rho  = float(st.get("rho", 0.5))
        self.fatigue = float(st.get("fatigue", 0.0))
        self.last_ts = float(st.get("last_ts", time.time()))
        meta["dopamine_state"] = {
            "rhat": self.rhat, "rho": self.rho, "fatigue": self.fatigue, "last_ts": self.last_ts
        }
        self._meta = meta

    def _write_back(self):
        """Copy the current running state into ``meta["dopamine_state"]``."""
        self._meta["dopamine_state"].update(
            {"rhat": self.rhat, "rho": self.rho, "fatigue": self.fatigue, "last_ts": self.last_ts}
        )

    def _entropy_from_scores(self, sims, beta):
        """Entropy (nats) of ``softmax(beta * sims)``; 0.0 for an empty list."""
        if len(sims) == 0: return 0.0
        z = np.array(sims, dtype=np.float32)
        p = np.exp(beta * z - np.max(beta*z))  # shift by the max for numerical stability
        p = p / (p.sum() + 1e-8)
        return float(-(p * (np.log(p + 1e-8))).sum())

    def step(self, feat: np.ndarray, label: str, all_protos: dict) -> dict:
        """Compute the dopamine gain for learning ``feat`` under ``label``, then update the state.

        Args:
            feat: feature vector of the current sample.
            label: label being learned.
            all_protos: mapping label -> prototype vector. Prototypes whose
                shape differs from ``feat`` (e.g. stored before the V1 bank size
                changed) are ignored instead of raising a shape error.

        Returns:
            Dict with the clamped ``dopamine`` and its components (P, T, delta, r,
            s_label, novelty, uncertainty) plus the updated rhat, rho and fatigue.
        """
        cfg = self.cfg
        now = time.time(); dt = max(1e-3, now - self.last_ts)

        feat_shape = np.shape(feat)
        protos = {lab: p for lab, p in all_protos.items() if np.shape(p) == feat_shape}

        sims = [cosine_sim(feat, p) for p in protos.values()]
        sims_sorted = sorted(sims, reverse=True) if sims else []
        max_sim = sims_sorted[0] if sims_sorted else -1.0

        s_lab = cosine_sim(feat, protos[label]) if label in protos else -1.0
        s01 = (s_lab + 1.0) * 0.5  # map similarity from [-1, 1] to [0, 1]

        novelty = 1.0 - max(-1.0, min(1.0, max_sim))
        novelty = max(0.0, min(2.0, novelty))
        H = self._entropy_from_scores(sims_sorted[:8], cfg.beta_softmax)

        # Phasic term: reward prediction error against the running estimate rhat.
        r_t = float(max(0.0, min(1.0, s01 + cfg.eta_novelty * novelty)))
        delta = r_t - self.rhat
        P = 1.0 + cfg.a_pos * max(0.0, delta) - cfg.a_neg * max(0.0, -delta)

        # Tonic term: fatigue relaxes toward fatigue_rise with time constant fatigue_decay_sec.
        decay = math.exp(-dt / max(1.0, cfg.fatigue_decay_sec))
        self.fatigue = self.fatigue * decay + cfg.fatigue_rise * (1.0 - decay)
        T = 1.0 + cfg.b_rho * (self.rho - cfg.rho0) + cfg.b_nov * novelty + cfg.b_unc * H - cfg.b_fatigue * self.fatigue

        D = float(max(cfg.Dmin, min(cfg.Dmax, cfg.wP * P + cfg.wT * T)))

        self.rhat = (1.0 - cfg.alpha_V) * self.rhat + cfg.alpha_V * r_t
        self.rho  = (1.0 - cfg.kappa)   * self.rho  + cfg.kappa   * r_t
        self.last_ts = now
        self._write_back()

        return {"dopamine": D, "P": P, "T": T, "delta": float(delta),
                "r": r_t, "s_label": s_lab, "novelty": float(novelty), "uncertainty": float(H),
                "rhat": self.rhat, "rho": self.rho, "fatigue": self.fatigue}

# --------------- IT invariance ---------------
@dataclass
class ITParams:
    """Parameters of the IT (invariance) stage used by VisionBrain."""
    use_it: bool = True                 # gates IT learning in record_image and IT scoring in recognize_image
    s2_rbf_sigma: float = 0.8           # sigma of rbf_sim between a view feature and a stored exemplar
    max_exemplars_per_label: int = 12   # cap on stored exemplars per label (greedy diversity pruning)
    pool_mode: str = "max"           # "max" | "median": how per-view best scores are pooled
    growth_start_deg: float = 0.0       # augmentation rotation range at zero experience (degrees)
    growth_max_deg: float = 60.0        # rotation range reached at full experience (degrees)
    growth_start_scale: float = 1.0     # augmentation zoom factor at zero experience
    growth_max_scale: float = 1.6       # zoom factor reached at full experience
    growth_steps: int = 5               # experience ratio = min(1, n / (10 * growth_steps))
    jitter_px: int = 4                  # pixel shift used for shifted augmented views
    contrast_jit: float = 0.15          # +/- contrast gain jitter applied to the base view
    ema_beta_proto: float = 0.90        # NOTE(review): presumably the prototype EMA factor — confirm where it is read
    min_margin_target: float = 0.12     # if the label's IT score beats the runner-up by less, an extra experience step is counted

_IT_MAX_VIEWS = 10  # cap on views per image from _generate_augmented_views; set to 4 for fast training

def _rand_contrast(img, jitter=0.15, rng=None):
    rng = rng or np.random.default_rng()
    c = 1.0 + float(rng.uniform(-jitter, jitter))
    out = np.clip(img * c, 0, 255).astype(img.dtype)
    return out

def _affine_warp(img, angle_deg=0.0, scale=1.0, dx=0, dy=0):
    """Rotate/scale ``img`` about its center, then translate by (dx, dy); borders are reflected."""
    h, w = img.shape[:2]
    center = (w/2, h/2)
    mat = cv2.getRotationMatrix2D(center, angle_deg, scale)
    mat[:, 2] += (dx, dy)
    return cv2.warpAffine(img, mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)

def _generate_augmented_views(img_bgr, itp: ITParams, experience_ratio: float):
    """Return up to ``_IT_MAX_VIEWS`` augmented copies of ``img_bgr``.

    The first view is a contrast-jittered base image. The others are distinct
    rotated/scaled/shifted warps of that base. The rotation and zoom ranges
    grow linearly with ``experience_ratio``, from ``itp.growth_start_*`` to
    ``itp.growth_max_*``. The generator is seeded, so a given input always
    yields the same views.
    """
    rot_max = itp.growth_start_deg + (itp.growth_max_deg - itp.growth_start_deg) * experience_ratio
    scl_max = itp.growth_start_scale + (itp.growth_max_scale - itp.growth_start_scale) * experience_ratio
    rng = np.random.default_rng(0)

    base = _rand_contrast(img_bgr, itp.contrast_jit, rng)
    views = [base]

    angles = [0.0, +rot_max, -rot_max/2]
    # Zoom-out factor must use min(): 1/sqrt(s) <= 1 for s >= 1, so
    # max(1.0, ...) would always collapse this scale to 1.0.
    scales = [1.0, scl_max, min(1.0, 1.0/np.sqrt(max(1e-6, scl_max)))]
    shifts = [(0, 0), (itp.jitter_px, 0)]

    # Skip transforms already used. The identity reproduces the base, and at low
    # experience every angle/scale coincides. Angles vary fastest, so the view
    # cap still keeps both rotation directions.
    seen = {(0.0, 1.0, 0, 0)}
    for sc in scales:
        for dx, dy in shifts:
            for th in angles:
                if len(views) >= _IT_MAX_VIEWS:
                    return views[:_IT_MAX_VIEWS]
                key = (round(float(th), 6), round(float(sc), 6), int(dx), int(dy))
                if key in seen:
                    continue
                seen.add(key)
                views.append(_affine_warp(base, angle_deg=th, scale=sc, dx=dx, dy=dy))
    return views[:_IT_MAX_VIEWS]

def v1_feature_from_bgr(img_bgr, ret_params: RetinaParams, v1_params: V1Params):
    """Pooled (8x8 grid) V1 energy feature vector of one BGR image, using the ON pathway only."""
    on_map = retina_drive(img_bgr, ret_params)[0]
    lgn = divisive_normalization(on_map)
    height, width = on_map.shape[0], on_map.shape[1]
    gabors = build_gabor_bank(height, width, v1_params)
    return spatial_pool(v1_energy_maps(lgn, gabors), grid=(8, 8))

def rbf_sim(x, c, sigma=0.8, eps=1e-8):
    """Gaussian RBF similarity whose squared distance is normalized by ``|x| * |c|``; 1.0 when x == c."""
    norm_scale = (np.linalg.norm(x) + eps) * (np.linalg.norm(c) + eps)
    d2 = np.sum((x - c) ** 2) / norm_scale
    return float(np.exp(-d2 / (2.0 * sigma ** 2)))

# --------------- Persistent Brain ---------------
class VisionBrain:
    """Persistent per-label memory for the Retina -> V1 -> IT pipeline.

    ``self.data`` is a JSON-serializable dict with these sections:

    - ``vision``: per-label V1 prototype, STDP weights and history;
    - ``it``: per-label IT exemplar features;
    - ``meta``: auto-dopamine state;
    - ``crossmodal``;
    - ``experience``: learning-step counter.

    It is loaded from ``path`` when that file exists and saved after every
    ``record_image`` call.

    Stored vectors whose length differs from the current feature length (e.g.
    after changing ``V1Params.orientations``/``scales``) are skipped when
    scoring and replaced when learning, instead of raising a shape error.
    """

    # Shared by all instances: record_image() starts one background save per
    # image, and concurrent saves would otherwise race on the same temp file.
    _SAVE_LOCK = threading.Lock()

    def __init__(self, path="/content/brain.json"):
        self.path = Path(path)
        self.data = {"vision": {}, "crossmodal": {}, "meta": {}, "it": {}, "experience": {"n":0}}
        if self.path.exists():
            with open(self.path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        # Backfill sections missing from older brain files.
        for key in ("vision", "crossmodal", "meta", "it"):
            self.data.setdefault(key, {})
        self.data.setdefault("experience", {"n":0})
        self.data["meta"].setdefault("dopamine_state", {"rhat":0.5, "rho":0.5, "fatigue":0.0, "last_ts": time.time()})

    # ---------------- persistence ----------------
    def save(self):
        """Atomically write ``self.data`` to ``self.path`` (temp file, then rename over it)."""
        with self._SAVE_LOCK:
            tmp = self.path.with_suffix(".tmp.json")
            _json_compact_dump(tmp, self.data)
            tmp.replace(self.path)

    def save_async(self):
        """Run :meth:`save` on a daemon thread; concurrent saves are serialized by ``_SAVE_LOCK``."""
        threading.Thread(target=self.save, daemon=True).start()

    def _store_spikes(self, store_dir, label, **arrays):
        """Best-effort save of ``arrays`` to ``store_dir`` as a compressed ``.npz``; failures only warn."""
        try:
            store = Path(store_dir); ensure_dir(store)
            # The label comes from user input: keep it a plain file-name token.
            safe = "".join(ch if (ch.isalnum() or ch in "-_.") else "_" for ch in str(label)) or "label"
            # Nanosecond stamp: images learned within the same second must not
            # overwrite each other.
            np.savez_compressed(store / f"vision_{safe}_{time.time_ns()}.npz", **arrays)
        except Exception as exc:
            import warnings
            warnings.warn(f"could not store spike arrays: {exc!r}")

    # ---------------- V1 label records ----------------
    def _ensure_label(self, label):
        v = self.data["vision"]
        if label not in v:
            v[label] = {"count":0,"updated_at":now_iso(),"prototype":None,"prototype_norm":0.0,
                        "W_sparse":None,"threshold_low":0.1,"history":[], "params":{}}

    def _all_prototypes(self):
        """Mapping label -> prototype (float32 array) for every label that has one."""
        protos = {}
        for lab, rec in self.data["vision"].items():
            if rec.get("prototype") is not None:
                protos[lab] = np.array(rec["prototype"], dtype=np.float32)
        return protos

    @staticmethod
    def _same_shape(stored, feat):
        """True when a stored vector (list or array) has the same shape as ``feat``."""
        return np.shape(stored) == np.shape(feat)

    def _update_prototype(self, rec, feat, beta=0.9):
        """EMA-update ``rec``'s prototype with ``feat``; (re)initialize it when missing or stale."""
        prev = rec["prototype"]
        if prev is None or not self._same_shape(prev, feat):
            rec["prototype"] = feat.tolist()
            rec["prototype_norm"] = float(np.linalg.norm(feat))
            return
        proto = ema_update(np.array(prev, dtype=np.float32), feat, beta=beta)
        rec["prototype"] = proto.tolist()
        rec["prototype_norm"] = float(np.linalg.norm(proto))

    # ---------------- IT ----------------
    def _it_experience_ratio(self, itp: ITParams):
        """Experience in [0, 1]: ``n / (10 * growth_steps)``, capped at 1."""
        n = int(self.data.get("experience", {}).get("n", 0))
        step = max(1, itp.growth_steps)
        return float(min(1.0, n / (10.0 * step)))

    def _ensure_it_label(self, label):
        it = self.data["it"]
        if label not in it:
            it[label] = {"exemplars": [], "updated_at": now_iso()}

    def _it_view_features(self, img_bgr, ret_params, v1_params, itp: ITParams):
        """V1 features of the augmented views of ``img_bgr``.

        They depend only on the image, not on any label, because view
        generation and microsaccades are seeded.
        """
        ratio = self._it_experience_ratio(itp)
        views = _generate_augmented_views(img_bgr, itp, ratio)
        return [v1_feature_from_bgr(v, ret_params, v1_params) for v in views]

    def _it_learn(self, label, img_bgr, ret_params, v1_params, itp: ITParams):
        """Add the augmented-view features of ``img_bgr`` as IT exemplars of ``label``.

        Exemplars whose length differs from the new ones (an older V1
        configuration) are dropped. The set is capped at
        ``itp.max_exemplars_per_label`` by greedy farthest-point selection
        seeded with the newest exemplar.
        """
        self._ensure_it_label(label)
        it = self.data["it"][label]
        new_feats = self._it_view_features(img_bgr, ret_params, v1_params, itp)
        it["exemplars"].extend({"feat": f.tolist()} for f in new_feats)

        ex = [np.array(e["feat"], dtype=np.float32) for e in it["exemplars"]]
        if new_feats:
            ex = [e for e in ex if e.shape == new_feats[0].shape]
        if len(ex) > itp.max_exemplars_per_label:
            keep = [ex[-1]]
            while len(keep) < itp.max_exemplars_per_label:
                best_i, best_d = None, -1.0
                for i, vec in enumerate(ex):
                    dmin = min(np.linalg.norm(vec-k) for k in keep)
                    if dmin > best_d: best_d, best_i = dmin, i
                keep.append(ex[best_i])
            ex = keep
        it["exemplars"] = [{"feat": e.tolist()} for e in ex]
        it["updated_at"] = now_iso()

    def _it_score_feats(self, label, view_feats, itp: ITParams):
        """Pooled RBF score of precomputed view features against ``label``'s exemplars.

        Returns None when the label has no comparable exemplars.
        """
        it = self.data["it"].get(label, None)
        if not it or not it.get("exemplars") or not view_feats:
            return None
        shape = np.shape(view_feats[0])
        ex = [c for c in (np.array(e["feat"], dtype=np.float32) for e in it["exemplars"]) if c.shape == shape]
        if not ex:
            return None
        sigma = itp.s2_rbf_sigma
        view_scores = [max(rbf_sim(f, c, sigma) for c in ex) for f in view_feats]
        if itp.pool_mode == "median":
            return float(np.median(view_scores))
        return float(np.max(view_scores))

    def _it_score_label(self, label, img_bgr, ret_params, v1_params, itp: ITParams):
        """IT score of ``img_bgr`` for one label, or None if the label has no exemplars."""
        it = self.data["it"].get(label, None)
        if not it or not it.get("exemplars"): return None
        view_feats = self._it_view_features(img_bgr, ret_params, v1_params, itp)
        return self._it_score_feats(label, view_feats, itp)

    def _it_score_all(self, labels, img_bgr, ret_params, v1_params, itp: ITParams):
        """IT scores for every label in ``labels`` that can be scored.

        The view features are computed once and shared across labels.
        """
        it_store = self.data["it"]
        scorable = [lab for lab in labels if (it_store.get(lab) or {}).get("exemplars")]
        if not scorable:
            return {}
        view_feats = self._it_view_features(img_bgr, ret_params, v1_params, itp)
        scores = {}
        for lab in scorable:
            sc = self._it_score_feats(lab, view_feats, itp)
            if sc is not None: scores[lab] = sc
        return scores

    def _it_margin_bonus(self, label, img_bgr, ret_params, v1_params, itp: ITParams):
        """Count one extra experience step when ``label`` wins by less than ``itp.min_margin_target``.

        Best-effort: a scoring failure is reported as a warning, not raised.
        """
        try:
            it_scores = self._it_score_all(list(self.data["vision"].keys()), img_bgr,
                                           ret_params, v1_params, itp)
        except Exception as exc:
            import warnings
            warnings.warn(f"IT margin check skipped: {exc!r}")
            return
        if not it_scores:
            return
        best_sc = it_scores.get(label, 0.0)
        second = sorted((v for k, v in it_scores.items() if k != label), reverse=True)
        sec_sc = second[0] if second else 0.0
        if best_sc - sec_sc < itp.min_margin_target:
            self.data["experience"]["n"] += 1

    # ---------------- public API ----------------
    def record_image(self, img_bgr, label, ret_params=None, v1_params=None,
                     learn_params=None, store_dir="/content/spikes",
                     auto_dopamine=False, it_params: ITParams=None, **kwargs):
        """Learn one BGR image under ``label`` and save the brain.

        The steps, in order:

        1. Run Retina -> RGC spikes and V1 pooled features/spikes, and store the
           spike arrays under ``store_dir`` (best effort).
        2. Update the label's prototype (EMA with ``it_params.ema_beta_proto``).
        3. Apply STDP to the label's readout weights, with a fixed or automatic
           dopamine gain.
        4. Optionally add IT exemplars.

        Extra ``kwargs`` are ignored.

        Returns:
            ``(dopamine_used, dopa_info)``. ``dopa_info`` is the dict from
            ``AutoDopamine.step``, or None when ``auto_dopamine`` is off.
        """
        ret_params = ret_params or RetinaParams()
        v1_params  = v1_params  or V1Params()
        it_params  = it_params  or ITParams()
        learn_params = learn_params or LearningParams()

        on_map, off_map = retina_drive(img_bgr, ret_params)
        rgc_spk, rgc_rate, rgc_idx = rgc_spike_generator(on_map, off_map, gain=ret_params.drive_gain, nl_alpha=ret_params.nl_alpha)

        on_lgn = divisive_normalization(on_map)
        bank = build_gabor_bank(on_map.shape[0], on_map.shape[1], v1_params)
        energies = v1_energy_maps(on_lgn, bank)
        feat = spatial_pool(energies, grid=(8,8))
        v1_spk, v1_rate = features_to_spikes(feat)

        self._store_spikes(store_dir, label, rgc_spk=rgc_spk, rgc_rate=rgc_rate, rgc_idx=rgc_idx,
                           v1_spk=v1_spk, v1_rate=v1_rate, feat=feat)

        self._ensure_label(label)
        rec = self.data["vision"][label]

        dopa_used = float(learn_params.dopamine)
        dopa_info = None
        if auto_dopamine:
            ad = AutoDopamine(self.data["meta"])
            protos = {lab: p for lab, p in self._all_prototypes().items() if self._same_shape(p, feat)}
            dopa_info = ad.step(feat, label, protos)
            dopa_used = float(dopa_info["dopamine"])

        # ema_beta_proto defaults to 0.90, the factor previously hard-coded here.
        self._update_prototype(rec, feat, beta=it_params.ema_beta_proto)

        # Label readout: a single post-synaptic spike in the middle of the window.
        label_spk = np.zeros((v1_spk.shape[0],), dtype=bool)
        label_spk[label_spk.size//2] = True
        if rec["W_sparse"] is None:
            rec["W_sparse"] = (0.01 * np.ones(v1_spk.shape[1], dtype=np.float32)).tolist()
        W = np.array(rec["W_sparse"], dtype=np.float32)
        lp = LearningParams(**asdict(learn_params)); lp.dopamine = float(dopa_used)
        W = stdp_update(W, v1_spk, label_spk, lp)
        rec["W_sparse"] = W.tolist()

        if it_params.use_it:
            self._it_learn(label, img_bgr, ret_params, v1_params, it_params)

        self.data["experience"]["n"] = int(self.data["experience"].get("n",0)) + 1
        self._it_margin_bonus(label, img_bgr, ret_params, v1_params, it_params)

        rec["count"] += 1
        rec["updated_at"] = now_iso()
        hist_item = {"t": rec["updated_at"], "feat_norm": float(np.linalg.norm(feat)),
                     "proto_norm": rec["prototype_norm"], "dopamine": float(dopa_used)}
        if dopa_info is not None:
            hist_item.update({"r": dopa_info["r"], "delta": dopa_info["delta"], "novelty": dopa_info["novelty"],
                              "uncertainty": dopa_info["uncertainty"], "rhat": dopa_info["rhat"],
                              "rho": dopa_info["rho"], "fatigue": dopa_info["fatigue"]})
        rec["history"].append(hist_item); rec["history"] = rec["history"][-40:]

        rec["params"] = {"retina": asdict(ret_params), "v1": asdict(v1_params),
                         "learn": asdict(learn_params), "auto_dopamine": bool(auto_dopamine)}

        try: self.save_async()
        except Exception: self.save()

        return float(dopa_used), dopa_info

    def recognize_image(self, img_bgr, ret_params=None, v1_params=None, it_params: ITParams=None, blend=0.6):
        """Score ``img_bgr`` against every stored label.

        A label's score is ``(1 - blend) * v1 + blend * it``, where:

        - ``v1`` is the prototype cosine similarity mapped from [-1, 1] to
          [0, 1] (0 when the label has no comparable prototype);
        - ``it`` is the IT score (0 when the label has none).

        Returns:
            ``(top5, feat)``: up to five ``(label, score)`` pairs, best first, and
            the V1 feature vector of ``img_bgr``.
        """
        ret_params = ret_params or RetinaParams()
        v1_params  = v1_params  or V1Params()
        it_params  = it_params  or ITParams()

        on_map, _ = retina_drive(img_bgr, ret_params)
        on_lgn = divisive_normalization(on_map)
        bank = build_gabor_bank(on_map.shape[0], on_map.shape[1], v1_params)
        energies = v1_energy_maps(on_lgn, bank)
        feat = spatial_pool(energies, grid=(8,8))

        v1_scores = {}
        for label, rec in self.data.get("vision", {}).items():
            if rec.get("prototype") is None: continue
            proto = np.array(rec["prototype"], dtype=np.float32)
            if proto.shape != feat.shape: continue  # stale: learned with another V1 configuration
            v1_scores[label] = float(cosine_sim(feat, proto))

        it_scores = {}
        if it_params.use_it:
            it_scores = self._it_score_all(list(self.data.get("vision", {}).keys()), img_bgr,
                                           ret_params, v1_params, it_params)

        labels = set(v1_scores.keys()) | set(it_scores.keys())
        final = []
        for lab in labels:
            v1s = v1_scores.get(lab, -1.0); v1u = (v1s + 1.0) * 0.5
            its = it_scores.get(lab, 0.0)
            score01 = (1.0 - blend) * v1u + blend * its
            final.append((lab, score01))
        final.sort(key=lambda x: -x[1])
        return final[:5], feat

# Helpers for plotting in Colab
def simulate_spikes_for_image(img_bgr, ret_params: RetinaParams, v1_params: V1Params):
    """Run the full pipeline once for plotting.

    Returns:
        ``(rgc_spikes, v1_spikes, v1_energy_maps, on_map)``.
    """
    on_map, off_map = retina_drive(img_bgr, ret_params)
    # At most half the map size per polarity, capped at 1024.
    rgc_topk = min(1024, on_map.size // 2)
    rgc_spk = rgc_spike_generator(on_map, off_map, gain=ret_params.drive_gain,
                                  nl_alpha=ret_params.nl_alpha, topk=rgc_topk)[0]
    lgn = divisive_normalization(on_map)
    gabors = build_gabor_bank(on_map.shape[0], on_map.shape[1], v1_params)
    energies = v1_energy_maps(lgn, gabors)
    v1_spk = features_to_spikes(spatial_pool(energies, grid=(8, 8)))[0]
    return rgc_spk, v1_spk, energies, on_map


In [None]:
import matplotlib.pyplot as plt

def normalize01(x, eps=1e-8):
    """Min-max scale ``x`` to [0, 1] as float32; an (almost) constant input maps to all zeros."""
    arr = x.astype(np.float32)
    lo, hi = arr.min(), arr.max()
    span = hi - lo
    if span < eps:
        return np.zeros_like(arr)
    return (arr - lo) / (span + eps)

def heat_overlay(base_bgr, heat, alpha=0.5, colormap=cv2.COLORMAP_TURBO):
    """Blend a min-max normalized, colormapped ``heat`` map over ``base_bgr`` (resized to match if needed)."""
    h, w = base_bgr.shape[:2]
    if heat.shape[:2] == (h, w):
        matched = heat
    else:
        matched = cv2.resize(heat, (w, h), interpolation=cv2.INTER_LINEAR)
    colored = cv2.applyColorMap((normalize01(matched) * 255).astype(np.uint8), colormap)
    return cv2.addWeighted(base_bgr, 1.0 - alpha, colored, alpha, 0.0)

def draw_orientation_strokes(base_bgr, energies, grid=(16,16), scale_len=14, thickness=2, orientations=None):
    """Draw the dominant orientation of each grid cell as a colored stroke.

    ``energies`` is the list from ``v1_energy_maps``. With ``build_gabor_bank``
    it is ordered scale-major (index = scale * orientations + orientation).

    Args:
        base_bgr: BGR image to draw on. It is not modified; a copy is returned.
        energies: per-filter energy maps indexed with ``base_bgr``'s (H, W).
        grid: (rows, cols) of cells; at most one stroke per cell.
        scale_len: stroke half-length scale in pixels.
        thickness: line thickness in pixels.
        orientations: number of orientations in the bank (``V1Params.orientations``).
            If omitted or inconsistent with ``len(energies)``, it is guessed from
            divisibility. The guess is ambiguous: 24 maps could be 6x4 or 8x3.

    Returns:
        A copy of ``base_bgr`` with the strokes drawn.
    """
    H, W = base_bgr.shape[:2]
    gh, gw = grid
    cell_h, cell_w = H // gh, W // gw
    out = base_bgr.copy()
    L = len(energies)
    if L == 0:
        return out

    if orientations is not None and int(orientations) > 0 and L % int(orientations) == 0:
        n_orient = int(orientations)
    else:
        n_orient = next((g for g in (8, 12, 6, 4) if L % g == 0), L)
    n_scales = L // n_orient

    # Per-orientation energy summed over all scales.
    Eo = [np.sum([energies[s * n_orient + o] for s in range(n_scales)], axis=0)
          for o in range(n_orient)]
    Esum = np.sum(Eo, axis=0) + 1e-6

    angles = np.linspace(0, np.pi, n_orient, endpoint=False)
    # One fixed hue per orientation, computed once instead of per cell.
    colors = [
        tuple(int(c) for c in cv2.cvtColor(np.uint8([[[int(180 * (th / np.pi)), 200, 255]]]),
                                           cv2.COLOR_HSV2BGR)[0, 0])
        for th in angles
    ]
    for gy in range(gh):
        for gx in range(gw):
            y0, x0 = gy*cell_h, gx*cell_w
            y1, x1 = min((gy+1)*cell_h, H), min((gx+1)*cell_w, W)
            if y1 <= y0 or x1 <= x0:
                continue  # grid finer than the image: empty cell (its mean would be NaN)
            cell_vals = [E[y0:y1, x0:x1].mean() for E in Eo]
            cell_sum = Esum[y0:y1, x0:x1].mean()
            if cell_sum <= 0: continue
            o_idx = int(np.argmax(cell_vals))
            strength = float(cell_vals[o_idx] / (cell_sum + 1e-6))
            if strength < 0.02: continue
            theta = angles[o_idx]
            cx, cy = (x0+x1)//2, (y0+y1)//2
            Lstroke = int(2 + scale_len * strength**0.5)
            dx, dy = int(np.cos(theta)*Lstroke), int(np.sin(theta)*Lstroke)
            cv2.line(out, (cx - dx, cy - dy), (cx + dx, cy + dy), colors[o_idx], thickness, lineType=cv2.LINE_AA)
    return out

def make_feature_previews(img_bgr, ret_params, v1_params):
    """Build BGR preview images for ``img_bgr``.

    Returns:
        ``(retina_overlay, v1_overlay, orientation_strokes)``: the retina ON map
        heat overlay, the summed V1 energy heat overlay, and the dominant
        orientation strokes, each drawn over a grayscale copy of the image.
    """
    on_map, _ = retina_drive(img_bgr, ret_params)
    on_lgn = divisive_normalization(on_map)
    bank = build_gabor_bank(on_map.shape[0], on_map.shape[1], v1_params)
    energies = v1_energy_maps(on_lgn, bank)
    # retina_drive also accepts single-channel 2-D images; mirror that here.
    gray = img_bgr if img_bgr.ndim == 2 else cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    base_vis = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

    retina_overlay = heat_overlay(base_vis, on_map, alpha=0.55, colormap=cv2.COLORMAP_PLASMA)
    v1_sum = np.zeros_like(energies[0], dtype=np.float32)
    for e in energies: v1_sum += e.astype(np.float32)
    v1_overlay = heat_overlay(base_vis, v1_sum, alpha=0.55, colormap=cv2.COLORMAP_TURBO)
    # Pass the real orientation count so energies are grouped correctly.
    orient_vis = draw_orientation_strokes(base_vis, energies, grid=(16,16), scale_len=14,
                                          thickness=2, orientations=int(v1_params.orientations))
    return retina_overlay, v1_overlay, orient_vis

def spike_raster(ax, spikes, title):
    """Scatter a (time, neuron) spike matrix on ``ax``, with time increasing downward."""
    n_steps, n_neurons = spikes.shape
    t_idx, n_idx = np.nonzero(spikes > 0)
    ax.scatter(n_idx, t_idx, s=2, marker='|')
    ax.set_xlim(0, n_neurons)
    ax.set_ylim(n_steps, 0)
    ax.set_xlabel("Neuron index")
    ax.set_ylabel("Time (step)")
    ax.set_title(title)

def plot_neuron_graphs(img_bgr, ret_params, v1_params, n_channels=8):
    """Plot RGC/V1 spike rasters and the first V1 energy maps for ``img_bgr``.

    The two rasters share the top row. Up to ``n_channels`` energy maps follow
    below, four per row, so no axes overlap the rasters.

    Returns:
        ``(rgc_spk, v1_spk, energies, on_map)`` from simulate_spikes_for_image.
    """
    rgc_spk, v1_spk, energies, on_map = simulate_spikes_for_image(img_bgr, ret_params, v1_params)
    n_maps = max(0, min(int(n_channels), len(energies)))
    map_rows = (n_maps + 3) // 4
    fig = plt.figure(figsize=(13, 4 + 3 * map_rows))
    gs = fig.add_gridspec(1 + map_rows, 4)
    ax1 = fig.add_subplot(gs[0, :2]); spike_raster(ax1, rgc_spk, "RGC spike raster")
    ax2 = fig.add_subplot(gs[0, 2:]); spike_raster(ax2, v1_spk, "V1 spike raster")
    # Small multiples of the first energy channels (the old loop created empty
    # 2x4 axes on top of the rasters and drew nothing).
    for i in range(n_maps):
        ax = fig.add_subplot(gs[1 + i // 4, i % 4])
        ax.imshow(energies[i], cmap="magma")
        ax.set_title(f"V1 energy ch {i}")
        ax.axis("off")
    plt.tight_layout()
    return rgc_spk, v1_spk, energies, on_map


In [None]:
# ==== Robust image input loaders for Gradio (Colab) ====
import numpy as np, cv2

def _unwrap_gallery_item(item):
    """
    Gradio Gallery may return:
      - PIL.Image.Image
      - numpy.ndarray
      - (image, caption) tuple
      - dict with 'image'/'data'/'orig'/'value' keys
    Normalize to a bare PIL.Image or numpy array.
    """
    # Tuple: (image, caption/metadata)
    if isinstance(item, tuple) and len(item) > 0:
        item = item[0]

    # Dict: try common keys
    if isinstance(item, dict):
        for k in ("image", "data", "orig", "value"):
            if k in item and item[k] is not None:
                item = item[k]
                break

    return item

def np_from_input(image_like):
    """
    Convert various Gradio inputs into an OpenCV BGR ndarray.

    Any input may first be wrapped as a gallery item (see
    ``_unwrap_gallery_item``). Supported inputs:

    - PIL images;
    - numpy arrays: gray (H,W), single-channel (H,W,1), RGB or RGBA;
    - file-like objects with ``read()``;
    - filesystem paths (``str`` / ``os.PathLike``);
    - dicts carrying a ``"path"`` or ``"name"`` entry.

    Returns:
        A BGR ndarray (numpy inputs keep their dtype), or None when the input
        cannot be decoded.
    """
    import os

    if image_like is None:
        return None

    item = _unwrap_gallery_item(image_like)

    # File payload dicts (e.g. {"path": ...}) that _unwrap_gallery_item leaves as-is.
    if isinstance(item, dict):
        item = item.get("path") or item.get("name")
        if item is None:
            return None

    # Filesystem path: read the bytes and decode them with OpenCV.
    if isinstance(item, (str, os.PathLike)):
        try:
            buf = np.fromfile(os.fspath(item), dtype=np.uint8)
        except (OSError, ValueError):
            return None
        if buf.size == 0:
            return None
        return cv2.imdecode(buf, cv2.IMREAD_COLOR)

    # PIL.Image
    try:
        from PIL import Image
        if isinstance(item, Image.Image):
            arr = np.array(item.convert("RGB"))
            return np.ascontiguousarray(arr[:, :, ::-1])  # RGB->BGR
    except Exception:
        pass

    # numpy array: (H,W) gray, (H,W,1), (H,W,3) RGB or (H,W,4) RGBA
    if isinstance(item, np.ndarray):
        arr = item
        if arr.ndim == 3 and arr.shape[2] == 1:
            arr = arr[:, :, 0]
        if arr.ndim == 2:
            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        if arr.ndim == 3 and arr.shape[2] == 4:
            # Heuristic: assume RGBA from Gradio; drop alpha and convert to BGR
            return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
        if arr.ndim == 3 and arr.shape[2] == 3:
            # Heuristic: assume RGB from Gradio and flip to BGR
            return np.ascontiguousarray(arr[:, :, ::-1])
        return arr

    # file-like (bytes)
    if hasattr(item, "read"):
        try:
            data = item.read()
            buf = np.frombuffer(data, dtype=np.uint8)
            img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
            return img
        except Exception:
            return None

    # Couldn’t parse
    return None


In [None]:
# ==== Gradio UI (Record / Recognize + previews + table) ====
import gradio as gr

# Single persistent brain shared by the Gradio callbacks in this cell.
BRAIN_PATH = "/content/brain.json"
brain = VisionBrain(path=BRAIN_PATH)

def learn_images(label, auto_dopa, use_it, sigma_c, sigma_s, saccades, jitter, orients, scales, images):
    """Gradio callback: learn every image in ``images`` under ``label``.

    Returns:
        ``(status_text, retina_preview, v1_preview, orientation_preview,
        dopamine_text)``. The previews are RGB arrays for the last image
        learned, or None if none succeeded. The status reports the first
        failure reason so errors are not silently swallowed.
    """
    # Whitespace-only labels would otherwise create a blank-named class.
    label = "" if label is None else str(label).strip()
    if not label or images is None or len(images) == 0:
        return "Provide a label and at least one image.", None, None, None, "—"

    ret = RetinaParams(sigma_c=float(sigma_c), sigma_s=float(sigma_s), saccades=int(saccades), jitter_px=int(jitter))
    v1  = V1Params(orientations=int(orients), scales=int(scales))
    itp = ITParams(use_it=bool(use_it))
    learn = LearningParams(dopamine=1.2)

    last = None
    D_used_last = None
    ok, fail = 0, 0
    first_error = None

    # images can be PIL / np / tuple / dict — robust loader handles all.
    for item in images:
        img_bgr = np_from_input(item)
        if img_bgr is None:
            fail += 1
            if first_error is None:
                first_error = "could not decode an input image"
            continue
        try:
            D_used, _ = brain.record_image(
                img_bgr=img_bgr, label=label, ret_params=ret, v1_params=v1,
                learn_params=learn, store_dir="/content/spikes",
                auto_dopamine=bool(auto_dopa), it_params=itp
            )
        except Exception as e:
            # Keep learning the rest of the batch, but surface why this one failed.
            fail += 1
            if first_error is None:
                first_error = f"{type(e).__name__}: {e}"
            continue
        D_used_last = D_used; ok += 1; last = img_bgr

    retina_viz = v1_viz = orient_viz = None
    if last is not None:
        r, v, o = make_feature_previews(last, ret, v1)
        retina_viz = cv2.cvtColor(r, cv2.COLOR_BGR2RGB)
        v1_viz = cv2.cvtColor(v, cv2.COLOR_BGR2RGB)
        orient_viz = cv2.cvtColor(o, cv2.COLOR_BGR2RGB)

    dopa_text = f"{D_used_last:.3f}" if D_used_last is not None else "—"
    status = f"Learned {ok} image(s), failed {fail}."
    if first_error is not None:
        status += f" First error: {first_error}"
    return status, retina_viz, v1_viz, orient_viz, dopa_text

def recognize_image(use_it, blend, sigma_c, sigma_s, saccades, jitter, orients, scales, image):
    """Score one query image against the learned prototypes (Recognize tab callback).

    Returns a 5-tuple ``(status, retina_rgb, v1_rgb, orient_rgb, rows)``
    matching the outputs wired to the Recognize button. ``rows`` holds
    ``[label, score]`` pairs with scores rounded to 3 decimals. When the
    brain returns no results, it holds a single placeholder row instead.
    """
    if image is None:
        return "Upload an image first.", None, None, None, []

    retina_params = RetinaParams(
        sigma_c=float(sigma_c),
        sigma_s=float(sigma_s),
        saccades=int(saccades),
        jitter_px=int(jitter),
    )
    v1_params = V1Params(orientations=int(orients), scales=int(scales))
    it_params = ITParams(use_it=bool(use_it))

    query_bgr = np_from_input(image)
    if query_bgr is None:
        return "Invalid image input.", None, None, None, [["—","Bad input"]]

    ranked, _features = brain.recognize_image(
        query_bgr,
        ret_params=retina_params,
        v1_params=v1_params,
        it_params=it_params,
        blend=float(blend),
    )

    # Previews come back BGR (OpenCV order); Gradio displays RGB.
    retina_bgr, energy_bgr, orient_bgr = make_feature_previews(query_bgr, retina_params, v1_params)
    retina_viz, v1_viz, orient_viz = [
        cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
        for preview in (retina_bgr, energy_bgr, orient_bgr)
    ]

    if not ranked:
        return "No match.", retina_viz, v1_viz, orient_viz, [["—","No prototypes"]]

    rows = [[name, float(f"{score:.3f}")] for name, score in ranked]
    winner = ranked[0]
    status = f"Done. Best: {winner[0]} ({winner[1]:.3f})"
    return status, retina_viz, v1_viz, orient_viz, rows

# Build the two-tab Gradio app. Inside the nested `with` contexts, the layout
# follows statement order, so component creation order here is significant.
with gr.Blocks(title="BabyAI Vision — Colab") as demo:
    gr.Markdown("## BabyAI Vision — Retina→V1→IT with Auto-Dopamine (Colab)")
    with gr.Tabs():
        # ---- Tab 1: learn a label from a batch of gallery images ----
        with gr.Tab("Record (Learn)"):
            with gr.Row():
                with gr.Column(scale=2):
                    label = gr.Textbox(label="Label", value="eman")
                    # Gallery can emit tuples/dicts; np_from_input handles it.
                    images = gr.Gallery(label="Training Images", show_label=True,
                                        columns=3, height=240, allow_preview=True, type="pil")
                    learn_btn = gr.Button("Learn Images", variant="primary")
                    status = gr.Markdown("Status: Ready")
                with gr.Column(scale=1):
                    # Slider defaults mirror the RetinaParams defaults
                    # (sigma_c=1.5, sigma_s=3.5, saccades=5, jitter_px=2).
                    auto_dopa = gr.Checkbox(label="Auto Dopamine", value=True)
                    use_it = gr.Checkbox(label="Use IT invariance", value=True)
                    sigma_c = gr.Slider(0.5, 4.0, value=1.5, step=0.1, label="Retina sigma_c")
                    sigma_s = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="Retina sigma_s")
                    saccades = gr.Slider(1, 8, value=5, step=1, label="Microsaccades")
                    jitter = gr.Slider(0, 6, value=2, step=1, label="Jitter (px)")
                    orients = gr.Slider(4, 16, value=8, step=1, label="V1 orientations")
                    scales = gr.Slider(1, 6, value=4, step=1, label="V1 scales")
            with gr.Row():
                retina_img = gr.Image(label="Retina (DoG heat overlay)", interactive=False)
                v1_img = gr.Image(label="V1 Energy heat overlay", interactive=False)
                orient_img = gr.Image(label="V1 Orientation map", interactive=False)
            dopa_used = gr.Textbox(label="Dopamine used (last)", value="—")
            # Input order must match learn_images' positional parameters;
            # output order must match its 5-tuple return.
            learn_btn.click(
                fn=learn_images,
                inputs=[label, auto_dopa, use_it, sigma_c, sigma_s, saccades, jitter, orients, scales, images],
                outputs=[status, retina_img, v1_img, orient_img, dopa_used],
            )

        # ---- Tab 2: recognize a single query image ----
        with gr.Tab("Recognize"):
            with gr.Row():
                with gr.Column(scale=2):
                    image = gr.Image(label="Query Image", type="pil")
                    rec_btn = gr.Button("Recognize", variant="primary")
                    rec_status = gr.Markdown("Status: Ready")
                with gr.Column(scale=1):
                    # NOTE(review): these are independent copies of the Record
                    # tab's sliders; they are not linked. Presumably they should
                    # match the values used when learning so features are
                    # comparable to stored prototypes — confirm.
                    use_it2 = gr.Checkbox(label="Use IT invariance", value=True)
                    blend = gr.Slider(0.0, 1.0, value=0.6, step=0.05, label="Blend (0=V1, 1=IT)")
                    sigma_c2 = gr.Slider(0.5, 4.0, value=1.5, step=0.1, label="Retina sigma_c")
                    sigma_s2 = gr.Slider(0.5, 8.0, value=3.5, step=0.1, label="Retina sigma_s")
                    saccades2 = gr.Slider(1, 8, value=5, step=1, label="Microsaccades")
                    jitter2 = gr.Slider(0, 6, value=2, step=1, label="Jitter (px)")
                    orients2 = gr.Slider(4, 16, value=8, step=1, label="V1 orientations")
                    scales2 = gr.Slider(1, 6, value=4, step=1, label="V1 scales")
            with gr.Row():
                retina_img2 = gr.Image(label="Retina (DoG heat overlay)", interactive=False)
                v1_img2 = gr.Image(label="V1 Energy heat overlay", interactive=False)
                orient_img2 = gr.Image(label="V1 Orientation map", interactive=False)
            # Five visible rows for the top-5 results returned by recognize_image.
            table = gr.Dataframe(headers=["Label","Score"], row_count=5)

            # Input order must match recognize_image's positional parameters;
            # output order must match its 5-tuple return.
            rec_btn.click(
                fn=recognize_image,
                inputs=[use_it2, blend, sigma_c2, sigma_s2, saccades2, jitter2, orients2, scales2, image],
                outputs=[rec_status, retina_img2, v1_img2, orient_img2, table],
            )

# debug=False: per Gradio's own launch message, callback errors are not shown in
# the Colab cell unless debug=True. share=False: no public link is created.
demo.launch(debug=False, share=False)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

