<a href="https://colab.research.google.com/github/Alberto-97sc/mmshap_medclip/blob/others-clips-version/notebooks/01_pubmedclip_isa_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# PubMedCLIP + SHAP: Análisis ISA con Balance Multimodal

Este notebook demuestra:
- **Image-Sentence Alignment (ISA)** usando PubMedCLIP
- **Análisis de explicabilidad** con SHAP
- **Medición del balance multimodal** (TScore/IScore)
- **Visualización de mapas de calor** para parches de imagen y tokens de texto

Dataset: **ROCO** (Radiology Objects in COntext)  
Modelo: **PubMedCLIP** (flaviagiammarino/pubmed-clip-vit-base-patch32)


## 🚀 Configuración inicial


In [None]:
# Mount Google Drive at /content/drive (uses the google.colab helper, so this cell is Colab-only)
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Configuraci√≥n del repositorio
REPO_URL = "https://github.com/Alberto-97sc/mmshap_medclip.git"
LOCAL_DIR = "/content/mmshap_medclip"
BRANCH = "others-clips-version"

%cd /content
import os, shutil, subprocess, sys

if not os.path.isdir(f"{LOCAL_DIR}/.git"):
    !git clone $REPO_URL $LOCAL_DIR
else:
    %cd $LOCAL_DIR
    !git fetch origin
    !git checkout $BRANCH
    !git reset --hard origin/$BRANCH

%cd $LOCAL_DIR
!git rev-parse --short HEAD


In [None]:
# Install the cloned project package in editable mode
%pip install -e /content/mmshap_medclip

# Additional dependencies
%pip install tqdm


## 📊 Carga de datos y modelo


In [None]:
# Path of the configuration file used for the PubMedCLIP + ISA experiment.
CFG_PATH = "/content/mmshap_medclip/configs/roco_isa_pubmedclip.yaml"

from mmshap_medclip.io_utils import load_config
from mmshap_medclip.devices import get_device
from mmshap_medclip.registry import build_dataset, build_model

# Load the configuration and pick the compute device.
cfg = load_config(CFG_PATH)
device = get_device()
print(f"🖥️ Dispositivo: {device}")

# Build the dataset from the "dataset" section of the configuration.
print("📁 Cargando dataset ROCO...")
dataset = build_dataset(cfg["dataset"])
print(f"✅ Dataset cargado: {len(dataset)} muestras")

# Build the model from the "model" section, placed on the selected device.
print("🤖 Cargando modelo PubMedCLIP...")
model = build_model(cfg["model"], device=device)
print("✅ Modelo PubMedCLIP cargado")


## 🧠 ISA con SHAP y Balance Multimodal


In [None]:
from mmshap_medclip.tasks.isa import run_isa_one
import matplotlib.pyplot as plt
import numpy as np

# Pick the sample to analyse; `image` and `caption` are reused by the next cell.
muestra_idx = 266
sample = dataset[muestra_idx]
image = sample['image']
caption = sample['text']

print(f"📋 Muestra {muestra_idx}:")
print(f"Caption: {caption}")

# Show the image with the first 80 characters of its caption as the title.
plt.figure(figsize=(10, 6))
plt.imshow(image)
# A single backslash is required here: "\\n" would render a literal "\n"
# in the title instead of breaking the line.
plt.title(f"Muestra {muestra_idx} - ROCO Dataset\n{caption[:80]}...")
plt.axis('off')
plt.show()


In [None]:
# Run ISA on the selected sample with SHAP explanations and heat-map plotting enabled.
print("🔬 Ejecutando ISA con SHAP...")
res_shap = run_isa_one(model, image, caption, device, explain=True, plot=True)

# "\n" (not "\\n") so that a blank line is printed before the results.
print("\n🎯 Resultados:")
print(f"Logit (similitud): {res_shap['logit']:.4f}")
print(f"TScore (Texto): {res_shap['tscore']:.2%}")
print(f"IScore (Imagen): {res_shap['iscore']:.2%}")

# Modality balance: a share above 60% for one modality is reported as a focus on it.
if res_shap['tscore'] > 0.6:
    balance_msg = "🔤 Enfoque en TEXTO"
elif res_shap['iscore'] > 0.6:
    balance_msg = "🖼️ Enfoque en IMAGEN"
else:
    balance_msg = "⚖️ Balance equilibrado"

