In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

In [2]:
from matplotlib.pyplot import imshow
%matplotlib inline

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [4]:
trans = transforms.Compose([
    transforms.Resize((64,128))
])

train_data = torchvision.datasets.ImageFolder(root='custom_data/', transform=trans)

In [5]:
for num, value in enumerate(train_data):
    data, label = value
    print(num, data, label)
    data.save('custom_data/changed_size/%d.jpg'%(num))

0 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 0
1 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 0
2 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 0
3 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 0
4 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 0
5 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 0
6 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 0
7 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 0
8 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 0
9 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 0
10 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 0
11 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 0
12 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 0
13 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 0
14 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 0
15 <P

134 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
135 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
136 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
137 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
138 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
139 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
140 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
141 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
142 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
143 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
144 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
145 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
146 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
147 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
148 <PIL.Image.Image image mode=RGB size=128x64 

293 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
294 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
295 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
296 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
297 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
298 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
299 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
300 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
301 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
302 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
303 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
304 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
305 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
306 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
307 <PIL.Image.Image image mode=RGB size=128x64 

459 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
460 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
461 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
462 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
463 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
464 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
465 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
466 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
467 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
468 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
469 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
470 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
471 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B769E5A130> 1
472 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
473 <PIL.Image.Image image mode=RGB size=128x64 

605 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79760> 1
606 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
607 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79850> 1
608 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
609 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79610> 1
610 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
611 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79760> 1
612 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
613 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79850> 1
614 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
615 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79820> 1
616 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
617 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF796A0> 1
618 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
619 <PIL.Image.Image image mode=RGB size=128x64 

729 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF794C0> 1
730 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
731 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79460> 1
732 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
733 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79850> 1
734 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
735 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79760> 1
736 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
737 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79610> 1
738 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
739 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF794C0> 1
740 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
741 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79880> 1
742 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
743 <PIL.Image.Image image mode=RGB size=128x64 

884 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
885 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79460> 1
886 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
887 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79820> 1
888 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
889 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF795E0> 1
890 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
891 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79790> 1
892 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
893 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF796A0> 1
894 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
895 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF797F0> 1
896 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
897 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79760> 1
898 <PIL.Image.Image image mode=RGB size=128x64 

1009 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79820> 1
1010 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1011 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79610> 1
1012 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1013 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79790> 1
1014 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1015 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF796A0> 1
1016 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1017 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF797F0> 1
1018 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1019 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79640> 1
1020 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1021 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF795E0> 1
1022 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1023 <PIL.Image.Image image mode=R

1188 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1189 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798E0> 1
1190 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1191 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798E0> 1
1192 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1193 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798B0> 1
1194 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1195 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798E0> 1
1196 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1197 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF79880> 1
1198 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DE8DC70> 1
1199 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798E0> 1
1200 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF5CEB0> 1
1201 <PIL.Image.Image image mode=RGB size=128x64 at 0x1B76DF798B0> 1
1202 <PIL.Image.Image image mode=R

In [6]:
#이미지 사이즈 조정 완료

In [7]:
trans = transforms.Compose([
    transforms.ToTensor()
])

train_data = torchvision.datasets.ImageFolder(root='./custom_data/changed_size/', transform=trans)

In [8]:
data_loader = DataLoader(dataset = train_data, batch_size = 8, shuffle = True, num_workers=2)

In [9]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,6,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6,16,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(16*13*29, 120),
            nn.ReLU(),
            nn.Linear(120,2)
        )
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.shape[0], -1)
        out = self.layer3(out)
        return out

In [10]:
net = CNN().to(device)
test_input = (torch.Tensor(3,3,64,128)).to(device)
test_out = net(test_input)

In [11]:
optimizer = optim.Adam(net.parameters(), lr=0.000005)
loss_func = nn.CrossEntropyLoss().to(device)

In [12]:
total_batch = len(data_loader)

epochs = 7
for epoch in range(epochs):
    avg_cost = 0.0
    for num, data in enumerate(data_loader):
        imgs, labels = data
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        out = net(imgs)
        loss = loss_func(out, labels)
        loss.backward()
        optimizer.step()
        
        avg_cost += loss / total_batch
        
    print('[Epoch:{}] cost = {}'.format(epoch+1, avg_cost))
print('Learning Finished!')

