In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import resnet
import dataloader

import matplotlib.pyplot as plt
import numpy as np

import os
from tqdm import tqdm

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Build the 10-class ResNet-18 from the local `resnet` module and move it
# straight onto the selected device.
model = resnet.ResNet18(num_classes=10).to(device)

In [4]:
# Classification loss shared by train() and test().
CELoss = nn.CrossEntropyLoss()

# SGD with momentum over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
)

# Multiplies the learning rate by gamma on every scheduler.step() call.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# Best test accuracy (in percent) seen so far; test() updates it.
best_acc = 0

In [5]:
def train(epoch, model, trainloader):
    """Run one optimisation pass over ``trainloader``.

    Uses the notebook-level globals ``device``, ``optimizer`` and ``CELoss``.

    Args:
        epoch: Epoch index; only used for the progress message.
        model: Network to optimise; it is switched to training mode.
        trainloader: Iterable with a length that yields ``(inputs, targets)``
            batches.

    Returns:
        A tuple ``(mean_loss, losses)``: the mean of the per-batch losses and
        the list of per-batch loss values, in iteration order.
    """
    print(f'Train Epoch:{epoch}')
    model.train()

    losses = []
    with tqdm(total=len(trainloader)) as progress:
        for batch in trainloader:
            x = batch[0].to(device)
            y = batch[1].to(device)

            # Standard step: clear old gradients, forward, backward, update.
            optimizer.zero_grad()
            batch_loss = CELoss(model(x), y)
            batch_loss.backward()
            optimizer.step()

            losses.append(batch_loss.item())
            progress.update(1)

    mean_loss = np.mean(losses)
    return mean_loss, losses


def test(epoch, model, testloader):
    global best_acc
    model.eval()
    batch_test_loss = []
    batch_correct = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = CELoss(outputs, targets)
            batch_test_loss.append(loss.item())
            
            _, predicted = torch.max(outputs, axis=1)
            batch_correct.append(predicted.eq(targets).numpy())
    
    batch_correct = np.array(batch_correct)
    correct = batch_correct.sum()
    total = batch_correct.size
    
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
    
    return np.array(batch_test_loss).mean(), batch_test_loss, acc

In [6]:
# History collected once per epoch. train() and test() both return the
# per-batch loss list as their second element; test() returns the accuracy
# as its third.
train_loss_list, test_loss_list, acc_list = [], [], []

for epoch in range(2):
    train_result = train(epoch, model, dataloader.trainloader)
    test_result = test(epoch, model, dataloader.testloader)

    # Decay the learning rate once the epoch has been trained and evaluated.
    scheduler.step()

    train_loss_list.append(train_result[1])
    test_loss_list.append(test_result[1])
    acc_list.append(test_result[2])

  0%|          | 0/391 [00:00<?, ?it/s]

Train Epoch:0


 23%|██▎       | 90/391 [05:49<19:29,  3.89s/it]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Each entry of the *_loss_list histories is one epoch's list of per-batch
# losses, so flattening gives one point per batch across all epochs.
# Use explicit figure/axes objects and label both plots so each figure can be
# read on its own.
train_fig, train_ax = plt.subplots()
train_ax.plot(np.array(train_loss_list).ravel())
train_ax.set(title='Training loss', xlabel='Batch (across epochs)', ylabel='Cross-entropy loss')

test_fig, test_ax = plt.subplots()
test_ax.plot(np.array(test_loss_list).ravel())
test_ax.set(title='Test loss', xlabel='Batch (across epochs)', ylabel='Cross-entropy loss')

plt.show()