In [None]:
import matplotlib.pyplot as plt
import requests
from google.colab import drive
import shap

In [None]:
# Mount Google Drive (only meaningful when running inside a Colab notebook).
# `drive` is already imported in the notebook's import cell, so it is not
# re-imported here.
drive.mount('/content/drive')

# Make the project package importable. Alternative to editing sys.path:
# install the package in editable mode by uncommenting the next line.
#%pip install -e /content/mmshap_medclip
import sys

# Single source of truth for the checkout location used below.
PKG_ROOT = "/content/mmshap_medclip"

# Guarded append: re-running this cell must not pile up duplicate entries
# in sys.path.
if PKG_ROOT not in sys.path:
    sys.path.append(PKG_ROOT)

# These imports must stay below the sys.path tweak, otherwise the package
# cannot be resolved on a fresh kernel.
from mmshap_medclip.io_utils import load_config
from mmshap_medclip.devices import get_device
from mmshap_medclip.registry import build_dataset, build_model

# Experiment configuration; its "dataset" and "model" entries feed the
# two builders below.
cfg = load_config(f"{PKG_ROOT}/configs/roco_isa_pubmedclip.yaml")
device = get_device()  # or get_device(prefer_cuda=True)

dataset = build_dataset(cfg["dataset"])
model = build_model(cfg["model"], device=device)  # CLIPWrapper, per the original note

# Index of the dataset sample that the rest of the notebook explains.
muestra = 357
sample = dataset[muestra]
image = sample["image"]    # PIL.Image, per the original note
caption = sample["text"]   # str, per the original note

# Handles reused by later cells.
processor = model.processor
tokenizer = model.tokenizer

In [None]:
from mmshap_medclip.tasks.utils import prepare_batch

# Forward the single (caption, image) pair through the wrapped model.
# debug_tokens=True is kept for inspection; switch it to False once the
# extra output is no longer needed.
batch_kwargs = {
    "model_wrapper": model,
    "texts": [caption],
    "images": [image],
    "device": device,
    "debug_tokens": True,
    "amp_if_cuda": True,
}
inputs, logits = prepare_batch(**batch_kwargs)

# Entry [0, 0] of the returned logits, converted to a plain Python float.
model_prediction = float(logits[0, 0])
print("🔮 logits_per_image[0,0]:", model_prediction)

In [None]:
from mmshap_medclip.tasks.utils import (
    compute_text_token_lengths,
    make_image_token_ids,
)

# Step 1 - text token counts (the original note says padding is excluded).
nb_text_tokens_tensor, nb_text_tokens = compute_text_token_lengths(
    inputs, tokenizer
)
print("Tokens por texto (sin PAD):", nb_text_tokens)

# Step 2 - ids for the image patches, with the helper's debug output on.
image_token_ids_expanded, imginfo = make_image_token_ids(
    inputs, model, debug=True
)
# Shape is expected to be (B, N) according to the original note.
print(
    "Forma image_token_ids_expanded:",
    tuple(image_token_ids_expanded.shape),
)
# imginfo carries at least 'patch_size' (used later in the notebook); the
# original note also lists 'grid_h', 'grid_w' and 'num_patches' - TODO confirm.
print("Info imagen:", imginfo)


In [None]:

from mmshap_medclip.shap_tools.masker import build_masker
from mmshap_medclip.shap_tools.predictor import Predictor
from mmshap_medclip.tasks.utils import (
    compute_text_token_lengths,
    concat_text_image_tokens,
    make_image_token_ids,
)

# Step 1 - text token lengths for the prepared batch.
nb_text_tokens_tensor, nb_text_tokens = compute_text_token_lengths(
    inputs, tokenizer
)

# Step 2 - image patch ids (debug output off here), then the combined
# [text | patches] input that SHAP will perturb.
image_token_ids_expanded, imginfo = make_image_token_ids(
    inputs, model, debug=False
)
X_clean, seq_len = concat_text_image_tokens(
    inputs, image_token_ids_expanded, device=device
)

# Step 3 - masker; the original note says it preserves BOS/EOS tokens.
masker = build_masker(nb_text_tokens_tensor, tokenizer=tokenizer)

# Step 4 - prediction callable, placed on the same device as the model.
# base_inputs is the processor output (already on device, per the original
# note); patch_size could reportedly be None to have it inferred.
predict_fn = Predictor(
    model_wrapper=model,
    base_inputs=inputs,
    patch_size=imginfo["patch_size"],
    device=device,
    use_amp=True,
)

# Step 5 - SHAP explainer. The input is moved to CPU before the call; the
# original note says SHAP usually hands x to the predictor on CPU anyway.
explainer = shap.Explainer(predict_fn, masker, silent=True)
shap_values = explainer(X_clean.cpu())  # values cover [text | patches]

print("SHAP values shape:", shap_values.values.shape)


In [None]:
from mmshap_medclip.metrics import compute_iscore, compute_mm_score

# Both scores are computed for the first (and only) example in the batch.
tscore, word_shap = compute_mm_score(
    shap_values=shap_values, tokenizer=tokenizer, inputs=inputs, i=0
)
iscore = compute_iscore(shap_values, inputs, i=0)

print(f"TScore: {tscore:.2%} | IScore: {iscore:.2%}")

# One line per entry of word_shap: right-aligned key, signed value.
for token, contribution in word_shap.items():
    print(f"{token:>15s}: {contribution:+.4f}")


In [None]:
from mmshap_medclip.vis.heatmaps import plot_text_image_heatmaps

# Captions of the batch; a one-element list because the batch size is 1.
texts = [caption]
# A single PIL image; per the original note a list of PIL images is also accepted.
images = image

# `mm_scores` was previously passed to the plot call without ever being
# assigned, so this cell raised NameError on a fresh kernel. Build it from
# the metrics already computed for example i=0. The expected layout, as
# described by the original note, is one (tscore, word_shap) tuple per
# batch item - TODO confirm against plot_text_image_heatmaps.
mm_scores = [(tscore, word_shap)]

plot_text_image_heatmaps(
    shap_values=shap_values,
    inputs=inputs,
    tokenizer=tokenizer,
    images=images,
    texts=texts,
    mm_scores=mm_scores,
    model_wrapper=model,     # CLIPWrapper, per the original note
    cmap_name="RdYlBu_r",
    alpha_img=0.60,
    return_fig=False
)
