In [3]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try: 
    import torchbearer
except:
    !pip install torchbearer

# automatically reload external modules if they change
%load_ext autoreload
%autoreload 2

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchbearer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchbearer import Trial

# fix random seed for reproducibility
seed = 7
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [0]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset

class MyDataset(Dataset):
  def __init__(self, size=5000, dim=40, random_offset=0):
        super(MyDataset, self).__init__()
        self.size = size
        self.dim = dim
        self.random_offset = random_offset

  def __getitem__(self, index):
      if index >= len(self):
          raise IndexError("{} index out of range".format(self.__class__.__name__))

      rng_state = torch.get_rng_state()
      torch.manual_seed(index + self.random_offset)

      while True:
        img = torch.zeros(self.dim, self.dim)
        dx = torch.randint(-10,10,(1,),dtype=torch.float)
        dy = torch.randint(-10,10,(1,),dtype=torch.float)
        c = torch.randint(-20,20,(1,), dtype=torch.float)

        params = torch.cat((dy/dx, c))
        xy = torch.randint(0,img.shape[1], (20, 2), dtype=torch.float)
        xy[:,1] = xy[:,0] * params[0] + params[1]

        xy.round_()
        xy = xy[ xy[:,1] > 0 ]
        xy = xy[ xy[:,1] < self.dim ]
        xy = xy[ xy[:,0] < self.dim ]

        for i in range(xy.shape[0]):
          x, y = xy[i][0], self.dim - xy[i][1]
          img[int(y), int(x)]=1
        if img.sum() > 2:
          break

      torch.set_rng_state(rng_state)
      return img.unsqueeze(0), params

  def __len__(self):
      return self.size

train_data = MyDataset()
val_data = MyDataset(size=500, random_offset=33333)
test_data = MyDataset(size=500, random_offset=99999)

In [0]:
# convert each image to tensor format
transform = transforms.Compose([
    transforms.ToTensor()  # convert to tensor
])

# create data loaders
trainloader = DataLoader(train_data, batch_size=128, shuffle=True)
valloader = DataLoader(val_data, batch_size=128, shuffle=True)
testloader = DataLoader(test_data, batch_size=128, shuffle=True)

In [0]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 48, (3, 3), stride=1, padding=1)
        self.fc1 = nn.Linear(48 * 40**2, 128)
        self.fc2 = nn.Linear(128, 2)
            
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)
        out = out.view(out.shape[0], -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        return out

In [20]:
# build the model
model = SimpleCNN()

# define the loss function and the optimiser
loss_function = nn.L1Loss()
optimiser = optim.Adam(model.parameters())

device = "cuda:0" if torch.cuda.is_available() else "cpu"

trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)
trial.with_generators(trainloader, valloader, test_generator=testloader)
trial.run(epochs=100)
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

HBox(children=(FloatProgress(value=0.0, description='0/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='0/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='1/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='1/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='2/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='2/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='3/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='3/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='4/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='4/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='5/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='5/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='6/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='6/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='7/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='7/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='8/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='8/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='9/100(t)', max=40.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='9/100(v)', max=4.0, style=ProgressStyle(description_width…




HBox(children=(FloatProgress(value=0.0, description='10/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='10/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='11/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='11/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='12/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='12/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='13/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='13/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='14/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='14/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='15/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='15/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='16/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='16/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='17/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='17/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='18/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='18/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='19/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='19/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='20/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='20/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='21/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='21/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='22/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='22/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='23/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='23/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='24/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='24/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='25/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='25/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='26/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='26/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='27/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='27/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='28/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='28/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='29/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='29/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='30/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='30/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='31/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='31/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='32/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='32/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='33/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='33/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='34/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='34/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='35/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='35/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='36/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='36/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='37/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='37/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='38/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='38/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='39/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='39/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='40/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='40/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='41/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='41/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='42/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='42/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='43/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='43/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='44/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='44/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='45/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='45/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='46/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='46/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='47/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='47/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='48/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='48/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='49/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='49/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='50/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='50/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='51/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='51/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='52/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='52/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='53/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='53/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='54/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='54/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='55/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='55/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='56/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='56/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='57/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='57/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='58/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='58/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='59/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='59/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='60/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='60/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='61/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='61/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='62/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='62/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='63/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='63/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='64/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='64/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='65/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='65/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='66/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='66/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='67/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='67/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='68/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='68/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='69/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='69/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='70/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='70/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='71/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='71/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='72/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='72/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='73/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='73/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='74/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='74/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='75/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='75/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='76/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='76/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='77/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='77/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='78/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='78/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='79/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='79/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='80/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='80/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='81/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='81/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='82/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='82/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='83/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='83/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='84/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='84/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='85/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='85/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='86/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='86/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='87/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='87/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='88/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='88/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='89/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='89/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='90/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='90/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='91/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='91/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='92/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='92/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='93/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='93/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='94/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='94/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='95/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='95/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='96/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='96/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='97/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='97/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='98/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='98/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='99/100(t)', max=40.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='99/100(v)', max=4.0, style=ProgressStyle(description_widt…




HBox(children=(FloatProgress(value=0.0, description='0/1(e)', max=4.0, style=ProgressStyle(description_width='…


{'test_loss': 2.087024688720703, 'test_acc': 0.9120000004768372}
