* [PyTorch Tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)
* [Do not need softmax](https://stackoverflow.com/questions/55675345/should-i-use-softmax-as-output-when-using-cross-entropy-loss-in-pytorch)
* I learned how to implement LeNet5
* I learned how to train net on GPU

In [32]:
# I would like to plot my  diagram inside the cell.
%matplotlib inline

## My Customized Dataset Class
In order to implement this section, I read two very good references:
* [WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)
* [utkuozbulak/pytorch-custom-dataset-examples](https://github.com/utkuozbulak/pytorch-custom-dataset-examples)

In [33]:
from torch.utils.data.dataset import Dataset
from torchvision import transforms
import numpy as np
import pandas as pd
import cv2 as cv

class AnimalDataset(Dataset):
    def __init__(self, root_path, subset, height, width):
        self.root_path = root_path
        
        # Need it in __getitem__() to transform numpy array to tensor.
        self.to_tensor = transforms.ToTensor()
        
        # Read CSV Metadata and convert it to numpy format
        subset_txt = pd.read_csv(root_path + subset, sep=" ")
        self.np_subset_txt = np.array(subset_txt)
        
        # Need it in __len__()
        self.count = len(self.np_subset_txt)
        
        # Need it in __getitem__() to resize image to desired size
        self.height = height
        self.width = width
    
    def __getitem__(self, index):
        # Get Label
        label = int(self.np_subset_txt[index][1])
        
        # Get Image
        img = cv.imread(self.root_path + self.np_subset_txt[index][0])
        img_resize = cv.resize(img, (self.height, self.width))/255.0
        img_resize = img_resize.reshape(-1, self.height, self.width) # to make channel first
        img_resize_tensor = self.to_tensor(img_resize)
        return (img_resize, label)
    
    def __len__(self):
        return self.count

In [None]:
AC_root_path="C:/Users/rathe/Desktop/Dataset/images/"
root_path = "C:/Users/USER/Desktop/Projects/Github_Repo/AI/DeepLearning/__HW1_DATA/"
height = 30
width = 30

Train_Dataset = AnimalDataset(root_path = root_path, subset = "train.txt", height = height, width = width)
Test_Dataset = AnimalDataset(root_path = root_path, subset = "test.txt", height = height, width = width)
Val_Dataset = AnimalDataset(root_path = root_path, subset = "val.txt", height = height, width = width)

## Define My DataLoader r.s.t. AnimalDataset

In [35]:
import torch
batch_size = 50
shuffle = True

Train_DataLoader = torch.utils.data.DataLoader(dataset=Train_Dataset, batch_size = batch_size, shuffle=shuffle)
Test_DataLoader = torch.utils.data.DataLoader(dataset=Test_Dataset, batch_size = batch_size, shuffle=shuffle)
Val_DataLoader = torch.utils.data.DataLoader(dataset=Val_Dataset, batch_size = batch_size, shuffle=shuffle)

## Define a Convolutional Neural Network(Use GPU)

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 50)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        #x = F.softmax(x)
        return x
    
lenet = LeNet()
lenet.to(device)

cuda:0


LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=50, bias=True)
)

## Define a loss function and Set up Hyperparameters

In [37]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(lenet.parameters(), lr=1e-3)

## Train the network on the training data
# How to use validation dataset
# Try to implement it as a function to call

In [None]:
from tqdm import tqdm
import time

tic = time.time()
for epoch in range(20):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(Train_DataLoader, 0):
        inputs, labels = data[0].to(device), data[1].to(device) # GPU
        #inputs, labels = data # CPU
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = lenet(inputs.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
            correct = 0
            total = 0
            # Train
            with torch.no_grad(): # since we're not training, we don't need to calculate the gradients for our outputs
                for datum in tqdm(Train_DataLoader):
                    imgs, labs = datum[0].to(device), datum[1].to(device)
                    # calculate outputs by running images through the network 
                    outputs = lenet(imgs.float())
                    # the class with the highest energy is what we choose as prediction
                    _, preds = torch.max(outputs.data, 1)
                    total += labs.size(0)
                    correct += (preds == labs).sum().item()

            print('Accuracy of the network on the train images: %d %%' % (100 * correct / total))
            # Val
            with torch.no_grad(): # since we're not training, we don't need to calculate the gradients for our outputs
                for datum in tqdm(Val_DataLoader):
                    imgs, labs = datum[0].to(device), datum[1].to(device)
                    # calculate outputs by running images through the network 
                    outputs = lenet(imgs.float())
                    # the class with the highest energy is what we choose as prediction
                    _, preds = torch.max(outputs.data, 1)
                    total += labs.size(0)
                    correct += (preds == labs).sum().item()

            print('Accuracy of the network on the val images: %d %%' % (100 * correct / total))

toc = time.time()
print(toc - tic)
print('Finished Training')

  0%|                                                                                 | 1/1267 [00:00<02:45,  7.64it/s]

[1,   100] loss: 3.746


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:38<00:00,  8.01it/s]


