In [23]:
# ================================================
# Colab: Wireframe OCT Classifier (Simple UI + Grad-CAM)
# ================================================

# ---------- 0) Install deps ----------
!pip -q install streamlit==1.37.1 pillow numpy pandas timm pyngrok \
  opencv-python-headless==4.10.0.84 \
  --extra-index-url https://download.pytorch.org/whl/cu121
# Torch (CU121 wheels)
!pip -q install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121

# ---------- 1) Drive mount ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- 2) CONFIG ----------
WEIGHTS_DRIVE_PATH = "/content/drive/MyDrive/Edu/My ICBT/TOPUP/RESEARCH/resnet18_oct2017_final_weights.pth"
NGROK_AUTH_TOKEN   = "30W6skuqZhzyrPnVSLr2EsiTBbf_5t8t1gHHNR7CSNczLzCDy"

# ---------- 3) Project folder ----------
import os, shutil, time, subprocess
APP_DIR = "/content/app"
os.makedirs(APP_DIR, exist_ok=True)
%cd $APP_DIR

# Copy weights if found
dst_ckpt = os.path.join(APP_DIR, "resnet18_oct2017_final_weights.pth")
if os.path.exists(WEIGHTS_DRIVE_PATH):
    shutil.copy2(WEIGHTS_DRIVE_PATH, dst_ckpt)
    print(f"[OK] Copied weights -> {dst_ckpt}")
else:
    print(f"[WARN] Weights not found at: {WEIGHTS_DRIVE_PATH}")
    print("      Update WEIGHTS_DRIVE_PATH above and rerun if needed.")

