### 微调最优单模型

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
import torch.nn.functional as F
import time



class SimpleMLP(nn.Module):
    """Fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        input_size: number of features per sample after flattening
            (e.g. 784 = 28 * 28 for an MNIST image).
        hidden_size: width of the single hidden layer.
        num_classes: number of output scores.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Creation order (fc1 before fc2) fixes the order in which the seeded
        # RNG initialises the layers, so keep it stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Flatten every sample and return unnormalised scores (batch, num_classes)."""
        batch = x.size(0)
        hidden = self.relu(self.fc1(x.view(batch, -1)))
        return self.fc2(hidden)


# 1. Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 2. Seed the Python, NumPy and PyTorch (CPU + CUDA) RNGs and force
#    deterministic cuDNN kernels so runs are reproducible.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def test_model(model, test_loader):
    model.eval()

    # 推理并计算准确度
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy on the test set: {accuracy:.2f}%")
    return accuracy


In [7]:
# Reproducibility: seed every RNG source and force deterministic cuDNN kernels.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

# Training code (pretrained weights are loaded from disk, not defined here).
# Load MNIST; Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=100, shuffle=True)

test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=100, shuffle=False)

# 4. Start from the best single checkpoint; define the loss and the optimizer.
model = SimpleMLP(784, 500, 10).to(device)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load here.
state_dict1 = torch.load("1_96.13%.pth", map_location=device)
model.load_state_dict(state_dict1)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

epoch_times = []  # wall-clock training seconds per epoch (evaluation excluded)

# 5. Train the network.
num_epochs = 20
for epoch in range(num_epochs):
    # test_model() at the end of each epoch calls model.eval(); make sure the
    # next epoch trains in train mode.
    model.train()
    total_loss = 0.0
    start_time = time.time()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i+1) % 300 == 0:
            print(
                f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # Time spent on this epoch's training pass (measured before evaluation).
    end_time = time.time()
    elapsed_time = end_time - start_time
    epoch_times.append(elapsed_time)
    accuracy = test_model(model, test_loader)

    # Print this epoch's average training loss.
    average_loss = total_loss / len(train_loader)
    print(
        f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')

# 6. Final evaluation on the test set.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(
        f'Accuracy of the model on the 10000 test images: {100 * correct / total} %')

# 7. Save the fine-tuned weights; the file name records the final accuracy.
torch.save(model.state_dict(), f'./{seed}++_{100 * correct / total}%.pth')


Epoch [1/20], Step [300/600], Loss: 0.0547
Epoch [1/20], Step [600/600], Loss: 0.1103
Accuracy on the test set: 97.15%
Epoch [1/20], Average Loss: 0.0603
Epoch [2/20], Step [300/600], Loss: 0.1117
Epoch [2/20], Step [600/600], Loss: 0.0235
Accuracy on the test set: 97.52%
Epoch [2/20], Average Loss: 0.0462
Epoch [3/20], Step [300/600], Loss: 0.0121
Epoch [3/20], Step [600/600], Loss: 0.0524
Accuracy on the test set: 97.53%
Epoch [3/20], Average Loss: 0.0397
Epoch [4/20], Step [300/600], Loss: 0.0231
Epoch [4/20], Step [600/600], Loss: 0.0274
Accuracy on the test set: 97.76%
Epoch [4/20], Average Loss: 0.0358
Epoch [5/20], Step [300/600], Loss: 0.0212
Epoch [5/20], Step [600/600], Loss: 0.0630
Accuracy on the test set: 97.59%
Epoch [5/20], Average Loss: 0.0308
Epoch [6/20], Step [300/600], Loss: 0.0400
Epoch [6/20], Step [600/600], Loss: 0.0313
Accuracy on the test set: 97.49%
Epoch [6/20], Average Loss: 0.0297
Epoch [7/20], Step [300/600], Loss: 0.0038
Epoch [7/20], Step [600/600], Los

### 微调融合单模型

In [12]:
# Fusion network: the learnable fusion weights are built into the model itself.
class CombinedMLP(nn.Module):
    """Two-layer MLP whose parameters are a learnable convex combination of checkpoints.

    Only one scalar logit per checkpoint is trainable (``self.weights``). The
    pretrained tensors are frozen. In ``forward`` the logits are
    softmax-normalised, and every layer tensor is the weighted sum of the
    corresponding checkpoint tensors.

    Args:
        input_size, hidden_size, num_classes: kept for signature compatibility
            with SimpleMLP; the real layer shapes come from ``state_dicts``.
        state_dicts: non-empty sequence of state dicts, each containing
            ``fc1.weight``, ``fc1.bias``, ``fc2.weight`` and ``fc2.bias``.

    Raises:
        ValueError: if ``state_dicts`` is empty.
    """

    _LAYER_KEYS = ('fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias')

    def __init__(self, input_size, hidden_size, num_classes, state_dicts):
        super().__init__()
        num_weights = len(state_dicts)
        if num_weights == 0:
            raise ValueError("CombinedMLP needs at least one state dict")

        # One learnable logit per checkpoint; all start at 1.0, so the initial
        # mixture is a uniform average.
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.tensor(1.0)) for _ in range(num_weights)])

        # Original inputs, kept for callers that read this attribute.
        self.pretrained_weights = state_dicts

        # Stack each layer's tensors across checkpoints -> shape (num_ckpts, ...).
        # Non-persistent buffers follow .to(device) but stay out of state_dict().
        for key in self._LAYER_KEYS:
            stacked = torch.stack([sd[key].detach() for sd in state_dicts])
            self.register_buffer(self._buffer_name(key), stacked, persistent=False)

        self.relu = nn.ReLU()

    @staticmethod
    def _buffer_name(key):
        # Buffer names may not contain '.': 'fc1.weight' -> '_pretrained_fc1_weight'.
        return '_pretrained_' + key.replace('.', '_')

    def _combine(self, key, norm_weights):
        # Weighted sum over the checkpoint axis; the weights are reshaped so they
        # broadcast against (num_ckpts, ...) and cast to the buffer's device/dtype.
        stacked = getattr(self, self._buffer_name(key))
        coeffs = norm_weights.to(device=stacked.device, dtype=stacked.dtype)
        coeffs = coeffs.reshape(-1, *([1] * (stacked.dim() - 1)))
        return (coeffs * stacked).sum(dim=0)

    def forward(self, x):
        x = x.view(x.size(0), -1)

        # Softmax makes the mixing weights non-negative and sum to 1.
        norm_weights = F.softmax(torch.stack(list(self.weights)), dim=0)

        x = F.linear(x, self._combine('fc1.weight', norm_weights),
                     self._combine('fc1.bias', norm_weights))
        x = self.relu(x)
        x = F.linear(x, self._combine('fc2.weight', norm_weights),
                     self._combine('fc2.bias', norm_weights))
        return x


