In [2]:
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss, Conv2d, Sequential, BatchNorm2d
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
from torchvision import transforms

from torchinfo import summary

from ActiveShiftLayer import ASL
from util import test_loss, train_NN

from ray import tune
from ray.tune.search.bayesopt import BayesOptSearch

In [3]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [4]:
batch_size = 100

# Preprocessing pipeline: turn each image into a tensor, then normalize it
# with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# Both splits share the storage directory, the download flag and the
# preprocessing pipeline; only the `train` flag differs.
mnist_options = dict(download=True, transform=transform)

train_dataset = datasets.MNIST("./data/MNIST", train=True, **mnist_options)

test_dataset = datasets.MNIST("./data/MNIST", train=False, **mnist_options)

# Loader settings common to both splits; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=1, pin_memory=True)

train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_options)

test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_options)

In [5]:
from Models import LeASLNet

# Loss passed to train_NN / test_loss by the trainable below.
criterion = CrossEntropyLoss()

# Shape handed to the model constructor: (batch_size, 1, 28, 28).
image_dims = (1, 28, 28)
input_shape = (batch_size, *image_dims)

def train_mnist(config):
    """Ray Tune trainable that fits a LeASLNet on the CPU.

    Args:
        config: Hyperparameter dict supplied by Ray Tune; the keys "lr",
            "momentum" and "weight_decay" are read.

    Runs two training rounds (each one `train_NN` call with `epochs=2`).
    After every round, element [1] of `test_loss`'s return value is reported
    to Tune as `mean_accuracy`.
    """
    cpu = "cpu"
    hyperparams = dict(
        initial_lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )
    model = LeASLNet(input_shape, 10, **hyperparams, device=cpu, expansion_rate=1).to(cpu)
    for _ in range(2):
        train_NN(
            model, criterion, train_dataloader, test_dataloader,
            epochs=2, batches_to_test=100, patience=2,
            device=cpu, print_test=False, verbose=False,
        )
        evaluation = test_loss(model, test_dataloader, criterion, cpu)
        tune.report(mean_accuracy=evaluation[1])

# Random search: 15 samples drawn from the distributions below.
search_space = {
    "lr": tune.loguniform(0.001, 0.1),
    "momentum": tune.uniform(0.8, 1),
    "weight_decay": tune.uniform(0, 0.1),
}
analysis = tune.run(train_mnist, num_samples=15, config=search_space)

best_config = analysis.get_best_config(metric="mean_accuracy", mode="max")
print("Best config: ", best_config)

# Collect the trial results in a dataframe for further analysis.
df = analysis.dataframe()

Trial name,status,loc,lr,momentum,weight_decay,acc,iter,total time (s)
train_mnist_44217_00000,TERMINATED,147.142.68.85:68312,0.00446645,0.993003,0.0399922,88.5,2,557.05
train_mnist_44217_00001,TERMINATED,147.142.68.85:68362,0.00252822,0.927413,0.00924895,95.89,2,554.376
train_mnist_44217_00002,TERMINATED,147.142.68.85:68364,0.00452166,0.881696,0.081959,78.7,2,556.69
train_mnist_44217_00003,TERMINATED,147.142.68.85:68368,0.0617364,0.89286,0.0637015,84.1,2,559.309
train_mnist_44217_00004,TERMINATED,147.142.68.85:68394,0.0189816,0.860582,0.0742808,80.58,2,567.034
train_mnist_44217_00005,TERMINATED,147.142.68.85:68420,0.00204454,0.903675,0.086614,78.12,2,567.705
train_mnist_44217_00006,TERMINATED,147.142.68.85:68446,0.0404193,0.881044,0.03029,88.91,2,569.063
train_mnist_44217_00007,TERMINATED,147.142.68.85:68473,0.00115685,0.969497,0.0585754,86.53,2,567.974
train_mnist_44217_00008,TERMINATED,147.142.68.85:68683,0.00100183,0.834387,0.021391,87.59,2,565.338
train_mnist_44217_00009,TERMINATED,147.142.68.85:68816,0.017448,0.936761,0.026766,90.64,2,559.398




