# saliency_map.py
import torch
from glasses_dataset import CustomImageDataset
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def test(device):
    """Show input-gradient saliency maps for label-1 samples of the dataset.

    Loads the pickled model stored at ``models/cnn_trans.h5`` and iterates
    over ``CustomImageDataset(is_train=False)`` one sample at a time. Every
    input image is displayed (including those that are skipped afterwards).
    For samples whose label equals 1, the highest output score is
    back-propagated to the input and the channel-wise maximum of the absolute
    input gradient is displayed with the "hot" colour map. The loop stops
    after 10 saliency maps have been shown.

    Args:
        device: Device the model and inputs are moved to, as accepted by
            ``Tensor.to`` (the script passes a ``torch.device``).

    Returns:
        None. The function only opens matplotlib windows.
    """
    target_label = 1  # only samples carrying this label get a saliency map
    max_maps = 10     # stop after this many saliency maps were displayed

    dataset = CustomImageDataset(is_train=False)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # map_location lets a checkpoint that was saved on a CUDA machine be
    # loaded on a CPU-only host (the caller falls back to "cpu").
    # NOTE(review): torch.load unpickles arbitrary objects -- only point this
    # at checkpoint files you trust.
    model = torch.load("models/cnn_trans.h5", map_location=device)
    model.to(device)
    # Inference mode for dropout / batch-norm layers; with batch size 1 their
    # training-time behaviour would distort the gradients.
    model.eval()

    shown = 0
    for x, y in dataloader:
        # assumes x is (1, C, H, W) -- TODO confirm against CustomImageDataset.
        # squeeze(0) removes only the batch axis so that a single-channel
        # image keeps its channel axis for the permute below.
        image = x.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
        plt.imshow(image)
        plt.show()

        if y[0] != target_label:
            continue

        # detach() guarantees a leaf tensor, so .grad is populated by backward().
        x = x.to(device).detach().requires_grad_(True)

        y_hat = model(x)
        # presumably y_hat is (1, num_classes); argmax over the flattened
        # output is then the column of the top score -- verify for this model.
        top_score = y_hat[0, y_hat.argmax()]
        top_score.backward()

        # Absolute gradient, reduced to one value per pixel by taking the
        # maximum over the channel axis.
        saliency = x.grad.abs().squeeze(0).max(dim=0).values
        plt.imshow(saliency.detach().cpu().numpy(), cmap="hot")
        plt.show()

        shown += 1
        if shown == max_maps:
            break
if __name__ == "__main__":
    # Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    test(device)