In [88]:
# Imports and device
import torch
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import numpy as np
from matplotlib import cm

device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)

print('Device:', device)

Device: mps


In [89]:
# SmallCNN definition (must match the trained model architecture)
class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 12, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(12*7*7, 12)
        self.fc2 = nn.Linear(12, num_classes)

    def forward(self, x):
        x = self.pool(torch.nn.functional.relu(self.conv1(x)))  # 14x14
        x = self.pool(x)                      # 7x7
        x = x.view(x.size(0), -1)
        x = self.dropout(torch.nn.functional.relu(self.fc1(x)))
        return self.fc2(x)

# convenience transform
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

In [90]:
models_dir = Path('../models')
model_name = models_dir / 'small_cnn.pth'
model = torch.load(model_name, map_location=device)
#model = torch.load(model_name, weights_only=False)
model.to(device)
model.eval()

  model = torch.load(model_name, map_location=device)


SmallCNN(
  (conv1): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.25, inplace=False)
  (fc1): Linear(in_features=588, out_features=12, bias=True)
  (fc2): Linear(in_features=12, out_features=10, bias=True)
)

In [91]:
def smoothgrad(model, x, y, n, noise):
    grads = torch.zeros_like(x)
    for _ in range(n):
        noise_tensor = torch.randn_like(x) * noise
        x_noisy = (x + noise_tensor).clamp(0,1)
        x_noisy.requires_grad_(True)
        out = model(x_noisy)
        score = out[0,y]
        model.zero_grad(set_to_none=True)
        score.backward()
        grads += x_noisy.grad.detach()
    grads /= n
    grads = grads.abs().mean(1, keepdim=True)
    grads = grads[0] / (grads.max() + 1e-8)
    return grads.squeeze().cpu().numpy()

In [92]:
def colorize_cam(cam_01):
    # cam_01 -> np.float32 [H,W] en 0..1, luego aplica 'jet' y quita alfa
    if isinstance(cam_01, torch.Tensor):
        cam = cam_01.detach().cpu().squeeze().float().numpy()
    else:
        cam = np.array(cam_01, dtype=np.float32)
    cam = np.nan_to_num(cam, nan=0.0, posinf=1.0, neginf=0.0)
    cam = np.clip(cam, 0.0, 1.0).astype(np.float32)
    jet = cm.get_cmap('jet')(cam)[:, :, :3]        # (H,W,3) RGB 0..1
    return jet.astype(np.float32)

def overlay_color(x01, cam01, alpha=0.55):
    base = x01.detach().cpu().permute(1,2,0).float().numpy()      # (H,W,3) 0..1
    base = np.clip(base, 0.0, 1.0).astype(np.float32)
    heat = colorize_cam(cam01)                                     # (H,W,3) 0..1
    out  = np.clip((1 - alpha) * base + alpha * heat, 0.0, 1.0)
    return torch.from_numpy(out).permute(2,0,1) 

In [93]:
def smooth_grad_true_label(model, img, y_true, n=125, noise=0.3):
    # img = Image.open(img).convert("RGB").resize((28,28))
    # x = to_tensor(img).unsqueeze(0).to(device)   # [1,3,28,28]

    sg_map = smoothgrad(model, img.clone(), y_true, n=n, noise=noise)

    ov_sg  = overlay_color(img.squeeze(0).cpu(), torch.as_tensor(sg_map), alpha=0.55)

    out_dir = Path("../data/MNIST/smooth_grad_true_label"); out_dir.mkdir(parents=True, exist_ok=True)
    from torchvision.utils import save_image
    save_image(ov_sg,  out_dir/f"{y_true}_SG_true_label.png") 
    return sg_map
    

In [94]:
def saliency_true_logit(model, x, y_idx):
    """
    Saliency respecto al LOGIT de la clase correcta: |∇_x logit_y|.
    x: (1,C,28,28) en [0,1]
    """
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)                 # pre-softmax
    target = logits[0, y_idx]
    model.zero_grad(set_to_none=True)
    target.backward()
    g = x.grad.detach().abs()[0]      # (C,H,W)
    if g.shape[0] > 1:
        g = g.mean(dim=0)
    else:
        g = g[0]
    g = g - g.min()
    g = g / (g.max() + 1e-8)
    return g.cpu().numpy()            # (H,W) en [0,1]


In [95]:
# Funciones necesarias luego
def predict_np(img_np):
    """ img_np: numpy array HWC [0,1] """
    img_t = torch.from_numpy(img_np.transpose(2,0,1)).unsqueeze(0).float().to(device)
    with torch.no_grad():
        pred = int(model(img_t).argmax(dim=1).item())
    return pred

