<a href="https://colab.research.google.com/github/BlacqTangent/BlacqTangent.github.io/blob/main/Inference_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import os, sys, json, math, time, random, platform
from pathlib import Path
from typing import Dict, Any, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

random.seed(42)
np.random.seed(42)
plt.rcParams.update({"figure.figsize": (6, 4), "axes.grid": True})

print({
    "python": platform.python_version(),
    "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES", "unset"),
})

{'python': '3.12.12', 'cuda_visible_devices': 'unset'}


In [9]:

!nvidia-smi || echo "No NVIDIA GPU detected"
!python -V

# Upgrade and install core deps
!pip -q install --upgrade pip setuptools wheel
!pip -q install transformers datasets peft accelerate huggingface_hub pillow opencv-python-headless hf_transfer bitsandbytes gdown

# Enable HF accelerated transfers
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
print("torch:", torch.__version__, "cuda available:", torch.cuda.is_available())

Mon Oct 27 18:39:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   49C    P8              9W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [24]:
from huggingface_hub import notebook_login


notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
import huggingface_hub as hub
print("huggingface_hub version:", hub.__version__)

huggingface_hub version: 0.35.3


In [12]:
# Robust import for HfHubHTTPError across versions
try:
    from huggingface_hub.errors import HfHubHTTPError  # new path
except ImportError:
    from huggingface_hub.utils import HfHubHTTPError   # older path

from huggingface_hub import HfApi, HfFolder

print("Token present:", bool(HfFolder.get_token()))
api = HfApi()
try:
    info = api.model_info("google/medgemma-4b-it")
    print("Access OK:", info.modelId)
except HfHubHTTPError as e:
    print("Access error:", e)
except Exception as e:
    print("Unexpected error:", repr(e))

Token present: True
Access OK: google/medgemma-4b-it


In [13]:
# 3) Environment-aware local dataset configuration (no Drive)
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules
USE_LOCAL = True  # always use a local dataset (both on laptop and in Colab)

if IN_COLAB:
    LOCAL_DATA_ROOT = Path('/content/data/processed')
else:
    LOCAL_DATA_ROOT = Path('/Users/macbookair/derm-reasoning/David-Akan/data/processed')  # <- change if needed on your laptop

DATA_ROOT = LOCAL_DATA_ROOT
DATA_ROOT.mkdir(parents=True, exist_ok=True)
print('[Local] Using local dataset root:', DATA_ROOT, '| IN_COLAB =', IN_COLAB)

[Local] Using local dataset root: /content/data/processed | IN_COLAB = True


In [14]:
# 2) Load model
import torch, gc
from transformers import AutoProcessor, AutoModelForCausalLM
MODEL_ID = 'google/medgemma-4b-it'
def load_model_processor(model_id: str = MODEL_ID):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = None
    try:
        import bitsandbytes as bnb  # noqa
        print('Loading 4-bit...')
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', load_in_4bit=True, bnb_4bit_quant_type='nf4', trust_remote_code=True)
    except Exception as e:
        print('4-bit failed, fallback to fp16/cpu:', e)
    if model is None:
        dtype = torch.float16 if device=='cuda' else torch.float32
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True)
        model.to(device)
    model.eval()
    return model, processor
model, processor = load_model_processor(MODEL_ID)
DEVICE = next(model.parameters()).device
print('Device:', DEVICE)

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Loading 4-bit...


config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

Device: cuda:0


In [15]:
# 5c) Colab uploader to add images into class folders (eczema, fungal, scabies)
from pathlib import Path
def _ensure_dirs():
    for c in ['eczema','fungal','scabies']:(DATA_ROOT/c).mkdir(parents=True, exist_ok=True)
_ensure_dirs()
try:
    from google.colab import files  # type: ignore
    IN_COLAB=True
except Exception:
    IN_COLAB=False
