In [1]:
import os.path as osp
from time import time
from datetime import datetime, timedelta
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as T
from torchvision.datasets import MNIST

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler



In [2]:
class FFNN(nn.Module):
    """Three-layer fully connected classifier with dropout after each hidden layer.

    The layer stack is ``in_shape -> 64 -> 32 -> out_shape`` with ReLU
    activations. The network returns raw logits.

    Args:
        in_shape: number of input features per sample.
        out_shape: number of output logits per sample.
        p_d1: dropout probability after the first hidden layer.
        p_d2: dropout probability after the second hidden layer.
    """

    def __init__(self, in_shape, out_shape, p_d1=0.5, p_d2=0.4):
        super().__init__()
        # No log_softmax at the end: the torch cross-entropy loss applies it
        # internally, so the model emits unnormalised logits.
        self.nn = nn.Sequential(
            nn.Linear(in_shape, 64),
            nn.ReLU(),
            nn.Dropout(p=p_d1),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=p_d2),
            nn.Linear(32, out_shape),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the logits."""
        return self.nn(x)

In [3]:
def train_mnist(config, epochs, checkpoint_dir=None, data_dir=None):
    """Train an ``FFNN`` on MNIST for one Ray Tune trial.

    After every epoch the model/optimiser state is checkpointed and the
    validation loss and accuracy are reported back to Tune.

    Args:
        config: hyperparameter dict with keys ``'p_d1'``, ``'p_d2'``,
            ``'batch_size'`` and ``'lr'``.
        epochs: number of epochs to train for.
        checkpoint_dir: optional directory holding a ``'checkpoint'`` file
            previously written by this function; when given, model and
            optimiser state are restored from it before training.
        data_dir: optional MNIST root directory. Falls back to ``'/data/'``
            when not given.
    """
    # create model
    model = FFNN(784, 10, p_d1=config['p_d1'], p_d2=config['p_d2'])

    # load data and make an 80/20 train/validation split
    transforms = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,)),
                            T.Lambda(lambda x: torch.flatten(x))])
    data_root = data_dir if data_dir is not None else '/data/'
    dataset_train = MNIST(root=data_root, transform=transforms, train=True)

    train_samples = int(len(dataset_train) * 0.8)
    train_subset, val_subset = random_split(
        dataset_train, [train_samples, len(dataset_train) - train_samples])

    # create dataloaders
    dataloader_train = torch.utils.data.DataLoader(
        dataset=train_subset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=8,
        pin_memory=True)
    # the whole validation split is evaluated as a single batch
    dataloader_val = torch.utils.data.DataLoader(
        dataset=val_subset,
        batch_size=len(val_subset),
        shuffle=False,
        num_workers=8)

    # choose computation host device
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    model.to(device)

    optimiser = torch.optim.SGD(params=model.parameters(), lr=config['lr'], momentum=0.9)
    f_loss = nn.CrossEntropyLoss()

    # resume from a previous checkpoint when Tune hands one over
    # NOTE(review): the epoch counter still restarts at 0 after a restore.
    if checkpoint_dir:
        model_state, optimiser_state = torch.load(
            osp.join(checkpoint_dir, 'checkpoint'), map_location=device)
        model.load_state_dict(model_state)
        optimiser.load_state_dict(optimiser_state)

    # training loop
    for n in range(epochs):
        # optimisation
        model.train()
        for X, y in dataloader_train:
            X, y = X.to(device), y.to(device)
            optimiser.zero_grad()
            loss = f_loss(model(X), y)
            loss.backward()
            optimiser.step()

        # validation set metrics
        model.eval()
        corrects, n_seen, val_losses = 0, 0, []
        with torch.no_grad():
            for X, y in dataloader_val:
                # inputs must live on the same device as the model
                X, y = X.to(device), y.to(device)
                logits = model(X)
                val_losses.append(f_loss(logits, y).item())
                # argmax over logits equals argmax over log_softmax
                corrects += (logits.argmax(dim=1) == y).sum().item()
                n_seen += len(y)

        val_accuracy = corrects / n_seen
        val_loss = sum(val_losses) / float(len(val_losses))

        # save checkpoint (named differently so the function argument is not shadowed)
        with tune.checkpoint_dir(n) as epoch_checkpoint_dir:
            path = osp.join(epoch_checkpoint_dir, 'checkpoint')
            torch.save((model.state_dict(), optimiser.state_dict()), path)

        # report metric values back to main scheduler
        tune.report(loss=val_loss, accuracy=val_accuracy)
        
def test_accuracy(model, device='cpu', data_dir='/data/'):
    """Compute classification accuracy of ``model`` on the MNIST test split.

    Args:
        model: network that maps flattened, normalised MNIST images to
            per-class logits. It is moved to ``device`` and put in eval mode
            (both in place).
        device: device (name or ``torch.device``) to evaluate on.
        data_dir: MNIST root directory.

    Returns:
        Fraction of test samples whose arg-max prediction equals the label.
    """
    transforms = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,)),
                            T.Lambda(lambda x: torch.flatten(x))])
    dataset_test = MNIST(root=data_dir, transform=transforms, train=False)
    # the whole test split is evaluated as a single batch
    dataloader_test = torch.utils.data.DataLoader(
        dataset=dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=8)

    # honour the requested device for both the model and the data
    device = torch.device(device)
    model.to(device)
    model.eval()

    corrects, n_seen = 0, 0
    with torch.no_grad():
        for X, y in dataloader_test:
            X, y = X.to(device), y.to(device)
            # argmax over logits equals argmax over log_softmax
            corrects += (model(X).argmax(dim=1) == y).sum().item()
            n_seen += len(y)

    return corrects / n_seen

In [4]:
# Hyperparameter search budget.
max_epochs = 20   # epochs per trial; also the scheduler's upper bound
num_samples = 10  # number of hyperparameter configurations to sample

# Search space sampled once per trial.
config = {'lr':tune.loguniform(1e-4, 1e-1), 
          'batch_size':tune.choice([32, 64, 256, 512]), 
          'p_d1':tune.uniform(0.1,0.9), 
          'p_d2':tune.uniform(0.1,0.9)}


# ASHA stops under-performing trials early based on validation loss.
# max_t is tied to max_epochs so the scheduler and training loop cannot drift apart.
scheduler = ASHAScheduler(metric='loss', 
                        mode='min',
                        max_t=max_epochs, 
                        grace_period=1, 
                        reduction_factor=2)

reporter  = CLIReporter(metric_columns=['loss', 'accuracy', 'training_iteration'])

result = tune.run(partial(train_mnist, epochs=max_epochs),
                  resources_per_trial={'cpu':2}, 
                  config=config, 
                  num_samples=num_samples, 
                  scheduler=scheduler, 
                  progress_reporter=reporter)

# Pick the trial whose last reported validation loss is lowest.
best_trial = result.get_best_trial('loss', 'min', 'last')

2021-07-04 20:27:44,548	INFO services.py:1272 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2021-07-04 20:27:46,426	INFO registry.py:64 -- Detected unknown callable for trainable. Converting to class.
2021-07-04 20:27:46,723	ERROR syncer.py:72 -- Log sync requires rsync to be installed.


== Status ==
Memory usage on this node: 1.9/7.7 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects
Result logdir: /home/chris/ray_results/DEFAULT_2021-07-04_20-27-46
Number of trials: 10/10 (10 PENDING)
+---------------------+----------+-------+--------------+-------------+----------+----------+
| Trial name          | status   | loc   |   batch_size |          lr |     p_d1 |     p_d2 |
|---------------------+----------+-------+--------------+-------------+----------+----------|
| DEFAULT_e953d_00000 | PENDING  |       |           64 | 0.000359741 | 0.776499 | 0.805166 |
| DEFAULT_e953d_00001 | PENDING  |       |          256 | 0.0163747   | 0.120924 | 0.76933  |
| DEFAULT_e953d_00002 | PENDING  |       |          512 | 0.0015177   | 0.832719 | 0.173309 |
| DEFAULT_e953d_00003 | PENDING  |       |          512 | 0.

[2m[36m(pid=120442)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
[2m[36m(pid=120447)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
[2m[36m(pid=120449)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
[2m[36m(pid=120443)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


== Status ==
Memory usage on this node: 2.8/7.7 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects
Result logdir: /home/chris/ray_results/DEFAULT_2021-07-04_20-27-46
Number of trials: 10/10 (6 PENDING, 4 RUNNING)
+---------------------+----------+-------+--------------+-------------+----------+----------+
| Trial name          | status   | loc   |   batch_size |          lr |     p_d1 |     p_d2 |
|---------------------+----------+-------+--------------+-------------+----------+----------|
| DEFAULT_e953d_00000 | RUNNING  |       |           64 | 0.000359741 | 0.776499 | 0.805166 |
| DEFAULT_e953d_00001 | RUNNING  |       |          256 | 0.0163747   | 0.120924 | 0.76933  |
| DEFAULT_e953d_00002 | RUNNING  |       |          512 | 0.0015177   | 0.832719 | 0.173309 |
| DEFAULT_e953d_00003 | RUNNING  |       |      

[2m[36m(pid=120446)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00000:
  accuracy: 0.38625
  date: 2021-07-04_20-28-00
  done: true
  experiment_id: ff1f4d46e68e4c929467767d40b508ae
  hostname: chris-server
  iterations_since_restore: 1
  loss: 2.22209095954895
  node_ip: 192.168.1.58
  pid: 120442
  should_checkpoint: true
  time_since_restore: 13.011545419692993
  time_this_iter_s: 13.011545419692993
  time_total_s: 13.011545419692993
  timestamp: 1625426880
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: e953d_00000
  
== Status ==
Memory usage on this node: 2.6/7.7 GiB
Using AsyncHyperBand: num_stopped=2
Bracket: Iter 16.000: None | Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: -2.218522310256958
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects (0.0/2.0 CPU_group_0_fbe0b91c8b2458be202dda7d5353dad3, 0.0/2.0 CPU_group_5ab07ccfc41880e219351fea009c30ac, 0.0/2.0 CPU_group_168530008168d15bc49a2e2d936da2f9, 0.0/2.0 CPU_group_0_168530008168d15bc49a2e2d936

[2m[36m(pid=120448)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00002:
  accuracy: 0.40925
  date: 2021-07-04_20-28-03
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 2
  loss: 2.061502695083618
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 15.595103025436401
  time_this_iter_s: 7.056028127670288
  time_total_s: 15.595103025436401
  timestamp: 1625426883
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00001:
  accuracy: 0.9023333333333333
  date: 2021-07-04_20-28-03
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 2
  loss: 0.4278530180454254
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 15.936234474182129
  time_this_iter_s: 7.03092098236084
  time_total_s: 15.936234474182129
  timestamp: 1625426883
  timesteps_since_restore: 0
  training_iteration: 2
  

Result for DEFAULT_e953d_00001:
  accuracy: 0.92
  date: 2021-07-04_20-28-19
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 4
  loss: 0.3156189024448395
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 31.8237361907959
  time_this_iter_s: 8.172274112701416
  time_total_s: 31.8237361907959
  timestamp: 1625426899
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: e953d_00001
  
Result for DEFAULT_e953d_00004:
  accuracy: 0.2638333333333333
  date: 2021-07-04_20-28-22
  done: true
  experiment_id: 89991063bd6948bd860f491341e0877c
  hostname: chris-server
  iterations_since_restore: 2
  loss: 2.0606777667999268
  node_ip: 192.168.1.58
  pid: 120446
  should_checkpoint: true
  time_since_restore: 23.785216569900513
  time_this_iter_s: 12.630645036697388
  time_total_s: 23.785216569900513
  timestamp: 1625426902
  timesteps_since_restore: 0
  training_iteration: 2
  trial

[2m[36m(pid=120444)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00002:
  accuracy: 0.6929166666666666
  date: 2021-07-04_20-28-25
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 5
  loss: 1.3755099773406982
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 37.883606910705566
  time_this_iter_s: 7.163927316665649
  time_total_s: 37.883606910705566
  timestamp: 1625426905
  timesteps_since_restore: 0
  training_iteration: 5
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00001:
  accuracy: 0.93075
  date: 2021-07-04_20-28-27
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 5
  loss: 0.2873825132846832
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 40.04896950721741
  time_this_iter_s: 8.225233316421509
  time_total_s: 40.04896950721741
  timestamp: 1625426907
  timesteps_since_restore: 0
  training_iteration: 5
  

[2m[36m(pid=120445)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00002:
  accuracy: 0.7499166666666667
  date: 2021-07-04_20-28-40
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 7
  loss: 1.1118183135986328
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 53.24133229255676
  time_this_iter_s: 7.398775577545166
  time_total_s: 53.24133229255676
  timestamp: 1625426920
  timesteps_since_restore: 0
  training_iteration: 7
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00001:
  accuracy: 0.93725
  date: 2021-07-04_20-28-42
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 7
  loss: 0.25126054883003235
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 55.217782497406006
  time_this_iter_s: 7.880500793457031
  time_total_s: 55.217782497406006
  timestamp: 1625426922
  timesteps_since_restore: 0
  training_iteration: 7
 

Result for DEFAULT_e953d_00001:
  accuracy: 0.9424166666666667
  date: 2021-07-04_20-28-58
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 9
  loss: 0.23115980625152588
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 70.93714332580566
  time_this_iter_s: 8.002756118774414
  time_total_s: 70.93714332580566
  timestamp: 1625426938
  timesteps_since_restore: 0
  training_iteration: 9
  trial_id: e953d_00001
  
Result for DEFAULT_e953d_00005:
  accuracy: 0.8343333333333334
  date: 2021-07-04_20-29-00
  done: false
  experiment_id: 0384a9bc828947e9aba84120999a55f8
  hostname: chris-server
  iterations_since_restore: 5
  loss: 0.8173239827156067
  node_ip: 192.168.1.58
  pid: 120448
  should_checkpoint: true
  time_since_restore: 58.98025965690613
  time_this_iter_s: 13.209801435470581
  time_total_s: 58.98025965690613
  timestamp: 1625426940
  timesteps_since_restore: 0
  training_iter

Result for DEFAULT_e953d_00002:
  accuracy: 0.8169166666666666
  date: 2021-07-04_20-29-12
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 11
  loss: 0.8455222845077515
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 84.75358057022095
  time_this_iter_s: 7.704926252365112
  time_total_s: 84.75358057022095
  timestamp: 1625426952
  timesteps_since_restore: 0
  training_iteration: 11
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00007:
  accuracy: 0.7021666666666667
  date: 2021-07-04_20-29-12
  done: true
  experiment_id: e6d5766b45bc401195f3fed9370cd43a
  hostname: chris-server
  iterations_since_restore: 4
  loss: 1.2410041093826294
  node_ip: 192.168.1.58
  pid: 120445
  should_checkpoint: true
  time_since_restore: 32.15148568153381
  time_this_iter_s: 7.884135961532593
  time_total_s: 32.15148568153381
  timestamp: 1625426952
  timesteps_since_restore: 0
  training_itera



Result for DEFAULT_e953d_00001:
  accuracy: 0.9438333333333333
  date: 2021-07-04_20-29-13
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 11
  loss: 0.21999235451221466
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 86.30267238616943
  time_this_iter_s: 7.309501647949219
  time_total_s: 86.30267238616943
  timestamp: 1625426953
  timesteps_since_restore: 0
  training_iteration: 11
  trial_id: e953d_00001
  


[2m[36m(pid=126983)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00002:
  accuracy: 0.8251666666666667
  date: 2021-07-04_20-29-19
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 12
  loss: 0.8026316165924072
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 92.13428044319153
  time_this_iter_s: 7.380699872970581
  time_total_s: 92.13428044319153
  timestamp: 1625426959
  timesteps_since_restore: 0
  training_iteration: 12
  trial_id: e953d_00002
  
== Status ==
Memory usage on this node: 2.6/7.7 GiB
Using AsyncHyperBand: num_stopped=5
Bracket: Iter 16.000: None | Iter 8.000: -0.6302493587136269 | Iter 4.000: -1.0702374577522278 | Iter 2.000: -1.9457329511642456 | Iter 1.000: -2.2034801244735718
Resources requested: 8.0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects (0.0/2.0 CPU_group_fbe0b91c8b2458be202dda7d5353dad3, 0.0/2.0 CPU_group_0_f5ea1d0973e2a800f78b19831dc9004f, 0.0/2.0 CPU_group_0_fbe0b91c8b2458be202d

Result for DEFAULT_e953d_00001:
  accuracy: 0.9466666666666667
  date: 2021-07-04_20-29-28
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 13
  loss: 0.21568216383457184
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 101.15271377563477
  time_this_iter_s: 7.604590654373169
  time_total_s: 101.15271377563477
  timestamp: 1625426968
  timesteps_since_restore: 0
  training_iteration: 13
  trial_id: e953d_00001
  
Result for DEFAULT_e953d_00008:
  accuracy: 0.7845833333333333
  date: 2021-07-04_20-29-30
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 1
  loss: 0.9172196984291077
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 14.558991193771362
  time_this_iter_s: 14.558991193771362
  time_total_s: 14.558991193771362
  timestamp: 1625426970
  timesteps_since_restore: 0
  trainin



Result for DEFAULT_e953d_00002:
  accuracy: 0.8374166666666667
  date: 2021-07-04_20-29-33
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 14
  loss: 0.7466351985931396
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 105.99567484855652
  time_this_iter_s: 6.443282604217529
  time_total_s: 105.99567484855652
  timestamp: 1625426973
  timesteps_since_restore: 0
  training_iteration: 14
  trial_id: e953d_00002
  


[2m[36m(pid=128385)[0m   return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Result for DEFAULT_e953d_00001:
  accuracy: 0.94
  date: 2021-07-04_20-29-35
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 14
  loss: 0.2266676425933838
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 108.20400977134705
  time_this_iter_s: 7.05129599571228
  time_total_s: 108.20400977134705
  timestamp: 1625426975
  timesteps_since_restore: 0
  training_iteration: 14
  trial_id: e953d_00001
  
Result for DEFAULT_e953d_00002:
  accuracy: 0.8410833333333333
  date: 2021-07-04_20-29-40
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 15
  loss: 0.7136115431785583
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 113.17044925689697
  time_this_iter_s: 7.174774408340454
  time_total_s: 113.17044925689697
  timestamp: 1625426980
  timesteps_since_restore: 0
  training_iteration: 15


Result for DEFAULT_e953d_00002:
  accuracy: 0.8506666666666667
  date: 2021-07-04_20-29-47
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 16
  loss: 0.6999717354774475
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 119.6470582485199
  time_this_iter_s: 6.476608991622925
  time_total_s: 119.6470582485199
  timestamp: 1625426987
  timesteps_since_restore: 0
  training_iteration: 16
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00001:
  accuracy: 0.9479166666666666
  date: 2021-07-04_20-29-50
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 16
  loss: 0.20862366259098053
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 123.08109593391418
  time_this_iter_s: 7.1991307735443115
  time_total_s: 123.08109593391418
  timestamp: 1625426990
  timesteps_since_restore: 0
  training

Result for DEFAULT_e953d_00002:
  accuracy: 0.8595
  date: 2021-07-04_20-30-01
  done: false
  experiment_id: 1549d07bf7d3489f8d17b8dd3a245ed8
  hostname: chris-server
  iterations_since_restore: 18
  loss: 0.6576714515686035
  node_ip: 192.168.1.58
  pid: 120449
  should_checkpoint: true
  time_since_restore: 133.5637993812561
  time_this_iter_s: 6.603286981582642
  time_total_s: 133.5637993812561
  timestamp: 1625427001
  timesteps_since_restore: 0
  training_iteration: 18
  trial_id: e953d_00002
  
Result for DEFAULT_e953d_00001:
  accuracy: 0.9514166666666667
  date: 2021-07-04_20-30-05
  done: false
  experiment_id: 9cb516a0ea654c30b4daf30356d5c5f8
  hostname: chris-server
  iterations_since_restore: 18
  loss: 0.19662445783615112
  node_ip: 192.168.1.58
  pid: 120443
  should_checkpoint: true
  time_since_restore: 137.3642065525055
  time_this_iter_s: 7.508636474609375
  time_total_s: 137.3642065525055
  timestamp: 1625427005
  timesteps_since_restore: 0
  training_iteration: 18


Result for DEFAULT_e953d_00008:
  accuracy: 0.88825
  date: 2021-07-04_20-30-16
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 4
  loss: 0.42959773540496826
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 59.64641571044922
  time_this_iter_s: 15.37753963470459
  time_total_s: 59.64641571044922
  timestamp: 1625427016
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: e953d_00008
  
Result for DEFAULT_e953d_00009:
  accuracy: 0.9010833333333333
  date: 2021-07-04_20-30-17
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 4
  loss: 0.3381090760231018
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 42.57322883605957
  time_this_iter_s: 9.348452091217041
  time_total_s: 42.57322883605957
  timestamp: 1625427017
  timesteps_since_restore: 0
  training_iteration: 4
  t

Result for DEFAULT_e953d_00008:
  accuracy: 0.9045
  date: 2021-07-04_20-30-30
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 6
  loss: 0.36765313148498535
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 74.29365134239197
  time_this_iter_s: 7.356085538864136
  time_total_s: 74.29365134239197
  timestamp: 1625427030
  timesteps_since_restore: 0
  training_iteration: 6
  trial_id: e953d_00008
  
Result for DEFAULT_e953d_00009:
  accuracy: 0.9153333333333333
  date: 2021-07-04_20-30-34
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 7
  loss: 0.2944280505180359
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 58.85636854171753
  time_this_iter_s: 5.35017728805542
  time_total_s: 58.85636854171753
  timestamp: 1625427034
  timesteps_since_restore: 0
  training_iteration: 7
  tri

Result for DEFAULT_e953d_00008:
  accuracy: 0.9080833333333334
  date: 2021-07-04_20-30-44
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 8
  loss: 0.3302193880081177
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 88.45154857635498
  time_this_iter_s: 7.162126541137695
  time_total_s: 88.45154857635498
  timestamp: 1625427044
  timesteps_since_restore: 0
  training_iteration: 8
  trial_id: e953d_00008
  
Result for DEFAULT_e953d_00009:
  accuracy: 0.91025
  date: 2021-07-04_20-30-50
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 10
  loss: 0.2965039610862732
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 75.14253330230713
  time_this_iter_s: 5.840113878250122
  time_total_s: 75.14253330230713
  timestamp: 1625427050
  timesteps_since_restore: 0
  training_iteration: 10
  

Result for DEFAULT_e953d_00009:
  accuracy: 0.91725
  date: 2021-07-04_20-31-00
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 12
  loss: 0.3101970851421356
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 85.55218887329102
  time_this_iter_s: 5.0559961795806885
  time_total_s: 85.55218887329102
  timestamp: 1625427060
  timesteps_since_restore: 0
  training_iteration: 12
  trial_id: e953d_00009
  
== Status ==
Memory usage on this node: 2.1/7.7 GiB
Using AsyncHyperBand: num_stopped=8
Bracket: Iter 16.000: -0.454297699034214 | Iter 8.000: -0.3302193880081177 | Iter 4.000: -0.6645342707633972 | Iter 2.000: -1.5008934140205383 | Iter 1.000: -2.177282691001892
Resources requested: 4.0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects (0.0/2.0 CPU_group_fbe0b91c8b2458be202dda7d5353dad3, 0.0/2.0 CPU_group_168530008168d15bc49a2e2d936da2f9, 0.0/2.0 CPU_group_0_fbe0b91c8b2458be202

Result for DEFAULT_e953d_00008:
  accuracy: 0.9174166666666667
  date: 2021-07-04_20-31-13
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 12
  loss: 0.3001546859741211
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 117.51355981826782
  time_this_iter_s: 7.33526086807251
  time_total_s: 117.51355981826782
  timestamp: 1625427073
  timesteps_since_restore: 0
  training_iteration: 12
  trial_id: e953d_00008
  
Result for DEFAULT_e953d_00009:
  accuracy: 0.9254166666666667
  date: 2021-07-04_20-31-17
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 15
  loss: 0.2665010094642639
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 101.89608073234558
  time_this_iter_s: 5.32626485824585
  time_total_s: 101.89608073234558
  timestamp: 1625427077
  timesteps_since_restore: 0
  training_i

Result for DEFAULT_e953d_00008:
  accuracy: 0.9220833333333334
  date: 2021-07-04_20-31-28
  done: false
  experiment_id: 83644ace4c2c4ae2a7b3ba039255a8f9
  hostname: chris-server
  iterations_since_restore: 14
  loss: 0.2851559519767761
  node_ip: 192.168.1.58
  pid: 126983
  should_checkpoint: true
  time_since_restore: 131.62208127975464
  time_this_iter_s: 7.133996486663818
  time_total_s: 131.62208127975464
  timestamp: 1625427088
  timesteps_since_restore: 0
  training_iteration: 14
  trial_id: e953d_00008
  
Result for DEFAULT_e953d_00009:
  accuracy: 0.9264166666666667
  date: 2021-07-04_20-31-33
  done: false
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 18
  loss: 0.27444714307785034
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 117.99564337730408
  time_this_iter_s: 5.699957370758057
  time_total_s: 117.99564337730408
  timestamp: 1625427093
  timesteps_since_restore: 0
  trainin

2021-07-04 20:31:43,549	INFO tune.py:549 -- Total run time: 237.13 seconds (236.96 seconds for the tuning loop).


Result for DEFAULT_e953d_00009:
  accuracy: 0.91875
  date: 2021-07-04_20-31-43
  done: true
  experiment_id: 6a950c29a43b4cb196f58d92aa0d6cbe
  hostname: chris-server
  iterations_since_restore: 20
  loss: 0.29735755920410156
  node_ip: 192.168.1.58
  pid: 128385
  should_checkpoint: true
  time_since_restore: 128.07539629936218
  time_this_iter_s: 4.778563499450684
  time_total_s: 128.07539629936218
  timestamp: 1625427103
  timesteps_since_restore: 0
  training_iteration: 20
  trial_id: e953d_00009
  
== Status ==
Memory usage on this node: 1.7/7.7 GiB
Using AsyncHyperBand: num_stopped=10
Bracket: Iter 16.000: -0.2786775827407837 | Iter 8.000: -0.3302193880081177 | Iter 4.000: -0.6645342707633972 | Iter 2.000: -1.5008934140205383 | Iter 1.000: -2.177282691001892
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/3.97 GiB heap, 0.0/1.98 GiB objects (0.0/2.0 CPU_group_0_fbe0b91c8b2458be202dda7d5353dad3, 0.0/2.0 CPU_group_168530008168d15bc49a2e2d936da2f9, 0.0/2.0 CPU_group_0_168530008168d15b

In [5]:
# Summarise the winning trial: its sampled hyperparameters and the last
# validation metrics it reported.
final_metrics = best_trial.last_result
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(final_metrics["loss"]))
print("Best trial final validation accuracy: {}".format(final_metrics["accuracy"]))

Best trial config: {'lr': 0.016374686378100446, 'batch_size': 256, 'p_d1': 0.12092449662468381, 'p_d2': 0.7693301753948671}
Best trial final validation loss: 0.19369226694107056
Best trial final validation accuracy: 0.9528333333333333


In [6]:
# Rebuild the network with the best trial's dropout settings so the saved
# weights load into an identically shaped model.
best_trained_model = FFNN(784, 10, 
                          p_d1=best_trial.config['p_d1'], 
                          p_d2=best_trial.config['p_d2'])

best_checkpoint_dir = best_trial.checkpoint.value
# map_location='cpu' lets a checkpoint written on a GPU load on a CPU-only
# host; evaluation below runs on the CPU by default anyway.
model_state, optimiser_state = torch.load(osp.join(best_checkpoint_dir, "checkpoint"),
                                          map_location='cpu')
best_trained_model.load_state_dict(model_state)
test_acc = test_accuracy(best_trained_model)
print("Best trial test set accuracy: {}".format(test_acc))

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Best trial test set accuracy: 0.9535
