# Imports

In [1]:
import torch
import torch.optim as optim
import numpy as np
import node_models
import loader
import training
import metrics
import autotune
import config

# GPU

In [2]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU so the
# notebook still runs (a hard-coded "cuda" device fails at the first .to(device)
# on a machine without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Loading

In [3]:
# Location of the CIFAR-10 "python version" batch directory, relative to the
# notebook's working directory.
cifar_data_path = "./cifar-10-batches-py"

In [4]:
# Build the project CIFAR loader from the batch directory. The second argument
# is presumably the number of splits (the stats cell reports len(splits)) —
# TODO confirm against loader.CIFAR.
cifar_data = loader.CIFAR(
    cifar_data_path,
    5,
)
# Keep a direct handle on the split list for the stats/training cells.
cifar_splits = cifar_data.splits

In [5]:
# Report the nested layout of the loaded splits, as labelled below:
# splits -> (train, val, test) segments -> (data, labels) -> examples.
# Only the first split / first segment is inspected.
print("CIFAR STATS")
print("Number of splits:", len(cifar_splits))
print("Number of segments per split (train, val, test):", len(cifar_splits[0]))
print("Info per segment (data, labels):", len(cifar_splits[0][0]))
print("Size of segment (num examples):", len(cifar_splits[0][0][0]))

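[Stale output captured from an earlier MNIST run of this notebook — re-run the cell above to regenerate the CIFAR statistics.]
MNIST STATS
Number of splits: 5
Number of segments per split (train, val, test): 3
Info per segment (data, labels): 2
Size of segement (num examples): 48000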


# Model & Optimizer

In [6]:
# Store the classes themselves (not instances); training.Trainer presumably
# instantiates them from model_params / lr — TODO confirm.
model, optimizer = node_models.NeuralODE, optim.Adam

# Training and Evaluation

## CIFAR

### NODE

In [7]:
# Pull the NODE/CIFAR baseline settings: model hyper-parameters plus the
# (lr, epochs, batch, workers) training tuple.
node_config = config.baseline_config_cifar_node
model_params = node_config['model']
lr, epochs, batch, workers = node_config['train']

In [8]:
# Trainer for the NODE model on the CIFAR data loaded above. The dataset is
# `cifar_data`; `mnist_data` is never defined in this notebook (NameError).
node_cifar_trainer = training.Trainer(model, optimizer, cifar_data, device)

In [None]:
# Train with the hyper-parameters unpacked from the NODE config cell. Use the
# config's `workers` value rather than a hard-coded 12 so the config is honored.
node_cifar_trainer.train(
    model_params,
    lr,
    epochs,
    batch,
    num_workers=workers,
    verbose=False,
    num_loss=10,
)

In [None]:
# Run the trainer's test routine. The positional 32 and 12 are presumably the
# batch size and worker count — TODO confirm against training.Trainer.test.
node_cifar_trainer.test(
    model_params,
    32,
    12,
)

### ANODE

In [None]:
# ANODE settings for the CIFAR section. The original lines read the MNIST config
# and referenced an undefined `baseline_config` name (NameError). By analogy with
# `config.baseline_config_cifar_node` in the NODE cell, the CIFAR ANODE entry is
# assumed to be `config.baseline_config_cifar_anode` — confirm it exists in config.
model_params = config.baseline_config_cifar_anode['model']
lr, epochs, batch, workers = config.baseline_config_cifar_anode['train']

In [None]:
# Trainer for the ANODE run on the CIFAR data loaded above. The dataset is
# `cifar_data`; `mnist_data` is never defined in this notebook (NameError).
# NOTE(review): `model` is still node_models.NeuralODE; the ANODE variant is
# presumably selected via model_params (e.g. an augmented dimension) — confirm.
anode_cifar_trainer = training.Trainer(model, optimizer, cifar_data, device)

In [None]:
# Train with the hyper-parameters unpacked from the ANODE config cell. Use the
# config's `workers` value rather than a hard-coded 12 so the config is honored.
anode_cifar_trainer.train(
    model_params,
    lr,
    epochs,
    batch,
    num_workers=workers,
    verbose=False,
    num_loss=10,
)

In [None]:
# Run the trainer's test routine. The positional 32 and 12 are presumably the
# batch size and worker count — TODO confirm against training.Trainer.test.
anode_cifar_trainer.test(
    model_params,
    32,
    12,
)

# Plots

In [None]:
node_cifar_trainer.val_metrics['legend'] = 'NODE'
anode_cifar_trainer.val_metrics['legend'] = 'ANODE'
metrics = [node_cifar_trainer.val_metrics, anode_cifar_trainer.val_metrics]

In [None]:
plt = metrics.Plotter(metrics)

In [None]:
# Plot validation loss for both runs.
# NOTE(review): the runs use `baseline_config_*` settings; confirm that
# "Optimized" in the title is intended.
loss_title = "Optimized Model Validation Loss Comparisons"
plt.plotLoss(loss_title)

In [None]:
# Plot validation accuracy for both runs.
# NOTE(review): the runs use `baseline_config_*` settings; confirm that
# "Optimized" in the title is intended.
accuracy_title = "Optimized Model Validation Accuracy Comparisons"
plt.plotAccuracy(accuracy_title)