In [20]:
import torch
import torchvision
import torchvision.transforms as transforms

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Loading the CIFAR-10 test split (only ToTensor, no augmentation).
test_dataset = torchvision.datasets.CIFAR10(root='../../data/',
                                            train=False,
                                            transform=transforms.ToTensor())

# shuffle=False gives a fixed batch order; the augmented-data cell pairs its
# images with the labels from this loader by batch position.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100,
                                          shuffle=False)

# Test the model.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints. On recent PyTorch (weights_only defaults to True) loading a whole
# pickled nn.Module may additionally require weights_only=False.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model = torch.load('resnet_trained_model.ckpt', map_location=device)
# Overwrite the pickled module's weights with the separately saved state dict.
model.load_state_dict(torch.load('resnet_trained_model_state.ckpt', map_location=device))
# Parameters must live on the same device as the input batches below.
model.to(device)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score for each sample.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint (weights are unchanged by evaluation).
torch.save(model.state_dict(), 'resnet_testing.ckpt')

Accuracy of the model on the test images: 88.49 %


In [38]:
# 2. Loading the segmented dataset (our dataset).
# NOTE(review): torch.load unpickles -- only load trusted files. map_location
# keeps this loadable on a CPU-only machine; batches are moved to `device` below.
augmented_data = torch.load('augmented_data.pt', map_location='cpu')

# The augmented file holds images only; labels are taken batch-by-batch from the
# CIFAR-10 test loader. This is only correct if augmented_data was produced from
# test_loader in the same (unshuffled) order -- TODO confirm with its generator.
augmented_labels = torch.stack([labels for _, labels in test_loader], dim=0)

print(augmented_data.shape)
print(augmented_labels.shape)
# Each augmented image batch must have a label batch of the same length.
assert augmented_data.shape[:2] == augmented_labels.shape, (
    'augmented_data batches do not line up with test_loader labels')

# Test the model.
# NOTE(review): on recent PyTorch, loading a whole pickled nn.Module may need
# weights_only=False.
model = torch.load('resnet_trained_model.ckpt', map_location=device)
# Overwrite the pickled module's weights with the separately saved state dict.
model.load_state_dict(torch.load('resnet_trained_model_state.ckpt', map_location=device))
# Parameters must live on the same device as the input batches below.
model.to(device)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for image_batch, label_batch in zip(augmented_data, augmented_labels):
        images = image_batch.to(device)
        labels = label_batch.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score for each sample.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint (weights are unchanged by evaluation).
torch.save(model.state_dict(), 'resnet_augment_testing.ckpt')

torch.Size([100, 100, 3, 32, 32])
torch.Size([100, 100])
Accuracy of the model on the test images: 23.6 %