# Alignment: a positive logit is reported as good alignment.
# NOTE(review): the 0 threshold is a heuristic of this notebook — confirm it suits the model's logit scale.
if res_shap['logit'] > 0:
    isa_msg = "✅ BUENA alineación"
else:
    isa_msg = "⚠️ Alineación moderada/pobre"

print(f"Balance: {balance_msg}")
print(f"Alineación: {isa_msg}")

# Show the heat map only when the result contains a figure.
if 'fig' in res_shap:
    display(res_shap['fig'])


<a href="https://colab.research.google.com/github/Alberto-97sc/mmshap_medclip/blob/others-clips-version/notebooks/01_pubmedclip_isa_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# PubMedCLIP + SHAP: Análisis ISA con Balance Multimodal

Este notebook demuestra:
- **Image-Sentence Alignment (ISA)** usando PubMedCLIP
- **Análisis de explicabilidad** con SHAP
- **Medición del balance multimodal** (TScore/IScore)
- **Visualización de mapas de calor** para parches de imagen y tokens de texto

Dataset: **ROCO** (Radiology Objects in COntext)<br>
Modelo: **PubMedCLIP** (flaviagiammarino/pubmed-clip-vit-base-patch32)


## 🚀 Configuración inicial


In [None]:
# Mount Google Drive at /content/drive (uses the google.colab helper, so this cell is Colab-only)
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Configuraci√≥n del repositorio
REPO_URL = "https://github.com/Alberto-97sc/mmshap_medclip.git"
LOCAL_DIR = "/content/mmshap_medclip"
BRANCH = "others-clips-version"  # Rama con RClip y PubMedCLIP

%cd /content
import os, shutil, subprocess, sys

if not os.path.isdir(f"{LOCAL_DIR}/.git"):
    # No est√° clonado a√∫n
    !git clone $REPO_URL $LOCAL_DIR
else:
    # Ya existe: actualiza a la √∫ltima versi√≥n del remoto
    %cd $LOCAL_DIR
    !git fetch origin
    !git checkout $BRANCH
    !git reset --hard origin/$BRANCH

%cd $LOCAL_DIR
!git rev-parse --short HEAD


In [None]:
# Install the cloned project package in editable mode
%pip install -e /content/mmshap_medclip

# Additional dependencies, in case they are needed
%pip install matplotlib seaborn pillow


## 📊 Carga de datos y modelo


In [None]:
# Configuration file for the PubMedCLIP + ISA experiment.
CFG_PATH = "/content/mmshap_medclip/configs/roco_isa_pubmedclip.yaml"

from mmshap_medclip.io_utils import load_config
from mmshap_medclip.devices import get_device
from mmshap_medclip.registry import build_dataset, build_model

# Load the configuration and pick the compute device.
cfg = load_config(CFG_PATH)
device = get_device()
print(f"🖥️ Dispositivo: {device}")

# Build the ROCO dataset from the "dataset" section of the configuration.
print("📁 Cargando dataset ROCO...")
dataset = build_dataset(cfg["dataset"])
print(f"✅ Dataset cargado: {len(dataset)} muestras")

# Build the PubMedCLIP model from the "model" section, placed on the selected device.
print("🤖 Cargando modelo PubMedCLIP...")
model = build_model(cfg["model"], device=device)
print("✅ Modelo PubMedCLIP cargado")


## 🔍 Ejemplo 1: ISA simple (sin SHAP)


In [None]:
from mmshap_medclip.tasks.isa import run_isa_one
import matplotlib.pyplot as plt
import numpy as np

# Pick one sample from the dataset; `image` and `caption` are reused by the
# following example cells.
# (An unrelated SQL query had been pasted into this cell and made it
# unparseable; it has been removed.)
muestra_idx = 266  # Change to any valid index
sample = dataset[muestra_idx]
image = sample['image']
caption = sample['text']

print(f"📋 Muestra {muestra_idx}:")
print(f"Caption: {caption}")
print(f"Metadata: {sample['meta']}")

# Show the image with the first 80 characters of its caption as the title.
plt.figure(figsize=(10, 6))
plt.imshow(image)
plt.title(f"Muestra {muestra_idx} - ROCO Dataset\n{caption[:80]}...")
plt.axis('off')
plt.show()


In [None]:
# Quick ISA run without SHAP: only the image-text logit is reported.
print("⚡ Ejecutando ISA rápido (sin SHAP)...")
res_simple = run_isa_one(
    model, image, caption, device,
    explain=False  # Skip explainability for speed
)

