In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.1307,), (0.3081,)
        ),  # 注意Mnist是灰度图，彩色图像应该是  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

batch_size = 8
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

classes = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")


  warn(f"Failed to load image Python extension: {e}")


In [7]:
import plotly.express as px

#获取一些随机训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)

# make_grid函数把4维的数据拼接为3维的rgb图像(首位是3),默认padding是2，pad_value=0
images = torchvision.utils.make_grid(images, pad_value=0.5)
fig = px.imshow(images.permute(1, 2, 0), 0, 1)
# 打印标签
fig.update_layout(title_text="标签值：" + "".join(f"{classes[labels[j]]:5s}"
                                              for j in range(batch_size)))
fig.show()


In [10]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # 除了batch展平所有维度
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
net.to(device)


Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [11]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


In [13]:
for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        print(inputs.shape)
        print(outputs.shape)
        print(labels.shape)
        loss = criterion(outputs, labels)
        print(loss.shape)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 1000 == 999:  # print every 2000 mini-batches
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 1000:.3f}")
            running_loss = 0.0

print("Finished Training")


torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
torch.Size([])
torch.Size([8, 1, 28, 28])
torch.Size([8, 10])
torch.Size([8])
t

In [6]:
PATH = "./mnist_lenet5.pth"
torch.save(net.state_dict(), PATH)