def upload_images_to_class(label: str):
    label=str(label).strip().lower()
    assert label in {'eczema','fungal','scabies'}, "label must be one of: 'eczema','fungal','scabies'"
    dest = (DATA_ROOT/label)
    dest.mkdir(parents=True, exist_ok=True)
    if not IN_COLAB:
        print('[Uploader] Not in Colab. Please manually place files under:', dest)
        return
    print(f"[Uploader] Select files to upload into {dest}...")
    uploaded = files.upload()  # opens file picker
    count=0
    for name, data in uploaded.items():
        out = dest/name if isinstance(name, Path) else dest/str(name)
        try:
            # Handle both dict-of-bytes and UploadedFile-like objects
            if isinstance(data, (bytes, bytearray)):
                payload = data
            elif hasattr(data, 'content'):
                payload = data.content
            else:
                try: payload = bytes(data)
                except Exception: payload = None
            if payload is None:
                print('[Uploader][Skip] Could not read payload for', name)
                continue
            with open(out, 'wb') as f: f.write(payload)
            count+=1
        except Exception as e:
            print('[Uploader][Err]', name, e)
    print(f"[Uploader] Saved {count} file(s) to {dest}")
print("Call upload_images_to_class('eczema')  # or 'fungal' / 'scabies'")

Call upload_images_to_class('eczema')  # or 'fungal' / 'scabies'


In [16]:
 # 5e) Batch uploader (Colab): pick files for eczema, fungal, then scabies
from pathlib import Path
from glob import glob
def _count_images(d: Path):
    pats=('*.jpg','*.jpeg','*.png','*.JPG','*.JPEG','*.PNG','*.webp','*.WEBP')
    return sum(len(glob(str(d/p))) for p in pats)
try:
    from google.colab import files  # type: ignore
    IN_COLAB=True
except Exception:
    IN_COLAB=False
if not IN_COLAB:
    print('[Batch Uploader] Not in Colab. Use filesystem to place files under', DATA_ROOT)
else:
    for label in ['eczema','fungal','scabies']:
        dest=(DATA_ROOT/label); dest.mkdir(parents=True, exist_ok=True)
        print(f"[Batch] Select files for '{label}'...")
        up=files.upload()
        saved=0
        for name, data in up.items():
            out=dest/str(name)
            payload = data if isinstance(data,(bytes,bytearray)) else getattr(data,'content',None)
            if payload is None:
                try: payload=bytes(data)
                except Exception: payload=None
            if payload is None:
                print('[Batch][Skip] Could not read payload for', name)
                continue
            with open(out,'wb') as f: f.write(payload)
            saved+=1
        print(f"[Batch] Saved {saved} to {dest}")
    print('Counts after batch upload:')
    for c in ['eczema','fungal','scabies']:
        print(f'  {c:8s}:', _count_images(DATA_ROOT/c))

[Batch] Select files for 'eczema'...


Saving AA00970170_1.jpg to AA00970170_1.jpg
Saving AA00970180_2.jpg to AA00970180_2.jpg
Saving AA00970203_3.jpg to AA00970203_3.jpg
Saving AA00970209_1.jpg to AA00970209_1.jpg
Saving AA00970397_3.jpg to AA00970397_3.jpg
Saving AA00970540_3.jpg to AA00970540_3.jpg
Saving AA00970533_2.jpg to AA00970533_2.jpg
Saving AA00970523_2.jpg to AA00970523_2.jpg
Saving AA00970513_1.jpg to AA00970513_1.jpg
Saving AA00970512_1.jpg to AA00970512_1.jpg
Saving AA00970541_1.jpg to AA00970541_1.jpg
Saving AA00970590_1.jpg to AA00970590_1.jpg
Saving AA00970591_2.jpg to AA00970591_2.jpg
Saving AA00970597_1.jpg to AA00970597_1.jpg
Saving AA00970606_2.jpg to AA00970606_2.jpg
Saving AA00970621_5.jpg to AA00970621_5.jpg
Saving AA00970621_4.jpg to AA00970621_4.jpg
Saving AA00970611_3.jpg to AA00970611_3.jpg
Saving AA00970608_3.jpg to AA00970608_3.jpg
Saving AA00970608_2.jpg to AA00970608_2.jpg
Saving AA00970622_2.jpg to AA00970622_2.jpg
Saving AA00970631_3.jpg to AA00970631_3.jpg
Saving AA00970633_3.jpg to AA009

