In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt

from collections import namedtuple

from sklearn.metrics import classification_report

In [3]:
def get_clases():
  """Return a fresh list of the ten CIFAR-10 class names.

  The position of each name is the integer label it is displayed for
  (used as ``target_names`` when printing predictions and reports).
  """
  return ['plane', 'car', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']

# Immutable (train, test) pair, used for both the datasets and their loaders.
TrainTest = namedtuple('TrainTest', 'train test')
    
def prepare_data(root='./data', download=True):
  """Build the CIFAR-10 train and test datasets.

  The training split gets light augmentation: a random 32x32 crop taken from
  the image padded by 4 pixels on each side, plus a random horizontal flip.
  The test split is only converted to a tensor, so evaluation sees the same
  inputs every time.

  Args:
    root: directory the CIFAR-10 files are read from / downloaded into.
    download: fetch the archive if it is not already present under ``root``.

  Returns:
    TrainTest(train=trainset, test=testset) holding the two datasets.
  """
  transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
  ])
  # No augmentation at test time.
  transform_test = transforms.Compose([
    transforms.ToTensor(),
  ])
  trainset = torchvision.datasets.CIFAR10(root=root, download=download, train=True, transform=transform_train)
  testset = torchvision.datasets.CIFAR10(root=root, download=download, train=False, transform=transform_test)
  return TrainTest(train=trainset, test=testset)

def prepare_loader(datasets, batch_size=128, num_workers=4):
  """Wrap a TrainTest pair of datasets in DataLoaders.

  The training loader reshuffles every epoch. The test loader keeps a fixed
  order, so predictions line up with the dataset's sample order.

  Args:
    datasets: TrainTest(train=..., test=...) of map-style datasets.
    batch_size: samples per batch for both loaders (default 128).
    num_workers: worker processes per loader (default 4). Lower it if
      PyTorch warns that it exceeds the number of CPUs available to the
      process.

  Returns:
    TrainTest(train=trainloader, test=testloader).
  """
  trainloader = DataLoader(dataset=datasets.train, batch_size=batch_size,
                           shuffle=True, num_workers=num_workers)
  testloader = DataLoader(dataset=datasets.test, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)
  return TrainTest(train=trainloader, test=testloader)

class VGG16(nn.Module):
  """VGG-style convolutional classifier sized for 32x32 inputs (e.g. CIFAR-10).

  The trunk has thirteen 3x3 convolutions, each followed by BatchNorm and
  ReLU6. They are grouped into five stages, and each stage ends with a 2x2
  max-pool. A single linear layer is the classification head (the original
  VGG-16 uses three fully connected layers).

  Five poolings reduce a 32x32 input to 1x1, so the flattened feature vector
  has exactly 512 entries. Inputs of other spatial sizes will not match the
  head.

  Args:
    num_classes: number of output logits (default 10).
    in_channels: channels of the input image (default 3).
  """

  # Output channels of each conv layer; 'MP' inserts a 2x2 max-pool.
  CONFIG = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP',
            512, 512, 512, 'MP', 512, 512, 512, 'MP']

  def __init__(self, num_classes=10, in_channels=3):
    super().__init__()
    self.features = self._make_features(in_channels)
    self.classification_head = nn.Linear(in_features=512, out_features=num_classes)

  def forward(self, x):
    """Return unnormalized class logits of shape (batch, num_classes)."""
    out = self.features(x)
    # (batch, 512, 1, 1) -> (batch, 512) for a 32x32 input.
    out = torch.flatten(out, 1)
    return self.classification_head(out)

  def _make_features(self, in_channels=3):
    """Build the convolutional trunk described by CONFIG as an nn.Sequential."""
    layers = []
    c_in = in_channels
    for c in self.CONFIG:
      if c == 'MP':
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
      else:
        # padding=1 keeps the spatial size, so only the pools downsample.
        layers += [nn.Conv2d(in_channels=c_in, out_channels=c, kernel_size=3, stride=1, padding=1),
                   nn.BatchNorm2d(num_features=c),
                   nn.ReLU6(inplace=True)]
        c_in = c
    return nn.Sequential(*layers)

