In [None]:
import torch
from torchsummary import summary
import os

from models.ResNet import ResNet50
from utils import DataLoader, Logger, Evaluation

# Run configuration: every knob a reader might want to tune lives here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PATH_TO_DATA = './data/'
DOWNLOAD = False  # forwarded as `download=` to DataLoader.load_data_cifar10
ROUND = 1  # number of train + evaluate epochs run by the final cell
LOG_NAME = "ResNetCifar10_{}.log".format(DEVICE)  # device in the name keeps CPU/GPU logs apart
LEARNING_RATE = 0.01  # Adam learning rate
SEED = 0

# Seed torch's RNG (all devices) so weight init is reproducible; data shuffling /
# augmentation is presumably covered too if the project loaders use torch's global RNG.
torch.manual_seed(SEED)

# Bind the instance to its own name instead of rebinding `Logger`: keeping the
# imported module name intact means `Logger.Logger` still resolves to the class
# in any later cell, and readers are not misled about what `Logger` refers to.
log_handler = Logger.Logger(log_name=LOG_NAME)
logger = log_handler.logger
progress_bar = log_handler.progress_bar

In [None]:
model = ResNet50().to(device=DEVICE)

# torchsummary's summary() pushes a random batch through the model to trace
# layer shapes. Run it in eval mode without autograd so that dummy pass cannot
# update any batch-norm running statistics or build a graph, then switch back
# to training mode (train() also calls model.train() at the start of each epoch).
model.eval()
with torch.no_grad():
    summary(model, (3, 32, 32), device=DEVICE)  # prints the layer table
model.train()

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
train_loader, test_loader, classes = DataLoader.load_data_cifar10(
    path_to_data=PATH_TO_DATA,
    download=DOWNLOAD,
)

# Record the size of each split and the label names in the run log.
logger.info("train_loader: len={} batch_size={}".format(len(train_loader), train_loader.batch_size))
logger.info("test_loader: len={} batch_size={}".format(len(test_loader), test_loader.batch_size))
logger.info("classes: labels={}".format(classes))

# Best test accuracy seen so far and the epoch it was reached in;
# the evaluation function updates both through `global`.
best_acc = 0
best_epoch = 0

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Reads the notebook-level ``model``, ``optimizer``, ``loss_func``,
    ``train_loader``, ``DEVICE``, ``logger`` and ``progress_bar``.

    Args:
        epoch: Epoch index; used only in log messages.

    Returns:
        ``(mean_loss, accuracy_percent)`` for the epoch, where ``mean_loss`` is
        the average per-batch loss. Returns ``(0.0, 0.0)`` if the loader
        produced no samples.
    """
    logger.info("training epoch {} start".format(epoch))
    model.train()  # training-mode behaviour for layers such as dropout / batch norm
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(train_loader)  # loop-invariant; hoisted out of the progress-bar call
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, targets)
        loss.backward()
        optimizer.step()

        # .item() turns the loss into a Python float so the running sum holds no graph.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Guard against an empty loader: avoids dividing by zero and touching an unbound batch_idx.
    if total == 0:
        logger.info("training epoch {}: train_loader produced no samples".format(epoch))
        return 0.0, 0.0

    mean_loss = train_loss / (batch_idx + 1)
    acc = 100. * correct / total
    logger.info("training epoch {} done: loss={:.3f} acc={:.3f}%".format(epoch, mean_loss, acc))
    return mean_loss, acc
        
def evaluation(epoch):
    """Evaluate ``model`` on ``test_loader``; checkpoint it when accuracy improves.

    Reads the notebook-level ``model``, ``loss_func``, ``test_loader``,
    ``DEVICE``, ``logger`` and ``progress_bar``, and updates the notebook-level
    ``best_acc`` / ``best_epoch``.

    Args:
        epoch: Epoch index; stored in the checkpoint and used in log messages.

    Returns:
        Test accuracy in percent. Returns ``None`` if the loader produced no
        samples; in that case no checkpoint is written and the best-so-far is
        unchanged.
    """
    global best_acc, best_epoch
    logger.info("evaluation epoch {} start".format(epoch))
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = len(test_loader)  # loop-invariant; hoisted out of the progress-bar call
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, num_batches, 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Previously an empty loader crashed here with ZeroDivisionError.
    if total == 0:
        logger.info("evaluation epoch {}: test_loader produced no samples, skipping checkpoint".format(epoch))
        return None

    acc = 100.*correct/total
    logger.info("evaluation epoch {} done: loss={:.3f} acc={:.3f}%".format(
        epoch, test_loss / (batch_idx + 1), acc))

    # Save checkpoint.
    if acc > best_acc:
        logger.info('Saving..')
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
        best_epoch = epoch

    logger.info("current best acc={} in epoch={}".format(best_acc, best_epoch))
    return acc


# Backward-compatible alias: existing callers use the original misspelled name.
evalaution = evaluation

In [None]:
# Main loop: each round trains for one epoch, then evaluates on the test split
# (which checkpoints the model whenever test accuracy improves).
for epoch in range(ROUND):
    train(epoch)
    evalaution(epoch)