# ---------- 4) predictor.py ----------
from pathlib import Path
Path("predictor.py").write_text(r"""# predictor.py
import io, pickle
from typing import List, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision import models
from PIL import Image
import cv2
import timm  # for DeiT / ViTs

class PadAndResize:
    def __init__(self, target_size=224, fill_color=(0,0,0)):
        self.target_size = target_size
        self.fill_color = fill_color
    def __call__(self, img: Image.Image):
        w, h = img.size
        max_side = max(w, h)
        canvas = Image.new("RGB", (max_side, max_side), self.fill_color)
        canvas.paste(img, ((max_side - w)//2, (max_side - h)//2))
        try:
            resample = Image.Resampling.BILINEAR
        except AttributeError:
            resample = Image.BILINEAR
        return canvas.resize((self.target_size, self.target_size), resample)

DEFAULT_CLASSES = ["CNV", "DME", "DRUSEN", "NORMAL"]

def _build_model(model_name: str, num_classes: int) -> nn.Module:
    if model_name == "resnet18":
        m = models.resnet18(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
    elif model_name == "resnet50":
        m = models.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
    elif model_name == "convnext_tiny":
        m = models.convnext_tiny(weights=None)
        m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, num_classes)
    elif model_name == "deit_small_distilled_patch16_224":
        m = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    else:
        raise ValueError(f"Unsupported model_name: {model_name}")
    return m

def _is_state_dict_like(obj: dict) -> bool:
    if not isinstance(obj, dict):
        return False
    if "state_dict" in obj: return True
    if "model" in obj and isinstance(obj["model"], dict): return True
    return any(("weight" in k) or ("bias" in k) for k in list(obj.keys()))

def _safe_load_any(path: str):
    try:
        return torch.load(path, map_location="cpu")
    except Exception:
        with open(path, "rb") as f:
            return pickle.load(f)

def _normalize_state_dict(obj):
    if isinstance(obj, dict):
        if "state_dict" in obj and isinstance(obj["state_dict"], dict):
            sd = obj["state_dict"]
        elif "model" in obj and isinstance(obj["model"], dict):
            sd = obj["model"]
        else:
            sd = obj
    else:
        raise ValueError("Unsupported checkpoint format")
    def strip_prefix(d, pfx):
        if any(k.startswith(pfx) for k in d.keys()):
            return {k[len(pfx):]: v for k, v in d.items()}
        return d
    sd = strip_prefix(sd, "_orig_mod.")
    sd = strip_prefix(sd, "module.")
    sd = strip_prefix(sd, "model.")
    return sd

def load_model(ckpt_path: str,
               fallback_model_name: str = "resnet18",
               fallback_classes: Optional[List[str]] = None,
               img_size: int = 224):
    blob = _safe_load_any(ckpt_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mean = [0.485, 0.456, 0.406]; std = [0.229, 0.224, 0.225]

    if isinstance(blob, dict) and "state_dict" in blob and "model_name" in blob:
        model_name = blob["model_name"]
        class_names = blob.get("class_names", fallback_classes or DEFAULT_CLASSES)
        img_size = int(blob.get("img_size", img_size))
        mean = blob.get("normalize_mean", mean); std  = blob.get("normalize_std",  std)
        state_dict = _normalize_state_dict(blob)
    elif isinstance(blob, dict) and _is_state_dict_like(blob):
        model_name = fallback_model_name
        class_names = fallback_classes or DEFAULT_CLASSES
        state_dict = _normalize_state_dict(blob)
    else:
        model_name = fallback_model_name
        class_names = fallback_classes or DEFAULT_CLASSES
        state_dict = blob

    model = _build_model(model_name, num_classes=len(class_names))
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        changed = False
        if hasattr(model, "head") and isinstance(model.head, nn.Linear):
            if model.head.out_features != len(class_names):
                model.head = nn.Linear(model.head.in_features, len(class_names)); changed = True
        if hasattr(model, "head_dist") and isinstance(model.head_dist, nn.Linear):
            if model.head_dist.out_features != len(class_names):
                model.head_dist = nn.Linear(model.head_dist.in_features, len(class_names)); changed = True
        model.load_state_dict(state_dict, strict=False if changed else False)

    model.to(device).eval()
    preprocess = T.Compose([
        PadAndResize(target_size=img_size, fill_color=(0,0,0)),
        T.Grayscale(num_output_channels=3),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])
    return model, preprocess, class_names, device, img_size, model_name

@torch.inference_mode()
def predict_image_bytes(model, preprocess, class_names, device, img_bytes: bytes) -> Dict:
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    x = preprocess(img).unsqueeze(0).to(device)
    logits = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy().tolist()
    best_idx = int(torch.argmax(logits, dim=1).item())
    return {"pred_label": class_names[best_idx],
            "pred_idx": best_idx,
            "probs": {cls: float(probs[i]) for i, cls in enumerate(class_names)}}

def _resolve_target_layer(model: nn.Module, model_name: str):
    if model_name == "resnet18":
        return getattr(model.layer4[-1], "conv2", model.layer4[-1])
    if model_name == "resnet50":
        return getattr(model.layer4[-1], "conv3", model.layer4[-1])
    if model_name == "convnext_tiny":
        try: return model.features[6][-1].dwconv
        except Exception: return model.features[6]
    return None

class GradCAM:
    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model; self.target_layer = target_layer
        self._activations = None; self._gradients = None
        self._fwd_handle = self.target_layer.register_forward_hook(self._forward_hook)
        self._bwd_handle = self.target_layer.register_full_backward_hook(self._backward_hook)
    def _forward_hook(self, module, inp, out): self._activations = out.clone()
    def _backward_hook(self, module, grad_in, grad_out): self._gradients = grad_out[0]
    def remove(self): self._fwd_handle.remove(); self._bwd_handle.remove()
    def __call__(self, input_tensor: torch.Tensor, class_idx: Optional[int] = None):
        self.model.zero_grad(set_to_none=True)
        with torch.inference_mode(False), torch.enable_grad():
            x = input_tensor.clone().detach().requires_grad_(True)
            logits = self.model(x)
        if class_idx is None: class_idx = int(torch.argmax(logits, dim=1).item())
        one_hot = torch.zeros_like(logits); one_hot[0, class_idx] = 1.0
        with torch.inference_mode(False): logits.backward(gradient=one_hot)
        A = self._activations; dA = self._gradients
        if A is None or dA is None: raise RuntimeError("Grad-CAM hooks failed.")
        weights = dA.mean(dim=(2,3), keepdim=True)
        cam = (weights * A).sum(dim=1).relu().squeeze(0)
        cam -= cam.min(); cam = cam / (cam.max() + 1e-8)
        return cam.detach().cpu().numpy(), class_idx, logits.detach()

def _apply_colormap_on_image(pil_img: Image.Image, cam: np.ndarray, alpha: float=0.35):
    img = np.array(pil_img.convert("RGB")); H, W = img.shape[:2]
    cam_resized = cv2.resize(cam, (W, H))
    heat = np.uint8(255 * cam_resized); heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    base = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    over = cv2.addWeighted(heat, alpha, base, 1 - alpha, 0)
    return Image.fromarray(cv2.cvtColor(over, cv2.COLOR_BGR2RGB))

def gradcam_image_bytes(model, preprocess, device, model_name, img_bytes, target_class, img_size, alpha=0.35):
    target_layer = _resolve_target_layer(model, model_name)
    if target_layer is None:
        raise NotImplementedError("Grad-CAM only for CNN backbones (ResNet/ConvNeXt).")
    pil = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    x = preprocess(pil).unsqueeze(0).to(device)
    cam_obj = GradCAM(model, target_layer)
    with torch.inference_mode(False), torch.enable_grad():
        logits = model(x)
    pred_idx = int(torch.argmax(logits, dim=1).item())
    class_idx = target_class if target_class is not None else pred_idx
    cam_map, used_idx, _ = cam_obj(x, class_idx=class_idx)
    cam_obj.remove()
    base = PadAndResize(target_size=img_size, fill_color=(0,0,0))(pil)
    overlay = _apply_colormap_on_image(base, cam_map, alpha=alpha)
    probs = torch.softmax(logits, dim=1).squeeze(0).detach().cpu().numpy().tolist()
    return overlay, used_idx, probs
""", encoding="utf-8")