[Epoch:1] cost = 0.6961687803268433
[Epoch:2] cost = 0.6630303263664246
[Epoch:3] cost = 0.6314494013786316
[Epoch:4] cost = 0.6013171672821045
[Epoch:5] cost = 0.57163006067276
[Epoch:6] cost = 0.5415332913398743
[Epoch:7] cost = 0.5114163160324097
Learning Finished!


In [13]:
torch.save(net.state_dict(), "./model/model.pth")

In [14]:
new_net = CNN().to(device)

In [15]:
new_net.load_state_dict(torch.load('./model/model.pth'))

<All keys matched successfully>

In [16]:
print(net.layer1[0])
print(new_net.layer1[0])

print(net.layer1[0].weight[0][0][0])
print(new_net.layer1[0].weight[0][0][0])

net.layer1[0].weight[0] == new_net.layer1[0].weight[0]

Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
tensor([-0.0962, -0.0017, -0.0220, -0.0260,  0.0884], grad_fn=<SelectBackward>)
tensor([-0.0962, -0.0017, -0.0220, -0.0260,  0.0884], grad_fn=<SelectBackward>)


tensor([[[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]],

        [[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]],

        [[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]]])

In [18]:
trans=torchvision.transforms.Compose([
    transforms.Resize((64,128)),
    transforms.ToTensor()
])
test_data = torchvision.datasets.ImageFolder(root='test_it/', transform=trans)

In [19]:
for num, value in enumerate(test_data):
    data, label = value
    print(num, data, label)

