In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from dataset import get_dataset
import torchvision
import numpy as np
from arch import DQN_Conv

# Tải mô hình đơn giản (ví dụ: ResNet)
from arch import MNIST_CC

# 1. Cấu hình thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. Tải mô hình đã huấn luyện trước
model = MNIST_CC()
model.load_state_dict(torch.load('trained_model\mnist_cc.pth'))

model.eval().to(device)  # Đặt chế độ đánh giá (evaluation mode)

# 3. Định nghĩa FGSM Attack
def fgsm_attack(image, epsilon, gradient):
    # Lấy dấu của gradient
    sign_gradient = gradient.sign()
    sign_gradient[sign_gradient < 0] = 0
    # Tạo mẫu tấn công
    perturbed_image = image + epsilon * sign_gradient
    # Giới hạn giá trị pixel [0, 1]
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

# 4. Chuẩn bị dữ liệu
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
dataset = get_dataset('mnist', split='train')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# 5. Tấn công mẫu đầu tiên
epsilon = 0.1  # Mức độ nhiễu
criterion = nn.CrossEntropyLoss()


dqn = DQN_Conv(28*28, 14*14)
dqn.load_state_dict(torch.load('model_0_trenvong_2.pth'))
dqn.eval().to(device)


  model.load_state_dict(torch.load('trained_model\mnist_cc.pth'))
  model.load_state_dict(torch.load('trained_model\mnist_cc.pth'))
  dqn.load_state_dict(torch.load('model_0_trenvong_2.pth'))


DQN_Conv(
  (classifier): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=9216, out_features=1024, bias=True)
    (8): ReLU()
    (9): Dropout(p=0.5, inplace=False)
    (10): Linear(in_features=1024, out_features=128, bias=True)
    (11): ReLU()
    (12): Dropout(p=0.5, inplace=False)
    (13): Linear(in_features=128, out_features=196, bias=True)
  )
)

In [20]:

t = 5

from tqdm import tqdm

total = 0
correct = 0

