In [None]:
import os
import math
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision.models import resnet152
from tqdm import tqdm

### Params

In [None]:
# os.cpu_count() may return None when the CPU count cannot be determined;
# DataLoader needs an int, so fall back to 0 (load data in the main process).
num_workers = os.cpu_count() or 0
use_cuda = torch.cuda.is_available()

# following params can be input by argparse
batch_size = 256
lr = 0.01
momentum = 0.9
epochs = 100

# Seed torch's RNG (CPU and CUDA) so weight init and shuffling repeat across runs.
seed = 0
torch.manual_seed(seed)

### Best Result

In [None]:
# Best metrics seen so far: loss starts at +inf, accuracy at 0.
# NOTE(review): nothing visible in this notebook updates these yet
# (the best model is not saved).
best_loss, best_acc = math.inf, 0.0

### Load Dataset

In [None]:
# TODO: replace these placeholder lists with real Dataset objects.
# NOTE: with shuffle=True, DataLoader's RandomSampler rejects an empty dataset,
# so the training placeholder must be filled in before this cell can run.
train_images = []
train_loader = DataLoader(train_images,
                          batch_size,
                          shuffle=True,
                          num_workers=num_workers)
test_images = []
# Evaluation order does not change accuracy, so the test set is read in order.
test_loader = DataLoader(test_images,
                         batch_size,
                         shuffle=False,
                         num_workers=num_workers)

### Model

In [None]:
# ResNet-152 with a 4-way classification head.
net = resnet152(num_classes=4)
if use_cuda:
    # Wrap for data-parallel execution over the visible GPUs and move the
    # parameters onto CUDA (Module.cuda() returns the module itself).
    net = nn.DataParallel(net).cuda()
    # Let cuDNN benchmark convolution algorithms; this pays off when input
    # sizes stay the same between batches.
    cudnn.benchmark = True

### Optimizer

In [None]:
# Plain SGD with momentum; weight_decay adds an L2 penalty of 1e-4.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=1e-4,
)

### Loss

In [None]:
# Cross-entropy over raw logits, averaged over the batch (PyTorch's default reduction).
criterion = nn.CrossEntropyLoss(reduction='mean')

### Train

In [None]:
def train():
    """Run one training epoch over ``train_loader`` and report the mean batch loss.

    Uses the notebook-level ``net``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``use_cuda``.

    Returns:
        float: mean of the per-batch losses, or ``nan`` if the loader
        yielded no batches.
    """
    net.train()

    total_loss = 0.0
    num_batches = 0

    loader = tqdm(train_loader)
    for images, labels in loader:
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # .item() converts to a plain Python float. Accumulating the tensor
        # itself would keep every batch's autograd graph alive for the epoch.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        loader.set_description(f'[Train] Loss: {batch_loss:.4f}')

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else math.nan
    print(f'[Train] Avg loss: {avg_loss:.4f}')
    return avg_loss

### Test

Note: this cell only evaluates accuracy; it does not save the best model.

In [None]:
def test():
    """Evaluate ``net`` on ``test_loader`` and report the overall accuracy.

    Accuracy is total correct predictions divided by total samples, so a
    short final batch is weighted by its real size. Averaging per-batch
    accuracies would over-weight it.

    Uses the notebook-level ``net``, ``test_loader`` and ``use_cuda``.

    Returns:
        float: accuracy in [0, 1], or ``nan`` if the loader yielded no samples.
    """
    net.eval()

    num_correct = 0
    num_seen = 0

    loader = tqdm(test_loader)
    # No gradients are needed anywhere during evaluation.
    with torch.no_grad():
        for images, labels in loader:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda()

            output = net(images)
            pred = torch.argmax(output, 1)

            batch_correct = torch.eq(pred, labels).sum().item()
            n_batch = pred.size(0)
            num_correct += batch_correct
            num_seen += n_batch

            loader.set_description(f'[Test] Accuracy: {batch_correct / n_batch:.4f}')

    # Guard against an empty loader instead of dividing by zero.
    accuracy = num_correct / num_seen if num_seen else math.nan
    print(f'[Test] Avg accuracy: {accuracy:.4f}')
    return accuracy

### Main

In [None]:
if __name__ == '__main__':
    # Each epoch is one full training pass followed by one evaluation pass,
    # with epochs numbered from 1.
    epoch_numbers = range(1, epochs + 1)
    for epoch in epoch_numbers:
        print(f'Epoch: {epoch}')
        train()
        test()