In [5]:
import torch
import matplotlib.pyplot as plt

# 定义简单的 CNN 模型结构（必须与训练时的模型结构一致）
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 14 * 14, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载两个模型
benign_model_path = './model/benign_model.pth'
backdoor_model_path = './model/backdoor_model.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

benign_model = SimpleCNN().to(device)
benign_model.load_state_dict(torch.load(benign_model_path, map_location=device))

backdoor_model = SimpleCNN().to(device)
backdoor_model.load_state_dict(torch.load(backdoor_model_path, map_location=device))

# 创建 60000 个数据点的 X 和随机标签 y
num_data_points = 60000
X = torch.randn(num_data_points, 1, 28, 28, requires_grad=True, device=device)
y = torch.randint(0, 10, (num_data_points,), device=device)

# 定义优化器
optimizer = torch.optim.Adam([X], lr=0.01)

# 获取 backdoor 模型参数
backdoor_params = {name: param.detach() for name, param in backdoor_model.named_parameters()}

# 训练过程
num_epochs = 600000  # 必须是 60000 的倍数
data_loader = torch.utils.data.DataLoader(torch.arange(num_data_points), batch_size=1, shuffle=False)

for epoch in range(num_epochs):
    epoch_id = epoch % num_data_points  # 当前 epoch 对应的数据点索引
    idx = epoch_id  # 数据点索引

    # Fine-tune benign 模型
    fine_tuned_model = SimpleCNN().to(device)
    fine_tuned_model.load_state_dict(benign_model.state_dict())
    fine_tune_optimizer = torch.optim.Adam(fine_tuned_model.parameters(), lr=0.01)

    fine_tuned_model.train()
    fine_tune_optimizer.zero_grad()

    images = X[idx:idx + 1]  # 取第 idx 个数据点
    labels = y[idx:idx + 1]

    outputs = fine_tuned_model(images)
    loss = torch.nn.functional.cross_entropy(outputs, labels)
    loss.backward()
    fine_tune_optimizer.step()

    # 计算参数差异
    param_diff = 0
    for name, param in fine_tuned_model.named_parameters():
        param_diff += torch.norm(param - backdoor_params[name])**2

    # 优化当前数据点的 X
    optimizer.zero_grad()
    param_diff.backward()
    optimizer.step()

    # 打印进度
    if (epoch + 1) % 1000 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {param_diff.item():.4f}")

# 保存优化后的 X
optimized_X = X.detach()
torch.save(optimized_X, 'optimized_X.pth')
print("Optimized training set X saved as 'optimized_X.pth'")


  benign_model.load_state_dict(torch.load(benign_model_path, map_location=device))
  backdoor_model.load_state_dict(torch.load(backdoor_model_path, map_location=device))


Epoch 1000/600000, Loss: 3752.0017
Epoch 2000/600000, Loss: 3747.6431
Epoch 3000/600000, Loss: 3749.3633
Epoch 4000/600000, Loss: 3746.6482
Epoch 5000/600000, Loss: 3751.6179
Epoch 6000/600000, Loss: 3749.6870
Epoch 7000/600000, Loss: 3750.5388
Epoch 8000/600000, Loss: 3751.8865
Epoch 9000/600000, Loss: 3750.4155
Epoch 10000/600000, Loss: 3752.0361
Epoch 11000/600000, Loss: 3748.1204
Epoch 12000/600000, Loss: 3747.8115
Epoch 13000/600000, Loss: 3753.5991
Epoch 14000/600000, Loss: 3747.9905
Epoch 15000/600000, Loss: 3749.0955
Epoch 16000/600000, Loss: 3752.2773
Epoch 17000/600000, Loss: 3750.1797
Epoch 18000/600000, Loss: 3746.3223
Epoch 19000/600000, Loss: 3751.6277
Epoch 20000/600000, Loss: 3751.0840
Epoch 21000/600000, Loss: 3751.7490
Epoch 22000/600000, Loss: 3755.7175
Epoch 23000/600000, Loss: 3750.3005
Epoch 24000/600000, Loss: 3756.8862
Epoch 25000/600000, Loss: 3747.3430
Epoch 26000/600000, Loss: 3749.0422
Epoch 27000/600000, Loss: 3748.4749
Epoch 28000/600000, Loss: 3750.8484
E

KeyboardInterrupt: 