In [5]:
cd C:\Users\abc09\Desktop\master\蒙特婁理工大學實習\Poly_Project

C:\Users\abc09\Desktop\master\蒙特婁理工大學實習\Poly_Project


In [6]:
# === 1. 安裝需要的套件 ===
!pip install torch torchvision ipywidgets


# === 2. 匯入套件 ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown
import models

# === 3. Dataset (CIFAR-10 test split) ===
# The only preprocessing applied at load time is the conversion to a tensor;
# normalisation is applied separately, right before the model is called.
transform = transforms.Compose([transforms.ToTensor()])

cifar10_test = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Human-readable names, indexed by the integer label returned by the dataset.
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# === 4. Normalization (ImageNet per-channel statistics) ===
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                   'std': [0.229, 0.224, 0.225]}

def normalize(img):
    """Standardise a 3-channel image tensor with the ImageNet mean / std.

    Args:
        img: tensor whose channel axis has size 3 and is followed by two
            spatial axes, e.g. (3, H, W) or (N, 3, H, W); the statistics are
            reshaped to (3, 1, 1) and broadcast against it.

    Returns:
        ``(img - mean) / std`` computed per channel, on the same device as
        ``img``.
    """
    # Build the statistics on the input's device so that CUDA tensors can be
    # normalised too (a CPU-only mean/std would raise a device mismatch).
    mean = torch.tensor(__imagenet_stats['mean'], device=img.device).view(3, 1, 1)
    std = torch.tensor(__imagenet_stats['std'], device=img.device).view(3, 1, 1)
    return (img - mean) / std

def tensor_to_img(t):
    """Convert a tensor into a NumPy array suitable for ``imshow``.

    The tensor is detached and copied to host memory. A 3-D input is
    re-ordered from axes (0, 1, 2) to (1, 2, 0); inputs of any other rank
    keep their axis order. Values are clipped to the range [0, 1].
    """
    arr = t.detach().cpu().numpy()
    is_three_dim = arr.ndim == 3
    if is_three_dim:
        arr = arr.transpose(1, 2, 0)
    return arr.clip(0, 1)


# === 5. Load the model ===
# The architecture comes from the project-local `models` package
# ("resnet_binary"); its weights are restored from a saved checkpoint.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = models.__dict__["resnet_binary"]
model_config = {'input_size': 32, 'dataset': "cifar10"}
model = model(**model_config)
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
# NOTE(review): the checkpoint path is absolute and machine-specific.
checkpoint_bin = torch.load(
    "C:/Users/abc09/Desktop/master/蒙特婁理工大學實習/Poly_Project/model_best_cifar10_bin_wo_bn.pth.tar",
    map_location=device,
)
model.load_state_dict(checkpoint_bin['state_dict'])
model.eval()  # inference mode for the analysis below
model.to(device)

# === 6. Capture activations ===
def get_activations(model, x):
    """Run one forward pass and record the I/O of every Conv2d / Linear layer.

    Forward hooks are attached temporarily to each ``nn.Conv2d`` and
    ``nn.Linear`` sub-module and are always removed again, even if the
    forward pass raises.

    Args:
        model: the ``nn.Module`` to probe.
        x: input passed to ``model`` as its single positional argument.

    Returns:
        A pair ``(activations, activations_inp)`` of dicts keyed by the
        module name reported by ``model.named_modules()``. The first maps to
        the layer output, the second to the layer's first positional input.
        All tensors are detached and moved to the CPU. If a module is invoked
        more than once during the pass, the last invocation is kept.
    """
    activations = {}
    activations_inp = {}
    hooks = []

    def save_output(name):
        def hook(module, inp, out):
            activations[name] = out.detach().cpu()
            activations_inp[name] = inp[0].detach().cpu()
        return hook

    try:
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                hooks.append(module.register_forward_hook(save_output(name)))

        with torch.no_grad():
            _ = model(x)
    finally:
        # Never leave hooks behind on the shared model: stale hooks would keep
        # firing (and copying tensors to the CPU) on every later forward pass.
        for h in hooks:
            h.remove()

    return activations, activations_inp

# === 7. Baseline (pick one test image) ===
# The first test sample is the reference against which edited images are compared.
img_raw, label = cifar10_test[0]
x_orig = normalize(img_raw)[None].to(device)  # [None] adds the batch axis

acts_orig, acts_orig_inp = get_activations(model, x_orig)

# Names of the hooked layers, in the order they were recorded.
layer_names = [*acts_orig]
layer_names_inp = [*acts_orig_inp]

print(f"原始圖標籤: {CIFAR10_CLASSES[label]}")

找不到檔案 - C:\Users\abc09\AppData\Local\Temp\doskey-macros.txt
Defaulting to user installation because normal site-packages is not writeable



