In [None]:
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

In [None]:
# Preprocessing for CIFAR-10 samples: PIL image -> float tensor in [0, 1],
# then each RGB channel is mapped to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Both CIFAR-10 splits share the storage location, download behavior and
# preprocessing; only the `train` flag differs.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_data = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
test_data = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 30.6MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Loader settings shared by both splits; only the training set is shuffled.
BATCH_SIZE = 4
NUM_WORKERS = 2

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

In [None]:
# Peek at the first preprocessed training sample: the image tensor and its label
# (presumably the integer index into `class_names` — confirm against the dataset).
image, label = train_data[0]

In [None]:
# Displays the sample's tensor dimensions (channels, height, width).
image.size()

torch.Size([3, 32, 32])

In [None]:
# Human-readable CIFAR-10 labels, ordered so that list position == class index
# predicted by the network.
class_names = 'plane car bird cat deer dog frog horse ship truck'.split()

In [None]:
class NeuralNet(nn.Module):
  """Small CNN classifying 3x32x32 images into 10 classes.

  Shape flow for a (N, 3, 32, 32) batch:
    conv1 5x5 -> (N, 12, 28, 28) -> ReLU -> 2x2 max-pool -> (N, 12, 14, 14)
    conv2 5x5 -> (N, 24, 10, 10) -> ReLU -> 2x2 max-pool -> (N, 24, 5, 5)
    flatten   -> (N, 600) -> fc1 (120) -> fc2 (84) -> fc3 (10 raw logits)
  """

  def __init__(self):
    super().__init__()
    # Keep this registration order: it fixes the state_dict key order and the
    # order in which layer weights draw from the RNG at construction time.
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # reused after both convs
    self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5)
    self.fc1 = nn.Linear(in_features=24 * 5 * 5, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=84)
    self.fc3 = nn.Linear(in_features=84, out_features=10)

  def forward(self, x):
    """Map a (N, 3, 32, 32) batch to unnormalized class scores of shape (N, 10)."""
    # Feature extractor: two conv -> ReLU -> pool stages.
    for conv in (self.conv1, self.conv2):
      x = self.pool(F.relu(conv(x)))
    x = x.flatten(start_dim=1)
    # Classifier head; the last layer has no activation so it returns logits.
    for hidden in (self.fc1, self.fc2):
      x = F.relu(hidden(x))
    return self.fc3(x)

In [None]:
# Seed torch's global RNG so weight initialization is reproducible. With the
# default sampler, the training DataLoader's shuffle order is also drawn from
# this RNG when iteration starts.
SEED = 42
torch.manual_seed(SEED)

net = NeuralNet()
# NeuralNet.forward returns raw logits, which is what CrossEntropyLoss expects.
loss_function = nn.CrossEntropyLoss()
# Plain SGD with momentum.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [None]:
NUM_EPOCHS = 32

for epoch in range(NUM_EPOCHS):
  print(f'Epoch: {epoch}')

  # Later cells call net.eval(); switch back explicitly so re-running this cell
  # after them still trains in training mode. NeuralNet currently has no
  # dropout/batch-norm layers, so the mode has no numerical effect today, but
  # relying on leftover kernel state is fragile.
  net.train()

  running_loss = 0.0
  for inputs, labels in train_loader:
    optimizer.zero_grad()

    outputs = net(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  # Mean per-batch training loss for this epoch.
  print(f'Loss: {running_loss / len(train_loader)}')

Epoch: 0
Loss: 1.6560420803207159
Epoch: 1
Loss: 1.2450916041195392
Epoch: 2
Loss: 1.0864098874369263
Epoch: 3
Loss: 0.9778060414333641
Epoch: 4
Loss: 0.899499504712969
Epoch: 5
Loss: 0.83641225784295
Epoch: 6
Loss: 0.7810519574615267
Epoch: 7
Loss: 0.733043318966201
Epoch: 8
Loss: 0.6892662267069845
Epoch: 9
Loss: 0.6550530498489737
Epoch: 10
Loss: 0.6222647298683183
Epoch: 11
Loss: 0.5968580974472454
Epoch: 12
Loss: 0.5763508282537776
Epoch: 13
Loss: 0.5453994342619134
Epoch: 14
Loss: 0.5296987986037007
Epoch: 15
Loss: 0.5109307186098871
Epoch: 16
Loss: 0.4929766548344791
Epoch: 17
Loss: 0.4857164517558086
Epoch: 18
Loss: 0.46667004325976946
Epoch: 19
Loss: 0.46930931827220795
Epoch: 20
Loss: 0.45575498424338234
Epoch: 21
Loss: 0.4540142018400499
Epoch: 22
Loss: 0.4424499706206807
Epoch: 23
Loss: 0.43051646676719174
Epoch: 24
Loss: 0.43067259686184295
Epoch: 25
Loss: 0.42711296619608635
Epoch: 26
Loss: 0.4288386247212842
Epoch: 27
Loss: 0.4298539037530623
Epoch: 28
Loss: 0.4262251765

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(obj=net.state_dict(), f='./trained_net.pth')

In [None]:
# Top-1 accuracy of `net` on the CIFAR-10 test split.
correct = 0
total = 0

# Inference only: eval() for mode-dependent layers, no_grad() to skip building
# the autograd graph.
net.eval()
with torch.no_grad():
  for images, labels in test_loader:
    outputs = net(images)
    # Predicted class = index of the highest logit in each row.
    predicted = outputs.argmax(dim=1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy: {accuracy}%')

Accuracy: 63.63%


In [None]:
# Same normalization as the training `transform`, plus a resize so arbitrary
# photos match the 32x32 input size the network was trained on.
# Note: Resize((32, 32)) forces the exact size and does not preserve aspect ratio.
new_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def load_image(path):
  """Load an image file and turn it into a one-image batch for `net`.

  Args:
    path: Path to an image file readable by PIL.

  Returns:
    A float tensor of shape (1, 3, 32, 32), normalized like the training data.
  """
  # The context manager closes the underlying file handle. convert('RGB') maps
  # grayscale, palette and RGBA images to the 3 channels that Normalize and
  # conv1 expect. Conversion happens inside the `with` so pixel data is loaded
  # before the file is closed.
  with Image.open(path) as img:
    image = new_transform(img.convert('RGB'))
  # Add the leading batch dimension.
  return image.unsqueeze(0)

# Classify a couple of custom photos with the trained network.
image_paths = ['/content/example1.jpg', '/content/example2.jpg']
images = list(map(load_image, image_paths))

# Inference only: eval mode and no autograd bookkeeping.
net.eval()
with torch.no_grad():
  for image in images:
    output = net(image)
    # Index of the highest logit names the predicted class.
    predicted = output.argmax(dim=1)
    label_name = class_names[predicted.item()]
    print(f'Predicted class: {label_name}')

Predicted class: bird
Predicted class: deer