def imshow(images, labels, predicted, target_names):
  """Show a batch of images as one grid and print true vs. predicted class names.

  Args:
    images: image batch accepted by torchvision.utils.make_grid (presumably
      (B, C, H, W) with values in [0, 1], since the datasets here use only
      ToTensor() — confirm for other inputs).
    labels: 1-D tensor of true class indices.
    predicted: 1-D tensor of predicted class indices, in the same order.
    target_names: sequence mapping a class index to its display name.
  """
  grid = torchvision.utils.make_grid(images)
  # make_grid returns (C, H, W); matplotlib expects (H, W, C) on the CPU.
  plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
  plt.axis('off')
  print('truth:    ', ' '.join(target_names[c] for c in labels.cpu().tolist()))
  print('predicted:', ' '.join(target_names[c] for c in predicted.cpu().tolist()))
  plt.show()

def train_epoch(epoch, model, loader, loss_func, optimizer, device, reporting_steps=60):
  """Run one optimization pass over ``loader``.

  Prints the mean loss of every ``reporting_steps`` consecutive batches.
  Batches after the last full window are not printed, but they are still
  included in the returned epoch mean.

  Args:
    epoch: epoch index, used only in the progress messages.
    model: module to train; it is switched to training mode first.
    loader: iterable of (images, labels) batches.
    loss_func: callable(outputs, labels) returning a scalar loss tensor.
    optimizer: optimizer over ``model``'s parameters.
    device: device the batches are moved to (should match the model's).
    reporting_steps: batches per progress message (default 60).

  Returns:
    Mean training loss over all batches of the epoch (0.0 if ``loader`` is empty).
  """
  model.train()
  running_loss = 0.0
  epoch_loss = 0.0
  n_batches = 0
  for i, (images, labels) in enumerate(loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    loss = loss_func(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_loss = loss.item()
    running_loss += batch_loss
    epoch_loss += batch_loss
    n_batches += 1
    if i % reporting_steps == reporting_steps - 1:
      print(f"Epoch {epoch} step {i} ave_loss {running_loss/reporting_steps:.4f}")
      running_loss = 0.0
  return epoch_loss / n_batches if n_batches else 0.0

def test_epoch(epoch, model, loader, device):
  ytrue = []
  ypred = []
  with torch.no_grad():
    model.eval()
    
    for i, (images, labels) in enumerate(loader):
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)
      _, predicted = torch.max(outputs, dim=1)

      ytrue += list(labels.cpu().numpy())
      ypred += list(predicted.cpu().numpy())

  return ypred, ytrue

def main(PATH='./model.pth', epochs=10, device=None):
  """Train VGG16 on CIFAR-10, printing test metrics and checkpointing every epoch.

  After each training epoch the model is evaluated on the test split and an
  sklearn classification report is printed. The weights are then saved to
  ``PATH``, overwriting the previous epoch's checkpoint.

  Args:
    PATH: file the model's state_dict is written to after every epoch.
    epochs: number of training epochs (default 10).
    device: torch device to train on. Defaults to "cuda:0" when CUDA is
      available, otherwise the CPU.

  Returns:
    The trained VGG16 model.
  """
  classes = get_clases()
  datasets = prepare_data()
  loaders = prepare_loader(datasets)

  if device is None:
    # Fall back to the CPU instead of failing on machines without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  model = VGG16().to(device)

  loss_func = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
  for epoch in range(epochs):
    train_epoch(epoch, model, loaders.train, loss_func, optimizer, device)
    ypred, ytrue = test_epoch(epoch, model, loaders.test, device)
    print(classification_report(ytrue, ypred, target_names=classes))

    torch.save(model.state_dict(), PATH)

  return model

# Train with the default settings; keep the returned network in the kernel for inspection.
model = main(PATH='./model.pth')

Files already downloaded and verified
Files already downloaded and verified
Epoch 0 step 59 ave_loss 2.0290
Epoch 0 step 119 ave_loss 1.7851
Epoch 0 step 179 ave_loss 1.5487
Epoch 0 step 239 ave_loss 1.3900
Epoch 0 step 299 ave_loss 1.3034
Epoch 0 step 359 ave_loss 1.1896
              precision    recall  f1-score   support

       plane       0.55      0.70      0.62      1000
         car       0.84      0.58      0.69      1000
        bird       0.37      0.55      0.44      1000
         cat       0.36      0.37      0.37      1000
        deer       0.45      0.56      0.50      1000
         dog       0.83      0.09      0.16      1000
        frog       0.54      0.85      0.66      1000
       horse       0.81      0.30      0.44      1000
        ship       0.63      0.82      0.71      1000
       truck       0.73      0.69      0.71      1000

    accuracy                           0.55     10000
   macro avg       0.61      0.55      0.53     10000
weighted avg       0.61

  cpuset_checked))