Accuracy of the network on the test images: 5 %


  0%|                                                                                 | 1/1267 [00:00<02:55,  7.19it/s]

[1,   200] loss: 3.726


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:33<00:00,  8.25it/s]


Accuracy of the network on the test images: 5 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.25it/s]

[1,   300] loss: 3.709


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.38it/s]


Accuracy of the network on the test images: 6 %


  0%|                                                                                 | 1/1267 [00:00<03:18,  6.37it/s]

[1,   400] loss: 3.677


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:35<00:00,  8.16it/s]


Accuracy of the network on the test images: 6 %


  0%|                                                                                 | 1/1267 [00:00<02:43,  7.75it/s]

[1,   500] loss: 3.651


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.34it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.25it/s]

[1,   600] loss: 3.651


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.40it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:33,  8.27it/s]

[1,   700] loss: 3.618


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.46it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:49,  7.46it/s]

[1,   800] loss: 3.615


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.48it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.58it/s]

[1,   900] loss: 3.608


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.49it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:50,  7.41it/s]

[1,  1000] loss: 3.584


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:48,  7.52it/s]

[1,  1100] loss: 3.594


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.48it/s]


Accuracy of the network on the test images: 7 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.63it/s]

[1,  1200] loss: 3.582


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.51it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<02:40,  7.87it/s]

[2,   100] loss: 3.562


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.51it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<02:53,  7.30it/s]

[2,   200] loss: 3.571


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.52it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.63it/s]

[2,   300] loss: 3.558


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:27<00:00,  8.59it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<03:07,  6.76it/s]

[2,   400] loss: 3.547


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:27<00:00,  8.57it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<02:30,  8.40it/s]

[2,   500] loss: 3.538


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.54it/s]


Accuracy of the network on the test images: 8 %


  0%|                                                                                 | 1/1267 [00:00<02:35,  8.13it/s]

[2,   600] loss: 3.538


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:33,  8.26it/s]

[2,   700] loss: 3.539


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.51it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:35,  8.13it/s]

[2,   800] loss: 3.519


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:35,  8.13it/s]

[2,   900] loss: 3.524


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.49it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:33,  8.24it/s]

[2,  1000] loss: 3.500


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.43it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:35,  8.13it/s]

[2,  1100] loss: 3.506


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:35<00:00,  8.15it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:49,  7.46it/s]

[2,  1200] loss: 3.510


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:32<00:00,  8.29it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:53,  7.30it/s]

[3,   100] loss: 3.479


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.44it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:28,  8.51it/s]

[3,   200] loss: 3.500


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.47it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:43,  7.75it/s]

[3,   300] loss: 3.493


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.41it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:50,  7.41it/s]

[3,   400] loss: 3.467


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.42it/s]


Accuracy of the network on the test images: 9 %


  0%|                                                                                 | 1/1267 [00:00<02:43,  7.75it/s]

[3,   500] loss: 3.470


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.39it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.64it/s]

[3,   600] loss: 3.474


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.44it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:36,  8.06it/s]

[3,   700] loss: 3.443


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.47it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.58it/s]

[3,   800] loss: 3.443


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.44it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.55it/s]

[3,   900] loss: 3.476


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.43it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:43,  7.75it/s]

[3,  1000] loss: 3.435


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.43it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:26,  8.62it/s]

[3,  1100] loss: 3.446


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.41it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:33,  8.26it/s]

[3,  1200] loss: 3.423


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.40it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:31,  8.33it/s]

[4,   100] loss: 3.417


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:33<00:00,  8.27it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:50,  7.41it/s]

[4,   200] loss: 3.413


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.45it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:33,  8.27it/s]

[4,   300] loss: 3.420


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.48it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:39,  7.93it/s]

[4,   400] loss: 3.390


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.25it/s]

[4,   500] loss: 3.398


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.50it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.57it/s]

[4,   600] loss: 3.413


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.49it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:53,  7.30it/s]

[4,   700] loss: 3.414


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.64it/s]

[4,   800] loss: 3.410


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.48it/s]


Accuracy of the network on the test images: 10 %


  0%|                                                                                 | 1/1267 [00:00<02:51,  7.38it/s]