In [96]:
def edit_with_saliency_minimal(img_pil, true_label, top_pct=0.15, darken_amt=0.30, s_map=None, n_sg=125, noise=0.3):
    """
    Edita SOLO dentro de la máscara top-k del saliency del logit verdadero:
      A) borrar ruido claro (llevar a blanco)
      B) reforzar trazo (oscurecer un poco)
    Devuelve (edited_np_uint8, pred_before, pred_after)
    """
    img28 = img_pil.convert("RGB").resize((28,28), Image.BILINEAR)
    x = to_tensor(img28).unsqueeze(0).to(device)

    # saliency & mask
    if s_map is None:
        s_map = smooth_grad_true_label(model, x.clone(), true_label, n=n_sg, noise=noise)  # (H,W) en [0,1]
    else:
        s_map = np.asarray(s_map, dtype=np.float32)

    th = np.percentile(s_map, 100*(1 - top_pct))
    mask = (s_map >= th).astype(np.float32)  # (H,W) {0,1}

    # a numpy [0,1]
    arr = (np.array(img28).astype(np.float32) / 255.0)
    pred0 = predict_np(arr)
    
    # PASO A: borrar ruido claro dentro de la máscara
    outA = arr.copy()
    bright = (outA.mean(axis=2) > 0.6).astype(np.float32)
    mA = (mask * bright)[:, :, None]
    outA = outA*(1 - mA) + 1.0*mA
    predA = predict_np(outA)

    if predA == true_label:
        return (outA*255).astype(np.uint8), pred0, predA

    # PASO B: reforzar trazo oscuro dentro de la máscara
    outB = outA.copy()
    dark = (outB.mean(axis=2) < 0.5).astype(np.float32)
    mB = (mask * dark)[:, :, None]
    outB = np.clip(outB - darken_amt*mB, 0, 1)
    predB  = predict_np(outB)
    
    return (outB*255).astype(np.uint8), pred0, predB

In [97]:
# Check percentage of editing an image
def calculate_edit_percentage(original_img, edited_img):
    original_pixels = original_img.load()
    edited_pixels = edited_img.load()
    width, height = original_img.size
    total_pixels = width * height
    changed_pixels = 0

    for x in range(width):
        for y in range(height):
            if original_pixels[x, y] != edited_pixels[x, y]:
                changed_pixels += 1

    return (changed_pixels / total_pixels) * 100

In [98]:
def try_fix_with_saliency(img_path, top_pcts=(0.12, 0.18, 0.24), darken_amt=0.30, grad_min=1e-6, n=125, noise=0.3):
    """
    Intenta corregir una imagen usando saliency si hay gradiente suficiente.
    Devuelve dict con resultado y guarda en .../edited/ si es válido.
    """
    edited_dir = img_path.parent / "edited"
    edited_dir.mkdir(parents=True, exist_ok=True)

    img_pil = Image.open(img_path).convert("RGB")
    true_label = int(Path(img_path).stem.split("label")[-1][0])  # último dígito del stem

    # mide norma de gradiente (respecto al logit de la clase verdadera)
    x28 = to_tensor(img_pil.convert("RGB").resize((28,28), Image.BILINEAR)).unsqueeze(0).to(device)
    s_map = smooth_grad_true_label(model, x28.clone(), true_label, n=n, noise=noise)
    gnorm = float(np.linalg.norm(s_map.reshape(-1)))

    if gnorm < grad_min:
        # sin gradiente útil: no intentamos saliency aquí
        return {"ok": False, "skipped": False, "reason": "vanishing_gradients", "grad_norm": gnorm}

    best = None
    for tp in top_pcts:
        edited_np, pred_before, pred_after = edit_with_saliency_minimal(img_pil, true_label, top_pct=tp, darken_amt=darken_amt, s_map=s_map, n_sg=n, noise=noise)
        edited = Image.fromarray(edited_np)

        p = calculate_edit_percentage(img_pil, edited)
        edited_np = (np.array(edited).astype(np.float32) / 255.0)
        pred_final = predict_np(edited_np)
    
        logline = f"{img_path.name} | before={pred_before} -> after={pred_final} | true={true_label} | edit%={p:.1f} | top_pct={tp} | grad_norm={gnorm:.2e}"
        print(logline)

        if (pred_final == true_label) and (p <= 40.0):
            edited.save(edited_dir / img_path.name)
            best = {"ok": True, "pred_before": pred_before, "pred_after": pred_final,
                    "edit_pct": p, "top_pct": tp, "grad_norm": gnorm}
            break

    if best is None:
        best = {"ok": False, "pred_before": pred_before, "pred_after": pred_final,
                "edit_pct": p, "top_pct": tp, "grad_norm": gnorm}
    return best

