In [1]:
import json
import os
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms

from my_models import AlexNet, VGG16, ResNet

In [2]:
def load_data(data_dir='./data/CIFAR'):
    """Load the CIFAR-10 training and test splits.

    Both splits are downloaded into ``data_dir`` if they are not already
    present, and each sample image is converted to a tensor with
    ``transforms.ToTensor()``. A single transform object is shared by both
    splits.

    Args:
        data_dir: directory used as the torchvision dataset root.

    Returns:
        tuple: ``(trainset, testset)`` as ``torchvision.datasets.CIFAR10``
        instances.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    def _split(is_train):
        # One constructor call per split; only the ``train`` flag differs.
        return torchvision.datasets.CIFAR10(
            root=data_dir,
            train=is_train,
            download=True,
            transform=to_tensor,
        )

    # Tuple elements are evaluated left to right, so the train split is
    # built before the test split.
    return _split(True), _split(False)

In [3]:
def test_accuracy(net, device = "cuda:0"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False)


    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [4]:
# Evaluate the best checkpoint of each model class on the CIFAR-10 test split.
in_ch = 3   # input channels passed to each model constructor
out_ch = 10  # output classes passed to each model constructor (CIFAR-10)
gpus_per_trial = 1  # not used in this cell

# Prefer the first GPU when present; otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

models_list = [AlexNet, VGG16, ResNet]

for model_cls in models_list:
    # Trial-info files are keyed by str(model_cls), e.g.
    # "<class 'my_models.AlexNet'>"; keep that naming so existing files are found.
    info_path = "./data/best_trials_info/best_trial_dir_{}.txt".format(str(model_cls))
    with open(info_path, "r") as file:
        config = json.load(file)

    best_trained_model = model_cls(in_ch, out_ch)
    best_trained_model.to(device=device)

    # The checkpoint is a (model_state, optimizer_state) pair; only the model
    # weights are needed here. map_location lets a checkpoint that holds CUDA
    # tensors load on a CPU-only machine.
    model_state, _optimizer_state = torch.load(
        os.path.join(config["path"], "checkpoint"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    test_acc = test_accuracy(best_trained_model, device=device)
    print("Best trial test set accuracy for {}: {}".format(model_cls.__name__, test_acc))

Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy for <class 'my_models.AlexNet'>: 0.244
Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy for <class 'my_models.VGG16'>: 0.1
Files already downloaded and verified
Files already downloaded and verified
Best trial test set accuracy for <class 'my_models.ResNet'>: 0.2517