Saving AA00970554_2.jpg to AA00970554_2.jpg
Saving AA00970556_1.jpg to AA00970556_1.jpg
Saving AA00970559_1.jpg to AA00970559_1.jpg
Saving AA00970600_1.jpg to AA00970600_1.jpg
Saving AA00970615_1.jpg to AA00970615_1.jpg
Saving AA00970725_1.jpg to AA00970725_1.jpg
Saving AA00970724_3.jpg to AA00970724_3.jpg
Saving AA00970711_1.jpg to AA00970711_1.jpg
Saving AA00970709_1.jpg to AA00970709_1.jpg
Saving AA00970681_1.jpg to AA00970681_1.jpg
Saving AA00970746_1.jpg to AA00970746_1.jpg
Saving AA00970756_2.jpg to AA00970756_2.jpg
Saving AA00970840_2.jpg to AA00970840_2.jpg
Saving AA00970855_1.jpg to AA00970855_1.jpg
Saving AA00970859_1.jpg to AA00970859_1.jpg
Saving AA00970963_1.jpg to AA00970963_1.jpg
Saving AA00970956_1.jpg to AA00970956_1.jpg
Saving AA00970945_1.jpg to AA00970945_1.jpg
Saving AA00970916_1.jpg to AA00970916_1.jpg
Saving AA00970915_1.jpg to AA00970915_1.jpg
Saving AA00970967_1.jpg to AA00970967_1.jpg
Saving AA00970983_1.jpg to AA00970983_1.jpg
Saving AA00970998_1.jpg to AA009

Saving AA00970047_1.jpg to AA00970047_1.jpg
Saving AA00970109_2.jpg to AA00970109_2.jpg
Saving AA00970110_1.jpg to AA00970110_1.jpg
Saving AA00970150_3.jpg to AA00970150_3.jpg
Saving AA00970161_1.jpg to AA00970161_1.jpg
Saving AA00970173_1.jpg to AA00970173_1.jpg
Saving AA00970206_1.jpg to AA00970206_1.jpg
Saving AA00970215_2.jpg to AA00970215_2.jpg
Saving AA00970241_2.jpg to AA00970241_2.jpg
Saving AA00970299_3.jpg to AA00970299_3.jpg
Saving AA00970315_1.jpg to AA00970315_1.jpg
Saving AA00970398_2.jpg to AA00970398_2.jpg
Saving AA00970399_3.jpg to AA00970399_3.jpg
Saving AA00970402_2.jpg to AA00970402_2.jpg
Saving AA00970435_2.jpg to AA00970435_2.jpg
Saving AA00970645_1.jpg to AA00970645_1.jpg
Saving AA00970536_1.jpg to AA00970536_1.jpg
Saving AA00970506_2.jpg to AA00970506_2.jpg
Saving AA00970506_1.jpg to AA00970506_1.jpg
Saving AA00970435_3.jpg to AA00970435_3.jpg
Saving AA00970712_1.jpg to AA00970712_1.jpg
Saving AA00970713_1.jpg to AA00970713_1.jpg
Saving AA00970726_1.jpg to AA009

In [17]:
# 5a) Quick data check (optional): counts and a few sample paths
from glob import glob
from pathlib import Path
def _count_images(d: Path):
    pats=('*.jpg','*.jpeg','*.png','*.JPG','*.JPEG','*.PNG','*.webp','*.WEBP')
    return sum(len(glob(str(d/p))) for p in pats)
print('DATA_ROOT =', DATA_ROOT)
for c in ['eczema','fungal','scabies']:
    d=DATA_ROOT/c
    print(f'  {c:8s}:', _count_images(d))
    samples=[]
    for p in ('*.jpg','*.jpeg','*.png','*.webp','*.JPG','*.JPEG','*.PNG','*.WEBP'):
        samples+=glob(str(d/p))
    print('    sample:', [Path(s).name for s in sorted(samples)[:3]])

DATA_ROOT = /content/data/processed
  eczema  : 50
    sample: ['AA00970170_1.jpg', 'AA00970180_2.jpg', 'AA00970203_3.jpg']
  fungal  : 50
    sample: ['AA00970554_2.jpg', 'AA00970556_1.jpg', 'AA00970559_1.jpg']
  scabies : 50
    sample: ['AA00970047_1.jpg', 'AA00970109_2.jpg', 'AA00970110_1.jpg']


In [18]:
# 3) Robust classifier utilities (image-token safe)
import math, numpy as np, pandas as pd
from PIL import Image
from glob import glob
from collections import defaultdict
from typing import List, Dict, Tuple
LABELS = ['eczema','fungal','scabies']
IMAGE_TOKEN_OVERRIDE = None
LAST_IMAGE_TOKEN_STR=LAST_IMAGE_TOKEN_ID=LAST_IMAGE_TOKEN_COUNT=None
def _sanitize(t):
    return torch.nan_to_num(t, nan=-1e9, posinf=1e9, neginf=-1e9) if not torch.isfinite(t).all() else t
