In [1]:
import sys
import idx2numpy
import pandas as pd
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, Subset, random_split

def load_dataset2():
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 加载完整的 MNIST 训练数据集
    full_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # 随机选择 20,000 个样本
    subset_indices = torch.randperm(len(full_dataset))[:20000]
    subset_dataset = Subset(full_dataset, subset_indices)

    # 将 20,000 个样本分为 5,000 个训练集、5,000 个验证集和 10,000 个测试集
    train_set, val_set, test_set = random_split(subset_dataset, [5000, 5000, 10000])

    # 打乱训练集中的 2,500 个样本的标签
    rand_indices = torch.randperm(len(train_set))[:2500]
    for idx in rand_indices:
        # 随机生成一个新的标签
        new_label = torch.randint(0, 10, (1,)).item()
        train_set.dataset.dataset.targets[subset_indices[train_set.indices[idx]]] = new_label

    # 创建数据加载器
    trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
    valloader = DataLoader(val_set, batch_size=64, shuffle=True)
    testloader = DataLoader(test_set, batch_size=64, shuffle=True)

    return trainloader, valloader, testloader

def test(net, testloader):
    # 测试网络
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')


trainloader, valloader, testloader = load_dataset2()

In [18]:
# 定义神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc(x)
        return x

net1 = SimpleNet()

N = 5000
la1 = torch.rand([N,1],requires_grad=True)

In [22]:
net = net1
la = la1

# 定义损失函数和优化器

def lower_function(output, label, la):
    crossentropy = nn.CrossEntropyLoss()
    loss = crossentropy(output, label)*la
    return loss

# SGD的效果明显要比Adam好不少
# optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.01)



# 第一次测试网络
test(net, testloader)

# 训练网络
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = lower_function(outputs, labels, la[i])
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 200:.3f}')
    

test(net, testloader)

#la = la - 0.01 *  

for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = lower_function(outputs, labels, la[i])
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 200:.3f}')


print('Finished Training')

test(net, testloader)

Accuracy of the network on the 10000 test images: 46.89 %
[Epoch 1, Batch 1] loss: 0.009
[Epoch 1, Batch 2] loss: 0.011
[Epoch 1, Batch 3] loss: 0.025
[Epoch 1, Batch 4] loss: 0.026
[Epoch 1, Batch 5] loss: 0.044
[Epoch 1, Batch 6] loss: 0.053
[Epoch 1, Batch 7] loss: 0.072
[Epoch 1, Batch 8] loss: 0.085
[Epoch 1, Batch 9] loss: 0.087
[Epoch 1, Batch 10] loss: 0.101
[Epoch 1, Batch 11] loss: 0.101
[Epoch 1, Batch 12] loss: 0.109
[Epoch 1, Batch 13] loss: 0.113
[Epoch 1, Batch 14] loss: 0.116
[Epoch 1, Batch 15] loss: 0.129
[Epoch 1, Batch 16] loss: 0.146
[Epoch 1, Batch 17] loss: 0.153
[Epoch 1, Batch 18] loss: 0.157
[Epoch 1, Batch 19] loss: 0.164
[Epoch 1, Batch 20] loss: 0.172
[Epoch 1, Batch 21] loss: 0.173
[Epoch 1, Batch 22] loss: 0.181
[Epoch 1, Batch 23] loss: 0.183
[Epoch 1, Batch 24] loss: 0.188
[Epoch 1, Batch 25] loss: 0.201
[Epoch 1, Batch 26] loss: 0.210
[Epoch 1, Batch 27] loss: 0.220
[Epoch 1, Batch 28] loss: 0.226
[Epoch 1, Batch 29] loss: 0.237
[Epoch 1, Batch 30] los

In [10]:
print(torch.norm(la.grad))

tensor(90.8010)
