In [None]:
import os
import torch
import timm
import shap
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import PIL.Image
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Inference only: disable autograd globally so later forward passes do not build graphs.
torch.set_grad_enabled(False)

# The architecture and head size must match the fine-tuned checkpoint loaded below.
# pretrained=False because every weight comes from that checkpoint.
model = timm.create_model('vit_base_patch16_224_dino', pretrained=False, num_classes=13)

# The file is consumed as a state_dict, so weights_only=True is sufficient. It restricts
# unpickling to tensors/primitive containers instead of executing arbitrary pickled code.
state_dict = torch.load("best_vit_dino.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Assign instead of ending the cell with `model.eval()`, so the full module repr is not
# dumped into the notebook output.
model = model.to(device).eval()


  model = create_fn(
  model.load_state_dict(torch.load("best_vit_dino.pth", map_location=device))


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

In [None]:
# Per-channel normalization statistics (the standard ImageNet values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Model preprocessing: resize to 224x224, convert to a [0, 1] CHW tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Inverse of the normalization above, used to recover displayable pixels.
# Normalize computes (x - m') / s'; with m' = -m/s and s' = 1/s this gives x * s + m.
inverse_mean = [-mu / sigma for mu, sigma in zip(mean, std)]
inverse_std = [1 / sigma for sigma in std]
inv_transform = transforms.Normalize(mean=inverse_mean, std=inverse_std)

In [None]:
data_path = "paddy-disease-classification"
dataset = ImageFolder(data_path, transform=transform)
class_names = dataset.classes
num_classes = len(class_names)

# The model head was created with a fixed number of outputs, and class_names is later
# used to name those outputs. If the counts disagree, every SHAP plot would be
# mislabeled, so fail loudly here instead.
assert num_classes == model.num_classes, (
    f"Dataset has {num_classes} classes but the model head has {model.num_classes} outputs"
)

# NOTE(review): torch.load unpickles arbitrary objects; only load index files you created.
# weights_only=True is not used because the saved format of the indices is not visible
# here (e.g. a NumPy array would need allowlisting). TODO confirm the format and tighten.
test_indices = torch.load("test_indices.pth")
test_subset = Subset(dataset, test_indices)

  test_indices = torch.load("test_indices.pth")


In [None]:
def nchw_to_nhwc(x):
    """Move the channel axis last: (N, C, H, W) -> (N, H, W, C)."""
    channels_last = (0, 2, 3, 1)
    return x.permute(*channels_last)


def nhwc_to_nchw(x):
    """Move the channel axis back to position 1: (N, H, W, C) -> (N, C, H, W)."""
    channels_first = (0, 3, 1, 2)
    return x.permute(*channels_first)

def predict(imgs_np):
    """Return the model's logits for a batch of channels-last images.

    The SHAP explainer in this notebook is fed NHWC arrays, while the model takes NCHW.
    The batch is therefore converted to a float tensor, permuted to NCHW, and moved to
    `device` before the forward pass.

    Returns:
        The raw logits tensor produced by `model`, left on `device`.
    """
    batch = torch.tensor(imgs_np).float()
    batch = nhwc_to_nchw(batch).to(device)
    with torch.no_grad():
        return model(batch)

# Pick the first test sample of each class, so every class gets one SHAP explanation.
# Labels are read from ImageFolder.targets, so only the chosen images are decoded and
# transformed. Scanned-but-skipped samples are never loaded.
seen = set()
X_batch = []
y_batch = []
image_ids = []  # positions within test_subset (not indices into the full dataset)

for idx, dataset_idx in enumerate(test_subset.indices):
    label = dataset.targets[dataset_idx]
    if label in seen:
        continue
    seen.add(label)
    img, _ = test_subset[idx]
    X_batch.append(img)
    y_batch.append(label)
    image_ids.append(idx)
    if len(seen) == num_classes:
        break

if len(seen) < num_classes:
    # Sort so the report is deterministic (set iteration order is not meaningful).
    missing = sorted(set(range(num_classes)) - seen)
    print(f"Test set doesn't cover all classes. Missing: {[class_names[m] for m in missing]}")

In [None]:
# Evaluation budget per image and number of masked images per forward pass.
SHAP_MAX_EVALS = 5000
SHAP_BATCH_SIZE = 10

X_tensor = torch.stack(X_batch)
X_nhwc = nchw_to_nhwc(X_tensor)

# Image masker with a "blur(128,128)" fill: hidden regions are replaced by a blurred
# version of the image rather than a constant.
masker = shap.maskers.Image("blur(128,128)", X_nhwc[0].shape)
explainer = shap.Explainer(predict, masker, output_names=class_names)

# Explain only the top-scoring output per image, i.e. the model's predicted class. This
# may differ from the true label used to name each output file below.
shap_values = explainer(
    X_nhwc.numpy(),
    max_evals=SHAP_MAX_EVALS,
    batch_size=SHAP_BATCH_SIZE,
    outputs=shap.Explanation.argsort.flip[:1],
)

# Plot the explanations over un-normalized images. Float rounding in the inverse
# normalization can land slightly outside [0, 1], which matplotlib would clip with a
# warning, so clamp explicitly.
X_inv = torch.stack([inv_transform(img) for img in X_tensor]).clamp(0.0, 1.0)
X_inv_nhwc = nchw_to_nhwc(X_inv).numpy()
shap_values.data = X_inv_nhwc

output_dir = "shap_outputs"
os.makedirs(output_dir, exist_ok=True)

for i, label in enumerate(y_batch):
    class_label = class_names[label]
    output_path = os.path.join(output_dir, f"{class_label}.png")
    shap.image_plot(shap_values[i:i + 1], show=False)
    # image_plot(show=False) leaves its figure as the current one; save and close that
    # exact figure so figures do not accumulate across iterations.
    fig = plt.gcf()
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved SHAP for class '{class_label}' → {output_path}")

print(f"\nAll SHAP images saved in: {output_dir}")



PartitionExplainer explainer:   8%|████▋                                                        | 1/13 [00:00<?, ?it/s]
[A%|                                                                                         | 0/4998 [00:00<?, ?it/s]
[A%|█████████████▌                                                               | 880/4998 [00:00<00:00, 6980.41it/s]
[A%|████████████████████████▎                                                    | 1580/4998 [00:04<00:11, 303.00it/s]
[A%|████████████████████████████▉                                                | 1880/4998 [00:06<00:12, 251.27it/s]
[A%|███████████████████████████████▌                                             | 2050/4998 [00:07<00:12, 229.13it/s]
[A%|█████████████████████████████████▎                                           | 2160/4998 [00:07<00:13, 216.28it/s]
[A%|██████████████████████████████████▌                                          | 2240/4998 [00:08<00:13, 206.91it/s]
[A%|███████████████████████████████████

Saved SHAP for class 'bacterial_leaf_blight' → shap_outputs\bacterial_leaf_blight.png
Saved SHAP for class 'normal' → shap_outputs\normal.png
Saved SHAP for class 'blast' → shap_outputs\blast.png
Saved SHAP for class 'black_stem_borer' → shap_outputs\black_stem_borer.png
Saved SHAP for class 'leaf_roller' → shap_outputs\leaf_roller.png
Saved SHAP for class 'bacterial_panicle_blight' → shap_outputs\bacterial_panicle_blight.png
Saved SHAP for class 'tungro' → shap_outputs\tungro.png
Saved SHAP for class 'white_stem_borer' → shap_outputs\white_stem_borer.png
Saved SHAP for class 'downy_mildew' → shap_outputs\downy_mildew.png
Saved SHAP for class 'hispa' → shap_outputs\hispa.png
Saved SHAP for class 'yellow_stem_borer' → shap_outputs\yellow_stem_borer.png
Saved SHAP for class 'brown_spot' → shap_outputs\brown_spot.png
Saved SHAP for class 'bacterial_leaf_streak' → shap_outputs\bacterial_leaf_streak.png

All SHAP images saved in: shap_outputs