def _softmax(x, dim=-1):
    x=_sanitize(x); x=x-x.max(dim=dim,keepdim=True).values; return torch.softmax(x,dim=dim)
def _discover_image_tokens(proc):
    toks=[]
    try:
        tok=getattr(proc,'tokenizer',None)
        if tok is not None:
            for d in [getattr(tok,'special_tokens_map',{}),getattr(tok,'special_tokens_map_extended',{})]:
                if isinstance(d,dict):
                    for v in d.values():
                        if isinstance(v,(list,tuple)):
                            for s in v:
                                if isinstance(s,str) and 'image' in s.lower(): toks.append(s)
                        elif isinstance(v,str) and 'image' in v.lower(): toks.append(v)
            added=getattr(tok,'added_tokens_encoder',{}) or {}
            for s in added.keys():
                if isinstance(s,str) and 'image' in s.lower(): toks.append(s)
    except Exception:
        pass
    for c in ['<image>','<img>','<image_token>','<|image|>','<image_1>']:
        if c not in toks: toks.append(c)
    return toks or ['<image>']
def _encode_one(tok, s):
    try:
        ids=tok(s,add_special_tokens=False,return_attention_mask=False,return_token_type_ids=False).get('input_ids',[])
        return int(ids[0]) if isinstance(ids,list) and len(ids)==1 else None
    except Exception:
        return None
def _get_image_token_id():
    global IMAGE_TOKEN_OVERRIDE
    tok=getattr(processor,'tokenizer',None)
    if isinstance(IMAGE_TOKEN_OVERRIDE,str) and IMAGE_TOKEN_OVERRIDE and tok is not None:
        tid=_encode_one(tok, IMAGE_TOKEN_OVERRIDE)
        if tid is not None: return IMAGE_TOKEN_OVERRIDE, tid
    try:
        cfg_id=getattr(getattr(model,'config',None),'image_token_id',None)
        if isinstance(cfg_id,int): return None,cfg_id
    except Exception:
        pass
    if tok is not None:
        for cand in _discover_image_tokens(processor):
            tid=_encode_one(tok,cand)
            if tid is not None: return cand,tid
        tid=_encode_one(tok,'<image>')
        if tid is not None: return '<image>',tid
    return '<image>',None
def _build_chat(messages, image: Image.Image):
    global LAST_IMAGE_TOKEN_STR, LAST_IMAGE_TOKEN_ID, LAST_IMAGE_TOKEN_COUNT
    LAST_IMAGE_TOKEN_STR=LAST_IMAGE_TOKEN_ID=LAST_IMAGE_TOKEN_COUNT=None
    user_text=next((m.get('content') for m in messages if m.get('role')=='user'),'Please analyze the image and answer.')
    # Try structured
    try:
        structured=[]
        sys_msg=next((m for m in messages if m.get('role')=='system'),None)
        if sys_msg: structured.append(sys_msg)
        structured.append({'role':'user','content':[{'type':'image'},{'type':'text','text':user_text}]})
        enc=processor.apply_chat_template(structured, add_generation_prompt=True, return_tensors='pt', images=[image])
        if isinstance(enc,dict) and 'input_ids' in enc and 'pixel_values' in enc:
            cand_str,cand_id=_get_image_token_id()
            if cand_id is not None:
                cnt=int((enc['input_ids'][0]==cand_id).sum().item())
                LAST_IMAGE_TOKEN_STR, LAST_IMAGE_TOKEN_ID, LAST_IMAGE_TOKEN_COUNT=cand_str,cand_id,cnt
                if cnt==1: return enc
            else: return enc
    except Exception:
        pass
    # Fallback brute force
    try:
        tok=getattr(processor,'tokenizer',None)
        if tok is not None:
            for cand in _discover_image_tokens(processor):
                tid=_encode_one(tok,cand)
                if tid is None: continue
                text=f"{cand}\n{user_text}"
                enc=processor(text=[text], images=[image], return_tensors='pt')
                if 'input_ids' not in enc: continue
                cnt=int((enc['input_ids'][0]==tid).sum().item())
                if cnt==1:
                    if 'pixel_values' not in enc:
                        vis=processor(images=[image], return_tensors='pt')
                        if isinstance(vis,dict) and 'pixel_values' in vis: enc['pixel_values']=vis['pixel_values']
                    LAST_IMAGE_TOKEN_STR, LAST_IMAGE_TOKEN_ID, LAST_IMAGE_TOKEN_COUNT=cand,tid,cnt
                    return enc
    except Exception:
        pass
    tok=getattr(processor,'tokenizer',None)
    cand_str,cand_id=_get_image_token_id()
    if tok is not None and cand_id is not None:
        ids_text=tok(user_text,add_special_tokens=False,return_attention_mask=False,return_token_type_ids=False).get('input_ids',[]) or []
        ids=[cand_id]+[int(i) for i in ids_text]
        enc={'input_ids':torch.tensor([ids],dtype=torch.long),'attention_mask':None,'pixel_values':processor(images=[image],return_tensors='pt').get('pixel_values')}
        LAST_IMAGE_TOKEN_STR, LAST_IMAGE_TOKEN_ID, LAST_IMAGE_TOKEN_COUNT=cand_str,cand_id,1
        return enc
    text=f"<image>\n{user_text}"
    enc=processor(text=[text], images=[image], return_tensors='pt')
    if 'pixel_values' not in enc:
        vis=processor(images=[image],return_tensors='pt')
        if isinstance(vis,dict) and 'pixel_values' in vis: enc['pixel_values']=vis['pixel_values']
    LAST_IMAGE_TOKEN_STR, LAST_IMAGE_TOKEN_ID, LAST_IMAGE_TOKEN_COUNT='<image>',None,None
    return enc