print("\n🎯 Resultados de ISA:")
print(f"Logit (similitud imagen-texto): {res_simple['logit']:.4f}")

# Interpret the logit with this notebook's thresholds:
# > 0 is reported as good, (-1, 0] as moderate, <= -1 as poor.
# NOTE(review): thresholds are heuristic — confirm they suit the model's logit scale.
if res_simple['logit'] > 0:
    alignment_msg = "✅ BUENA alineación imagen-texto"
elif res_simple['logit'] > -1:
    alignment_msg = "⚠️ Alineación MODERADA imagen-texto"
else:
    alignment_msg = "❌ POBRE alineación imagen-texto"

print(f"Interpretación: {alignment_msg}")


## 🧠 Ejemplo 2: ISA con SHAP y Balance Multimodal


In [None]:
# Run ISA on the selected sample with SHAP explanations and heat-map plotting enabled.
print("🔬 Ejecutando ISA con SHAP (esto puede tomar varios minutos)...")
res_shap = run_isa_one(
    model, image, caption, device,
    explain=True,  # Compute SHAP explanations
    plot=True      # Generate heat maps
)

print("\n🎯 Resultados con SHAP:")
print(f"Logit (similitud): {res_shap['logit']:.4f}")

print("\n⚖️ Balance Multimodal:")
print(f"TScore (Text Score): {res_shap['tscore']:.2%}")
print(f"IScore (Image Score): {res_shap['iscore']:.2%}")

# Modality balance: a share above 60% for one modality is reported as a focus on it.
if res_shap['tscore'] > 0.6:
    balance_msg = "🔤 Modelo se enfoca más en el TEXTO"
elif res_shap['iscore'] > 0.6:
    balance_msg = "🖼️ Modelo se enfoca más en la IMAGEN"
else:
    balance_msg = "⚖️ Balance equilibrado entre texto e imagen"

print(f"Interpretación del balance: {balance_msg}")

# Alignment, using the same thresholds as the quick ISA example:
# > 0 is reported as good, (-1, 0] as moderate, <= -1 as poor.
if res_shap['logit'] > 0:
    isa_msg = "✅ BUENA alineación imagen-texto"
elif res_shap['logit'] > -1:
    isa_msg = "⚠️ Alineación MODERADA imagen-texto"
else:
    isa_msg = "❌ POBRE alineación imagen-texto"

print(f"Interpretación del ISA: {isa_msg}")


In [None]:
# Show the heat map when the SHAP result contains a figure; otherwise say so.
if 'fig' in res_shap:
    print("🗺️ Mapa de calor con importancia de parches y tokens:")
    display(res_shap['fig'])
else:
    print("⚠️ No se generó mapa de calor")


## 📈 Ejemplo 3: Análisis de múltiples muestras


In [None]:
import pandas as pd
from tqdm import tqdm

# Analyse several random samples to obtain aggregate statistics.
num_samples = 15  # Adjust as needed (more samples = more time)
SEED = 42         # Fixed seed so Restart & Run All picks the same samples

# Never request more samples than the dataset holds: choice(..., replace=False)
# raises ValueError when the population is smaller than the sample size.
num_samples = min(num_samples, len(dataset))
rng = np.random.default_rng(SEED)
samples_indices = rng.choice(len(dataset), size=num_samples, replace=False)

results = []

print(f"📊 Analizando {num_samples} muestras aleatorias...")

for idx in tqdm(samples_indices, desc="Procesando ISA"):
    # Plain Python int: safe for datasets that do not accept NumPy integer
    # indices, and keeps the 'sample_idx' column a simple integer.
    idx = int(idx)

    # Loop-local names so the `sample`/`image`/`caption` chosen in Example 1
    # are not overwritten (re-running Example 2 afterwards stays consistent).
    sample_i = dataset[idx]
    image_i = sample_i['image']
    caption_i = sample_i['text']

    # ISA with SHAP (no plot, for speed)
    res = run_isa_one(
        model, image_i, caption_i, device,
        explain=True, plot=False
    )

    # float() keeps the DataFrame columns numeric.
    # NOTE(review): assumes run_isa_one returns scalar-like values (they are
    # formatted with ':.4f' / ':.2%' elsewhere in this notebook) — confirm.
    logit = float(res['logit'])

    # Classify the alignment with the same thresholds as the single-sample examples.
    if logit > 0:
        alignment_category = "Buena"
    elif logit > -1:
        alignment_category = "Moderada"
    else:
        alignment_category = "Pobre"

    # Only add the ellipsis when the caption was actually truncated.
    short_caption = caption_i if len(caption_i) <= 60 else caption_i[:60] + "..."

    results.append({
        'sample_idx': idx,
        'logit': logit,
        'alignment_category': alignment_category,
        'tscore': float(res['tscore']),
        'iscore': float(res['iscore']),
        'original_caption': short_caption
    })

