# Part 1: Training and evaluating simple CNNs

__Before starting, we recommend you enable GPU acceleration if you're running on Colab.__

In [None]:
# Execute this code block to install dependencies when running on colab
try:
    import torch
except:
    from os.path import exists
    from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

    !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

try: 
    import torchbearer
except:
    !pip install torchbearer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchbearer
  Downloading torchbearer-0.5.3-py3-none-any.whl (138 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.1/138.1 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchbearer
Successfully installed torchbearer-0.5.3


In [None]:
# automatically reload external modules if they change
%load_ext autoreload
%autoreload 2

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchbearer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchbearer import Trial
from pickle import load

In [None]:
# fix random seed for reproducibility
seed = 7
# Seed PyTorch's global random number generator.
torch.manual_seed(seed)
# Ask cuDNN to use deterministic algorithms (only relevant when running on a GPU).
torch.backends.cudnn.deterministic = True

In [None]:
class MyDataset(Dataset):
    """Synthetic dataset of rasterised straight lines.

    Each item is a pair ``(image, params)``:

    * ``image`` has shape ``(1, dim, dim)`` and holds 1.0 at pixels lying on a
      randomly drawn line ``y = slope * x + intercept`` and 0.0 elsewhere.
    * ``params`` is the 2-element tensor ``(slope, intercept)``.

    Items are generated on demand from a seed of ``index + random_offset``, so a
    given index always yields the same sample.

    Args:
        size: number of items reported by ``len()``.
        dim: side length of the square image in pixels.
        random_offset: value added to the index to form the per-item seed.
    """

    def __init__(self, size=5000, dim=40, random_offset=0):
        super().__init__()
        self.size = size
        self.dim = dim
        self.random_offset = random_offset

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        """Generate the sample for ``index``; raises IndexError past the end."""
        if index >= len(self):
            raise IndexError("{} index out of range".format(self.__class__.__name__))

        # Seed the global RNG for this item only; the caller's RNG state is
        # put back before returning.
        saved_state = torch.get_rng_state()
        torch.manual_seed(index + self.random_offset)

        accepted = False
        while not accepted:
            img = torch.zeros(self.dim, self.dim)
            dx = torch.randint(-10, 10, (1,), dtype=torch.float)
            dy = torch.randint(-10, 10, (1,), dtype=torch.float)
            c = torch.randint(-20, 20, (1,), dtype=torch.float)

            # (slope, intercept). dx == 0 gives an inf/nan slope; every point
            # is then filtered out below, so that draw is rejected and redrawn.
            params = torch.cat((dy / dx, c))

            # 20 random x positions (column 0); column 1 is overwritten with
            # the corresponding y on the line.
            points = torch.randint(0, img.shape[1], (20, 2), dtype=torch.float)
            points[:, 1] = points[:, 0] * params[0] + params[1]
            points.round_()

            # Keep only points that land inside the image.
            inside = (
                (points[:, 1] > 0)
                & (points[:, 1] < self.dim)
                & (points[:, 0] < self.dim)
            )
            points = points[inside]

            # Rows are counted from the top, so y is flipped to dim - y.
            rows = (self.dim - points[:, 1]).long()
            cols = points[:, 0].long()
            img[rows, cols] = 1

            # Require more than two lit pixels, otherwise draw a new line.
            accepted = bool(img.sum() > 2)

        torch.set_rng_state(saved_state)
        return img.unsqueeze(0), params

# Three splits of the synthetic dataset. Each item is seeded with
# index + random_offset, so the distinct offsets below keep the per-item seeds
# of the validation and test splits from colliding with those of the training
# split (which uses the defaults: size=5000, random_offset=0).
train_data = MyDataset()
val_data = MyDataset(size=500, random_offset=33333)
test_data = MyDataset(size=500, random_offset=99999)


In [None]:
def loadData(batch_size: int=128):
    """Build DataLoaders over the module-level train/validation/test datasets.

    Only the training loader shuffles. The validation and test loaders iterate
    in a fixed order so that repeated evaluations see identical batches and
    report identical batch-averaged metrics.

    Args:
        batch_size: number of samples per batch for all three loaders.

    Returns:
        Tuple ``(trainloader, val_loader, testloader)``.
    """
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return trainloader, val_loader, testloader

In [None]:
class CNNmodel(nn.Module):
    """Baseline regressor: one 3x3 convolution followed by two linear layers.

    Expects input of shape ``(batch, 1, 40, 40)`` (the first linear layer is
    sized for 48 feature maps of 40x40) and returns ``(batch, 2)``.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order fixes the order in which initial weights are
        # drawn from the RNG, so it is kept as conv1, fc1, fc2.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=48, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=48 * 40 * 40, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=2)

    def forward(self, x):
        """Map a batch of images to two regression outputs per image."""
        features = F.relu(self.conv1(x))
        # Flatten everything except the batch dimension.
        flat = features.view(features.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# build the model
model = CNNmodel()

# define the loss function and the optimiser
loss_function = nn.MSELoss()
optimiser = optim.Adam(model.parameters())

# Train on the first GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): with MSELoss the 'accuracy' metric appears to be reported as MSE
# (the recorded output shows a `test_mse` key) — confirm against torchbearer's docs.
trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)

# Attach the three data loaders, train for 100 epochs, then evaluate on the test split.
train_loader, val_loader, test_loader = loadData()
trial.with_generators(train_generator=train_loader, val_generator=val_loader,test_generator=test_loader)
trial.run(epochs=100)
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

0/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

0/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

1/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

1/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

2/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

2/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

3/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

3/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

4/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

4/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

5/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

5/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

6/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

6/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

7/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

7/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

8/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

8/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

9/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

9/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

10/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

10/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

11/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

11/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

12/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

12/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

13/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

13/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

14/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

14/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

15/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

15/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

16/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

16/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

17/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

17/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

18/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

18/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

19/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

19/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

20/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

20/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

21/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

21/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

22/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

22/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

23/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

23/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

24/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

24/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

25/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

25/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

26/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

26/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

27/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

27/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

28/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

28/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

29/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

29/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

30/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

30/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

31/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

31/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

32/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

32/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

33/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

33/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

34/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

34/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

35/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

35/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

36/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

36/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

37/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

37/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

38/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

38/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

39/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

39/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

40/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

40/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

41/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

41/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

42/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

42/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

43/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

43/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

44/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

44/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

45/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

45/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

46/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

46/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

47/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

47/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

48/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

48/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

49/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

49/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

50/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

50/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

51/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

51/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

52/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

52/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

53/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

53/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

54/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

54/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

55/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

55/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

56/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

56/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

57/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

57/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

58/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

58/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

59/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

59/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

60/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

60/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

61/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

61/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

62/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

62/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

63/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

63/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

64/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

64/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

65/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

65/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

66/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

66/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

67/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

67/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

68/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

68/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

69/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

69/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

70/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

70/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

71/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

71/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

72/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

72/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

73/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

73/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

74/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

74/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

75/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

75/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

76/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

76/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

77/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

77/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

78/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

78/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

79/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

79/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

80/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

80/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

81/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

81/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

82/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

82/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

83/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

83/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

84/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

84/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

85/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

85/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

86/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

86/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

87/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

87/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

88/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

88/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

89/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

89/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

90/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

90/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

91/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

91/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

92/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

92/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

93/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

93/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

94/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

94/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

95/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

95/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

96/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

96/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

97/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

97/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

98/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

98/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

99/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

99/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

0/1(e):   0%|          | 0/4 [00:00<?, ?it/s]

{'test_loss': 10.333111763000488, 'test_mse': 10.268647193908691}


In [None]:
# Evaluate the current trial on the test split and print the resulting metrics dict.
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

0/1(e):   0%|          | 0/4 [00:00<?, ?it/s]

{'test_loss': 10.25437068939209, 'test_mse': 10.268647193908691}


In [None]:
class CNNmodel2(nn.Module):
    """Two 3x3 convolutions, a global max pool, then two linear layers.

    The adaptive pool reduces each of the 48 feature maps to a single value,
    so the first linear layer takes 48 inputs. Input is ``(batch, 1, H, W)``
    and output is ``(batch, 2)``.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order fixes the order in which initial weights are
        # drawn from the RNG, so it is kept as conv1, conv2, fc1, fc2.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=48, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=48, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=2)
        self.pool1 = nn.AdaptiveMaxPool2d(output_size=1)

    def forward(self, x):
        """Map a batch of images to two regression outputs per image."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        # Global max pool: (batch, 48, H, W) -> (batch, 48, 1, 1).
        pooled = self.pool1(features)
        flat = pooled.view(pooled.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
# build the model
model = CNNmodel2()

# define the loss function and the optimiser
loss_function = nn.MSELoss()
optimiser = optim.Adam(model.parameters())

# Train on the first GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): with MSELoss the 'accuracy' metric appears to be reported as MSE
# (the recorded output shows `mse` / `val_mse` keys) — confirm against torchbearer's docs.
trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)

train_loader, val_loader, test_loader = loadData()
trial.with_generators(train_generator=train_loader, val_generator=val_loader,test_generator=test_loader)
# As the last expression in the cell, the value returned by run() is displayed;
# the recorded output is a list with one metrics dict per epoch.
trial.run(epochs=100)

0/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

0/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

1/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

1/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

2/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

2/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

3/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

3/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

4/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

4/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

5/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

5/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

6/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

6/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

7/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

7/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

8/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

8/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

9/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

9/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

10/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

10/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

11/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

11/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

12/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

12/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

13/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

13/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

14/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

14/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

15/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

15/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

16/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

16/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

17/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

17/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

18/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

18/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

19/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

19/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

20/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

20/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

21/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

21/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

22/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

22/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

23/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

23/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

24/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

24/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

25/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

25/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

26/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

26/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

27/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

27/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

28/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

28/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

29/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

29/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

30/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

30/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

31/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

31/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

32/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

32/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

33/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

33/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

34/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

34/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

35/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

35/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

36/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

36/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

37/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

37/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

38/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

38/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

39/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

39/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

40/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

40/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

41/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

41/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

42/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

42/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

43/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

43/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

44/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

44/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

45/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

45/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

46/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

46/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

47/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

47/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

48/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

48/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

49/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

49/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

50/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

50/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

51/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

51/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

52/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

52/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

53/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

53/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

54/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

54/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

55/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

55/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

56/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

56/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

57/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

57/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

58/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

58/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

59/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

59/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

60/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

60/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

61/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

61/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

62/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

62/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

63/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

63/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

64/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

64/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

65/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

65/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

66/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

66/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

67/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

67/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

68/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

68/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

69/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

69/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

70/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

70/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

71/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

71/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

72/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

72/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

73/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

73/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

74/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

74/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

75/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

75/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

76/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

76/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

77/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

77/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

78/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

78/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

79/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

79/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

80/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

80/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

81/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

81/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

82/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

82/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

83/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

83/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

84/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

84/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

85/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

85/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

86/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

86/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

87/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

87/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

88/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

88/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

89/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

89/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

90/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

90/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

91/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

91/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

92/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

92/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

93/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

93/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

94/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

94/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

95/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

95/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

96/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

96/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

97/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

97/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

98/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

98/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

99/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

99/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

[{'running_loss': 64.1994400024414,
  'running_mse': 64.1994400024414,
  'loss': 62.601104736328125,
  'mse': 62.997440338134766,
  'val_loss': 54.19070053100586,
  'val_mse': 54.171775817871094,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 57.78752136230469,
  'running_mse': 57.78752136230469,
  'loss': 56.5319938659668,
  'mse': 56.93645477294922,
  'val_loss': 48.58041763305664,
  'val_mse': 48.39444351196289,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 49.83289337158203,
  'running_mse': 49.83289337158203,
  'loss': 46.400794982910156,
  'mse': 45.857845306396484,
  'val_loss': 34.71123504638672,
  'val_mse': 34.71419143676758,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 35.99846267700195,
  'running_mse': 35.99846267700195,
  'loss': 31.625089645385742,
  'mse': 31.62813377380371,
  'val_loss': 25.622026443481445,
  'val_mse': 25.623075485229492,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 27.090349

In [None]:
# Evaluate the current trial on the test split and print the resulting metrics dict.
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

0/1(e):   0%|          | 0/4 [00:00<?, ?it/s]

{'test_loss': 15.032289505004883, 'test_mse': 15.025256156921387}


In [None]:
class CNNmodel3(nn.Module):
    """Small CNN that receives pixel coordinates as two extra input channels.

    ``forward`` appends a normalised x-coordinate map and a normalised
    y-coordinate map to the input before the first convolution. The network
    is two 3x3 convolutions, a global max-pool and two fully connected
    layers, and produces two output values per sample.
    """

    def __init__(self):
        super(CNNmodel3, self).__init__()
        # 3 input channels = 1 image channel + the 2 coordinate channels
        # that forward() concatenates onto the input.
        self.conv1 = nn.Conv2d(3, 48, (3, 3), stride=1, padding=1)
        self.conv2 = nn.Conv2d(48, 48, (3, 3), stride=1, padding=1)
        # The global pool leaves a 1x1 map per channel, hence 48 features.
        self.fc1 = nn.Linear(48 * 1**2, 128)
        self.fc2 = nn.Linear(128, 2)
        self.pool1 = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Run the network on a batch of single-channel images.

        Args:
            x: tensor of shape ``(N, 1, H, W)``. Any spatial size is accepted;
                the coordinate channels are derived from ``H`` and ``W``
                (for a 40x40 input they equal ``arange(-20, 20) / 40``).

        Returns:
            Tensor of shape ``(N, 2)``.
        """
        batch = x.shape[0]
        height, width = x.shape[-2:]

        # Coordinates centred on the image and scaled to roughly [-0.5, 0.5),
        # created directly on the input's device to avoid a host->device copy.
        cols = (torch.arange(width, dtype=torch.float, device=x.device)
                - width // 2) / width
        rows = (torch.arange(height, dtype=torch.float, device=x.device)
                - height // 2) / height
        idxx = cols.unsqueeze(0).expand(height, width)   # varies along width
        idxy = rows.unsqueeze(1).expand(height, width)   # varies along height

        # expand() broadcasts over the batch without copying; cat() below
        # materialises the result.
        idx = torch.stack([idxx, idxy]).unsqueeze(0).expand(batch, -1, -1, -1)
        x = torch.cat([x, idx], dim=1)

        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.pool1(out)
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

In [None]:
# build the model
# (CNNmodel3 concatenates two coordinate channels onto the input inside forward())
model = CNNmodel3()

# define the loss function and the optimiser
# (Adam receives only the model parameters, so its default hyper-parameters apply)
loss_function = nn.MSELoss()
optimiser = optim.Adam(model.parameters())

# run on the first CUDA device when one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# the history printed below this cell reports 'loss' and 'mse' entries for these metrics
trial = Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy']).to(device)

# NOTE(review): loadData() is not defined in this cell -- presumably an earlier cell provides it; confirm
train_loader, val_loader, test_loader = loadData()
trial.with_generators(train_generator=train_loader, val_generator=val_loader,test_generator=test_loader)
# train for 100 epochs; as the last expression of the cell, the value returned by run() is displayed
trial.run(epochs=100)

0/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

0/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

1/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

1/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

2/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

2/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

3/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

3/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

4/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

4/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

5/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

5/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

6/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

6/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

7/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

7/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

8/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

8/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

9/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

9/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

10/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

10/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

11/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

11/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

12/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

12/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

13/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

13/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

14/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

14/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

15/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

15/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

16/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

16/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

17/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

17/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

18/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

18/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

19/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

19/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

20/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

20/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

21/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

21/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

22/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

22/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

23/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

23/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

24/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

24/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

25/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

25/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

26/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

26/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

27/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

27/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

28/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

28/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

29/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

29/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

30/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

30/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

31/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

31/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

32/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

32/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

33/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

33/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

34/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

34/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

35/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

35/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

36/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

36/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

37/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

37/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

38/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

38/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

39/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

39/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

40/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

40/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

41/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

41/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

42/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

42/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

43/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

43/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

44/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

44/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

45/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

45/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

46/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

46/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

47/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

47/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

48/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

48/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

49/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

49/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

50/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

50/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

51/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

51/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

52/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

52/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

53/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

53/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

54/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

54/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

55/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

55/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

56/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

56/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

57/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

57/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

58/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

58/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

59/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

59/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

60/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

60/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

61/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

61/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

62/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

62/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

63/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

63/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

64/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

64/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

65/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

65/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

66/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

66/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

67/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

67/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

68/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

68/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

69/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

69/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

70/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

70/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

71/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

71/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

72/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

72/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

73/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

73/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

74/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

74/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

75/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

75/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

76/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

76/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

77/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

77/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

78/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

78/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

79/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

79/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

80/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

80/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

81/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

81/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

82/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

82/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

83/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

83/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

84/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

84/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

85/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

85/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

86/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

86/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

87/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

87/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

88/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

88/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

89/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

89/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

90/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

90/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

91/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

91/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

92/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

92/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

93/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

93/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

94/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

94/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

95/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

95/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

96/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

96/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

97/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

97/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

98/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

98/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

99/100(t):   0%|          | 0/40 [00:00<?, ?it/s]

99/100(v):   0%|          | 0/4 [00:00<?, ?it/s]

[{'running_loss': 64.3674545288086,
  'running_mse': 64.3674545288086,
  'loss': 62.3181266784668,
  'mse': 62.661590576171875,
  'val_loss': 51.0767936706543,
  'val_mse': 51.12935256958008,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 52.50098419189453,
  'running_mse': 52.50098419189453,
  'loss': 45.357078552246094,
  'mse': 45.70233154296875,
  'val_loss': 25.512638092041016,
  'val_mse': 25.6035099029541,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 26.024810791015625,
  'running_mse': 26.024810791015625,
  'loss': 18.320240020751953,
  'mse': 18.558773040771484,
  'val_loss': 12.472264289855957,
  'val_mse': 12.521013259887695,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 13.35020637512207,
  'running_mse': 13.35020637512207,
  'loss': 12.059473991394043,
  'mse': 11.90164852142334,
  'val_loss': 9.706626892089844,
  'val_mse': 9.762402534484863,
  'train_steps': 40,
  'validation_steps': 4},
 {'running_loss': 10.089457

In [None]:
# Evaluate the current `trial` on the data registered under torchbearer.TEST_DATA
# and print the metrics returned by evaluate().
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)

0/1(e):   0%|          | 0/4 [00:00<?, ?it/s]

{'test_loss': 1.144503116607666, 'test_mse': 1.1445499658584595}