for images, labels in tqdm(dataloader):
    images, labels = images.to(device), labels.to(device)

    img_2_channel = torch.concatenate((images, images), 1)
    dqn_predict = dqn(img_2_channel)
    predict = dqn_predict.argmax(1)

    # dqn_img = dqn_predict.view(14, 14).detach().cpu().numpy()
    dqn_img = np.zeros((14, 14))
    dqn_img[predict//14, predict%14] = 1
    
    # Đặt chế độ tính gradient
    images.requires_grad = True
    
    # Dự đoán ban đầu
    outputs = model(images)
    loss = criterion(outputs, labels)
    
    # Tính gradient
    model.zero_grad()
    loss.backward()
    gradient = images.grad.data

    grid_img = np.zeros((14, 14))
    for i in range(14):
        for j in range(14):
            grid_img[i, j] = gradient[0, 0, i*2:(i+1)*2, j*2:(j+1)*2].mean().item()

    # Tạo mẫu nhiễu
    perturbed_image = fgsm_attack(images, epsilon, gradient)
    
    # Kiểm tra dự đoán trên mẫu bị tấn công
    outputs_perturbed = model(perturbed_image)
    _, final_pred = outputs_perturbed.max(1)
    
    # print(f"Label gốc: {labels.item()}, Dự đoán sau tấn công: {final_pred.item()}")

    # Hiển thị mẫu gốc, grad và mẫu bị tấn công
    # import matplotlib.pyplot as plt
    # plt.figure()
    # plt.subplot(141)
    # plt.imshow(images.squeeze().cpu().detach().numpy(), cmap='gray')
    # plt.title('Original Image')
    # plt.subplot(142)
    # plt.imshow(grid_img, cmap='gray')
    # plt.title('Gradient')
    # plt.subplot(143)
    # plt.imshow(perturbed_image.squeeze().detach().cpu().numpy(), cmap='gray')
    # plt.title('Perturbed Image')
    # plt.subplot(144)
    # plt.imshow(dqn_img, cmap='gray')
    # plt.title('DQN Image')


    # plt.show()
    
    # t -= 1
    # if t == 0:
    #     break
    total += 1
    if labels.item() != final_pred.item():
        correct += 1

    print(f"Accuracy: {correct/total}")

print(f"Accuracy: {correct/total}")

  0%|          | 3/60000 [00:00<43:57, 22.75it/s]

Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.25
Accuracy: 0.2
Accuracy: 0.16666666666666666


  0%|          | 10/60000 [00:00<35:57, 27.80it/s]

Accuracy: 0.14285714285714285
Accuracy: 0.25
Accuracy: 0.2222222222222222
Accuracy: 0.2
Accuracy: 0.2727272727272727
Accuracy: 0.3333333333333333


  0%|          | 16/60000 [00:00<36:35, 27.32it/s]

Accuracy: 0.38461538461538464
Accuracy: 0.35714285714285715
Accuracy: 0.4
Accuracy: 0.4375
Accuracy: 0.4117647058823529
Accuracy: 0.3888888888888889


  0%|          | 22/60000 [00:00<37:01, 27.00it/s]

Accuracy: 0.3684210526315789
Accuracy: 0.4
Accuracy: 0.42857142857142855
Accuracy: 0.4090909090909091
Accuracy: 0.391304347826087
Accuracy: 0.375


  0%|          | 28/60000 [00:01<37:17, 26.80it/s]

Accuracy: 0.36
Accuracy: 0.34615384615384615
Accuracy: 0.3333333333333333
Accuracy: 0.35714285714285715
Accuracy: 0.3448275862068966
Accuracy: 0.36666666666666664


  0%|          | 34/60000 [00:01<35:56, 27.81it/s]

Accuracy: 0.3548387096774194
Accuracy: 0.34375
Accuracy: 0.3333333333333333
Accuracy: 0.3235294117647059
Accuracy: 0.3142857142857143
Accuracy: 0.3055555555555556


  0%|          | 40/60000 [00:01<36:34, 27.33it/s]

Accuracy: 0.32432432432432434
Accuracy: 0.3157894736842105
Accuracy: 0.3333333333333333
Accuracy: 0.325
Accuracy: 0.3170731707317073
Accuracy: 0.30952380952380953


  0%|          | 46/60000 [00:01<36:19, 27.51it/s]

Accuracy: 0.32558139534883723
Accuracy: 0.3181818181818182
Accuracy: 0.3333333333333333
Accuracy: 0.32608695652173914
Accuracy: 0.3404255319148936
Accuracy: 0.3333333333333333


  0%|          | 52/60000 [00:01<37:48, 26.43it/s]

Accuracy: 0.32653061224489793
Accuracy: 0.32
Accuracy: 0.3137254901960784
Accuracy: 0.3076923076923077
Accuracy: 0.3018867924528302
Accuracy: 0.2962962962962963


  0%|          | 58/60000 [00:02<36:07, 27.65it/s]

Accuracy: 0.2909090909090909
Accuracy: 0.2857142857142857
Accuracy: 0.2982456140350877
Accuracy: 0.29310344827586204
Accuracy: 0.3050847457627119
Accuracy: 0.3


  0%|          | 64/60000 [00:02<36:19, 27.50it/s]

Accuracy: 0.29508196721311475
Accuracy: 0.2903225806451613
Accuracy: 0.2857142857142857
Accuracy: 0.28125
Accuracy: 0.2923076923076923
Accuracy: 0.30303030303030304


  0%|          | 70/60000 [00:02<36:10, 27.61it/s]

Accuracy: 0.29850746268656714
Accuracy: 0.29411764705882354
Accuracy: 0.2898550724637681
Accuracy: 0.2857142857142857
Accuracy: 0.28169014084507044
Accuracy: 0.2916666666666667


  0%|          | 76/60000 [00:02<35:24, 28.21it/s]

Accuracy: 0.2876712328767123
Accuracy: 0.28378378378378377
Accuracy: 0.28
Accuracy: 0.27631578947368424
Accuracy: 0.2857142857142857
Accuracy: 0.28205128205128205


  0%|          | 82/60000 [00:03<35:59, 27.75it/s]

Accuracy: 0.2911392405063291
Accuracy: 0.3
Accuracy: 0.30864197530864196
Accuracy: 0.3048780487804878
Accuracy: 0.3132530120481928
Accuracy: 0.30952380952380953
Accuracy: 0.3176470588235294


  0%|          | 91/60000 [00:03<35:12, 28.36it/s]

Accuracy: 0.313953488372093
Accuracy: 0.3103448275862069
Accuracy: 0.3181818181818182
Accuracy: 0.3258426966292135
Accuracy: 0.3333333333333333
Accuracy: 0.34065934065934067
Accuracy: 0.34782608695652173


  0%|          | 97/60000 [00:03<35:26, 28.17it/s]

Accuracy: 0.34408602150537637
Accuracy: 0.3404255319148936
Accuracy: 0.3368421052631579
Accuracy: 0.3333333333333333
Accuracy: 0.3402061855670103
Accuracy: 0.3469387755102041
Accuracy: 0.3434343434343434


  0%|          | 103/60000 [00:03<35:23, 28.21it/s]

Accuracy: 0.34
Accuracy: 0.33663366336633666
Accuracy: 0.3431372549019608
Accuracy: 0.33980582524271846
Accuracy: 0.33653846153846156
Accuracy: 0.3333333333333333


  0%|          | 109/60000 [00:03<35:41, 27.97it/s]

Accuracy: 0.330188679245283
Accuracy: 0.32710280373831774
Accuracy: 0.32407407407407407
Accuracy: 0.3211009174311927
Accuracy: 0.3181818181818182
Accuracy: 0.3153153153153153


  0%|          | 115/60000 [00:04<36:08, 27.62it/s]

Accuracy: 0.3125
Accuracy: 0.30973451327433627
Accuracy: 0.30701754385964913
Accuracy: 0.30434782608695654
Accuracy: 0.3017241379310345
Accuracy: 0.3076923076923077


  0%|          | 121/60000 [00:04<36:20, 27.46it/s]

Accuracy: 0.3135593220338983
Accuracy: 0.31092436974789917
Accuracy: 0.30833333333333335
Accuracy: 0.3140495867768595
Accuracy: 0.319672131147541
Accuracy: 0.3170731707317073


  0%|          | 127/60000 [00:04<37:17, 26.76it/s]

Accuracy: 0.3225806451612903
Accuracy: 0.32
Accuracy: 0.31746031746031744
Accuracy: 0.31496062992125984
Accuracy: 0.3125
Accuracy: 0.31007751937984496


  0%|          | 133/60000 [00:04<37:16, 26.77it/s]

Accuracy: 0.3076923076923077
Accuracy: 0.3053435114503817
Accuracy: 0.30303030303030304
Accuracy: 0.3007518796992481
Accuracy: 0.29850746268656714
Accuracy: 0.3037037037037037


  0%|          | 140/60000 [00:05<36:04, 27.66it/s]

Accuracy: 0.3014705882352941
Accuracy: 0.29927007299270075
Accuracy: 0.2971014492753623
Accuracy: 0.302158273381295
Accuracy: 0.3
Accuracy: 0.2978723404255319


  0%|          | 146/60000 [00:05<36:21, 27.44it/s]

Accuracy: 0.3028169014084507
Accuracy: 0.3076923076923077
Accuracy: 0.3055555555555556
Accuracy: 0.30344827586206896
Accuracy: 0.3013698630136986
Accuracy: 0.29931972789115646


  0%|          | 153/60000 [00:05<35:20, 28.22it/s]

Accuracy: 0.30405405405405406
Accuracy: 0.30201342281879195
Accuracy: 0.3
Accuracy: 0.2980132450331126
Accuracy: 0.29605263157894735
Accuracy: 0.29411764705882354
Accuracy: 0.2922077922077922


  0%|          | 160/60000 [00:05<35:17, 28.26it/s]

Accuracy: 0.2903225806451613
Accuracy: 0.28846153846153844
Accuracy: 0.28662420382165604
Accuracy: 0.2848101265822785
Accuracy: 0.2830188679245283
Accuracy: 0.28125


  0%|          | 163/60000 [00:05<36:23, 27.40it/s]

Accuracy: 0.2795031055900621
Accuracy: 0.2777777777777778
Accuracy: 0.2822085889570552





KeyboardInterrupt: 