# Build the results DataFrame once, from the list of records.
df_results = pd.DataFrame(results)
# "\n" (not "\\n") so that a blank line is printed before the message.
print("\n✅ Análisis completado")
display(df_results)


In [None]:
# Summary statistics and plots for the multimodal balance and ISA results in `df_results`.
import seaborn as sns

print("📈 Estadísticas del análisis ISA:")
print(f"Logit promedio: {df_results['logit'].mean():.4f} ± {df_results['logit'].std():.4f}")
print(f"TScore promedio: {df_results['tscore'].mean():.2%} ± {df_results['tscore'].std():.2%}")
print(f"IScore promedio: {df_results['iscore'].mean():.2%} ± {df_results['iscore'].std():.2%}")

# Distribution of alignment categories
alignment_counts = df_results['alignment_category'].value_counts()
# "\n" (not "\\n") so that a blank line is printed before the heading.
print("\n🎯 Distribución de alineación:")
for categoria, count in alignment_counts.items():
    percentage = (count / len(df_results)) * 100
    print(f"  {categoria}: {count} ({percentage:.1f}%)")

# Visualisation: 2x2 grid of panels
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Logit (ISA) distribution, with the mean and the "good alignment" threshold at 0
axes[0,0].hist(df_results['logit'], bins=10, alpha=0.7, color='purple', edgecolor='black')
axes[0,0].set_title('Distribución de Logits (ISA)')
axes[0,0].set_xlabel('Logit (Similitud Imagen-Texto)')
axes[0,0].set_ylabel('Frecuencia')
axes[0,0].axvline(df_results['logit'].mean(), color='red', linestyle='--', label='Media')
axes[0,0].axvline(0, color='green', linestyle=':', label='Umbral Buena Alineación')
axes[0,0].legend()

# TScore distribution, with its mean
axes[0,1].hist(df_results['tscore'], bins=10, alpha=0.7, color='blue', edgecolor='black')
axes[0,1].set_title('Distribución TScore (Importancia Texto)')
axes[0,1].set_xlabel('TScore')
axes[0,1].set_ylabel('Frecuencia')
axes[0,1].axvline(df_results['tscore'].mean(), color='red', linestyle='--', label='Media')
axes[0,1].legend()

# Scatter plot of TScore vs IScore, with the anti-diagonal as a reference line
axes[1,0].scatter(df_results['tscore'], df_results['iscore'], alpha=0.7, s=60)
axes[1,0].set_xlabel('TScore (Importancia Texto)')
axes[1,0].set_ylabel('IScore (Importancia Imagen)')
axes[1,0].set_title('Balance Texto vs Imagen')
axes[1,0].plot([0, 1], [1, 0], 'r--', alpha=0.5, label='Balance perfecto')
axes[1,0].legend()
axes[1,0].grid(True, alpha=0.3)

# Bar chart of alignment categories.
# Colours are looked up per category: value_counts() orders by frequency, so
# assigning a fixed colour list by position could paint "Pobre" green.
category_colors = {'Buena': 'green', 'Moderada': 'orange', 'Pobre': 'red'}
bar_colors = [category_colors.get(category, 'gray') for category in alignment_counts.index]
bars = axes[1,1].bar(alignment_counts.index, alignment_counts.values,
                     color=bar_colors, alpha=0.7, edgecolor='black')
axes[1,1].set_title('Distribución de Categorías de Alineación')
axes[1,1].set_xlabel('Categoría')
axes[1,1].set_ylabel('Frecuencia')

# Write each bar's count just above it
for bar in bars:
    height = bar.get_height()
    axes[1,1].annotate(f'{int(height)}',
                       xy=(bar.get_x() + bar.get_width() / 2, height),
                       xytext=(0, 3),
                       textcoords="offset points",
                       ha='center', va='bottom')

plt.tight_layout()
plt.show()
