In [69]:
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

In [71]:
# Training/test preprocessing: PIL image -> float tensor in [0, 1] (ToTensor),
# then each RGB channel is mapped to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [73]:
# CIFAR-10 train/test splits; download=True fetches the data into ./data if needed.
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

BATCH_SIZE = 32

# Only the training set is shuffled. Evaluation order does not affect accuracy,
# and a fixed test order keeps evaluation runs deterministic and comparable.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [74]:
(image, label) = train_data[0]  # first training sample: transformed image tensor and its class label

In [75]:
image.shape  # (channels, height, width) of one preprocessed sample

torch.Size([3, 32, 32])

In [76]:
# Human-readable CIFAR-10 labels; position i is the name for class index i.
class_names = 'plane car bird cat deer dog frog horse ship truck'.split()

In [77]:
class NeuralNet(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images (e.g. CIFAR-10).

    Two (conv -> ReLU -> 2x2 max-pool) stages feed three fully connected
    layers. ``forward`` returns raw, unnormalized logits; pair it with
    ``nn.CrossEntropyLoss``, which applies log-softmax internally.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10), so
            existing ``NeuralNet()`` calls and saved state_dicts are unaffected.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 12, 5)    # (3, 32, 32)   -> (12, 28, 28)
        self.pool = nn.MaxPool2d(2, 2)      # halves H and W; reused after both convs
        self.conv2 = nn.Conv2d(12, 24, 5)   # (12, 14, 14)  -> (24, 10, 10)
        # The second pool gives (24, 5, 5), flattened to 24 * 5 * 5 = 600 features.
        self.fc1 = nn.Linear(24 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a batch of images (N, 3, 32, 32) to class logits (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))  # (N, 12, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # (N, 24, 5, 5)
        x = torch.flatten(x, 1)               # (N, 600); dim 0 (batch) is kept
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # raw logits, no softmax

In [78]:
# Seed torch's global RNG so weight initialization and the training
# DataLoader's shuffle order (drawn from the default generator when no
# generator is passed) are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

net = NeuralNet()
# NeuralNet outputs raw logits, which is what CrossEntropyLoss expects.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [79]:
NUM_EPOCHS = 30

# Put the model in training mode explicitly: a later cell calls net.eval(),
# so re-running this cell would otherwise train in eval mode. This has no
# effect on the current NeuralNet (no dropout/batch-norm) but guards future changes.
net.train()

for epoch in tqdm(range(NUM_EPOCHS)):
    # tqdm.write prints above the progress bar instead of breaking it across lines.
    tqdm.write(f'Training epoch {epoch}...')

    running_loss = 0.0

    for inputs, labels in train_loader:
        optimizer.zero_grad()               # gradients accumulate by default; clear each step
        outputs = net(inputs)               # raw logits
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()         # Python float, so no autograd graph is retained

    # Mean per-batch training loss for this epoch.
    tqdm.write(f'Loss: {running_loss/len(train_loader):.4f}')
        

  0%|                                                                                           | 0/30 [00:00<?, ?it/s]

Training epoch 0...


  3%|██▊                                                                                | 1/30 [00:22<10:50, 22.44s/it]

Loss: 2.2490
Training epoch 1...


  7%|█████▌                                                                             | 2/30 [00:43<10:04, 21.58s/it]

Loss: 1.7823
Training epoch 2...


 10%|████████▎                                                                          | 3/30 [01:04<09:39, 21.47s/it]

Loss: 1.5197
Training epoch 3...


 13%|███████████                                                                        | 4/30 [01:29<09:50, 22.73s/it]

Loss: 1.3989
Training epoch 4...


 17%|█████████████▊                                                                     | 5/30 [01:51<09:18, 22.33s/it]

Loss: 1.3091
Training epoch 5...


 20%|████████████████▌                                                                  | 6/30 [02:16<09:23, 23.49s/it]

Loss: 1.2193
Training epoch 6...


 23%|███████████████████▎                                                               | 7/30 [02:44<09:31, 24.86s/it]

Loss: 1.1463
Training epoch 7...


 27%|██████████████████████▏                                                            | 8/30 [03:09<09:09, 24.98s/it]

Loss: 1.0862
Training epoch 8...


 30%|████████████████████████▉                                                          | 9/30 [03:36<08:57, 25.58s/it]

Loss: 1.0359
Training epoch 9...


 33%|███████████████████████████▎                                                      | 10/30 [03:58<08:09, 24.49s/it]

Loss: 0.9913
Training epoch 10...


 37%|██████████████████████████████                                                    | 11/30 [04:28<08:14, 26.04s/it]

Loss: 0.9524
Training epoch 11...


 40%|████████████████████████████████▊                                                 | 12/30 [04:54<07:48, 26.03s/it]

Loss: 0.9184
Training epoch 12...


 43%|███████████████████████████████████▌                                              | 13/30 [05:21<07:28, 26.38s/it]