In [103]:
challenge_dir = Path("../data/MNIST/challenge")  # ajusta si tu ruta es otra
files = sorted(challenge_dir.glob("*.png"))

ok = 0
attempted = 0
skipped_grad = 0
for f in files:
    r = try_fix_with_saliency(f, top_pcts=(0.50, 0.40, 0.30, 0.20, 0.10),
        darken_amt=0.40,   # un poco más de efecto que 0.30
        grad_min=1e-6,
        n=150,             # mapa de saliencia algo más estable
        noise=0.25         # ruido moderado
        )

    if r.get("skipped", False):
        continue
    attempted += 1
    if r.get("reason") == "vanishing_gradients":
        skipped_grad += 1
        print(f"  -> saltada (vanishing gradients): {f.name} (‖grad‖={r['grad_norm']:.2e})")
    if r["ok"]:
        ok += 1

print(f"\nHe intentado corregir {attempted} imágenes (saltadas por gradiente: {skipped_grad}).")
print(f"Guardadas válidas (pred correcta y ≤40% editado): {ok}")
print("Se han guardado en:", challenge_dir / "edited")


  jet = cm.get_cmap('jet')(cam)[:, :, :3]        # (H,W,3) RGB 0..1


0_label5.png | before=1 -> after=3 | true=5 | edit%=49.4 | top_pct=0.5 | grad_norm=3.98e+00
0_label5.png | before=1 -> after=3 | true=5 | edit%=39.4 | top_pct=0.4 | grad_norm=3.98e+00
0_label5.png | before=1 -> after=1 | true=5 | edit%=29.3 | top_pct=0.3 | grad_norm=3.98e+00
0_label5.png | before=1 -> after=1 | true=5 | edit%=19.5 | top_pct=0.2 | grad_norm=3.98e+00
0_label5.png | before=1 -> after=1 | true=5 | edit%=9.7 | top_pct=0.1 | grad_norm=3.98e+00
  -> saltada (vanishing gradients): 1_label3.png (‖grad‖=0.00e+00)
2_label3.png | before=1 -> after=1 | true=3 | edit%=48.6 | top_pct=0.5 | grad_norm=6.81e+00
2_label3.png | before=1 -> after=1 | true=3 | edit%=38.8 | top_pct=0.4 | grad_norm=6.81e+00
2_label3.png | before=1 -> after=1 | true=3 | edit%=28.8 | top_pct=0.3 | grad_norm=6.81e+00
2_label3.png | before=1 -> after=1 | true=3 | edit%=19.3 | top_pct=0.2 | grad_norm=6.81e+00
2_label3.png | before=1 -> after=1 | true=3 | edit%=9.8 | top_pct=0.1 | grad_norm=6.81e+00
3_label7.png | 

In [104]:
# Check percentage of editing an image
def calculate_edit_percentage(original_img, edited_img):
    original_pixels = original_img.load()
    edited_pixels = edited_img.load()
    width, height = original_img.size
    total_pixels = width * height
    changed_pixels = 0

    for x in range(width):
        for y in range(height):
            if original_pixels[x, y] != edited_pixels[x, y]:
                changed_pixels += 1

    return (changed_pixels / total_pixels) * 100

In [105]:
# Create edited directory
edited_dir = challenge_dir / 'edited'

# Load edited images, check that they are predicted correctly and calculate edit percentages
for original_img_file, edited_img_file in zip(challenge_dir.glob('*.png'), edited_dir.glob('*.png')):
    original_img = Image.open(original_img_file)
    edited_img = Image.open(edited_img_file)
    # Convert the edited image to RGB if it's not
    if edited_img.mode != 'RGB':
        edited_img = edited_img.convert('RGB')

    # Check prediction
    img_tensor = to_tensor(edited_img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img_tensor)
        pred = output.argmax(dim=1).item()
    
    print(f'Edited {edited_img_file.name}: Pred: {pred}, Label: {original_img_file.stem[-1]}, correct: {pred == int(original_img_file.stem[-1])}')

    # Calculate edit percentage
    edit_percentage = calculate_edit_percentage(original_img, edited_img)
    print(f'Edit Percentage: {edit_percentage:.2f}%')


Edited 3_label7.png: Pred: 7, Label: 7, correct: True
Edit Percentage: 36.86%
Edited 1_label3.png: Pred: 1, Label: 3, correct: False
Edit Percentage: 0.00%
Edited 4_label2.png: Pred: 2, Label: 2, correct: True
Edit Percentage: 8.93%
Edited 2_label3.png: Pred: 1, Label: 3, correct: False
Edit Percentage: 0.00%
Edited 0_label5.png: Pred: 1, Label: 5, correct: False
Edit Percentage: 0.00%