Epoch 1 step 59 ave_loss 1.0778
Epoch 1 step 119 ave_loss 1.0425
Epoch 1 step 179 ave_loss 0.9724
Epoch 1 step 239 ave_loss 0.9476
Epoch 1 step 299 ave_loss 0.9010
Epoch 1 step 359 ave_loss 0.8719
              precision    recall  f1-score   support

       plane       0.75      0.73      0.74      1000
         car       0.80      0.91      0.85      1000
        bird       0.74      0.35      0.47      1000
         cat       0.44      0.66      0.53      1000
        deer       0.57      0.73      0.64      1000
         dog       0.71      0.48      0.57      1000
        frog       0.62      0.87      0.72      1000
       horse       0.87      0.53      0.66      1000
        ship       0.78      0.86      0.82      1000
       truck       0.89      0.76      0.82      1000

    accuracy                           0.69     10000
   macro avg       0.72      0.69      0.68     10000
weighted avg       0.72      0.69      0.68     10000



  cpuset_checked))


Epoch 2 step 59 ave_loss 0.8019
Epoch 2 step 119 ave_loss 0.8206
Epoch 2 step 179 ave_loss 0.7511
Epoch 2 step 239 ave_loss 0.7372
Epoch 2 step 299 ave_loss 0.7232
Epoch 2 step 359 ave_loss 0.7113
              precision    recall  f1-score   support

       plane       0.76      0.82      0.79      1000
         car       0.96      0.81      0.88      1000
        bird       0.57      0.76      0.65      1000
         cat       0.78      0.34      0.47      1000
        deer       0.73      0.75      0.74      1000
         dog       0.58      0.77      0.66      1000
        frog       0.80      0.83      0.82      1000
       horse       0.82      0.78      0.80      1000
        ship       0.92      0.81      0.86      1000
       truck       0.80      0.91      0.85      1000

    accuracy                           0.76     10000
   macro avg       0.77      0.76      0.75     10000
weighted avg       0.77      0.76      0.75     10000



  cpuset_checked))


Epoch 3 step 59 ave_loss 0.6554
Epoch 3 step 119 ave_loss 0.6578
Epoch 3 step 179 ave_loss 0.6367
Epoch 3 step 239 ave_loss 0.6410
Epoch 3 step 299 ave_loss 0.6392
Epoch 3 step 359 ave_loss 0.6071
              precision    recall  f1-score   support

       plane       0.75      0.84      0.80      1000
         car       0.75      0.96      0.84      1000
        bird       0.78      0.55      0.65      1000
         cat       0.60      0.63      0.61      1000
        deer       0.69      0.73      0.71      1000
         dog       0.88      0.45      0.60      1000
        frog       0.89      0.79      0.84      1000
       horse       0.61      0.94      0.74      1000
        ship       0.93      0.76      0.83      1000
       truck       0.82      0.85      0.84      1000

    accuracy                           0.75     10000
   macro avg       0.77      0.75      0.74     10000
weighted avg       0.77      0.75      0.74     10000



  cpuset_checked))


Epoch 4 step 59 ave_loss 0.5795
Epoch 4 step 119 ave_loss 0.5637
Epoch 4 step 179 ave_loss 0.5642
Epoch 4 step 239 ave_loss 0.5841
Epoch 4 step 299 ave_loss 0.5558
Epoch 4 step 359 ave_loss 0.5375
              precision    recall  f1-score   support

       plane       0.88      0.72      0.79      1000
         car       0.96      0.81      0.88      1000
        bird       0.87      0.56      0.68      1000
         cat       0.55      0.69      0.61      1000
        deer       0.80      0.73      0.76      1000
         dog       0.56      0.84      0.68      1000
        frog       0.86      0.84      0.85      1000
       horse       0.92      0.70      0.80      1000
        ship       0.86      0.89      0.88      1000
       truck       0.75      0.95      0.84      1000

    accuracy                           0.77     10000
   macro avg       0.80      0.77      0.78     10000
weighted avg       0.80      0.77      0.78     10000



  cpuset_checked))


Epoch 5 step 59 ave_loss 0.4831
Epoch 5 step 119 ave_loss 0.5217
Epoch 5 step 179 ave_loss 0.5193
Epoch 5 step 239 ave_loss 0.5031
Epoch 5 step 299 ave_loss 0.4878
Epoch 5 step 359 ave_loss 0.4988
              precision    recall  f1-score   support

       plane       0.77      0.88      0.82      1000
         car       0.95      0.79      0.86      1000
        bird       0.75      0.75      0.75      1000
         cat       0.85      0.46      0.60      1000
        deer       0.81      0.75      0.78      1000
         dog       0.76      0.76      0.76      1000
        frog       0.85      0.88      0.86      1000
       horse       0.73      0.90      0.81      1000
        ship       0.90      0.83      0.86      1000
       truck       0.70      0.96      0.81      1000

    accuracy                           0.80     10000
   macro avg       0.81      0.80      0.79     10000