Loss: 0.8850
Training epoch 13...


 47%|██████████████████████████████████████▎                                           | 14/30 [05:45<06:51, 25.74s/it]

Loss: 0.8510
Training epoch 14...


 50%|█████████████████████████████████████████                                         | 15/30 [06:10<06:21, 25.40s/it]

Loss: 0.8215
Training epoch 15...


 53%|███████████████████████████████████████████▋                                      | 16/30 [06:39<06:12, 26.63s/it]

Loss: 0.7962
Training epoch 16...


 57%|██████████████████████████████████████████████▍                                   | 17/30 [07:03<05:36, 25.90s/it]

Loss: 0.7681
Training epoch 17...


 60%|█████████████████████████████████████████████████▏                                | 18/30 [07:27<05:01, 25.15s/it]

Loss: 0.7406
Training epoch 18...


 63%|███████████████████████████████████████████████████▉                              | 19/30 [07:52<04:35, 25.02s/it]

Loss: 0.7145
Training epoch 19...


 67%|██████████████████████████████████████████████████████▋                           | 20/30 [08:16<04:08, 24.86s/it]

Loss: 0.6920
Training epoch 20...


 70%|█████████████████████████████████████████████████████████▍                        | 21/30 [08:46<03:58, 26.53s/it]

Loss: 0.6680
Training epoch 21...


 73%|████████████████████████████████████████████████████████████▏                     | 22/30 [09:08<03:19, 24.98s/it]

Loss: 0.6463
Training epoch 22...


 77%|██████████████████████████████████████████████████████████████▊                   | 23/30 [09:29<02:47, 23.92s/it]

Loss: 0.6258
Training epoch 23...


 80%|█████████████████████████████████████████████████████████████████▌                | 24/30 [09:50<02:18, 23.06s/it]

Loss: 0.5994
Training epoch 24...


 83%|████████████████████████████████████████████████████████████████████▎             | 25/30 [10:11<01:52, 22.43s/it]

Loss: 0.5790
Training epoch 25...


 87%|███████████████████████████████████████████████████████████████████████           | 26/30 [10:32<01:27, 22.00s/it]

Loss: 0.5567
Training epoch 26...


 90%|█████████████████████████████████████████████████████████████████████████▊        | 27/30 [10:54<01:05, 21.77s/it]

Loss: 0.5410
Training epoch 27...


 93%|████████████████████████████████████████████████████████████████████████████▌     | 28/30 [11:15<00:43, 21.53s/it]

Loss: 0.5190
Training epoch 28...


 97%|███████████████████████████████████████████████████████████████████████████████▎  | 29/30 [11:36<00:21, 21.51s/it]

Loss: 0.4997
Training epoch 29...


100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [11:57<00:00, 23.92s/it]

Loss: 0.4837





In [91]:
torch.save(obj=net.state_dict(), f='train_net.pth')  # persist only the learned parameters, not the module object

In [93]:
# Rebuild a fresh model and restore the trained weights saved above.
net = NeuralNet()
# A state_dict is just tensors in containers, so the restricted weights_only
# loader is sufficient. It avoids arbitrary-object unpickling and is the
# default in recent PyTorch releases.
net.load_state_dict(torch.load('train_net.pth', weights_only=True))

<All keys matched successfully>

In [95]:
# Top-1 accuracy of `net` on the CIFAR-10 test split.
net.eval()  # inference mode (matters only for layers like dropout/batch-norm)

total = 0
correct = 0

with torch.no_grad():  # evaluation needs no autograd graph
    for images, labels in test_loader:
        outputs = net(images)
        predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

accuracy = 100 * correct / total

print(f'Accuracy:{accuracy}%')

Accuracy:68.4%


In [96]:
# Preprocessing for arbitrary user images: resize to the 32x32 input NeuralNet
# expects, then apply the same ToTensor + Normalize as training. Bound to its
# own name (the one load_image uses) so the training `transform` is left intact.
new_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def load_image(image_path):
    """Load an image file as a (1, 3, 32, 32) tensor ready to feed to `net`.

    The image is converted to RGB first, so PNGs with an alpha channel or
    grayscale images still produce exactly the 3 channels conv1 expects.

    Raises:
        FileNotFoundError: if `image_path` does not exist.
    """
    with Image.open(image_path) as img:  # close the file handle once decoded
        image = new_transform(img.convert('RGB'))
    return image.unsqueeze(0)  # prepend a batch dimension of size 1


image_paths = ['car.png', 'cat.jpg']
# Distinct name so the loop below cannot pick up the test batch `images`
# left over from the evaluation cell.
custom_images = [load_image(path) for path in image_paths]

net.eval()
with torch.no_grad():
    for image in custom_images:
        outputs = net(image)
        _, predicted = torch.max(outputs, 1)
        print(f'Prediction: {class_names[predicted.item()]}')

FileNotFoundError: [Errno 2] No such file or directory: 'car.png'