In [None]:
import os
import sys; sys.path.append('../lib')
from functools import partial

import matplotlib.pyplot as plt

from assignment2 import visualize_learning_curves
from data import Cifar
from history import TrainHistory
from search import SearchResultSeries
from two_layer_fully_connected import TwoLayerFullyConnected

# Constants

In [None]:
# Number of hidden units handed to TwoLayerFullyConnected (hidden_nodes=...).
HIDDEN_NODES = 50

# Directories used by the notebook, all relative to its working directory:
# figures are written to FIGURE_DIR, search results and training histories
# are read from / written to PICKLE_DIR, and DATA_DIR is passed to Cifar.
FIGURE_DIR = '../figures'
PICKLE_DIR = '../pickle'
DATA_DIR = '../data'

# Load dataset

In [None]:
# Build the dataset object from DATA_DIR; it is split into train/val/test below.
# NOTE(review): presumably Cifar reads the CIFAR files found under DATA_DIR -
# confirm in data.py.
dataset = Cifar(DATA_DIR)

# Split into training, validation and test set

In [None]:
# Ask the dataset for its three partitions: n_val=1000 is the requested
# validation size and normalize='zscore' selects the normalisation mode.
# NOTE(review): how the remaining samples are divided and which statistics the
# z-scoring uses is decided inside train_val_test_split - confirm there.
splits = dataset.train_val_test_split(n_val=1000, normalize='zscore')
data_train, data_val, data_test = splits

# Default network constructor

In [None]:
# Load the search-result series stored under PICKLE_DIR with the 'fine' postfix.
search_results = SearchResultSeries.load(PICKLE_DIR, postfix='fine')

# Constructor arguments shared by every network built in this notebook.
# alpha is read once, here, from the optimum of the loaded search results;
# the same random_seed is passed to every network created via default_network.
network_kwargs = dict(
    input_size=data_train.input_size,
    hidden_nodes=HIDDEN_NODES,
    num_classes=data_train.num_classes,
    alpha=search_results.optimum()['alpha'],
    random_seed=0,
)

# Zero-argument factory: default_network() returns a fresh TwoLayerFullyConnected.
default_network = partial(TwoLayerFullyConnected, **network_kwargs)

# Find good range for $\eta$

In [None]:
# Learning-rate range test on a freshly constructed network, logarithmic sweep.
network = default_network()

# NOTE(review): with logarithmic=True the bounds -5 and -1 are presumably
# base-10 exponents (i.e. 1e-5 .. 1e-1) - confirm against lr_range_test.
network.lr_range_test(data_train,
                      eta_low=-5,
                      eta_high=-1,
                      logarithmic=True,
                      verbose=True)

# plt.savefig does not create missing parent directories; make sure the
# output directory exists so a fresh checkout does not fail here.
os.makedirs(FIGURE_DIR, exist_ok=True)

# Saves pyplot's current figure.
# NOTE(review): assumes lr_range_test leaves its plot as the current figure.
plt.savefig(os.path.join(FIGURE_DIR, 'lr_range_log.svg'))

In [None]:
# Learning-rate range test on a freshly constructed network; here the bounds
# are given directly as learning rates (1e-5 .. 0.2) and the logarithmic flag
# is left at its default.
network = default_network()

network.lr_range_test(data_train,
                      eta_low=1e-5,
                      eta_high=0.2,
                      verbose=True)

# plt.savefig does not create missing parent directories; make sure the
# output directory exists so a fresh checkout does not fail here.
os.makedirs(FIGURE_DIR, exist_ok=True)

# Saves pyplot's current figure.
# NOTE(review): assumes lr_range_test leaves its plot as the current figure.
plt.savefig(os.path.join(FIGURE_DIR, 'lr_range_linear.svg'))

In [None]:
# Lower and upper learning-rate bounds passed to train_cyclic below.
# NOTE(review): presumably read off the range-test plots above - confirm.
eta_min = 1e-4
eta_max = 0.025

# Train and evaluate network

In [None]:
# Train a fresh network with the cyclic learning-rate schedule and persist
# the resulting history.
network = default_network()

# Value handed to train_cyclic as eta_ss: (2 * n) // 100, i.e. the integer
# division is applied after the multiplication.
# NOTE(review): 100 looks like the mini-batch size - confirm and consider
# promoting it to a named constant.
step_size = 2 * data_train.n // 100

history = network.train_cyclic(
    data_train,
    data_val,
    eta_min=eta_min,
    eta_max=eta_max,
    eta_ss=step_size,
    n_cycles=3,
    verbose=True,
)

# Stored under PICKLE_DIR with the postfix the next cell loads from.
history.save(PICKLE_DIR, postfix='lr_range_train')

In [None]:
# Reload the history from PICKLE_DIR using the same postfix the training cell
# saved it under; presumably this lets the cells below run without retraining.
history = TrainHistory.load(PICKLE_DIR, postfix='lr_range_train')

In [None]:
# Plot the learning curves recorded in the training history.
visualize_learning_curves(history)

# plt.savefig does not create missing parent directories; make sure the
# output directory exists so a fresh checkout does not fail here.
os.makedirs(FIGURE_DIR, exist_ok=True)

# Saves pyplot's current figure.
# NOTE(review): assumes visualize_learning_curves leaves its plot as the
# current figure.
plt.savefig(os.path.join(FIGURE_DIR, 'curves_lr_range_train.svg'))

In [None]:
# Visualise the performance of the final network from the history on the
# held-out test split.
history.final_network.visualize_performance(data_test)

# plt.savefig does not create missing parent directories; make sure the
# output directory exists so a fresh checkout does not fail here.
os.makedirs(FIGURE_DIR, exist_ok=True)

# Saves pyplot's current figure.
# NOTE(review): assumes visualize_performance leaves its plot as the current
# figure.
plt.savefig(os.path.join(FIGURE_DIR, 'performance_lr_range_train.svg'))