Result for train_mnist_44217_00000:
  date: 2022-08-29_13-51-22
  done: false
  experiment_id: c386d6390e1944f5a0fcf79467e8ebc5
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 86.59
  node_ip: 147.142.68.85
  pid: 68312
  time_since_restore: 263.56469988822937
  time_this_iter_s: 263.56469988822937
  time_total_s: 263.56469988822937
  timestamp: 1661773882
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: '44217_00000'
  warmup_time: 0.002355813980102539
  
Result for train_mnist_44217_00001:
  date: 2022-08-29_13-51-29
  done: false
  experiment_id: ea3a79f2a8c04b69be51b1ad36850f88
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 92.61
  node_ip: 147.142.68.85
  pid: 68362
  time_since_restore: 266.13845586776733
  time_this_iter_s: 266.13845586776733
  time_total_s: 266.13845586776733
  timestamp: 1661773889
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: '44217_00001'
  warmup_time: 0.0037386417

2022-08-29 13:59:00,316	INFO tune.py:758 -- Total run time: 731.87 seconds (730.86 seconds for the tuning loop).


Best config:  {'lr': 0.0025282211848313913, 'momentum': 0.9274131119760397, 'weight_decay': 0.009248952703622115}


In [4]:
from Models import LeASLNet

# Loss passed to train_NN / test_loss by the trainable below.
criterion = CrossEntropyLoss()

# Shape handed to the model constructor: (batch_size, 1, 28, 28).
image_dims = (1, 28, 28)
input_shape = (batch_size, *image_dims)

def train_mnist(config):
    """Ray Tune trainable that fits a LeASLNet on the CPU.

    Args:
        config: Hyperparameter dict supplied by Ray Tune; the keys "lr",
            "momentum" and "weight_decay" are read.

    Runs two training rounds (each one `train_NN` call with `epochs=2`).
    After every round, element [1] of `test_loss`'s return value is reported
    to Tune as `mean_accuracy`.
    """
    cpu = "cpu"
    net = LeASLNet(
        input_shape,
        10,
        initial_lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        device=cpu,
        expansion_rate=1,
    )
    model = net.to(cpu)
    training_options = dict(
        epochs=2, batches_to_test=100, patience=2,
        device=cpu, print_test=False, verbose=False,
    )
    for _ in range(2):
        train_NN(model, criterion, train_dataloader, test_dataloader, **training_options)
        acc = test_loss(model, test_dataloader, criterion, cpu)[1]
        tune.report(mean_accuracy=acc)

# Grid search over 3 learning rates x 2 momenta x 3 weight decays.
grid = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.9, 0.99]),
    "weight_decay": tune.grid_search([0, 0.01, 0.1]),
}
analysis = tune.run(train_mnist, config=grid)

best_config = analysis.get_best_config(metric="mean_accuracy", mode="max")
print("Best config: ", best_config)

# Collect the trial results in a dataframe for further analysis.
df = analysis.dataframe()

2022-08-29 14:19:38,106	INFO worker.py:1509 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265


Trial name,status,loc,lr,momentum,weight_decay,acc,iter,total time (s)
train_mnist_dacf6_00000,TERMINATED,147.142.68.85:83875,0.001,0.9,0.0,91.45,2,536.509
train_mnist_dacf6_00001,TERMINATED,147.142.68.85:83927,0.01,0.9,0.0,97.93,2,562.865
train_mnist_dacf6_00002,TERMINATED,147.142.68.85:83930,0.1,0.9,0.0,98.69,2,557.857
train_mnist_dacf6_00003,TERMINATED,147.142.68.85:83956,0.001,0.99,0.0,97.32,2,558.735
train_mnist_dacf6_00004,TERMINATED,147.142.68.85:83959,0.01,0.99,0.0,98.02,2,550.848
train_mnist_dacf6_00005,TERMINATED,147.142.68.85:83982,0.1,0.99,0.0,10.09,2,558.601
train_mnist_dacf6_00006,TERMINATED,147.142.68.85:84160,0.001,0.9,0.01,91.36,2,575.619
train_mnist_dacf6_00007,TERMINATED,147.142.68.85:84163,0.01,0.9,0.01,83.8,2,572.41
train_mnist_dacf6_00008,TERMINATED,147.142.68.85:84349,0.1,0.9,0.01,96.83,2,570.339
train_mnist_dacf6_00009,TERMINATED,147.142.68.85:84404,0.001,0.99,0.01,95.29,2,577.206