def _normalize_token_str(s:str)->str:
    if s is None: return ''
    s=s.replace('▁',' ').replace('Ġ',' ').strip()
    while len(s) and s[-1] in ")].,:;!?\'\"": s=s[:-1]
    return s
def _collect_letter_token_ids(sm):
    V=sm.numel(); k=min(512,V)
    _,ids=torch.topk(sm,k); ids=ids.tolist()
    letter={'A':set(),'B':set(),'C':set()}
    tok=getattr(processor,'tokenizer',None)
    for tid in ids:
        try: ts=tok.convert_ids_to_tokens(int(tid))
        except Exception: ts=None
        tn=_normalize_token_str(ts or '')
        if tn.upper() in ('A','B','C'): letter[tn.upper()].add(int(tid))
    variants={'A':['A',' A','A)',' A)','A.',' A.','a',' a','a)',' a)'],
              'B':['B',' B','B)',' B)','B.',' B.','b',' b','b)',' b)'],
              'C':['C',' C','C)',' C)','C.',' C.','c',' c','c)',' c)']}
    for L,txts in variants.items():
        for t in txts:
            try:
                enc=tok(t,add_special_tokens=False,return_attention_mask=False,return_token_type_ids=False)
                for x in enc.get('input_ids',[]): letter[L].add(int(x))
            except Exception: pass
    return {k:sorted(list(v)) for k,v in letter.items()}
@torch.no_grad()
def _score_letter_sequences_ll(enc_base, letters):
    tok=getattr(processor,'tokenizer',None); assert tok is not None
    base_ids=enc_base['input_ids'].to(DEVICE)
    pv=enc_base.get('pixel_values',None)
    if pv is not None: pv=pv.to(DEVICE)
    variants={'A':['A',' A','A)',' A)','A.',' A.','a',' a','a)',' a)'],
              'B':['B',' B','B)',' B)','B.',' B.','b',' b','b)',' b)'],
              'C':['C',' C','C)',' C)','C.',' C.','c',' c','c)',' c)']}
    scores={k:-float('inf') for k in letters}
    for L in letters:
        best=-float('inf')
        for txt in variants[L]:
            ids=tok(txt,add_special_tokens=False,return_attention_mask=False,return_token_type_ids=False).get('input_ids',[]) or []
            if not ids: continue
            ids_t=torch.tensor([ids],dtype=torch.long,device=DEVICE)
            full_ids=torch.cat([base_ids,ids_t],dim=1)
            with torch.autocast(device_type=('cuda' if torch.cuda.is_available() else 'cpu'), enabled=False):
                out=model(input_ids=full_ids,pixel_values=pv)
                logits=_sanitize(out.logits.float())
                logp=0.0; base_len=base_ids.shape[1]
                for i,tok_id in enumerate(ids):
                    pos=base_len-1+i
                    step_lp=torch.log_softmax(logits[0,pos,:],dim=-1)
                    logp+=float(step_lp[int(tok_id)].item())
            if logp>best: best=logp
        scores[L]=best
    return scores