weighted avg       0.81      0.80      0.79     10000



  cpuset_checked))


Epoch 6 step 59 ave_loss 0.4542
Epoch 6 step 119 ave_loss 0.4598
Epoch 6 step 179 ave_loss 0.4537
Epoch 6 step 239 ave_loss 0.4544
Epoch 6 step 299 ave_loss 0.4533
Epoch 6 step 359 ave_loss 0.4468
              precision    recall  f1-score   support

       plane       0.88      0.81      0.84      1000
         car       0.94      0.91      0.93      1000
        bird       0.88      0.63      0.74      1000
         cat       0.47      0.88      0.62      1000
        deer       0.74      0.88      0.80      1000
         dog       0.91      0.44      0.59      1000
        frog       0.92      0.81      0.86      1000
       horse       0.93      0.82      0.87      1000
        ship       0.88      0.93      0.91      1000
       truck       0.89      0.93      0.91      1000

    accuracy                           0.80     10000
   macro avg       0.84      0.80      0.81     10000
weighted avg       0.84      0.80      0.81     10000



  cpuset_checked))


Epoch 7 step 59 ave_loss 0.4203
Epoch 7 step 119 ave_loss 0.4143
Epoch 7 step 179 ave_loss 0.4288
Epoch 7 step 239 ave_loss 0.4104
Epoch 7 step 299 ave_loss 0.4067
Epoch 7 step 359 ave_loss 0.4051
              precision    recall  f1-score   support

       plane       0.76      0.84      0.80      1000
         car       0.92      0.94      0.93      1000
        bird       0.82      0.71      0.76      1000
         cat       0.85      0.43      0.57      1000
        deer       0.79      0.85      0.82      1000
         dog       0.59      0.89      0.71      1000
        frog       0.91      0.85      0.88      1000
       horse       0.95      0.76      0.84      1000
        ship       0.76      0.97      0.85      1000
       truck       0.92      0.84      0.88      1000

    accuracy                           0.81     10000
   macro avg       0.83      0.81      0.80     10000
weighted avg       0.83      0.81      0.80     10000



  cpuset_checked))


Epoch 8 step 59 ave_loss 0.3959
Epoch 8 step 119 ave_loss 0.3720
Epoch 8 step 179 ave_loss 0.3859
Epoch 8 step 239 ave_loss 0.3852
Epoch 8 step 299 ave_loss 0.3926
Epoch 8 step 359 ave_loss 0.3720
              precision    recall  f1-score   support

       plane       0.78      0.80      0.79      1000
         car       0.95      0.93      0.94      1000
        bird       0.70      0.80      0.75      1000
         cat       0.71      0.69      0.70      1000
        deer       0.65      0.91      0.76      1000
         dog       0.81      0.68      0.74      1000
        frog       0.98      0.69      0.81      1000
       horse       0.98      0.71      0.82      1000
        ship       0.77      0.97      0.86      1000
       truck       0.93      0.89      0.91      1000

    accuracy                           0.81     10000
   macro avg       0.83      0.81      0.81     10000
weighted avg       0.83      0.81      0.81     10000



  cpuset_checked))


Epoch 9 step 59 ave_loss 0.3474
Epoch 9 step 119 ave_loss 0.3604
Epoch 9 step 179 ave_loss 0.3738
Epoch 9 step 239 ave_loss 0.3587
Epoch 9 step 299 ave_loss 0.3650
Epoch 9 step 359 ave_loss 0.3424
              precision    recall  f1-score   support

       plane       0.92      0.78      0.85      1000
         car       0.96      0.89      0.92      1000
        bird       0.77      0.81      0.79      1000
         cat       0.77      0.66      0.71      1000
        deer       0.70      0.94      0.80      1000
         dog       0.83      0.74      0.78      1000
        frog       0.88      0.91      0.90      1000
       horse       0.93      0.82      0.87      1000
        ship       0.88      0.93      0.91      1000
       truck       0.84      0.95      0.89      1000

    accuracy                           0.84     10000
   macro avg       0.85      0.84      0.84     10000
weighted avg       0.85      0.84      0.84     10000

