In [4]:
import numpy as np
import matplotlib.pyplot as plt
import json

import torch
from torch.nn import CrossEntropyLoss, NLLLoss
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from torchinfo import summary

from os.path import exists

from util import test_loss, train_NN

from ray import tune
from ray.tune.search.bayesopt import BayesOptSearch

In [5]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [6]:
batch_size = 100

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True,
                                        download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False,
                                       download=True, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

input_shape = (batch_size, 3, 32, 32)
num_labels = 10

Files already downloaded and verified
Files already downloaded and verified


In [5]:
from Models import LeNet
input_shape = (batch_size, 1, 28, 28)
criterion = CrossEntropyLoss()

def train_mnist(config):
    test_device = "cpu"
    model = LeNet(input_shape, 10, initial_lr=config["lr"], momentum=config["momentum"], weight_decay=config["weight_decay"]).to(test_device)
    for i in range(2):
        train_NN(model, criterion, train_dataloader,
        test_dataloader, epochs=2, batches_to_test=100,patience=2,device=test_device, print_test=False, verbose=False)
        acc = test_loss(model, test_dataloader, criterion, test_device)[1]
        tune.report(mean_accuracy=acc)

analysis = tune.run(
    train_mnist, config={"lr": tune.grid_search([0.001, 0.01, 0.1]), "momentum": tune.grid_search([0.9, 0.99]), "weight_decay": tune.grid_search([0, 0.01, 0.1])})

print("Best config: ", analysis.get_best_config(metric="mean_accuracy", mode="max",))

# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

Trial name,status,loc,lr,momentum,weight_decay,acc,iter,total time (s)
train_mnist_d6d05_00000,TERMINATED,147.142.68.85:96840,0.001,0.9,0.0,95.77,2,430.805
train_mnist_d6d05_00001,TERMINATED,147.142.68.85:96891,0.01,0.9,0.0,94.67,2,444.224
train_mnist_d6d05_00002,TERMINATED,147.142.68.85:96898,0.1,0.9,0.0,99.07,2,444.087
train_mnist_d6d05_00003,TERMINATED,147.142.68.85:96920,0.001,0.99,0.0,98.35,2,450.634
train_mnist_d6d05_00004,TERMINATED,147.142.68.85:96922,0.01,0.99,0.0,98.67,2,438.771
train_mnist_d6d05_00005,TERMINATED,147.142.68.85:96962,0.1,0.99,0.0,30.74,2,444.898
train_mnist_d6d05_00006,TERMINATED,147.142.68.85:97124,0.001,0.9,0.01,96.17,2,461.305
train_mnist_d6d05_00007,TERMINATED,147.142.68.85:97127,0.01,0.9,0.01,93.32,2,460.013
train_mnist_d6d05_00008,TERMINATED,147.142.68.85:97310,0.1,0.9,0.01,97.27,2,455.669
train_mnist_d6d05_00009,TERMINATED,147.142.68.85:97393,0.001,0.99,0.01,95.66,2,457.808




Result for train_mnist_d6d05_00000:
  date: 2022-08-29_14-58-48
  done: false
  experiment_id: 6eb2d6ddb1ca433a98476d01ce30bb94
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 91.01
  node_ip: 147.142.68.85
  pid: 96840
  time_since_restore: 196.7659192085266
  time_this_iter_s: 196.7659192085266
  time_total_s: 196.7659192085266
  timestamp: 1661777928
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: d6d05_00000
  warmup_time: 0.0029163360595703125
  
Result for train_mnist_d6d05_00001:
  date: 2022-08-29_14-59-00
  done: false
  experiment_id: 09dead0ff7394abfba708c66d4c6fb7e
  hostname: max-Latitude-5401
  iterations_since_restore: 1
  mean_accuracy: 84.55
  node_ip: 147.142.68.85
  pid: 96891
  time_since_restore: 205.38097524642944
  time_this_iter_s: 205.38097524642944
  time_total_s: 205.38097524642944
  timestamp: 1661777940
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: d6d05_00001
  warmup_time: 0.0028080940246582

2022-08-29 15:07:06,676	INFO tune.py:758 -- Total run time: 706.66 seconds (705.67 seconds for the tuning loop).


Best config:  {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0}