def _load_image(p):
    img=Image.open(p).convert('RGB')
    if min(img.size)<8: img=img.resize((max(32,img.size[0]*4),max(32,img.size[1]*4)))
    return img
@torch.no_grad()
def mcq_letter_predict_safe(image: Image.Image, labels=LABELS):
    messages=[{'role':'system','content':'You are a medical assistant for skin disease triage.'},
              {'role':'user','content':f'Look at the image and choose the single best answer. Options: A) {labels[0]}, B) {labels[1]}, C) {labels[2]}. Reply with only A, B, or C.'}]
    enc=_build_chat(messages,image)
    input_ids=enc['input_ids'].to(DEVICE)
    pv=enc.get('pixel_values',None)
    if pv is not None: pv=pv.to(DEVICE)
    with torch.autocast(device_type=('cuda' if torch.cuda.is_available() else 'cpu'), enabled=False):
        out=model(input_ids=input_ids,pixel_values=pv)
        logits=_sanitize(out.logits.float()); last=logits[:,-1,:]
    sm=_softmax(last,dim=-1)[0]
    letter_ids=_collect_letter_token_ids(sm)
    probs={}
    for L in ['A','B','C']:
        p=0.0
        for tid in letter_ids.get(L,[]):
            if 0<=tid<sm.numel(): p+=float(sm[tid].item())
        probs[L]=p
    total=sum(probs.values())
    fallback=(not math.isfinite(total)) or (total<=0) or (abs(probs['A']-probs['B'])<1e-9 and abs(probs['B']-probs['C'])<1e-9)
    if fallback:
        scores=_score_letter_sequences_ll(enc,['A','B','C'])
        arr=torch.tensor([scores['A'],scores['B'],scores['C']],dtype=torch.float32); arr=arr-arr.max(); arr=torch.softmax(arr,dim=-1)
        probs={'A':float(arr[0].item()),'B':float(arr[1].item()),'C':float(arr[2].item())}
    else:
        probs={k:(v/total if total>0 else 1/3) for k,v in probs.items()}
    letter=max(probs.items(),key=lambda kv:kv[1])[0]; idx={'A':0,'B':1,'C':2}[letter]; label=labels[idx]
    return label, {labels[0]:probs.get('A',0.0),labels[1]:probs.get('B',0.0),labels[2]:probs.get('C',0.0)}
def detect_and_set_image_token_override(example_image_path: str, user_text: str = 'Please analyze the image and answer.'):
    global IMAGE_TOKEN_OVERRIDE
    img=_load_image(example_image_path)
    tok=getattr(processor,'tokenizer',None)
    if tok is None: print('Tokenizer missing'); return None
    cands=_discover_image_tokens(processor); print('[Detect] candidates:', cands)
    for cand in cands:
        tid=_encode_one(tok,cand)
        if tid is None: continue
        try:
            text=f"{cand}\n{user_text}"
            enc=processor(text=[text],images=[img],return_tensors='pt')
            if 'input_ids' not in enc: continue
            cnt=int((enc['input_ids'][0]==tid).sum().item())
            print(f"[Detect] {cand} -> id={tid}, count={cnt}")
            if cnt==1: IMAGE_TOKEN_OVERRIDE=cand; print('Set IMAGE_TOKEN_OVERRIDE=',cand); return cand
        except Exception as e: print('[Detect fail]',cand, e)
    print('No single-match token found')

In [19]:
# 6a) Estimate blank-image priors and define debiased predictor
from collections import defaultdict
import math
try:
    from PIL import Image  # type: ignore
except Exception:
    Image=None
def _blank_image(w=512,h=512,color=255):
    if Image is None:
        raise RuntimeError('Pillow not available; install pillow or run earlier cells.')
    return Image.new('RGB',(w,h),(color,color,color))
