In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per printed line (PyTorch's default is 80) so that
# wide tensor rows are less likely to be wrapped when displayed.
torch.set_printoptions(linewidth=120)

In [49]:
# Record the library versions this notebook was executed with, one per line.
print(torch.__version__, torchvision.__version__, sep='\n')

1.3.0
0.4.1a0+d94043a


In [50]:
# Training split of Fashion-MNIST. With download=True the files are fetched
# into `root` if they are not already present; ToTensor converts each image
# to a tensor when a sample is accessed.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [51]:
class Network(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages, then three linear layers.

    ``fc1`` takes 12*4*4 features, which is what the convolutional stack
    yields for 28x28 single-channel images (spatial side 28 -> 24 -> 12 -> 8 -> 4).
    The last layer emits 10 raw scores per sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Run the network on a batch of images.

        Args:
            t: image tensor; the layer sizes are set up for input shaped
                (batch, 1, 28, 28).

        Returns:
            Tensor of raw class scores with one row per sample and 10 columns.

        Raises:
            RuntimeError: if the convolutional stack does not produce exactly
                ``fc1.in_features`` values per sample (i.e. the spatial size
                of the input is not the one the linear layers were sized for).
        """
        # First conv stage
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Second conv stage
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Guard the flatten: reshape(-1, n) succeeds whenever the *total*
        # element count is divisible by n, so a wrongly sized input could
        # otherwise be silently re-chunked into a different number of rows,
        # mixing features from different samples.
        features_per_sample = t.shape[-3] * t.shape[-2] * t.shape[-1]
        if features_per_sample != self.fc1.in_features:
            raise RuntimeError(
                "Network expects {} features per sample after the conv layers "
                "but got {} (feature map shape {}); check the input image size."
                .format(self.fc1.in_features, features_per_sample, tuple(t.shape))
            )
        t = t.reshape(-1, self.fc1.in_features)

        # Fully connected head
        t = self.fc1(t)
        t = F.relu(t)

        t = self.fc2(t)
        t = F.relu(t)

        # Output layer: raw scores, no activation
        t = self.out(t)

        return t

In [52]:
# Switch autograd off. This is called as a plain function rather than used as a
# context manager, so the setting stays in effect for the code that follows
# until it is changed again (e.g. torch.set_grad_enabled(True)).
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x11cb127f0>

In [53]:
# Fresh, untrained instance. No random seed is set in the cells shown here, so
# the initial weights (and the predictions below) presumably differ between runs.
network = Network()

In [54]:
# Serve the training set in mini-batches of 10 samples.
data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=10)

In [55]:
# Take the first batch from the loader. No shuffle argument was passed to the
# DataLoader, so it iterates in dataset order (shuffle defaults to False).
batch = next(iter(data_loader))

In [56]:
# A batch unpacks into two tensors: the images and their labels.
images, labels = batch

In [57]:
# 10 images (the batch_size), each a single-channel 28x28 tensor.
images.shape

torch.Size([10, 1, 28, 28])

In [58]:
# One label per image in the batch.
labels.shape

torch.Size([10])

In [59]:
# Forward pass: calling the module runs Network.forward on the image batch.
# No training step appears in the cells shown, so these are untrained outputs.
preds = network(images)

In [60]:
# One row per image, one column per output of the final layer (out_features=10).
preds.shape

torch.Size([10, 10])

In [61]:
# Raw outputs of the last linear layer (forward applies no softmax).
preds

tensor([[ 0.1002,  0.0247,  0.0203,  0.0696, -0.0957,  0.1410, -0.1120, -0.0720,  0.1192,  0.0185],
        [ 0.1043,  0.0248,  0.0245,  0.0635, -0.0896,  0.1358, -0.1125, -0.0624,  0.1309,  0.0126],
        [ 0.1023,  0.0209,  0.0231,  0.0668, -0.1001,  0.1282, -0.1101, -0.0599,  0.1300,  0.0200],
        [ 0.1004,  0.0230,  0.0198,  0.0639, -0.0955,  0.1332, -0.1111, -0.0604,  0.1318,  0.0176],
        [ 0.1034,  0.0186,  0.0256,  0.0658, -0.0939,  0.1428, -0.1107, -0.0672,  0.1288,  0.0122],
        [ 0.1045,  0.0301,  0.0206,  0.0690, -0.0901,  0.1439, -0.1136, -0.0652,  0.1252,  0.0199],
        [ 0.1038,  0.0149,  0.0282,  0.0630, -0.1033,  0.1335, -0.1089, -0.0708,  0.1297,  0.0118],
        [ 0.1028,  0.0305,  0.0181,  0.0685, -0.0876,  0.1449, -0.1112, -0.0652,  0.1267,  0.0214],
        [ 0.1027,  0.0279,  0.0250,  0.0690, -0.1085,  0.1314, -0.1086, -0.0716,  0.1147,  0.0298],
        [ 0.0961,  0.0357,  0.0115,  0.0708, -0.1021,  0.1477, -0.1080, -0.0752,  0.1104,  0.0280]])

In [62]:
# Column index of the largest score in each row, i.e. the predicted class per image.
preds.argmax(dim=1)

tensor([5, 5, 8, 5, 5, 5, 5, 5, 5, 5])

In [63]:
# Labels of the same batch, for comparison with the predicted classes.
labels

tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

In [64]:
# Element-wise comparison: True where the predicted class equals the label.
preds.argmax(dim=1).eq(labels)

tensor([False, False, False, False, False, False, False, False,  True,  True])

In [65]:
# Summing the boolean tensor counts the matches; the result is a 0-dim tensor.
preds.argmax(dim=1).eq(labels).sum()

tensor(2)

In [66]:
def get_num_correct(preds, labels):
    """Count how many rows of ``preds`` have their largest score at the labelled class.

    Args:
        preds: 2-D tensor of scores, one row per sample and one column per class.
        labels: 1-D tensor of class indices, one per row of ``preds``.

    Returns:
        The number of matching predictions as a plain Python number.
    """
    predicted_classes = preds.argmax(dim=1)
    is_correct = predicted_classes.eq(labels)
    num_correct = is_correct.sum()
    return num_correct.item()

In [67]:
# Same count as the cell above, now returned as a plain Python number via .item().
get_num_correct(preds,labels)

2