In [None]:
# === CELDA: Importaciones y carga de datos + modelo para explicabilidad ===
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from captum.attr import IntegratedGradients
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from your_model_file import FineTuneResNet50  # ajusta al nombre de tu módulo

# 1) Pick the device once and load the checkpoint directly onto it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "checkpoints/resnet50_animals_exp.pth"
checkpoint = torch.load(ckpt_path, map_location=device)

# 2) Rebuild the exact preprocessing stored in the checkpoint
tf_cfg = checkpoint["transform"]
tf = transforms.Compose([
    # NOTE(review): if "resize" is a single int, Resize only fixes the shorter side and
    # images with different aspect ratios cannot be collated into one batch — confirm
    # it is an (H, W) pair.
    transforms.Resize(tf_cfg["resize"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=tf_cfg["normalize_mean"], std=tf_cfg["normalize_std"]),
])

# 3) Load the dataset and reproduce the train/val split with the checkpoint's seed.
# The image folder can be overridden with ANIMALS_IMAGES_DIR instead of editing this path.
images_dir = os.environ.get(
    "ANIMALS_IMAGES_DIR",
    r"C:\Users\juanj\Desktop\Reconocimineto-AnimalesDomesticos-CNN-Explicabilidad\data\images",
)
full_ds = datasets.ImageFolder(root=images_dir, transform=tf)
# Labels index straight into the model's outputs, so the folder classes must match the checkpoint.
assert len(full_ds.classes) == checkpoint["num_classes"], (
    f"{images_dir} has {len(full_ds.classes)} classes but the checkpoint expects "
    f"{checkpoint['num_classes']}"
)

n = len(full_ds)
n_train = int(checkpoint["train_val_ratio"] * n)
n_val = n - n_train
_, val_ds = random_split(
    full_ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(checkpoint["split_seed"]),
)

# pin_memory only helps (and only avoids a warning) when batches are copied to a GPU.
# NOTE(review): if DataLoader workers hang inside Jupyter on Windows, try num_workers=0.
val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

# 4) Rebuild the model, load the trained weights and switch to inference mode
model = FineTuneResNet50(num_classes=checkpoint["num_classes"])
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device).eval()

# 5) Integrated Gradients explainer over the trained model
ig = IntegratedGradients(model)

# 6) Función de visualización
def show_attr(image, attr, title="Attribution", mean=None, std=None):
    """Overlay a channel-summed attribution heatmap on an image and show it.

    Args:
        image: Image tensor in channel-first layout (C, H, W); it is permuted to
            (H, W, C) for matplotlib. If it is still normalized, pass ``mean`` and
            ``std`` so it is mapped back to pixel space before plotting.
        attr: Attribution tensor with the channel axis first (e.g. one element of
            the output of ``ig.attribute``). It is summed over dim 0 into a 2-D map.
        title: Title drawn above the figure.
        mean: Optional per-channel means used by ``transforms.Normalize``; added
            back to the image (after the ``std`` rescale) when given.
        std: Optional per-channel standard deviations used by
            ``transforms.Normalize``; the image is multiplied by them when given.

    The heatmap is min-max scaled to [0, 1] on the signed attributions, so the
    most negative value maps to 0. The image is clipped to [0, 1], the range
    matplotlib expects for float RGB data.
    """
    # Collapse channels into one 2-D map and min-max scale it for the colormap.
    heatmap = attr.detach().sum(dim=0).cpu().numpy()
    low = heatmap.min()
    heatmap = (heatmap - low) / (heatmap.max() - low + 1e-8)

    # Optionally undo Normalize (x * std + mean), then clip to the displayable range.
    pixels = image.detach().cpu().float()
    if std is not None:
        pixels = pixels * torch.as_tensor(std, dtype=pixels.dtype).view(-1, 1, 1)
    if mean is not None:
        pixels = pixels + torch.as_tensor(mean, dtype=pixels.dtype).view(-1, 1, 1)
    pixels = pixels.clamp(0.0, 1.0).permute(1, 2, 0).numpy()

    _, ax = plt.subplots()
    ax.imshow(pixels, alpha=0.8)
    ax.imshow(heatmap, cmap="hot", alpha=0.4)
    ax.set_title(title)
    ax.axis("off")
    plt.show()

# ¡Listo! Ahora puedes usar el bloque de IG:
# images, labels = next(iter(val_loader))
# img, lbl = images[0:1].to(device), labels[0].item()
# attr_ig, _ = ig.attribute(img, target=lbl, return_convergence_delta=True)
# show_attr(img[0], attr_ig[0], title=f"IG para clase {lbl}")


In [None]:
# Explain the first validation image with Integrated Gradients, using its true label as target.
images, labels = next(iter(val_loader))
img, lbl = images[0:1].to(device), labels[0].item()
attr_ig, delta = ig.attribute(img, target=lbl, return_convergence_delta=True)

# The model's own prediction tells us whether we are explaining a hit or a miss.
# NOTE(review): assumes the model returns (batch, num_classes) logits.
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()

class_names = full_ds.classes
print(
    f"Label: {class_names[lbl]} | Prediction: {class_names[pred]} | "
    f"IG convergence delta: {delta.abs().max().item():.4g}"
)

# val_loader yields Normalize'd tensors; undo that so the image keeps its real colours.
norm_mean = torch.as_tensor(checkpoint["transform"]["normalize_mean"], dtype=torch.float32).view(-1, 1, 1)
norm_std = torch.as_tensor(checkpoint["transform"]["normalize_std"], dtype=torch.float32).view(-1, 1, 1)
img_vis = (img[0].detach().cpu() * norm_std + norm_mean).clamp(0.0, 1.0)

show_attr(img_vis, attr_ig[0], title=f"IG para clase {lbl} ({class_names[lbl]})")

In [None]:
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt
import numpy as np

# `ig` and `show_attr` are created in the setup cell at the top of the notebook.
# They are deliberately not redefined here: a second copy would silently shadow the
# original and any later fix to it.

# Example on a validation batch
model.eval()
images, labels = next(iter(val_loader))
img = images[0:1].to(device)
lbl = labels[0].item()

# Compute the attributions for the true label
attr_ig, _ = ig.attribute(img, target=lbl, return_convergence_delta=True)

# Undo Normalize so the underlying image is displayed with its real colours
norm_mean = torch.as_tensor(checkpoint["transform"]["normalize_mean"], dtype=torch.float32).view(-1, 1, 1)
norm_std = torch.as_tensor(checkpoint["transform"]["normalize_std"], dtype=torch.float32).view(-1, 1, 1)
img_vis = (img[0].detach().cpu() * norm_std + norm_mean).clamp(0.0, 1.0)

# Visualize
show_attr(img_vis, attr_ig[0], title=f"IG para clase {lbl} ({full_ds.classes[lbl]})")