def estimate_priors_blank(n=4):
    if 'mcq_letter_predict_safe' not in globals():
        raise RuntimeError("Run the 'Robust classifier utilities' cell first.")
    pri_sum=defaultdict(float)
    for i in range(n):
        img=_blank_image()
        _, probs = mcq_letter_predict_safe(img, LABELS)
        for k,v in probs.items(): pri_sum[k]+=float(v)
    pri={k: (pri_sum[k]/max(1e-8,n)) for k in LABELS}
    # guard against zeros
    for k in pri: pri[k]=max(pri[k],1e-6)
    s=sum(pri.values())
    pri={k: pri[k]/s for k in pri}
    print('[Debias] Estimated priors:', pri)
    return pri
def debiased_predict(image, labels, priors):
    pred, p = mcq_letter_predict_safe(image, labels)
    # divide by priors and renormalize
    q={k: float(p.get(k,0.0))/float(priors.get(k,1.0)) for k in labels}
    # clamp and renorm
    for k in q: q[k]=max(q[k],1e-9)
    s=sum(q.values()); q={k: q[k]/s for k in q}
    pred2=max(q, key=q.get)
    return pred2, q

In [20]:
# 6a) Estimate blank-image priors and define debiased predictor (fast + cached)
from collections import defaultdict
import math, time
try:
    import torch  # type: ignore
except Exception:
    torch=None
try:
    from PIL import Image  # type: ignore
except Exception:
    Image=None
DEBIAS_PRIORS_CACHE=None
def _blank_image(w=384,h=384,color=255):
    if Image is None:
        raise RuntimeError('Pillow not available; install pillow or run earlier cells.')
    return Image.new('RGB',(w,h),(color,color,color))
def estimate_priors_blank(n=3, force=False):
    global DEBIAS_PRIORS_CACHE
    if (DEBIAS_PRIORS_CACHE is not None) and not force:
        print('[Debias] Using cached priors:', DEBIAS_PRIORS_CACHE)
        return DEBIAS_PRIORS_CACHE
    if 'mcq_letter_predict_safe' not in globals():
        raise RuntimeError("Run the 'Robust classifier utilities' cell first.")
    t0=time.time()
    pri_sum=defaultdict(float)
    ctx=(torch.inference_mode() if torch is not None else nullcontext()) if 'nullcontext' in globals() else (torch.inference_mode() if torch is not None else None)
    def _ctx():
        if torch is None:
            class _N:
                def __enter__(self): return None
                def __exit__(self, *a): return False
            return _N()
        return torch.inference_mode()
    with _ctx():
        for i in range(n):
            if i==0: print(f'[Debias] Estimating priors on {n} blank images ...')
            img=_blank_image()
            _, probs = mcq_letter_predict_safe(img, LABELS)
            for k,v in probs.items(): pri_sum[k]+=float(v)
    pri={k: (pri_sum[k]/max(1e-8,n)) for k in LABELS}
    for k in pri: pri[k]=max(pri[k],1e-6)
    s=sum(pri.values()); pri={k: pri[k]/s for k in pri}
    DEBIAS_PRIORS_CACHE=pri
    print('[Debias] Estimated priors:', pri, f'| took {time.time()-t0:.1f}s')
    return pri
def debiased_predict(image, labels, priors):
    # Guard fast inference context
    if torch is not None:
        with torch.inference_mode():
            pred, p = mcq_letter_predict_safe(image, labels)
    else:
        pred, p = mcq_letter_predict_safe(image, labels)
    q={k: float(p.get(k,0.0))/float(priors.get(k,1.0)) for k in labels}
    for k in q: q[k]=max(q[k],1e-9)
    s=sum(q.values()); q={k: q[k]/s for k in q}
    pred2=max(q, key=q.get)
    return pred2, q

