In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision.models import efficientnet_b0
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from torchvision.datasets import CIFAR10
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
train_dataset = CIFAR10(root="./data", train=True,
                        download=True, transform=transform)
test_dataset = CIFAR10(root="./data", train=False,
                        download=True, transform=transform)
train_data_size = len(train_dataset)
test_data_size = len(test_dataset)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [4]:
def train_and_valid(model, loss_function, optimizer, epochs=25):
    history = list()
    best_acc = 0.0
    best_epoch = 0

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))

        model.train()

        train_loss = 0.0
        train_acc = 0.0
        test_loss = 0.0
        test_acc = 0.0

        for inputs, labels in tqdm(train_dataloader,total=len(train_dataloader)):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs,labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
            ret, predictions = torch.max(outputs.data,1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            acc = torch.mean(correct_counts.type(torch.FloatTensor))
            train_acc += acc.item() * inputs.size(0)

        with torch.no_grad():
            model.eval()

            for inputs, labels in tqdm(test_dataloader,total=len(test_dataloader)):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = loss_function(outputs, labels)

                test_loss += loss.item() * inputs.size(0)
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))
                acc = torch.mean(correct_counts.type(torch.FloatTensor))
                test_acc += acc.item() * inputs.size(0)

        avg_train_loss = train_loss/train_data_size
        avg_train_acc = train_acc/train_data_size
        avg_test_loss = test_loss/test_data_size
        avg_test_acc = test_acc/test_data_size

        history.append([avg_train_loss, avg_test_loss, avg_train_acc, avg_test_acc])

        if best_acc < avg_test_acc:
            best_acc = avg_test_acc
            best_epoch = epoch + 1
            torch.save(model.state_dict(), 'weights/best_model.pth')
            

        epoch_end = time.time()

        print("Epoch: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\Test: Loss: {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(
            epoch+1, avg_test_loss, avg_train_acc*100, avg_test_loss, avg_test_acc*100, epoch_end-epoch_start
        ))
        print("Best Accuracy for validation : {:.4f} at epoch {:03d}".format(best_acc, best_epoch))

    return model, history


In [5]:
model = efficientnet_b0(pretrained=True)
model.classifier.linear = nn.Linear(1000, 10)
model.to(device)
num_epochs = 30
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
trained_model, history = train_and_valid(model, loss_func, optimizer, num_epochs)



Epoch: 1/30


100%|██████████| 1563/1563 [06:38<00:00,  3.92it/s]
100%|██████████| 313/313 [00:16<00:00, 19.30it/s]


Epoch: 001, Training: Loss: 0.2970, Accuracy: 83.5560%, 
	\Test: Loss: 0.2970, Accuracy: 89.9700%, Time: 414.5447s
Best Accuracy for validation : 0.8997 at epoch 001
Epoch: 2/30


100%|██████████| 1563/1563 [03:54<00:00,  6.66it/s]
100%|██████████| 313/313 [00:17<00:00, 18.17it/s]


Epoch: 002, Training: Loss: 0.3840, Accuracy: 89.3860%, 
	\Test: Loss: 0.3840, Accuracy: 88.4800%, Time: 252.0304s
Best Accuracy for validation : 0.8997 at epoch 001
Epoch: 3/30


100%|██████████| 1563/1563 [03:58<00:00,  6.56it/s]
100%|██████████| 313/313 [00:16<00:00, 18.47it/s]


Epoch: 003, Training: Loss: 0.2462, Accuracy: 91.4340%, 
	\Test: Loss: 0.2462, Accuracy: 92.0000%, Time: 255.2438s
Best Accuracy for validation : 0.9200 at epoch 003
Epoch: 4/30


100%|██████████| 1563/1563 [04:00<00:00,  6.50it/s]
100%|██████████| 313/313 [00:19<00:00, 15.66it/s]


Epoch: 004, Training: Loss: 0.2431, Accuracy: 93.1240%, 
	\Test: Loss: 0.2431, Accuracy: 92.4200%, Time: 260.5830s