# ---------- 5) app.py (simple wireframe UI; no preset/preset radio/CSV) ----------
Path("app.py").write_text(r"""# app.py — simple wireframe UI
import json, pandas as pd, streamlit as st
from predictor import load_model, predict_image_bytes, gradcam_image_bytes, DEFAULT_CLASSES

st.set_page_config(page_title="Retinal OCT Classifier", page_icon="👁️", layout="wide")

WIREFRAME_CSS = '''
<style>
.block-container { padding-top: 1.25rem; padding-bottom: 1rem; }
.w-card { border: 2px solid #1f1f1f20; border-radius: 8px; padding: 16px 18px; background: #fafafa; }
.w-title { font-weight: 700; margin-bottom: 6px; }
.w-subtle { color: #666; font-size: 0.9rem; }
.stButton>button, .stDownloadButton>button { border: 1px solid #33333355; background: #e9e9e9; color: #111; border-radius: 8px; }
section[data-testid="stSidebar"] .stSelectbox,
section[data-testid="stSidebar"] .stTextInput,
section[data-testid="stSidebar"] .stNumberInput,
section[data-testid="stSidebar"] .stCheckbox,
section[data-testid="stSidebar"] .stSlider { border-bottom: 1px solid #cfcfcf; padding-bottom: .35rem; margin-bottom: .65rem; }
</style>
'''
st.markdown(WIREFRAME_CSS, unsafe_allow_html=True)

st.title("Retinal OCT Classifier")
st.caption("CNV · DME · DRUSEN · NORMAL | padded-resize input | Grad-CAM for CNN backbones")

# ----- Sidebar: minimal controls -----
with st.sidebar:
    st.header("Model Options")
    backbone = st.selectbox("Backbone", ["resnet18", "deit_small_distilled_patch16_224"], index=0)
    ckpt_path = st.text_input("Checkpoint path", "resnet18_oct2017_final_weights.pth")
    img_size = st.number_input("Image size (px)", min_value=128, max_value=1024, value=224, step=32)
    st.subheader("Grad-CAM")
    show_cam = st.checkbox("Show Grad-CAM", value=True)
    cam_alpha = st.slider("Heatmap strength", 0.1, 0.9, 0.35, 0.05)

if not ckpt_path:
    st.warning("Provide a checkpoint path in the sidebar.")
    st.stop()

try:
    model, preprocess, class_names, device, real_imgsz, real_modelname = load_model(
        ckpt_path, fallback_model_name=backbone, fallback_classes=DEFAULT_CLASSES, img_size=int(img_size)
    )
    st.success(f"Loaded {real_modelname} on {device} · classes: {', '.join(class_names)} · image {real_imgsz}px")
except Exception as e:
    st.error(f"Load error: {e}")
    st.stop()

# ----- Upload card -----
st.markdown('<div class="w-card">', unsafe_allow_html=True)
st.markdown('<div class="w-title">Upload OCT Image</div><div class="w-subtle">Drag & drop file here or use Browse</div>', unsafe_allow_html=True)
up = st.file_uploader("", type=["png","jpg","jpeg","tif","bmp"])
st.markdown('</div>', unsafe_allow_html=True)

if up is None:
    st.info("Tip: toggle Grad-CAM in the sidebar. Anonymize patient data before upload.")
    st.stop()

raw = up.read()

# ----- Middle row: Original + Grad-CAM -----
c1, c2 = st.columns([1,1], gap="large")
with c1:
    st.markdown('<div class="w-card"><div class="w-title">Original OCT</div><div class="w-subtle">uploaded OCT image</div>', unsafe_allow_html=True)
    st.image(up, use_column_width=True)
    st.markdown("</div>", unsafe_allow_html=True)
with c2:
    st.markdown('<div class="w-card"><div class="w-title">Grad-CAM Heatmap</div><div class="w-subtle">red = higher importance</div>', unsafe_allow_html=True)
    if show_cam:
        try:
            overlay, used_idx, _ = gradcam_image_bytes(model, preprocess, device, real_modelname, raw,
                                                       target_class=None, img_size=real_imgsz, alpha=cam_alpha)
            st.image(overlay, use_column_width=True)
        except NotImplementedError as e:
            st.info(f"{e}  Switch to a CNN (e.g., ResNet18) to view Grad-CAM.")
        except Exception as e:
            st.warning(f"Grad-CAM failed: {e}")
    else:
        st.caption("Grad-CAM disabled")
    st.markdown("</div>", unsafe_allow_html=True)

# ----- Bottom row: Prediction + Probabilities -----
p1, p2 = st.columns([1,1], gap="large")
with p1:
    st.markdown('<div class="w-card"><div class="w-title">Prediction</div><div class="w-subtle">Top-1 class and confidence</div>', unsafe_allow_html=True)
    try:
        out = predict_image_bytes(model, preprocess, class_names, device, raw)
    except Exception as e:
        st.error(f"Prediction error: {e}")
        st.stop()
    pred_label = out["pred_label"]; conf = float(out["probs"].get(pred_label, 0.0))
    st.subheader(f"Predicted class: {pred_label}")
    st.write(f"Confidence: {conf:.2f}")
    st.download_button("Download results",
                       data=json.dumps({"prediction": pred_label, "confidence": conf, "probs": out["probs"]}, indent=2),
                       file_name="oct_prediction.json", mime="application/json")
    st.markdown("</div>", unsafe_allow_html=True)
with p2:
    st.markdown('<div class="w-card"><div class="w-title">Class Probabilities</div>', unsafe_allow_html=True)
    df = pd.DataFrame({"Class": list(out["probs"].keys()),
                       "Probability": [float(v) for v in out["probs"].values()]}) \
         .sort_values("Class").reset_index(drop=True)
    st.dataframe(df, use_container_width=True, hide_index=True)
    st.markdown("</div>", unsafe_allow_html=True)
""", encoding="utf-8")