Result for train_mnist_dacf6_00000:
  date: 2022-08-29_14-23-55
  done: false
  experiment_id: 7079a85c0fc04a629399b20373dc3293
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 86.9
  node_ip: 147.142.68.85
  pid: 83875
  time_since_restore: 245.56458640098572
  time_this_iter_s: 245.56458640098572
  time_total_s: 245.56458640098572
  timestamp: 1661775835
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: dacf6_00000
  warmup_time: 0.0023016929626464844
  
Result for train_mnist_dacf6_00002:
  date: 2022-08-29_14-24-17
  done: false
  experiment_id: 155fc01774114092ba5f488a31e405a0
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 97.72
  node_ip: 147.142.68.85
  pid: 83930
  time_since_restore: 262.0902826786041
  time_this_iter_s: 262.0902826786041
  time_total_s: 262.0902826786041
  timestamp: 1661775857
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: dacf6_00002
  warmup_time: 0.00828075408935546

2022-08-29 14:34:38,071	INFO tune.py:758 -- Total run time: 898.94 seconds (897.58 seconds for the tuning loop).


Best config:  {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0}


In [5]:
from Models import LeNet

# Loss passed to train_NN / test_loss by the trainable below.
criterion = CrossEntropyLoss()

# Shape handed to the model constructor: (batch_size, 1, 28, 28).
image_dims = (1, 28, 28)
input_shape = (batch_size, *image_dims)

def train_mnist(config):
    """Ray Tune trainable that fits a LeNet on the CPU.

    Args:
        config: Hyperparameter dict supplied by Ray Tune; the keys "lr",
            "momentum" and "weight_decay" are read.

    Runs two training rounds (each one `train_NN` call with `epochs=2`).
    After every round, element [1] of `test_loss`'s return value is reported
    to Tune as `mean_accuracy`.
    """
    cpu = "cpu"
    hyperparams = dict(
        initial_lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )
    model = LeNet(input_shape, 10, **hyperparams).to(cpu)
    for _ in range(2):
        train_NN(
            model, criterion, train_dataloader, test_dataloader,
            epochs=2, batches_to_test=100, patience=2,
            device=cpu, print_test=False, verbose=False,
        )
        evaluation = test_loss(model, test_dataloader, criterion, cpu)
        tune.report(mean_accuracy=evaluation[1])

# Grid search over 3 learning rates x 2 momenta x 3 weight decays.
grid = {
    "lr": tune.grid_search([0.001, 0.01, 0.1]),
    "momentum": tune.grid_search([0.9, 0.99]),
    "weight_decay": tune.grid_search([0, 0.01, 0.1]),
}
analysis = tune.run(train_mnist, config=grid)

best_config = analysis.get_best_config(metric="mean_accuracy", mode="max")
print("Best config: ", best_config)

# Collect the trial results in a dataframe for further analysis.
df = analysis.dataframe()

Trial name,status,loc,lr,momentum,weight_decay,acc,iter,total time (s)
train_mnist_d6d05_00000,TERMINATED,147.142.68.85:96840,0.001,0.9,0.0,95.77,2,430.805
train_mnist_d6d05_00001,TERMINATED,147.142.68.85:96891,0.01,0.9,0.0,94.67,2,444.224
train_mnist_d6d05_00002,TERMINATED,147.142.68.85:96898,0.1,0.9,0.0,99.07,2,444.087
train_mnist_d6d05_00003,TERMINATED,147.142.68.85:96920,0.001,0.99,0.0,98.35,2,450.634
train_mnist_d6d05_00004,TERMINATED,147.142.68.85:96922,0.01,0.99,0.0,98.67,2,438.771
train_mnist_d6d05_00005,TERMINATED,147.142.68.85:96962,0.1,0.99,0.0,30.74,2,444.898
train_mnist_d6d05_00006,TERMINATED,147.142.68.85:97124,0.001,0.9,0.01,96.17,2,461.305
train_mnist_d6d05_00007,TERMINATED,147.142.68.85:97127,0.01,0.9,0.01,93.32,2,460.013
train_mnist_d6d05_00008,TERMINATED,147.142.68.85:97310,0.1,0.9,0.01,97.27,2,455.669
train_mnist_d6d05_00009,TERMINATED,147.142.68.85:97393,0.001,0.99,0.01,95.66,2,457.808