In [41]:
# Reproducibility: seed every RNG source and force deterministic cuDNN kernels.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

# Candidate checkpoints. map_location=device puts the tensors on the model's
# device and lets CUDA-saved files load on CPU-only machines.
state_dict1 = torch.load("1_96.13%.pth", map_location=device)
state_dict2 = torch.load("2_95.8%.pth", map_location=device)
state_dict3 = torch.load("3_95.94%.pth", map_location=device)
state_dict4 = torch.load("1_92.17%.pth", map_location=device)
state_dict5 = torch.load("2_92.05%.pth", map_location=device)
state_dict6 = torch.load("3_92.03%.pth", map_location=device)

# Checkpoints fused in this run (state_dict2 and state_dict4 are left out).
state_dicts = [state_dict1,  state_dict3,
               state_dict5, state_dict6]

# Raw fusion logits, one per entry of `state_dicts`; softmax turns them into
# non-negative weights summing to 1. NOTE(review): presumably copied from a
# trained CombinedMLP -- confirm where these values come from.
unnormalized_weights = [3.501100540161133,
                        -0.06588193774223328, -1.0708355903625488, -1.384895920753479]
if len(unnormalized_weights) != len(state_dicts):
    raise ValueError(
        f"need one fusion logit per checkpoint: got {len(unnormalized_weights)} "
        f"logits for {len(state_dicts)} checkpoints")
weights = F.softmax(torch.tensor(unnormalized_weights), dim=0)

# Training code (weights not included).
# Load MNIST before the first evaluation so that evaluation uses this cell's loader.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=100, shuffle=True)

test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=100, shuffle=False)

# Materialise the fused network as a plain SimpleMLP: each parameter is the
# weighted sum of the corresponding tensor across the selected checkpoints.
model = SimpleMLP(28*28, 500, 10).to(device)
fused_state = {
    key: sum(w * sd[key] for w, sd in zip(weights, state_dicts))
    for key in ('fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias')
}
# load_state_dict copies into the existing parameters (so the optimizer below
# tracks them) and rejects missing/unexpected keys or shape mismatches.
model.load_state_dict(fused_state)

# Accuracy of the fused model before fine-tuning (test_model prints it).
accuracy = test_model(model, test_loader)

# 3. Fine-tune the fused SimpleMLP.
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)