[4,   900] loss: 3.388


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.37it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:43,  7.75it/s]

[4,  1000] loss: 3.380


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:33<00:00,  8.24it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.58it/s]

[4,  1100] loss: 3.369


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:33<00:00,  8.25it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.27it/s]

[4,  1200] loss: 3.392


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:34<00:00,  8.19it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.24it/s]

[5,   100] loss: 3.341


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.34it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:34,  8.20it/s]

[5,   200] loss: 3.345


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.39it/s]


Accuracy of the network on the test images: 11 %


  0%|                                                                                 | 1/1267 [00:00<02:42,  7.81it/s]

[5,   300] loss: 3.347


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.38it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.63it/s]

[5,   400] loss: 3.362


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:32<00:00,  8.29it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:44,  7.69it/s]

[5,   500] loss: 3.339


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.37it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:44,  7.69it/s]

[5,   600] loss: 3.358


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.41it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:54,  7.25it/s]

[5,   700] loss: 3.311


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.39it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:42,  7.81it/s]

[5,   800] loss: 3.371


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.52it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:38,  8.00it/s]

[5,   900] loss: 3.343


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:41,  7.84it/s]

[5,  1000] loss: 3.331


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.46it/s]


Accuracy of the network on the test images: 12 %


  0%|                                                                                 | 1/1267 [00:00<02:30,  8.40it/s]

[5,  1100] loss: 3.347


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.50it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:38,  8.00it/s]

[5,  1200] loss: 3.336


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.47it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                         | 0/1267 [00:00<?, ?it/s]

[6,   100] loss: 3.324


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [04:35<00:00,  4.59it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:42,  7.81it/s]

[6,   200] loss: 3.312


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.39it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:53,  7.30it/s]

[6,   300] loss: 3.310


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.54it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:30,  8.40it/s]

[6,   400] loss: 3.299


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.42it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:24,  8.77it/s]

[6,   500] loss: 3.318


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:29<00:00,  8.47it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.63it/s]

[6,   600] loss: 3.302


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:27<00:00,  8.59it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:25,  8.70it/s]

[6,   700] loss: 3.285


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:27<00:00,  8.59it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:39,  7.94it/s]

[6,   800] loss: 3.300


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:26<00:00,  8.64it/s]


Accuracy of the network on the test images: 14 %


  0%|                                                                                 | 1/1267 [00:00<02:35,  8.13it/s]

[6,   900] loss: 3.309


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:26<00:00,  8.63it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:36,  8.06it/s]

[6,  1000] loss: 3.273


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:27<00:00,  8.61it/s]


Accuracy of the network on the test images: 14 %


  0%|                                                                                 | 1/1267 [00:00<02:45,  7.63it/s]

[6,  1100] loss: 3.271


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:28<00:00,  8.52it/s]


Accuracy of the network on the test images: 14 %


  0%|                                                                                 | 1/1267 [00:00<02:16,  9.26it/s]

[6,  1200] loss: 3.281


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:34<00:00,  8.21it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:58,  7.09it/s]

[7,   100] loss: 3.258


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:30<00:00,  8.39it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:21,  8.93it/s]

[7,   200] loss: 3.241


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:35<00:00,  8.15it/s]


Accuracy of the network on the test images: 13 %


  0%|                                                                                 | 1/1267 [00:00<02:47,  7.58it/s]

[7,   300] loss: 3.260


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.35it/s]


Accuracy of the network on the test images: 14 %


  0%|                                                                                 | 1/1267 [00:00<02:42,  7.81it/s]

[7,   400] loss: 3.269


100%|██████████████████████████████████████████████████████████████████████████████| 1267/1267 [02:31<00:00,  8.38it/s]


Accuracy of the network on the test images: 14 %


  0%|                                                                                 | 1/1267 [00:00<02:53,  7.30it/s]

[7,   500] loss: 3.265


 12%|█████████▌                                                                     | 154/1267 [00:19<02:29,  7.46it/s]

## Save the model for later usage

In [29]:
PATH = './LeNet.pth'
torch.save(net.state_dict(), PATH)

## Test the network on the test data

In [31]:
lenet = LeNet()
lenet.load_state_dict(torch.load(PATH))

correct = 0
total = 0

with torch.no_grad(): # since we're not training, we don't need to calculate the gradients for our outputs
    for data in Test_DataLoader:
        images, labels = data
        # calculate outputs by running images through the network 
        outputs = lenet(images.float())
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

Accuracy of the network on the test images: 2 %
