In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms



In [2]:
# Record library versions so the outputs below can be reproduced with the same stack.
print(torch.__version__, torchvision.__version__, sep='\n')

1.8.1+cu111
0.9.1+cu111


In [3]:
# Fashion-MNIST training split: downloaded into data/FashionMNIST/ on first use,
# each PIL image converted to a tensor by ToTensor.
train_set = torchvision.datasets.FashionMNIST(
    'data/FashionMNIST/',
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)

In [4]:
class Network(nn.Module):
    """Small LeNet-style CNN classifier.

    Two (conv -> ReLU -> 2x2 max-pool) stages followed by three fully
    connected layers. The final layer is a plain ``nn.Linear`` with no
    softmax, so ``forward`` returns raw, unnormalized class scores (logits).

    Args:
        in_channels: number of channels in the input images (default 1).
        num_classes: number of output classes (default 10).

    Note:
        ``fc1`` expects 12 feature maps of size 4x4, which is what a 28x28
        input yields: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=num_classes)

    def forward(self, t):
        """Compute class scores for a batch of images.

        Args:
            t: image batch of shape (batch, in_channels, 28, 28).

        Returns:
            Tensor of shape (batch, num_classes) holding raw logits.

        Raises:
            RuntimeError: if the spatial size does not reduce to 12x4x4
                feature maps (i.e. the input is not 28x28).
        """
        # Compact form: convolution and ReLU applied in one expression.
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Flatten everything except the batch dimension. Unlike
        # reshape(-1, 12*4*4), this never folds spatial data of a mis-sized
        # input into extra "batch" rows; fc1 raises a shape error instead.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)

        return t

In [5]:
# Serve the training set in mini-batches of 10; shuffle is left at its default (False),
# so the first batch is always the same 10 samples.
data_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

In [6]:
batch = next(iter(data_loader))    # pull only the first mini-batch from the loader for inspection

In [7]:
(images, labels) = batch    # the batch is an (images, labels) pair

In [8]:
images.size()

torch.Size([10, 1, 28, 28])

In [9]:
labels.size()

torch.Size([10])

In [10]:
# Seed the RNG so the randomly initialised weights, and therefore the
# predictions shown below, are reproducible on Restart & Run All.
torch.manual_seed(50)
network = Network()
preds = network(images)    # pass the image batch through the freshly initialised network to get predictions

In [11]:
preds.size()

torch.Size([10, 10])

In [12]:
preds                       # one row per image, one column per class: each row holds that image's 10 raw class scores (logits)

tensor([[ 0.0806,  0.1091,  0.1149,  0.0842,  0.0857, -0.0113,  0.0329,  0.1492,
         -0.0322,  0.0383],
        [ 0.0785,  0.1127,  0.1178,  0.0863,  0.0906, -0.0102,  0.0311,  0.1514,
         -0.0374,  0.0397],
        [ 0.0844,  0.1090,  0.1176,  0.0829,  0.0863, -0.0105,  0.0348,  0.1477,
         -0.0334,  0.0351],
        [ 0.0831,  0.1106,  0.1174,  0.0860,  0.0874, -0.0084,  0.0340,  0.1499,
         -0.0360,  0.0362],
        [ 0.0845,  0.1091,  0.1195,  0.0852,  0.0929, -0.0112,  0.0393,  0.1526,
         -0.0403,  0.0411],
        [ 0.0796,  0.1095,  0.1177,  0.0892,  0.0889, -0.0153,  0.0339,  0.1552,
         -0.0357,  0.0409],
        [ 0.0759,  0.1070,  0.1233,  0.0873,  0.0843, -0.0100,  0.0319,  0.1462,
         -0.0364,  0.0341],
        [ 0.0887,  0.1082,  0.1230,  0.0897,  0.0864, -0.0160,  0.0428,  0.1536,
         -0.0335,  0.0410],
        [ 0.0809,  0.1042,  0.1192,  0.0814,  0.0872, -0.0056,  0.0310,  0.1464,
         -0.0350,  0.0348],
        [ 0.0806,  

In [13]:
torch.argmax(preds, dim=1)    # predicted class = index of the highest score in each row

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])

In [14]:
labels    # ground-truth labels of the same batch, for comparison with the predicted classes above

tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

In [15]:
preds.argmax(dim=1) == labels   # element-wise equality check: True where the predicted class matches the label


tensor([False, False, False, False, False, False,  True, False, False, False])

In [16]:
torch.sum(preds.argmax(dim=1) == labels)    # number of correct predictions in this batch

tensor(1)