In [1]:
from theloop import TheLoop

import torch
from torch.nn import functional as F
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm, tqdm_notebook
from sklearn.metrics import accuracy_score

In [2]:
def batch_callback(state):
    """Run one training batch through the model and store its loss on ``state``.

    Reads ``state.model``, ``state.batch`` (``batch[0]`` = inputs,
    ``batch[1]`` = targets), ``state.device`` and ``state.criterion``.
    Writes the loss tensor to ``state.loss`` (presumably consumed by
    TheLoop for the backward/optimizer step — TODO confirm).
    """
    net, loss_fn, target_device = state.model, state.criterion, state.device
    inputs, targets = state.batch[0], state.batch[1]

    logits = net(inputs.to(target_device))
    state.loss = loss_fn(logits, targets.to(target_device))

In [3]:
def val_callback(state):
    """Compute top-1 classification accuracy over the validation loader.

    Reads ``state.model``, ``state.data`` (an iterable of batches where
    ``batch[0]`` are inputs and ``batch[1]`` are integer class labels) and
    ``state.device``. Records the result via
    ``state.set_metric("accuracy", value)`` as a float in [0, 1].

    The model is put into eval mode for the pass, so Dropout/BatchNorm
    behave deterministically. Every submodule's previous train/eval flag
    is restored afterwards, even if the forward pass raises.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model = state.model
    device = state.device

    # Snapshot per-submodule modes rather than calling model.train() at the
    # end. A blanket train() would flip submodules the caller deliberately
    # kept in eval mode.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in state.data:
                logits = model(batch[0].to(device))
                # argmax(logits) == argmax(softmax(logits)). Skipping softmax
                # also avoids float rounding merging near-tied scores.
                preds = logits.argmax(dim=1).cpu()
                targets = batch[1].cpu()
                correct += (preds == targets).sum().item()
                total += targets.numel()
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    if total == 0:
        raise ValueError("validation data yielded no samples; accuracy is undefined")
    state.set_metric("accuracy", correct / total)

In [4]:
# Pretrained torchvision ImageNet weights expect inputs normalized with the
# ImageNet per-channel statistics. The previous ad-hoc (0.5,)/(1.0,) only
# shifted the data and did not match what the pretrained features saw.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
NUM_CLASSES = 10  # CIFAR-10 has 10 labels

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
train_set = CIFAR10(root="./", train=True, transform=trans, download=True)
test_set = CIFAR10(root="./", train=False, transform=trans, download=True)

# NOTE(review): newer torchvision deprecates `pretrained=True` in favor of
# `weights=...`; kept as-is for compatibility with the installed version.
model = models.squeezenet1_0(pretrained=True)
# The ImageNet head emits 1000 logits, so argmax could predict labels that do
# not exist in CIFAR-10. Replace the final 1x1 conv (classifier[1]) with a
# fresh NUM_CLASSES-way conv and keep the pretrained feature extractor, as in
# the PyTorch "Finetuning Torchvision Models" tutorial.
model.classifier[1] = torch.nn.Conv2d(512, NUM_CLASSES, kernel_size=1)
# Kept in sync per the same tutorial; some torchvision versions read this
# attribute when reshaping the output.
model.num_classes = NUM_CLASSES

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Train on a GPU when one is available and fall back to the CPU otherwise.
# The callbacks move every batch to `state.device`. Move the model explicitly
# too, before TheLoop is built, so it never depends on TheLoop doing it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)

theloop = TheLoop(model, "CrossEntropyLoss", batch_callback,
                  val_callback=val_callback,
                  optimizer_params={"lr": 1e-4},
                  logdir="./logdir",
                  val_rate=1000,  # presumably: run val_callback every 1000 batches — TODO confirm in TheLoop docs
                  device=DEVICE,
                  val_criterion_key="accuracy")  # must match the key val_callback passes to set_metric

In [None]:
# Launch training for 2 epochs. `a` is TheLoop's entry point; `test_set` is
# passed as the second argument, presumably the data handed to val_callback
# as `state.data` — TODO confirm.
# NOTE(review): if val_criterion_key="accuracy" is used to pick a best
# checkpoint, selecting on the test split leaks test information into model
# selection. Consider carving a validation split out of train_set.
# The return value rebinds `model` (presumably the trained or best model —
# TODO confirm against TheLoop docs).
model = theloop.a(train_set, test_set, n_epoch=2)

   _____ _______       _____ _______   _______ _    _ ______ _      ____   ____  _____
  / ____|__   __|/\   |  __ \__   __| |__   __| |  | |  ____| |    / __ \ / __ \|  __ \
 | (___    | |  /  \  | |__) | | |       | |  | |__| | |__  | |   | |  | | |  | | |__) |
  \___ \   | | / /\ \ |  _  /  | |       | |  |  __  |  __| | |   | |  | | |  | |  ___/
  ____) |  | |/ ____ \| | \ \  | |       | |  | |  | | |____| |___| |__| | |__| | |
 |_____/   |_/_/    \_\_|  \_\ |_|       |_|  |_|  |_|______|______\____/ \____/|_|


EXPERIMENT NAME: experiment
EXPERIMENT ID: 2766
NUM EPOCH: 2
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+


  |￣￣￣￣￣￣|
    EPOCH: 0  
  |＿＿＿＿＿＿|
(\__/) || 
(•ㅅ•) || 
/ 　 づ


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

In [None]:
import theloop

In [None]:
?theloop.states.BatchState