In [1]:
from theloop import TheLoop

import torch
from torch.nn import functional as F
from torchvision.datasets import MNIST, CIFAR10
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm, tqdm_notebook
from sklearn.metrics import accuracy_score

In [2]:
def batch_callback(**kwargs):
    model, batch, device, criterion = kwargs["model"], kwargs["batch"], kwargs["device"], kwargs["criterion"]

    out = model(batch[0].to(device))
    loss = criterion(out, batch[1].to(device))

    return {"loss": loss}

In [3]:
def val_callback(**kwargs):
    model, dloader, device = kwargs["model"], kwargs["data"], kwargs["device"]

    predict = []
    ground_truth = []

    for batch in dloader:
        with torch.no_grad():
            out = F.softmax(model(batch[0].to(device)).cpu(), dim=1)
            pred = torch.argmax(out, dim=1)

        predict += pred.tolist()
        ground_truth += batch[1].tolist()

    accuracy = accuracy_score(predict, ground_truth)

    return {"accuracy": accuracy}

In [4]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = CIFAR10(root="./", train=True, transform=trans, download=True)
test_set = CIFAR10(root="./", train=False, transform=trans, download=True)

model = models.squeezenet1_0(pretrained=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
theloop = TheLoop(model, "CrossEntropyLoss", batch_callback,
                  val_callback=val_callback,
                  optimizer_params={"lr": 1e-4},
                  logdir="./logdir",
                  val_rate=1000,
                  device="cpu",
                  val_criterion_key="accuracy",
                  using_tqdm_notebook=False)

In [6]:
model = theloop.a(train_set, test_set, n_epoch=2)

  0%|          | 0/1563 [00:00<?, ?it/s]

||STARTING THE LOOP||


  |￣￣￣￣￣￣|
  |  EPOCH: 0  |
  |＿＿＿＿＿＿|
(\__/) || 
(•ㅅ•) || 
/ 　 づ


BATCH 0; ITER 0:   0%|          | 0/1563 [00:00<?, ?it/s, loss=44.8]

Starting validation...


BATCH 1; ITER 1:   0%|          | 2/1563 [00:13<3:58:11,  9.16s/it, loss=36.9]

Validation ready!


BATCH 1000; ITER 1000:  64%|██████▍   | 1000/1563 [02:13<01:14,  7.60it/s, loss=1.44]

Starting validation...


BATCH 1001; ITER 1001:  64%|██████▍   | 1002/1563 [02:25<25:49,  2.76s/it, loss=2.09]

Validation ready!


BATCH 1562; ITER 1562: 100%|██████████| 1563/1563 [03:36<00:00,  7.21it/s, loss=1.46] 
  0%|          | 0/1563 [00:00<?, ?it/s]

Save epoch checkpoint


  |￣￣￣￣￣￣|
  |  EPOCH: 1  |
  |＿＿＿＿＿＿|
(\__/) || 
(•ㅅ•) || 
/ 　 づ


BATCH 437; ITER 2000:  28%|██▊       | 437/1563 [00:55<02:24,  7.82it/s, loss=1.37] 

Starting validation...


BATCH 438; ITER 2001:  28%|██▊       | 439/1563 [01:08<50:15,  2.68s/it, loss=0.98]  

Validation ready!


BATCH 1437; ITER 3000:  92%|█████████▏| 1437/1563 [03:12<00:15,  8.08it/s, loss=1.03] 

Starting validation...


BATCH 1438; ITER 3001:  92%|█████████▏| 1439/1563 [03:24<05:34,  2.70s/it, loss=1.45]

Validation ready!


BATCH 1562; ITER 3125: 100%|█████████▉| 1562/1563 [03:39<00:00,  8.21it/s, loss=0.981]

Starting validation...


BATCH 1562; ITER 3125: 100%|██████████| 1563/1563 [03:52<00:00,  3.74s/it, loss=0.981]

Validation ready!
Save epoch checkpoint




BEST METRICS
|| Best checkpoint score: 0.6856
|| accuracy: 0.6856



