In [1]:
import dill
from functools import partial
import json
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.tune import JupyterNotebookReporter
from ray.tune.schedulers import ASHAScheduler

from my_models import (AlexNet, VGG16, ResNet)


In [2]:
%env CUDA_VISIBLE_DEVICES=0
device = torch.device("cuda:0")

env: CUDA_VISIBLE_DEVICES=0


In [3]:
input_channels = 3
output_channels = 10
resnet_output_shapes = [64, 128, 256, 512]
resnet_layers_depths = [2,2,2,2]

models_list = [AlexNet, VGG16, ResNet]

In [4]:
# создание папок для логов разых моделей:
checkpoint_dir="./data/checkpoints/"
"""
for name in models_list:    
    if not os.path.isdir(checkpoint_dir+str(name)):
        os.makedirs(checkpoint_dir+str(name))
"""       
if not os.path.isdir(checkpoint_dir+str("_all_models")):
    os.makedirs(checkpoint_dir+str("_all_models"))

In [5]:
# исходники: https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html

In [7]:
def train_cifar(config, epoch_num=2,
                checkpoint_dir=checkpoint_dir, data_dir=None): 
    net = config["models"](input_channels, 
                           output_channels, 
                           resnet_output_shapes, 
                           resnet_layers_depths)
    net.to(device)
    
    criterion = config["losses"]() # для итерирования разных losses
    optimizer = config["optimizers_names"](net.parameters(), lr=config["lr"])

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True)

    for epoch in range(epoch_num):  # loop over the dataset multiple times
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.cpu().numpy()
                val_steps += 1

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)

    print("Finished Training")

In [8]:
def main(num_samples=1, stop_criteria=20, max_num_epochs=20):    
    data_dir = os.path.abspath("./data/CIFAR")
    load_data(data_dir)

    config = {        
        "lr":tune.grid_search([1e-3, 1e-4]),   
        "batch_size": 10000,#tune.grid_search([100, 1000]),
        "optimizers_names":  tune.grid_search([optim.Adam, optim.SGD]),
        "losses": nn.CrossEntropyLoss,
        "models": tune.grid_search(models_list)
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=stop_criteria)
    reporter = JupyterNotebookReporter(
        overwrite = True,
        print_intermediate_tables = True,
        metric_columns=["loss", "accuracy", "precision", "training_iteration"])
    result = tune.run(
        partial(train_cifar, data_dir=data_dir,
                checkpoint_dir=checkpoint_dir, epoch_num=max_num_epochs),
        name = 'CIFAR',
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        resources_per_trial = {"gpu": 1},
        local_dir= checkpoint_dir+"_all_models")
    
    best_trial = result.get_best_trial("loss", "min", "last")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(
        best_trial.last_result["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_trial.last_result["accuracy"]))    

    return result, best_trial   

In [9]:
result, best_trial = main(num_samples=1, stop_criteria=30, max_num_epochs=1000)
best_trial_conf = best_trial.config

Trial name,status,loc,lr,models,optimizers_names,loss,accuracy,training_iteration
DEFAULT_2886a_00000,TERMINATED,,0.001,<class 'my_models.AlexNet'>,<class 'torch.optim.adam.Adam'>,1.38993,0.5835,1000.0
DEFAULT_2886a_00001,TERMINATED,,0.0001,<class 'my_models.AlexNet'>,<class 'torch.optim.adam.Adam'>,2.05037,0.2703,30.0
DEFAULT_2886a_00002,TERMINATED,,0.001,<class 'my_models.VGG16'>,<class 'torch.optim.adam.Adam'>,2.30259,0.1013,30.0
DEFAULT_2886a_00003,TERMINATED,,0.0001,<class 'my_models.VGG16'>,<class 'torch.optim.adam.Adam'>,2.30197,0.1063,30.0
DEFAULT_2886a_00006,TERMINATED,,0.001,<class 'my_models.AlexNet'>,<class 'torch.optim.sgd.SGD'>,2.30516,0.1007,30.0
DEFAULT_2886a_00007,TERMINATED,,0.0001,<class 'my_models.AlexNet'>,<class 'torch.optim.sgd.SGD'>,2.30487,0.1034,30.0
DEFAULT_2886a_00008,TERMINATED,,0.001,<class 'my_models.VGG16'>,<class 'torch.optim.sgd.SGD'>,2.3039,0.0976,30.0
DEFAULT_2886a_00009,TERMINATED,,0.0001,<class 'my_models.VGG16'>,<class 'torch.optim.sgd.SGD'>,2.30386,0.0989,30.0
DEFAULT_2886a_00004,ERROR,,0.001,<class 'my_models.ResNet'>,<class 'torch.optim.adam.Adam'>,,,
DEFAULT_2886a_00005,ERROR,,0.0001,<class 'my_models.ResNet'>,<class 'torch.optim.adam.Adam'>,,,

Trial name,# failures,error file
DEFAULT_2886a_00004,1,"/notebooks/sorokina/data/checkpoints/_all_models/CIFAR/DEFAULT_2886a_00004_4_lr=0.001,models=<class 'my_models.ResNet'>,optimizers_names=<class 'torch.optim.adam.Adam'>_2021-01-05_15-31-31/error.txt"
DEFAULT_2886a_00005,1,"/notebooks/sorokina/data/checkpoints/_all_models/CIFAR/DEFAULT_2886a_00005_5_lr=0.0001,models=<class 'my_models.ResNet'>,optimizers_names=<class 'torch.optim.adam.Adam'>_2021-01-05_15-31-38/error.txt"
DEFAULT_2886a_00010,1,"/notebooks/sorokina/data/checkpoints/_all_models/CIFAR/DEFAULT_2886a_00010_10_lr=0.001,models=<class 'my_models.ResNet'>,optimizers_names=<class 'torch.optim.sgd.SGD'>_2021-01-05_15-46-32/error.txt"
DEFAULT_2886a_00011,1,"/notebooks/sorokina/data/checkpoints/_all_models/CIFAR/DEFAULT_2886a_00011_11_lr=0.0001,models=<class 'my_models.ResNet'>,optimizers_names=<class 'torch.optim.sgd.SGD'>_2021-01-05_15-46-39/error.txt"


TuneError: ('Trials did not complete', [DEFAULT_2886a_00004, DEFAULT_2886a_00005, DEFAULT_2886a_00010, DEFAULT_2886a_00011])

In [None]:
# Извлечение наилучшей модели:
best_trained_model = best_trial_conf["models"](input_channels, output_channels, [64, 128, 256, 512],[2, 2, 2, 2])
best_trained_model.to(device=device)
model_state, optimizer_state = torch.load(os.path.join(
    best_trial.checkpoint.value, "checkpoint"))
best_trained_model.load_state_dict(model_state)

# Сохранение модели и конфига:
model_class = best_trial_conf["models"]
del best_trial_conf["models"]
with torch.no_grad():
    torch.save({"model_instance" : best_trained_model, 
                "model_class" : model_class,
                "config" : best_trial_conf},
               "./data/best_trials_info/best_trial_model_and_config.txt", 
               pickle_module=dill)