In [21]:
# 6b) Debiased evaluation on same dataset (progress + faster defaults)
from pathlib import Path
from glob import glob
import pandas as pd
import time
from collections import defaultdict
def evaluate_dataset_debiased(root: Path, per_class: int = 50, outfile: str = '/content/outputs/predictions_150_mcq_debiased.csv', n_prior=3):
    if 'LABELS' not in globals() or 'mcq_letter_predict_safe' not in globals():
        print('[Debias][Error] Missing LABELS or classifier. Run earlier cells first.')
        return
    pri=estimate_priors_blank(n=n_prior)
    # Resolve helpers
    _collect = collect_images if 'collect_images' in globals() else None
    _choose = _choose_eval_root if '_choose_eval_root' in globals() else None
    if _choose is None:
        def _choose(root: Path):
            cands=[root,root/'test',root/'processed'/'test',root/'processed_dataset'/'test']
            for r in cands:
                if all((r/c).exists() for c in ['eczema','fungal','scabies']): return r
            return root
    if _collect is None:
        def _collect(root: Path, per_class: int):
            items=[]
            pats=('*.jpg','*.jpeg','*.png','*.JPG','*.JPEG','*.PNG','*.webp','*.WEBP')
            for c in ['eczema','fungal','scabies']:
                paths=[]
                for ext in pats: paths+=glob(str((root/c)/ext))
                for p in sorted(paths)[:per_class]: items.append((c,p))
            return items
    root=Path(root)
    root=_choose(root)
    print('[Debias] Using evaluation root:', root)
    Path(outfile).parent.mkdir(parents=True, exist_ok=True)
    items=_collect(root, per_class)
    print('[Debias] Total images:', len(items), f'(target {per_class} per class)')
    rows=[]; dist=defaultdict(int); correct=0
    if len(items)==0:
        print('[Debias] No images found.');
        pd.DataFrame(columns=['filename','true_label','predicted_label','prob_eczema','prob_fungal','prob_scabies']).to_csv(outfile, index=False)
        print('Saved empty CSV to:', outfile)
        return
    try: from PIL import Image
    except Exception: Image=None
    def _ldr(p):
        if '_load_image' in globals():
            try: return globals()['_load_image'](p)
            except Exception: pass
        if Image is None: raise RuntimeError('Pillow not available')
        return Image.open(p).convert('RGB')
    t0=time.time()
    for idx,(true_label, p) in enumerate(items, start=1):
        if idx==1: print('[Debias] Starting predictions ...')
        if idx%10==0: print(f'  progress: {idx}/{len(items)}')
        try:
            img=_ldr(p)
            pred, probs=debiased_predict(img, LABELS, pri)
        except Exception as e:
            print('[Debias][ERR]', Path(p).name, e)
            pred, probs='eczema', {'eczema':1.0,'fungal':0.0,'scabies':0.0}
        rows.append({'filename':Path(p).name,'true_label':true_label,'predicted_label':pred,
                     'prob_eczema':float(probs.get('eczema',0.0)),'prob_fungal':float(probs.get('fungal',0.0)),
                     'prob_scabies':float(probs.get('scabies',0.0))})
        dist[pred]+=1; correct+=int(pred==true_label)
    df=pd.DataFrame(rows); df.to_csv(outfile, index=False)
    print('Saved debiased CSV to:', outfile, f'| took {time.time()-t0:.1f}s')
    n=len(df); acc=correct/n if n else 0.0
    print(f'[Debias] Overall accuracy: {acc:.3f} ({correct}/{n})')
    for c in ['eczema','fungal','scabies']:
        sub=df[df['true_label']==c]
        if len(sub):
            acc_c=(sub['predicted_label']==c).mean()
            print(f'  {c}: {acc_c:.3f} ({int((sub["predicted_label"]==c).sum())}/{len(sub)})')
        else: print('  ', c, ': no samples')
    print('\n[Debias] Predicted distribution:', dict(dist))
evaluate_dataset_debiased(DATA_ROOT, per_class=50)

[Debias] Estimating priors on 3 blank images ...
[Debias] Estimated priors: {'eczema': 0.3333333333333333, 'fungal': 0.3333333333333333, 'scabies': 0.3333333333333333} | took 195.4s
[Debias] Using evaluation root: /content/data/processed
[Debias] Total images: 150 (target 50 per class)
[Debias] Starting predictions ...
  progress: 10/150
  progress: 20/150
  progress: 30/150
  progress: 40/150
  progress: 50/150
  progress: 60/150
  progress: 70/150
  progress: 80/150
  progress: 90/150
  progress: 100/150
  progress: 110/150
  progress: 120/150
  progress: 130/150
  progress: 140/150
  progress: 150/150
Saved debiased CSV to: /content/outputs/predictions_150_mcq_debiased.csv | took 9649.4s
[Debias] Overall accuracy: 0.333 (50/150)
  eczema: 1.000 (50/50)
  fungal: 0.000 (0/50)
  scabies: 0.000 (0/50)

[Debias] Predicted distribution: {'eczema': 150}