Result for train_mnist_d6d05_00000:
  date: 2022-08-29_14-58-48
  done: false
  experiment_id: 6eb2d6ddb1ca433a98476d01ce30bb94
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 91.01
  node_ip: 147.142.68.85
  pid: 96840
  time_since_restore: 196.7659192085266
  time_this_iter_s: 196.7659192085266
  time_total_s: 196.7659192085266
  timestamp: 1661777928
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: d6d05_00000
  warmup_time: 0.0029163360595703125
  
Result for train_mnist_d6d05_00001:
  date: 2022-08-29_14-59-00
  done: false
  experiment_id: 09dead0ff7394abfba708c66d4c6fb7e
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 84.55
  node_ip: 147.142.68.85
  pid: 96891
  time_since_restore: 205.38097524642944
  time_this_iter_s: 205.38097524642944
  time_total_s: 205.38097524642944
  timestamp: 1661777940
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: d6d05_00001
  warmup_time: 0.0028080940246582

2022-08-29 15:07:06,676	INFO tune.py:758 -- Total run time: 706.66 seconds (705.67 seconds for the tuning loop).


Best config:  {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0}


In [5]:
from Models import LeNet

# Loss passed to train_NN / test_loss by the trainable below.
criterion = CrossEntropyLoss()

# Shape handed to the model constructor: (batch_size, 1, 28, 28).
image_dims = (1, 28, 28)
input_shape = (batch_size, *image_dims)

def train_mnist(config):
    """Ray Tune trainable that fits a LeNet (weight decay fixed at 0) on the CPU.

    Args:
        config: Hyperparameter dict supplied by Ray Tune. The keys "lr" and
            "momentum" are required. The optional key "steps" sets how many
            train/evaluate/report rounds are run; it defaults to 2, the count
            that was previously hard-coded.

    Each round is one `train_NN` call with `epochs=2`. After every round,
    element [1] of `test_loss`'s return value is reported to Tune as
    `mean_accuracy`.
    """
    test_device = "cpu"
    # The search space used with this trainable passes a "steps" entry; honour
    # it instead of silently ignoring it. int() guards against a search
    # algorithm handing the value back as a float.
    n_rounds = int(config.get("steps", 2))
    model = LeNet(input_shape, 10, initial_lr=config["lr"], momentum=config["momentum"], weight_decay=0).to(test_device)
    for _ in range(n_rounds):
        train_NN(model, criterion, train_dataloader,
        test_dataloader, epochs=2, batches_to_test=100,patience=2,device=test_device, print_test=False, verbose=False)
        acc = test_loss(model, test_dataloader, criterion, test_device)[1]
        tune.report(mean_accuracy=acc)

In [12]:
# Search space: "steps" is a fixed entry, lr and momentum are sampled uniformly.
config = {
    "steps": 2,
    "lr": tune.uniform(0.001, 0.1),
    "momentum": tune.uniform(0.1, 1),
}

# Bayesian-optimisation searcher that maximises the reported mean_accuracy.
bayesopt = BayesOptSearch(metric="mean_accuracy", mode="max")

search_settings = tune.TuneConfig(search_alg=bayesopt, num_samples=5)
tuner = tune.Tuner(train_mnist, tune_config=search_settings, param_space=config)

analysis = tuner.fit()



Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_dca770a6,TERMINATED,129.206.61.139:129797,0.0380795,0.955643,98.94,2,185.733
train_mnist_debca8d4,TERMINATED,129.206.61.139:129848,0.0734674,0.638793,81.88,2,185.63
train_mnist_df3b427a,TERMINATED,129.206.61.139:129872,0.0164458,0.240395,84.22,2,187.914
train_mnist_dfc3da04,TERMINATED,129.206.61.139:129898,0.00675028,0.879559,98.01,2,187.512
train_mnist_e0607eea,TERMINATED,129.206.61.139:129939,0.0605104,0.737265,98.35,2,186.534




