In [1]:
import torchvision
from torchvision.utils import make_grid
from torchvision import datasets, transforms
from src.model import MLModel
from src.trainer import Trainer
from src.dataloader import Loader
from src.utils.utils import plot_history, load_model, load_history

In [2]:
training_dir='cifar10-dataset'

In [3]:
# Define data augmentation
def _get_transforms():
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    return transform

In [4]:
val_set = torchvision.datasets.CIFAR10(root=training_dir,
                                        train=False,
                                        download=False,
                                        transform=_get_transforms())

In [5]:
val_set.data.shape

(10000, 32, 32, 3)

In [6]:
test_loader = Loader(val_set, batch_size=32, shuffle=True)

In [7]:
model = MLModel()
model = load_model(model, 'model_output/model.pth')

In [8]:
trainer = Trainer(model)

2023-02-03 15:16.15 [info     ] Config inputs.                 config={}
2023-02-03 15:16.15 [info     ] Loading the model.
2023-02-03 15:16.15 [info     ] Training on device: cpu.


In [9]:
test_loss, test_accuracy = trainer.test(model, test_loader)

2023-02-03 15:16.15 [info     ] Testing..


100%|██████████| 313/313 [00:05<00:00, 52.86batch/s, loss=0.859, metric=0.707] 


In [10]:
print(f'Model accuracy on test: {test_accuracy}')

Model accuracy on test: 0.7068690095846646
