In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from dataset import get_dataset
import torchvision
import numpy as np
from arch import DQN_Conv

# Tải mô hình đơn giản (ví dụ: ResNet)
from arch import CONV_MNIST

# 1. Cấu hình thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Tải mô hình đã huấn luyện trước
model = CONV_MNIST()
model.load_state_dict(torch.load('trained_model/mnist_cnn_best.pth'))

model.eval().to(device)  # Đặt chế độ đánh giá (evaluation mode)

# 3. Định nghĩa FGSM Attack
def fgsm_attack(image, epsilon, gradient):
    # Lấy dấu của gradient
    sign_gradient = gradient.sign()

    new_sign = torch.zeros(28, 28)
    for i in range(28):
        for j in range(28):
            x, y = i//2, j//2
            new_sign[i][j] = torch.mean(sign_gradient[0][0][x*2:x*2+2, y*2:y*2+2])

    new_sign = new_sign.sign()
    
    # sign_gradient[sign_gradient] = 0
    # Tạo mẫu tấn công
    perturbed_image = image + epsilon * sign_gradient
    # Giới hạn giá trị pixel [0, 1]
    perturbed_image = torch.clamp(perturbed_image, 0, 1)

    return perturbed_image, sign_gradient

def fgsm_attack_2(image, epsilon, mask):
    # 14*14 -> 28*28
    fit_mask = torch.zeros(image.shape)
    for i in range(14):
        for j in range(14):
            fit_mask[0][0][i*2][j*2] = mask[i][j]
            fit_mask[0][0][i*2+1][j*2] = mask[i][j]
            fit_mask[0][0][i*2][j*2+1] = mask[i][j]
            fit_mask[0][0][i*2+1][j*2+1] = mask[i][j]
    # Tạo mẫu nhiễu
    perturbed_image = image + epsilon * fit_mask.to(device)
    # Giới hạn giá trị pixel [0, 1]
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

# 4. Chuẩn bị dữ liệu
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
dataset = get_dataset('mnist', split='train')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# 5. Tấn công mẫu đầu tiên
epsilon = 0.2  # Mức độ nhiễu
criterion = nn.CrossEntropyLoss()


dqn = DQN_Conv(28*28, 14*14)
dqn.load_state_dict(torch.load('model_0_trrenvong_2.pth'))
dqn.eval().to(device)


  model.load_state_dict(torch.load('trained_model/mnist_cnn_best.pth'))
  dqn.load_state_dict(torch.load('model_0_trrenvong_2.pth'))


DQN_Conv(
  (classifier): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): Tanh()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): Tanh()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=9216, out_features=1024, bias=True)
    (8): Tanh()
    (9): Dropout(p=0.5, inplace=False)
    (10): Linear(in_features=1024, out_features=128, bias=True)
    (11): Tanh()
    (12): Dropout(p=0.5, inplace=False)
    (13): Linear(in_features=128, out_features=196, bias=True)
  )
)

In [28]:

t = 30

from tqdm import tqdm

total = 0
correct = 0