Best Accuracy for validation : 0.9242 at epoch 004
Epoch: 5/30


100%|██████████| 1563/1563 [07:43<00:00,  3.37it/s]
100%|██████████| 313/313 [00:29<00:00, 10.66it/s]


Epoch: 005, Training: Loss: 0.2361, Accuracy: 93.5160%, 
	\Test: Loss: 0.2361, Accuracy: 92.5800%, Time: 493.2817s
Best Accuracy for validation : 0.9258 at epoch 005
Epoch: 6/30


100%|██████████| 1563/1563 [06:14<00:00,  4.18it/s]
100%|██████████| 313/313 [00:35<00:00,  8.70it/s]


Epoch: 006, Training: Loss: 0.2212, Accuracy: 94.6000%, 
	\Test: Loss: 0.2212, Accuracy: 93.2500%, Time: 410.1637s
Best Accuracy for validation : 0.9325 at epoch 006
Epoch: 7/30


100%|██████████| 1563/1563 [10:40<00:00,  2.44it/s]
100%|██████████| 313/313 [00:37<00:00,  8.27it/s]


Epoch: 007, Training: Loss: 0.2250, Accuracy: 95.3960%, 
	\Test: Loss: 0.2250, Accuracy: 93.2400%, Time: 678.8317s
Best Accuracy for validation : 0.9325 at epoch 006
Epoch: 8/30


100%|██████████| 1563/1563 [10:30<00:00,  2.48it/s]
100%|██████████| 313/313 [00:34<00:00,  9.00it/s]


Epoch: 008, Training: Loss: 0.2264, Accuracy: 95.8040%, 
	\Test: Loss: 0.2264, Accuracy: 92.9800%, Time: 665.4034s
Best Accuracy for validation : 0.9325 at epoch 006
Epoch: 9/30


100%|██████████| 1563/1563 [10:42<00:00,  2.43it/s]
100%|██████████| 313/313 [00:35<00:00,  8.91it/s]


Epoch: 009, Training: Loss: 0.2255, Accuracy: 96.2760%, 
	\Test: Loss: 0.2255, Accuracy: 93.5400%, Time: 677.9765s
Best Accuracy for validation : 0.9354 at epoch 009
Epoch: 10/30


100%|██████████| 1563/1563 [10:45<00:00,  2.42it/s]
100%|██████████| 313/313 [00:31<00:00,  9.86it/s]


Epoch: 010, Training: Loss: 0.3165, Accuracy: 96.5680%, 
	\Test: Loss: 0.3165, Accuracy: 92.0900%, Time: 677.5913s
Best Accuracy for validation : 0.9354 at epoch 009
Epoch: 11/30


100%|██████████| 1563/1563 [09:56<00:00,  2.62it/s]
100%|██████████| 313/313 [00:36<00:00,  8.48it/s]


Epoch: 011, Training: Loss: 0.2815, Accuracy: 96.8080%, 
	\Test: Loss: 0.2815, Accuracy: 92.6000%, Time: 633.2906s
Best Accuracy for validation : 0.9354 at epoch 009
Epoch: 12/30


100%|██████████| 1563/1563 [10:44<00:00,  2.43it/s]
100%|██████████| 313/313 [00:37<00:00,  8.28it/s]


Epoch: 012, Training: Loss: 0.2992, Accuracy: 96.9760%, 
	\Test: Loss: 0.2992, Accuracy: 92.6000%, Time: 682.0001s
Best Accuracy for validation : 0.9354 at epoch 009
Epoch: 13/30


100%|██████████| 1563/1563 [10:12<00:00,  2.55it/s]
100%|██████████| 313/313 [00:35<00:00,  8.78it/s]


Epoch: 013, Training: Loss: 0.2493, Accuracy: 97.2720%, 
	\Test: Loss: 0.2493, Accuracy: 93.4000%, Time: 648.0194s
Best Accuracy for validation : 0.9354 at epoch 009
Epoch: 14/30


 73%|███████▎  | 1140/1563 [07:29<02:46,  2.54it/s]


KeyboardInterrupt: 