# Imports

In [None]:
import torch
import torch.optim as optim
import numpy as np
import node_models
import loader
import training
import metrics
import autotune
import config

# GPU

In [None]:
# Select the compute device. The notebook targets the second GPU (ordinal 1),
# but torch.device(1) is not validated on creation: it only fails with
# "invalid device ordinal" at first use when fewer than two GPUs exist.
# Clamp to the GPUs actually present and fall back to the CPU without CUDA.
preferred_gpu = 1
if torch.cuda.is_available():
    device = torch.device('cuda', min(preferred_gpu, torch.cuda.device_count() - 1))
else:
    device = torch.device('cpu')
print("Using device:", device)

# Data Loading

In [None]:
# Directory holding the CIFAR-10 python batches, relative to the notebook's CWD.
cifar_data_path = "./cifar-10-batches-py"

In [None]:
# Load CIFAR-10 through the project loader and keep a handle on its splits.
cifar_data = loader.CIFAR(
    cifar_data_path,
    3,  # NOTE(review): meaning defined by loader.CIFAR (possibly the split count) — confirm
)
cifar_splits = cifar_data.splits

In [None]:
# Print the nesting of the loaded data, inspecting only the first split:
# splits -> segments (train, val, test) -> (data, labels) -> examples.
print("CIFAR STATS")
print("Number of splits:", len(cifar_splits))
print("Number of segments per split (train, val, test):", len(cifar_splits[0]))
print("Info per segment (data, labels):", len(cifar_splits[0][0]))
# Fixed typo in the label ("segement" -> "segment").
print("Size of segment (num examples):", len(cifar_splits[0][0][0]))

# Model & Optimizer

In [None]:
# Both names are bound to classes, not instances; they are handed to
# training.Trainer uninstantiated (presumably it constructs them itself).
model, optimizer = node_models.NeuralODE, optim.Adam

# Training and Evaluation

## CIFAR

### NODE

In [None]:
# NODE baseline hyperparameters from the project config. The 'train' entry is
# unpacked positionally, so it must hold exactly (lr, epochs, batch, workers).
model_params = config.baseline_config_cifar_node['model']
(lr, epochs, batch, workers) = config.baseline_config_cifar_node['train']

In [None]:
# Trainer for the NODE run, sharing the dataset and device with the ANODE run.
node_cifar_trainer = training.Trainer(
    model,
    optimizer,
    cifar_data,
    device,
)

In [None]:
# Train the NODE baseline. The data-loader worker count now comes from the
# `workers` value unpacked from the config; the previous hard-coded 6 left
# `workers` unused and silently overrode the config.
# NOTE(review): num_loss=3 is interpreted by training.Trainer.train — document once confirmed.
node_cifar_trainer.train(
    model_params,
    lr,
    epochs,
    batch,
    num_workers=workers,
    verbose=True,
    num_loss=3,
)

In [None]:
# Evaluate the trained NODE model.
# NOTE(review): 10 and 3 are positional arguments whose meaning is defined by
# training.Trainer.test — pass them by keyword once their names are confirmed.
node_cifar_trainer.test(
    model_params,
    10,
    3,
)

In [None]:
 torch.cuda.empty_cache()

### ANODE

In [None]:
# ANODE baseline hyperparameters; these rebind the same globals the NODE cells
# used, so re-running the NODE cells after this point would pick up ANODE values.
model_params = config.baseline_config_cifar_anode['model']
(lr, epochs, batch, workers) = config.baseline_config_cifar_anode['train']

In [None]:
# Trainer for the ANODE run. It uses the same model class as NODE, so the
# difference between the two runs comes only from model_params.
anode_cifar_trainer = training.Trainer(
    model,
    optimizer,
    cifar_data,
    device,
)

In [None]:
# Train the ANODE baseline. The data-loader worker count now comes from the
# `workers` value unpacked from the ANODE config; the previous hard-coded 6
# left `workers` unused and silently overrode the config.
# NOTE(review): num_loss=3 is interpreted by training.Trainer.train — document once confirmed.
anode_cifar_trainer.train(
    model_params,
    lr,
    epochs,
    batch,
    num_workers=workers,
    verbose=True,
    num_loss=3,
)

In [None]:
# Evaluate the trained ANODE model.
# NOTE(review): 10 and 3 are positional arguments whose meaning is defined by
# training.Trainer.test — pass them by keyword once their names are confirmed.
anode_cifar_trainer.test(
    model_params,
    10,
    3,
)

In [None]:
# Quick look at the validation losses recorded for the ANODE run.
print(
    anode_cifar_trainer.val_metrics['loss']
)

# Plots

In [None]:
# Tag each run's validation metrics with its legend label. This writes the
# 'legend' key into the trainers' own val_metrics objects, not copies.
node_cifar_trainer.val_metrics['legend'] = 'NODE'
anode_cifar_trainer.val_metrics['legend'] = 'ANODE'
# Plotting order: NODE first, then ANODE.
out_metrics = [
    node_cifar_trainer.val_metrics,
    anode_cifar_trainer.val_metrics,
]

In [None]:
# Plotter over the labelled metric sets collected above.
# NOTE: `plt` shadows the conventional matplotlib.pyplot alias; renaming it
# requires updating the plotting cells that follow at the same time.
plt = metrics.Plotter(
    out_metrics
)

In [None]:
# Validation-loss curves for every run in out_metrics.
plt.plotLoss(
    "Optimized Model Validation Loss Comparisons"
)

In [None]:
# Validation-accuracy curves for every run in out_metrics.
plt.plotAccuracy(
    "Optimized Model Validation Accuracy Comparisons"
)

In [None]:
# Loss against NFE (presumably number of ODE function evaluations).
plt.plotNFE(
    "Loss vs NFE",
    style='loss',
)

In [None]:
# Accuracy against NFE (presumably number of ODE function evaluations).
plt.plotNFE(
    "Accuracy vs NFE",
    style='accuracy',
)