epoch_times = []  # wall-clock training seconds per epoch (evaluation excluded)

# 5. Train the network.
num_epochs = 20
for epoch in range(num_epochs):
    # test_model() calls model.eval(); make sure every epoch trains in train mode.
    model.train()
    total_loss = 0.0
    start_time = time.time()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i+1) % 300 == 0:
            print(
                f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # Time spent on this epoch's training pass (measured before evaluation).
    end_time = time.time()
    elapsed_time = end_time - start_time
    epoch_times.append(elapsed_time)
    accuracy = test_model(model, test_loader)

    # Print this epoch's average training loss.
    average_loss = total_loss / len(train_loader)
    print(
        f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')

# 6. Final evaluation on the test set.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(
        f'Accuracy of the model on the 10000 test images: {100 * correct / total} %')

# 7. Save the fine-tuned fused weights; the file name records the final accuracy.
torch.save(model.state_dict(), f'./fuse1356_{100 * correct / total}%.pth')


Accuracy on the test set: 96.04%
Accuracy on the test set: 96.04%
Accuracy on the test set: 96.04%
Accuracy on the test set: 96.04%
Epoch [1/20], Step [300/600], Loss: 0.0726


KeyboardInterrupt: 

In [17]:
# Reload all six checkpoints onto `device` (map_location also lets CUDA-saved
# files load on CPU-only machines).
state_dict1 = torch.load("1_96.13%.pth", map_location=device)
state_dict2 = torch.load("2_95.8%.pth", map_location=device)
state_dict3 = torch.load("3_95.94%.pth", map_location=device)
state_dict4 = torch.load("1_92.17%.pth", map_location=device)
state_dict5 = torch.load("2_92.05%.pth", map_location=device)
state_dict6 = torch.load("3_92.03%.pth", map_location=device)

state_dict = [state_dict1, state_dict2, state_dict3,
              state_dict4, state_dict5, state_dict6]

model = CombinedMLP(
    28*28, 500, 10, state_dict).to(device)

# Raw fusion logits, one per checkpoint (CombinedMLP softmax-normalises them in forward).
specific_weights = [3.999155044555664,
                    0.18632793426513672, -1.046790599822998, 2.2161786556243896, -1.1982409954071045, -1.0864580869674683]
if len(specific_weights) != len(model.weights):
    raise ValueError(
        f"expected {len(model.weights)} fusion logits, got {len(specific_weights)}")

# Overwrite the learnable logits in place; no_grad keeps this out of autograd
# and fill_ preserves each parameter's device and dtype.
with torch.no_grad():
    for param, weight_value in zip(model.weights, specific_weights):
        param.fill_(weight_value)

# test_model prints the accuracy itself.
accuracy = test_model(model, test_loader)


Accuracy on the test set: 96.22%
Accuracy on the test set: 96.22%


In [18]:
# Baseline: a CombinedMLP left at its initial logits (all equal), i.e. a
# uniform average of every checkpoint in `state_dict`.
m_model = CombinedMLP(
    input_size=28 * 28, hidden_size=500, num_classes=10, state_dicts=state_dict
).to(device)
accuracy = test_model(m_model, test_loader)


Accuracy on the test set: 38.76%


### 测试伪标签预测错误或者正确样本的鲁棒性

In [30]:
# MNIST datasets / loaders used by the following cells.
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel; one transform
# object serves both splits.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True)
# batch_size=1: every training batch holds a single sample.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=1, shuffle=True)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=64, shuffle=False)


In [54]:
# Reproducibility: seed every RNG source and force deterministic cuDNN kernels.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)

# The six candidate checkpoints. List order matters: entry i is paired with
# CombinedMLP's i-th fusion logit.
checkpoint_paths = ["1_96.13%.pth", "2_95.8%.pth", "3_95.94%.pth",
                    "1_92.17%.pth", "2_92.05%.pth", "3_92.03%.pth"]
# map_location=device puts the tensors on the same device as the inputs and
# lets CUDA-saved files load on CPU-only machines.
state_dict = [torch.load(path, map_location=device) for path in checkpoint_paths]
# Individual names kept for cells that refer to a single checkpoint.
state_dict1, state_dict2, state_dict3, state_dict4, state_dict5, state_dict6 = state_dict


def test_model(model, test_loader):
    model.eval()

    # 推理并计算准确度
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy on the test set: {accuracy:.2f}%")
    return accuracy

# Per-sample adaptation experiment. For each training sample that passes the
# pseudo-label filter, a fresh CombinedMLP over the checkpoints in `state_dict`
# has only its fusion logits fine-tuned on that single sample, using the model's
# own prediction as the target; test accuracy is then recorded.

# Defined here so this cell does not depend on a loss created in another cell.
criterion = nn.CrossEntropyLoss()

# Random visiting order over the whole training set; the loop stops once
# MAX_SAMPLES samples have passed the filter.
indices = torch.randperm(len(train_dataset))

accuracies = []  # test accuracy after adapting on each processed sample

STEPS_PER_SAMPLE = 2  # optimisation steps per sample
MAX_SAMPLES = 200  # number of samples to process before stopping
# True: keep only samples whose pseudo-label is correct (skip mispredicted ones);
# False: keep only mispredicted samples.
KEEP_CORRECT_PSEUDO_LABELS = True

count = 0  # number of processed samples
best_accuracy = 0.0
worst_accuracy = 100.0

max_index_count = 0  # times the first checkpoint ends up with the largest fusion weight


for index in indices:
    # Fresh model for every sample, so the adaptations are independent.
    model = CombinedMLP(28*28, 500, 10, state_dict).to(device)
    optimizer = optim.Adam(model.weights.parameters(), lr=8)

    # Fetch the single sample directly and add a batch dimension of 1.
    image, label = train_dataset[int(index)]
    inputs = image.unsqueeze(0).to(device)
    labels = torch.tensor([label], device=device)

    outputs = model(inputs)
    pseudo_labels = outputs.detach().argmax(dim=1)

    # Filter on whether the initial pseudo-label matches the true label.
    pseudo_is_correct = bool((pseudo_labels == labels).item())
    if pseudo_is_correct != KEEP_CORRECT_PSEUDO_LABELS:
        continue

    for _ in range(STEPS_PER_SAMPLE):
        # Self-training: the target is the model's own current prediction.
        loss = criterion(outputs, pseudo_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-run with the updated logits for the next step's loss and target.
        outputs = model(inputs)
        pseudo_labels = outputs.detach().argmax(dim=1)

    norm_weights = F.softmax(torch.stack(list(model.weights)), dim=0)
    max_index = torch.argmax(norm_weights).item()  # checkpoint with the largest weight
    if max_index == 0:
        max_index_count += 1

    # Print the weights every 10 processed samples to confirm they are being updated.
    if count % 10 == 0:
        print("Weights:", norm_weights)
        print("Index of the maximum weight:", max_index)

    # Evaluate the adapted model (test_model also prints the accuracy).
    accuracy = test_model(model, test_loader)
    accuracies.append(accuracy)
    print(f"count:{count} - Sample:{index}- Accuracy on the test set: {accuracy:.2f}%")

    best_accuracy = max(best_accuracy, accuracy)
    worst_accuracy = min(worst_accuracy, accuracy)

    count += 1
    if count >= MAX_SAMPLES:
        break

if accuracies:
    # Mean and variance of the per-sample test accuracies.
    mean_accuracy = np.mean(accuracies)
    var_accuracy = np.var(accuracies)

    print(f"Best accuracy: {best_accuracy:.2f}%.")
    print(f"Worst accuracy: {worst_accuracy:.2f}%.")
    print(f"Average accuracy: {mean_accuracy:.2f}%.")
    print(f"Variance of accuracy: {var_accuracy:.2f}%^2.")
    print(f"Count of choosing the best model: {max_index_count}")
else:
    print("No sample passed the pseudo-label filter; nothing to report.")
print(f"Count of choosing the best model: {max_index_count}")


Weights: tensor([5.7115e-01, 1.1627e-13, 4.2885e-01, 1.1627e-13, 1.1627e-13, 1.5753e-06],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)
Index of the maximum weight: 0
Accuracy on the test set: 80.48%
count:0 - Sample:58355- Accuracy on the test set: 80.48%
Accuracy on the test set: 95.80%
count:1 - Sample:33583- Accuracy on the test set: 95.80%
Accuracy on the test set: 96.12%
count:2 - Sample:17862- Accuracy on the test set: 96.12%
Accuracy on the test set: 96.13%
count:3 - Sample:28517- Accuracy on the test set: 96.13%
Accuracy on the test set: 96.14%
count:4 - Sample:29135- Accuracy on the test set: 96.14%
Accuracy on the test set: 96.15%
count:5 - Sample:54842- Accuracy on the test set: 96.15%
Accuracy on the test set: 73.63%
count:6 - Sample:1176- Accuracy on the test set: 73.63%
Accuracy on the test set: 96.13%
count:7 - Sample:20200- Accuracy on the test set: 96.13%
Accuracy on the test set: 95.57%
count:8 - Sample:46194- Accuracy on the test set: 95.57%
Accuracy on the te