<a href="https://colab.research.google.com/github/Alberto-97sc/mmshap_medclip/blob/others-clips-version/notebooks/02_rclip_classification_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# RClip + SHAP: Clasificación Médica con Balance Multimodal

Este notebook demuestra:
- **Clasificación de imágenes médicas** usando RClip
- **Análisis de explicabilidad** con SHAP
- **Medición del balance multimodal** (TScore/IScore: proporción de la contribución atribuida al texto y a la imagen, respectivamente)
- **Visualización de mapas de calor** para parches de imagen y tokens de texto

Dataset: **ROCO** (Radiology Objects in COntext)  
Modelo: **RClip** (kaveh/rclip)


## 🚀 Configuración inicial


In [None]:
# Mount the user's Google Drive inside the Colab filesystem.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


In [None]:
REPO_URL = "https://github.com/Alberto-97sc/mmshap_medclip.git"
LOCAL_DIR = "/content/mmshap_medclip"
BRANCH = "others-clips-version"

%cd /content
import os, shutil, subprocess, sys

if not os.path.isdir(f"{LOCAL_DIR}/.git"):
    !git clone $REPO_URL $LOCAL_DIR
else:
    %cd $LOCAL_DIR
    !git fetch origin
    !git checkout $BRANCH
    !git reset --hard origin/$BRANCH

%cd $LOCAL_DIR
!git rev-parse --short HEAD


In [None]:
# Install the cloned project in editable mode (-e) from its checkout directory.
%pip install -e /content/mmshap_medclip
# tqdm is installed explicitly in addition to the project.
%pip install tqdm


## 📊 Carga de datos y modelo


In [None]:
# Project helpers: config loading, device selection, dataset/model builders.
from mmshap_medclip.io_utils import load_config
from mmshap_medclip.devices import get_device
from mmshap_medclip.registry import build_dataset, build_model

# Configuration file for this demo; its "dataset" and "model" sections are
# handed to the corresponding builders below.
CFG_PATH = "/content/mmshap_medclip/configs/roco_classification_rclip.yaml"
cfg = load_config(CFG_PATH)

# Device on which the model is built.
device = get_device()
print(f"🖥️ Dispositivo: {device}")

# --- Dataset ---
print("📁 Cargando dataset ROCO...")
dataset = build_dataset(cfg["dataset"])
print(f"✅ Dataset cargado: {len(dataset)} muestras")

# --- Model ---
print("🤖 Cargando modelo RClip...")
model = build_model(cfg["model"], device=device)
print("✅ Modelo RClip cargado")


In [None]:
# Candidate labels passed to the classifier in the cells below.
class_names = [
    "Chest X-Ray",
    "Brain MRI",
    "Abdominal CT Scan",
    "Ultrasound",
    "OPG",
    "Mammography",
    "Bone X-Ray",
]

print(f"🏷️ Clases definidas: {len(class_names)}")
# Print a 1-based numbered listing of the labels.
for position, class_label in enumerate(class_names, start=1):
    print(f"  {position}. {class_label}")


## 🧠 Clasificación con SHAP y Balance Multimodal


In [None]:
from mmshap_medclip.tasks.classification import run_classification_one
import matplotlib.pyplot as plt

# Pick one dataset item; each item is indexed by the 'image' and 'text' keys.
muestra_idx = 266
sample = dataset[muestra_idx]
image, caption = sample['image'], sample['text']

# Show the index and the first 100 characters of the caption.
caption_preview = caption[:100]
print(f"📋 Muestra {muestra_idx}:")
print(f"Caption original: {caption_preview}...")

# Preview the image on a single axes with the axis decorations hidden.
sample_fig, sample_ax = plt.subplots(figsize=(8, 6))
sample_ax.imshow(image)
sample_ax.set_title(f"Muestra {muestra_idx} - ROCO Dataset")
sample_ax.axis('off')
plt.show()


In [None]:
print("🔬 Ejecutando clasificación con SHAP...")
# Classify the selected image against `class_names`.
# NOTE(review): explain=True / plot=True presumably turn on the SHAP analysis
# and the heat-map figure — confirm against run_classification_one.
res_shap = run_classification_one(
    model, image, class_names, device,
    explain=True, plot=True
)

# The result is read below through the keys 'predicted_class',
# 'probabilities', 'tscore', 'iscore' and (optionally) 'fig'.
# "\n" must be a real newline escape: the previous "\\n" printed a literal
# backslash followed by "n" instead of a blank line.
print("\n🎯 Resultados:")
print(f"Clase predicha: {res_shap['predicted_class']}")
print(f"Confianza: {res_shap['probabilities'].max():.2%}")
print(f"TScore (Texto): {res_shap['tscore']:.2%}")
print(f"IScore (Imagen): {res_shap['iscore']:.2%}")

# Interpret the balance: a score above 0.6 on one side is reported as a focus
# on that modality; otherwise the result is reported as balanced.
if res_shap['tscore'] > 0.6:
    balance_msg = "🔤 Enfoque en TEXTO"
elif res_shap['iscore'] > 0.6:
    balance_msg = "🖼️ Enfoque en IMAGEN"
else:
    balance_msg = "⚖️ Balance equilibrado"

print(f"Balance: {balance_msg}")

# Show the probability of every class with a text bar (20 blocks = 100%).
print("\n📊 Probabilidades por clase:")
for clase, prob in zip(class_names, res_shap['probabilities']):
    bar = "█" * int(prob * 20)
    print(f"  {clase:<20}: {prob:.2%} {bar}")

# Show the heat map when a figure was returned; skip a missing or None entry
# so the cell does not render a stray "None".
if 'fig' in res_shap and res_shap['fig'] is not None:
    display(res_shap['fig'])