Result for train_mnist_dca770a6:
  date: 2022-08-29_16-44-26
  done: false
  experiment_id: 318f1ba9084d438bbb046ef1c981b5c6
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 98.49
  node_ip: 129.206.61.139
  pid: 129797
  time_since_restore: 89.69512009620667
  time_this_iter_s: 89.69512009620667
  time_total_s: 89.69512009620667
  timestamp: 1661784266
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: dca770a6
  warmup_time: 0.002826690673828125
  
Result for train_mnist_debca8d4:
  date: 2022-08-29_16-44-31
  done: false
  experiment_id: e7ef166a1f8d4b41b5c058e9f1dfb97a
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 95.53
  node_ip: 129.206.61.139
  pid: 129848
  time_since_restore: 91.39510655403137
  time_this_iter_s: 91.39510655403137
  time_total_s: 91.39510655403137
  timestamp: 1661784271
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: debca8d4
  warmup_time: 0.0028324127197265625
  
Resul

2022-08-29 16:46:10,268	INFO tune.py:758 -- Total run time: 198.28 seconds (197.33 seconds for the tuning loop).


In [13]:
# Print the trial with the highest mean_accuracy, then display all trials.
best_result = analysis.get_best_result(metric="mean_accuracy", mode="max")
print(best_result)
analysis.get_dataframe()

Result(metrics={'mean_accuracy': 98.94, 'done': True, 'trial_id': 'dca770a6', 'experiment_tag': '1_lr=0.0381,momentum=0.9556,steps=2'}, error=None, log_dir=PosixPath('/home/max/ray_results/train_mnist_2022-08-29_16-42-51/train_mnist_dca770a6_1_lr=0.0381,momentum=0.9556,steps=2_2022-08-29_16-42-53'))


Unnamed: 0,mean_accuracy,time_this_iter_s,done,timesteps_total,episodes_total,training_iteration,trial_id,experiment_id,date,timestamp,...,hostname,node_ip,time_since_restore,timesteps_since_restore,iterations_since_restore,warmup_time,config/lr,config/momentum,config/steps,logdir
0,98.94,96.037631,False,,,2,dca770a6,318f1ba9084d438bbb046ef1c981b5c6,2022-08-29_16-46-02,1661784362,...,max-Latitude-5401,129.206.61.139,185.732751,0,2,0.002827,0.038079,0.955643,2,/home/max/ray_results/train_mnist_2022-08-29_1...
1,81.88,94.234621,False,,,2,debca8d4,e7ef166a1f8d4b41b5c058e9f1dfb97a,2022-08-29_16-46-05,1661784365,...,max-Latitude-5401,129.206.61.139,185.629727,0,2,0.002832,0.073467,0.638793,2,/home/max/ray_results/train_mnist_2022-08-29_1...
2,84.22,94.828889,False,,,2,df3b427a,6f5dbd2f8aba4fd4bac247a56731ee81,2022-08-29_16-46-08,1661784368,...,max-Latitude-5401,129.206.61.139,187.913504,0,2,0.003958,0.016446,0.240395,2,/home/max/ray_results/train_mnist_2022-08-29_1...
3,98.01,94.461907,False,,,2,dfc3da04,50ed1c2ad36947cb827342474e38c73a,2022-08-29_16-46-09,1661784369,...,max-Latitude-5401,129.206.61.139,187.5116,0,2,0.004179,0.00675,0.879559,2,/home/max/ray_results/train_mnist_2022-08-29_1...
4,98.35,93.54927,False,,,2,e0607eea,bd08cbb745a344969bde54ecc3c103be,2022-08-29_16-46-10,1661784370,...,max-Latitude-5401,129.206.61.139,186.534145,0,2,0.005148,0.06051,0.737265,2,/home/max/ray_results/train_mnist_2022-08-29_1...
