In [1]:
from theloop import TheLoop

import torch
from torch.nn import functional as F
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm, tqdm_notebook
from sklearn.metrics import accuracy_score

In [2]:
def batch_callback(**kwargs):
    """Compute the training loss for a single mini-batch.

    TheLoop invokes this with keyword arguments; the ones read here are
    ``model``, ``batch`` (indexable as ``batch[0]`` = inputs,
    ``batch[1]`` = targets), ``device`` and ``criterion``.

    Returns:
        dict mapping ``"loss"`` to the loss returned by ``criterion``.
    """
    net, batch, dev, loss_fn = (
        kwargs[key] for key in ("model", "batch", "device", "criterion")
    )

    logits = net(batch[0].to(dev))
    targets = batch[1].to(dev)
    return {"loss": loss_fn(logits, targets)}

In [3]:
def val_callback(**kwargs):
    """Run the model over the validation data and report classification accuracy.

    TheLoop invokes this with keyword arguments; the ones read here are
    ``model``, ``data`` (an iterable of batches indexable as ``batch[0]`` =
    inputs, ``batch[1]`` = integer targets) and ``device``.

    The model is switched to eval mode for the pass (so layers such as
    Dropout behave deterministically) and its previous train/eval mode is
    restored afterwards, even if an exception is raised.

    Returns:
        dict mapping ``"accuracy"`` to the fraction of correctly classified samples.
    """
    model, dloader, device = kwargs["model"], kwargs["data"], kwargs["device"]

    predict = []
    ground_truth = []

    was_training = model.training
    model.eval()
    try:
        # One no_grad context for the whole pass instead of one per batch.
        with torch.no_grad():
            for batch in tqdm_notebook(dloader):
                logits = model(batch[0].to(device)).cpu()
                # softmax is monotonic, so argmax over raw logits picks the
                # same class without the extra computation.
                predict += torch.argmax(logits, dim=1).tolist()
                ground_truth += batch[1].tolist()
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    # sklearn's signature is accuracy_score(y_true, y_pred).
    accuracy = accuracy_score(ground_truth, predict)

    return {"accuracy": accuracy}

In [4]:
# Per-channel statistics spelled out for all three RGB channels, so the
# transform does not depend on how a given torchvision version broadcasts a
# single-channel mean/std.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
])
train_set = CIFAR10(root="./", train=True, transform=trans, download=True)
test_set = CIFAR10(root="./", train=False, transform=trans, download=True)

# NOTE: despite its name this is SqueezeNet 1.0 (ImageNet-pretrained), not a
# ResNet-18; the name is kept because later cells refer to `resnet18`.
resnet18 = models.squeezenet1_0(pretrained=True)

# The pretrained head emits 1000 ImageNet logits, but CIFAR-10 has only 10
# labels. Without replacing it, argmax can select indices >= 10, which can
# never match a target. Swap the final 1x1 conv for a 10-way one,
# initialised the way torchvision initialises SqueezeNet's final conv.
NUM_CLASSES = 10  # CIFAR-10
head = torch.nn.Conv2d(512, NUM_CLASSES, kernel_size=1)
torch.nn.init.normal_(head.weight, mean=0.0, std=0.01)
torch.nn.init.constant_(head.bias, 0)
resnet18.classifier[1] = head
# Some torchvision versions reshape the output using self.num_classes.
resnet18.num_classes = NUM_CLASSES

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Build the training loop around the model, the loss name and the per-batch
# callback defined above.
# NOTE(review): val_rate=1000 presumably means "validate every 1000 training
# steps", and val_criterion_key presumably names the val_callback metric used
# to rank checkpoints -- confirm against TheLoop's documentation.
theloop = TheLoop(
    resnet18,
    "CrossEntropyLoss",
    batch_callback,
    # validation
    val_callback=val_callback,
    val_rate=1000,
    val_criterion_key="accuracy",
    # optimisation / hardware
    optimizer_params={"lr": 1e-4},
    device="cpu",
    # logging / UI
    logdir="./logdir",
    using_tqdm_notebook=True,
)

In [None]:
# Run training for one epoch on the train set (validating on the test set);
# the model returned by TheLoop replaces the previous reference.
resnet18 = theloop.a(
    train_set,
    test_set,
    n_epoch=1,
)

||STARTING THE LOOP||


  |￣￣￣￣￣￣|
  |  EPOCH: 0  |
  |＿＿＿＿＿＿|
(\__/) || 
(•ㅅ•) || 
/ 　 づ


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

Starting validation...


HBox(children=(IntProgress(value=0, max=313), HTML(value='')))

Validation ready!
Checkpoint saved
Starting validation...


HBox(children=(IntProgress(value=0, max=313), HTML(value='')))

Validation ready!
Checkpoint saved