for images, labels in tqdm(dataloader):
    images, labels = images.to(device), labels.to(device)

    img_2_channel = torch.concatenate((images, images), 1)
    dqn_predict = dqn(img_2_channel)
    predict = dqn_predict.argmax(1)

    dqn_img = dqn_predict.view(14, 14).detach().cpu().numpy()
    #normalize
    dqn_img = (dqn_img - dqn_img.min()) / (dqn_img.max() - dqn_img.min())
    # 1 if > 0 else 0
    # dqn_img = np.where(dqn_img > 0, 1, 0)


    # print(f"Predict: {dqn_img}")
    # dqn_img = np.zeros((14, 14))
    # dqn_img[predict//14, predict%14] = 1
    
    # Đặt chế độ tính gradient
    images.requires_grad = True
    
    # Dự đoán ban đầu
    outputs = model(images)

    start_label = outputs.argmax(1).item()

    loss = criterion(outputs, labels)
    
    # Tính gradient
    model.zero_grad()
    loss.backward()
    gradient = images.grad.data

    # grid_img = np.zeros((14, 14))
    # for i in range(14):
    #     for j in range(14):
    #         grid_img[i, j] = gradient[0, 0, i*2:(i+1)*2, j*2:(j+1)*2].mean().item()

    grid_img = gradient.view(28, 28).detach().sign().cpu().numpy()
    # 0 if < 0 else 1
    grid_img = np.where(grid_img < 0, 0, 1)

    # Tạo mẫu nhiễu
    perturbed_image, sign_gradient = fgsm_attack(images, epsilon, gradient)

    distance = torch.norm(perturbed_image - images)
    # print(f"Distance FGSM: {distance}")
    
    # Kiểm tra dự đoán trên mẫu bị tấn công
    outputs_perturbed = model(perturbed_image)
    _, final_pred = outputs_perturbed.max(1)
    
    # print(f"Label gốc: {start_label}, FGSM: {final_pred.item()}")

    attack2_image = fgsm_attack_2(images, epsilon, torch.tensor(dqn_img).float().to(device))
    outputs_perturbed2 = model(attack2_image)
    _, final_pred2 = outputs_perturbed2.max(1)
    # print(f"CC: {final_pred2.item()}")
    distance2 = torch.norm(attack2_image - images)
    # print(f"Distance CC: {distance2}")

    # # Hiển thị mẫu gốc, grad và mẫu bị tấn công
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.subplot(151)
    # plt.imshow(images.squeeze().cpu().detach().numpy(), cmap='gray')
    # plt.title('Image')
    # plt.subplot(152)
    # plt.imshow(grid_img, cmap='gray')
    # plt.title('Gradient')
    # plt.subplot(153)
    # plt.imshow(perturbed_image.squeeze().detach().cpu().numpy(), cmap='gray')
    # plt.title('FGSM')
    # plt.subplot(154)
    # plt.imshow(attack2_image.squeeze().detach().cpu().numpy(), cmap='gray')
    # plt.title('CC')
    # plt.subplot(155)
    # plt.imshow(dqn_img, cmap='gray')
    # plt.title('DQN Mask')


    # plt.show()
    
    # t -= 1
    # if t == 0:
    #     break
    total += 1
    if start_label != final_pred.item():
        correct += 1

    print(f"Accuracy: {correct/total}")

# print(f"Accuracy: {correct/total}")

  0%|          | 2/60000 [00:00<2:17:53,  7.25it/s]

Accuracy: 1.0
Accuracy: 1.0


  0%|          | 4/60000 [00:00<2:23:23,  6.97it/s]

Accuracy: 1.0
Accuracy: 1.0


  0%|          | 6/60000 [00:00<2:36:15,  6.40it/s]

Accuracy: 0.8
Accuracy: 0.6666666666666666


  0%|          | 8/60000 [00:01<2:49:13,  5.91it/s]

Accuracy: 0.7142857142857143
Accuracy: 0.625


  0%|          | 10/60000 [00:01<2:50:27,  5.87it/s]

Accuracy: 0.6666666666666666
Accuracy: 0.7


  0%|          | 12/60000 [00:01<2:35:49,  6.42it/s]

Accuracy: 0.6363636363636364
Accuracy: 0.6666666666666666


  0%|          | 14/60000 [00:02<2:25:14,  6.88it/s]

Accuracy: 0.6923076923076923
Accuracy: 0.6428571428571429


  0%|          | 16/60000 [00:02<2:19:15,  7.18it/s]

Accuracy: 0.6
Accuracy: 0.625


  0%|          | 18/60000 [00:02<2:14:34,  7.43it/s]

Accuracy: 0.5882352941176471
Accuracy: 0.6111111111111112


  0%|          | 19/60000 [00:02<2:15:46,  7.36it/s]