print("Files:", os.listdir(APP_DIR))



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/app
[OK] Copied weights -> /content/app/resnet18_oct2017_final_weights.pth
Files: ['predictor.py', '__pycache__', 'app.py', 'resnet18_oct2017_final_weights.pth']


In [24]:
# ---------- 6) Run Streamlit with ngrok (single-session, cleaned) ----------
import subprocess, time, os, signal
from pyngrok import ngrok, conf

# 1) Hard kill any leftover ngrok/Streamlit processes in this runtime
!pkill -f "ngrok" || true
!pkill -f "streamlit run app.py" || true

# 2) Double-check with pyngrok API and close any existing tunnels
try:
    for t in ngrok.get_tunnels():
        try:
            ngrok.disconnect(t.public_url)
        except Exception:
            pass
    ngrok.kill()
except Exception:
    pass

# 3) Your token (single agent only)
NGROK_AUTH_TOKEN = "30W6skuqZhzyrPnVSLr2EsiTBbf_5t8t1gHHNR7CSNczLzCDy"
!ngrok config add-authtoken "$NGROK_AUTH_TOKEN"

# 4) Start Streamlit server
server = subprocess.Popen(
    [
        "streamlit", "run", "app.py",
        "--server.port", "8501",
        "--server.address", "0.0.0.0",
        "--server.headless", "true",
    ],
    stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1
)

# 5) Give Streamlit a moment to boot
time.sleep(5)

# 6) Open ONE public tunnel (HTTP)
public_tunnel = ngrok.connect(8501, "http")
print("🔗 Public URL:", public_tunnel.public_url)
print("----- STREAMLIT LOGS -----")
try:
    for line in server.stdout:
        if line is None:
            break
        print(line, end="")
except KeyboardInterrupt:
    print("\n[Interrupted]")


^C
^C
Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml
🔗 Public URL: https://0df02015ddfe.ngrok-free.app
----- STREAMLIT LOGS -----

Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.


  You can now view your Streamlit app in your browser.

  URL: http://0.0.0.0:8501

2025-09-06 18:17:35.059 `label` got an empty value. This is discouraged for accessibility reasons and may be disallowed in the future by raising an exception. Please provide a non-empty label and hide it with label_visibility if needed.
2025-09-06 18:17:35.471 Examining the path of torch.classes raised: Tried to instantiate class '__path__._path', but it does not exist! Ensure that it is registered via torch::class_
2025-09-06 18:17:42.174 `label` got an empty value. This is discouraged for accessibility reasons and may be disallowed in the future by raising an exception. Please provide a non-empty label and hide it with label_visibility if needed.
2025-09-06 18:17:42.9