0 tensor([[[0.2863, 0.2784, 0.3020,  ..., 0.9765, 0.9765, 0.9765],
         [0.2667, 0.2863, 0.2980,  ..., 0.9765, 0.9765, 0.9765],
         [0.2353, 0.2549, 0.2588,  ..., 0.9765, 0.9765, 0.9765],
         ...,
         [0.4157, 0.4157, 0.4118,  ..., 0.7882, 0.7843, 0.7882],
         [0.4196, 0.4118, 0.4118,  ..., 0.7804, 0.7843, 0.7882],
         [0.4078, 0.4000, 0.3843,  ..., 0.7804, 0.7882, 0.7961]],

        [[0.4784, 0.5020, 0.5529,  ..., 0.9804, 0.9804, 0.9804],
         [0.4706, 0.5333, 0.5569,  ..., 0.9804, 0.9804, 0.9804],
         [0.4235, 0.4902, 0.5216,  ..., 0.9804, 0.9804, 0.9804],
         ...,
         [0.3569, 0.3569, 0.3490,  ..., 0.3569, 0.3569, 0.3608],
         [0.3608, 0.3529, 0.3490,  ..., 0.3529, 0.3569, 0.3412],
         [0.3490, 0.3412, 0.3255,  ..., 0.3451, 0.3451, 0.3529]],

        [[0.5608, 0.6000, 0.6510,  ..., 0.9961, 0.9961, 0.9961],
         [0.5647, 0.6353, 0.6549,  ..., 0.9961, 0.9961, 0.9961],
         [0.5098, 0.5843, 0.6157,  ..., 0.9961, 0.9961, 

16 tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]]) 0
17 tensor([[[0.2118, 0.2118, 0.2118,  ..., 0.2078, 0.1882, 0.2000],
         [0.2118, 0.2118, 0.2118,  ..., 0.2078, 0.1882, 0.2000],
         [0.2118, 0.2118, 0.2118,  ..., 0.2078, 0.1882, 0.2000],
         ..

29 tensor([[[0.3922, 0.3882, 0.3882,  ..., 0.4667, 0.4667, 0.4745],
         [0.3882, 0.3843, 0.3843,  ..., 0.4549, 0.4549, 0.4627],
         [0.3804, 0.3804, 0.3804,  ..., 0.4471, 0.4471, 0.4549],
         ...,
         [0.2196, 0.2157, 0.2157,  ..., 0.2039, 0.2039, 0.2000],
         [0.2235, 0.2235, 0.2196,  ..., 0.2000, 0.2000, 0.1922],
         [0.2275, 0.2275, 0.2431,  ..., 0.1961, 0.1961, 0.1882]],

        [[0.3882, 0.3843, 0.3843,  ..., 0.4667, 0.4667, 0.4745],
         [0.3843, 0.3804, 0.3804,  ..., 0.4549, 0.4549, 0.4627],
         [0.3765, 0.3765, 0.3765,  ..., 0.4471, 0.4471, 0.4549],
         ...,
         [0.2118, 0.2078, 0.2078,  ..., 0.2000, 0.2000, 0.1961],
         [0.2196, 0.2157, 0.2157,  ..., 0.1961, 0.1961, 0.1882],
         [0.2196, 0.2196, 0.2353,  ..., 0.1922, 0.1922, 0.1843]],

        [[0.3725, 0.3686, 0.3686,  ..., 0.4627, 0.4588, 0.4667],
         [0.3686, 0.3647, 0.3647,  ..., 0.4471, 0.4471, 0.4549],
         [0.3608, 0.3569, 0.3608,  ..., 0.4392, 0.4392,

37 tensor([[[0.9647, 0.9137, 0.9412,  ..., 1.0000, 1.0000, 1.0000],
         [0.9529, 0.8627, 0.9137,  ..., 1.0000, 1.0000, 1.0000],
         [0.9686, 0.9608, 0.9569,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.9412, 0.9373, 0.9451,  ..., 1.0000, 1.0000, 1.0000],
         [0.8863, 0.9020, 0.9333,  ..., 1.0000, 1.0000, 1.0000],
         [0.9490, 0.9569, 0.9608,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.9608, 0.9137, 0.9412,  ..., 1.0000, 1.0000, 1.0000],
         [0.9529, 0.8627, 0.9176,  ..., 1.0000, 1.0000, 1.0000],
         [0.9686, 0.9608, 0.9569,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.9412, 0.9373, 0.9451,  ..., 1.0000, 1.0000, 1.0000],
         [0.8902, 0.9020, 0.9373,  ..., 1.0000, 1.0000, 1.0000],
         [0.9490, 0.9569, 0.9608,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.9608, 0.9098, 0.9451,  ..., 1.0000, 1.0000, 1.0000],
         [0.9490, 0.8588, 0.9216,  ..., 1.0000, 1.0000, 1.0000],
         [0.9686, 0.9569, 0.9569,  ..., 1.0000, 1.0000,

50 tensor([[[0.2431, 0.2471, 0.2549,  ..., 0.0275, 0.0667, 0.1098],
         [0.2471, 0.2549, 0.2588,  ..., 0.0392, 0.0745, 0.1176],
         [0.2510, 0.2627, 0.2745,  ..., 0.0353, 0.0863, 0.1294],
         ...,
         [0.1176, 0.0941, 0.0824,  ..., 0.2000, 0.2157, 0.1922],
         [0.1137, 0.1059, 0.1020,  ..., 0.3255, 0.2863, 0.2157],
         [0.3137, 0.3059, 0.2941,  ..., 0.4118, 0.3490, 0.2706]],

        [[0.2392, 0.2431, 0.2510,  ..., 0.3137, 0.2980, 0.2588],
         [0.2353, 0.2431, 0.2510,  ..., 0.3059, 0.2863, 0.2510],
         [0.2275, 0.2353, 0.2431,  ..., 0.2980, 0.2784, 0.2431],
         ...,
         [0.1490, 0.1294, 0.1176,  ..., 0.3569, 0.3804, 0.3608],
         [0.1569, 0.1490, 0.1412,  ..., 0.4784, 0.4314, 0.3569],
         [0.3569, 0.3490, 0.3333,  ..., 0.5451, 0.4824, 0.4039]],

        [[0.4667, 0.4706, 0.4784,  ..., 0.8039, 0.7608, 0.6784],
         [0.4627, 0.4745, 0.4824,  ..., 0.7961, 0.7451, 0.6588],
         [0.4627, 0.4706, 0.4824,  ..., 0.7804, 0.7255,

In [20]:
test_set = DataLoader(dataset = test_data, batch_size = len(test_data))

In [21]:
with torch.no_grad():
    for num, data in enumerate(test_set):
        imgs, label = data
        imgs = imgs.to(device)
        label = label.to(device)
        
        prediction = net(imgs)
        
        correct_prediction = torch.argmax(prediction, 1) == label
        
        accuracy = correct_prediction.float().mean()
        print('Accuracy:', accuracy.item())

Accuracy: 1.0