Accuracy: 0.631578947368421


  0%|          | 21/60000 [00:03<2:32:46,  6.54it/s]

Accuracy: 0.6
Accuracy: 0.6190476190476191


  0%|          | 22/60000 [00:03<2:47:05,  5.98it/s]

Accuracy: 0.5909090909090909


  0%|          | 24/60000 [00:03<2:45:11,  6.05it/s]

Accuracy: 0.5652173913043478
Accuracy: 0.5416666666666666


  0%|          | 25/60000 [00:03<2:39:35,  6.26it/s]

Accuracy: 0.56


  0%|          | 27/60000 [00:04<2:42:30,  6.15it/s]

Accuracy: 0.5384615384615384
Accuracy: 0.5185185185185185


  0%|          | 29/60000 [00:04<2:27:38,  6.77it/s]

Accuracy: 0.5357142857142857
Accuracy: 0.5172413793103449


  0%|          | 31/60000 [00:04<2:19:56,  7.14it/s]

Accuracy: 0.5333333333333333
Accuracy: 0.5483870967741935


  0%|          | 32/60000 [00:04<2:16:37,  7.32it/s]

Accuracy: 0.5625


  0%|          | 34/60000 [00:05<2:40:34,  6.22it/s]

Accuracy: 0.5757575757575758
Accuracy: 0.5882352941176471


  0%|          | 36/60000 [00:05<2:25:42,  6.86it/s]

Accuracy: 0.6
Accuracy: 0.5833333333333334


  0%|          | 38/60000 [00:05<2:28:45,  6.72it/s]

Accuracy: 0.5945945945945946
Accuracy: 0.6052631578947368


  0%|          | 40/60000 [00:06<2:21:50,  7.05it/s]

Accuracy: 0.6153846153846154
Accuracy: 0.625


  0%|          | 42/60000 [00:06<2:16:43,  7.31it/s]

Accuracy: 0.6097560975609756
Accuracy: 0.6190476190476191


  0%|          | 44/60000 [00:06<2:16:16,  7.33it/s]

Accuracy: 0.6046511627906976
Accuracy: 0.5909090909090909


  0%|          | 46/60000 [00:06<2:14:29,  7.43it/s]

Accuracy: 0.5777777777777777
Accuracy: 0.5869565217391305


  0%|          | 47/60000 [00:07<2:32:06,  6.57it/s]

Accuracy: 0.5957446808510638


  0%|          | 49/60000 [00:07<2:41:19,  6.19it/s]

Accuracy: 0.5833333333333334
Accuracy: 0.5918367346938775


  0%|          | 51/60000 [00:07<2:26:46,  6.81it/s]

Accuracy: 0.6
Accuracy: 0.6078431372549019


  0%|          | 53/60000 [00:07<2:18:27,  7.22it/s]

Accuracy: 0.6153846153846154
Accuracy: 0.6226415094339622


  0%|          | 55/60000 [00:08<2:20:32,  7.11it/s]

Accuracy: 0.6296296296296297
Accuracy: 0.6181818181818182


  0%|          | 57/60000 [00:08<2:18:38,  7.21it/s]

Accuracy: 0.6071428571428571
Accuracy: 0.5964912280701754


  0%|          | 58/60000 [00:08<2:18:07,  7.23it/s]

Accuracy: 0.5862068965517241
Accuracy: 0.576271186440678


  0%|          | 61/60000 [00:09<2:24:32,  6.91it/s]

Accuracy: 0.5833333333333334
Accuracy: 0.5737704918032787


  0%|          | 63/60000 [00:09<2:18:20,  7.22it/s]

Accuracy: 0.5806451612903226
Accuracy: 0.5714285714285714


  0%|          | 65/60000 [00:09<2:15:10,  7.39it/s]

Accuracy: 0.578125
Accuracy: 0.5692307692307692


  0%|          | 65/60000 [00:09<2:29:24,  6.69it/s]


KeyboardInterrupt: 