In [None]:
import os
import time

import torch
import torch.nn as nn

from utils.model.resnet import ResNet18
from utils.utils.dataloader import *

train_loader, test_loader, labels = cifar10(100)
PATH_PARAMETERS = "models/cifar10/resnet.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

train_net = ResNet18().to(device)
epochs = 30

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(train_net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)



In [None]:
def train(epoch):
    train_net.train()
    train_loss = 0.0
    correct = 0.0
    total = 0.0
    start_time = time.time()
    for i, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        output = train_net(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        pred = output.argmax(1)
        correct += torch.eq(pred, y).sum().item()
        total += y.size(0)
        train_loss += loss.item()
    end_time = time.time()
    time_taken = end_time - start_time

    print("Epoch: {}  train_loss: {}  accuracy: {}".format(epoch + 1, train_loss / len(train_loader),
                                                             100*correct / total))
    print("Time:", time_taken)
    train_loss = 0.0

def test(epoch):
    train_net.eval()
    test_loss = 0.0
    correct = 0.0
    total = 0.0
    start_time = time.time()
    with torch.no_grad():
        for i, (X, y) in enumerate(test_loader):
            X, y = X.to(device), y.to(device)
            output = train_net(X)
            loss = criterion(output, y)
            pred = output.argmax(1)
            correct += torch.eq(pred, y).sum().item()
            test_loss += loss.item()
            
            total += y.size(0)
        end_time = time.time()
        time_taken = end_time - start_time

        print(
            "Epoch: {}  test_loss: {}  accuracy: {}".format(epoch + 1, test_loss / len(test_loader), 100*correct / total))
        print("Time:", time_taken)
        test_loss = 0.0
    print("saving net...")
    torch.save(train_net.state_dict(), "models/cifar10/advresnet.pth")

In [None]:
# Alternate one training pass and one evaluation pass per epoch.
# test() also writes the current weights to disk on every call.
for epoch in range(epochs):
    train(epoch)
    test(epoch)
    # Step the LR scheduler once per epoch, after that epoch's optimizer updates.
    scheduler.step()