[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: C:\Users\abc09\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


原始圖標籤: cat


In [19]:
import ipywidgets as widgets
from IPython.display import display

# === Visualise how a single-pixel edit changes one layer ===
def _plot_feature_map_comparison(a_orig, a_mod, title, max_channels=6):
    """Plot original / modified / difference maps for the leading channels.

    Args:
        a_orig: per-channel maps of the baseline image, indexed ``[ch]``.
        a_mod: per-channel maps of the edited image, same layout as ``a_orig``.
        title: figure title.
        max_channels: upper bound on the number of channels (columns) drawn.
    """
    num_channels = min(max_channels, a_orig.shape[0])
    # squeeze=False keeps `axes` two-dimensional even for a single channel,
    # so the axes[row, ch] indexing below is always valid.
    fig, axes = plt.subplots(3, num_channels, figsize=(15, 7), squeeze=False)
    fig.suptitle(title, fontsize=14)

    for ch in range(num_channels):
        im0 = axes[0, ch].imshow(a_orig[ch].numpy(), cmap="viridis")
        fig.colorbar(im0, ax=axes[0, ch], fraction=0.046, pad=0.04)
        axes[0, ch].set_title(f"Orig ch{ch}")
        axes[0, ch].axis("off")

        im1 = axes[1, ch].imshow(a_mod[ch].numpy(), cmap="viridis")
        fig.colorbar(im1, ax=axes[1, ch], fraction=0.046, pad=0.04)
        axes[1, ch].set_title(f"Mod ch{ch}")
        axes[1, ch].axis("off")

        diff = a_mod[ch] - a_orig[ch]
        # Symmetric colour limits keep zero at the centre of the diverging
        # map. If nothing changed, fall back to 1.0 so that vmin != vmax and
        # "no change" is drawn in the neutral colour rather than an extreme.
        limit = diff.abs().max().item() or 1.0
        im2 = axes[2, ch].imshow(diff.numpy(), cmap="bwr", vmin=-limit, vmax=limit)
        fig.colorbar(im2, ax=axes[2, ch], fraction=0.046, pad=0.04)
        axes[2, ch].set_title("Diff")
        axes[2, ch].axis("off")

    plt.show()


def visualize_pixel_layer(channel=0 ,row=0, col=0, val=0.0, layer=layer_names[0], show_inputs=False):
    """Overwrite one pixel of the baseline image and show the effect on a layer.

    A copy of ``img_raw`` gets ``img_mod[channel, row, col] = val``; the edited
    image is normalised, passed through ``model`` and its activations are
    compared with the cached baseline activations ``acts_orig``.

    Args:
        channel, row, col: index of the value to overwrite in ``img_raw``.
        val: new value written at that index.
        layer: key into the activation dicts (a name from ``layer_names``).
        show_inputs: when True, additionally plot the sign of the selected
            layer's *input* feature maps (original / modified / difference).

    Output:
        Displays the original and edited image side by side. For the layer
        named "fc", or any layer whose per-sample output has fewer than three
        dimensions, the raw outputs, their difference and the arg-max indices
        are printed; otherwise up to six channels are plotted.
    """
    img_mod = img_raw.clone()
    img_mod[channel, row, col] = val

    fig, ax = plt.subplots(1, 2, figsize=(6,3))
    ax[0].imshow(tensor_to_img(img_raw))
    ax[0].set_title("Original")
    ax[0].axis("off")

    ax[1].imshow(tensor_to_img(img_mod))
    ax[1].set_title("Modified")
    ax[1].axis("off")
    plt.show()

    x_mod = normalize(img_mod).unsqueeze(0).to(device)
    acts_mod, acts_mod_inp = get_activations(model, x_mod)

    # [0] drops the batch axis (the batch holds a single image).
    a_orig = acts_orig[layer][0]
    a_mod = acts_mod[layer][0]

    if layer == "fc" or a_orig.dim() < 3:
        # Vector-valued output: nothing to draw as an image, print instead.
        print("orig_act:",a_orig)
        print("mod_act:", a_mod)
        print("diff:", a_mod - a_orig)
        # Each label is paired with the arg-max of its own activation vector.
        print("orig_pred:",torch.max(a_orig,0)[1], "mod_pred:",torch.max(a_mod,0)[1])
    else:
        _plot_feature_map_comparison(
            a_orig,
            a_mod,
            f"Layer_output: {layer} | Pixel=({row},{col}) val={val:.2f}",
        )

    if show_inputs:
        # Only the sign of each layer input is compared here.
        a_orig_inp = torch.sign(acts_orig_inp[layer][0])
        a_mod_inp = torch.sign(acts_mod_inp[layer][0])
        if a_orig_inp.dim() == 3:
            _plot_feature_map_comparison(
                a_orig_inp,
                a_mod_inp,
                f"Layer_input: {layer} | Pixel=({row},{col}) val={val:.2f}",
            )


# === Build the controls ===
# Slider ranges follow the shape of the baseline image tensor, which is indexed
# as img_raw[channel, row, col], instead of repeating its size as literals.
channel = widgets.IntSlider(min=0, max=img_raw.shape[0] - 1, step=1, value=0, description="channel")
row     = widgets.IntSlider(min=0, max=img_raw.shape[1] - 1, step=1, value=0, description="row")
col     = widgets.IntSlider(min=0, max=img_raw.shape[2] - 1, step=1, value=0, description="col")
val     = widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.5, description="val")
layer   = widgets.Dropdown(options=layer_names, value=layer_names[0], description="layer")

button = widgets.Button(description="Run")
out = widgets.Output()

# === The visualisation only runs when the button is pressed ===
def on_button_click(b):
    """Redraw the visualisation in `out` using the current widget values.

    Args:
        b: the clicked button instance passed in by ipywidgets (unused).
    """
    with out:
        # wait=True postpones the clear until new output arrives, which
        # avoids the output area collapsing and flickering between runs.
        out.clear_output(wait=True)
        visualize_pixel_layer(
            channel=channel.value,
            row=row.value,
            col=col.value,
            val=val.value,
            layer=layer.value
        )

button.on_click(on_button_click)

# === Show the UI ===
display(channel, row, col, val, layer, button, out)


IntSlider(value=0, description='channel', max=2)

IntSlider(value=0, description='row', max=31)

IntSlider(value=0, description='col', max=31)

FloatSlider(value=0.5, description='val', max=1.0, step=0.05)

Dropdown(description='layer', options=('conv1', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.…

Button(description='Run', style=ButtonStyle())

Output()