This notebook applies circular shifts to the FashionMNIST test set and evaluates the trained model "best.pt", loaded from "models/best.pt", on the shifted images.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# we need the exact model definition
class FCN2(nn.Module):
    """Fully connected classifier: 784 -> 1024 -> 1024 -> 10.

    The layer attribute names (``fc1``, ``fc2``, ``fc3``) and sizes are part of
    the contract: they determine the ``state_dict`` keys and tensor shapes, so
    they must stay as they are for saved weights to load.
    """

    def __init__(self):
        super().__init__()

        # Two hidden layers of width 1024 followed by a 10-way output layer
        self.fc1 = nn.Linear(28 * 28, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 10)

    def forward(self, x):
        """Return raw (un-normalised) class scores for a batch of images.

        The input is reshaped to rows of 28 * 28 values, so any leading
        layout is accepted as long as the element count is a multiple of 784.
        """
        flat = x.view(-1, 28 * 28)

        # ReLU after each hidden layer; the output layer stays linear
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [3]:
# Location of the trained weights on disk
PATH = "models/best.pt"

# Start from a freshly initialised network; its weights are overwritten below
model = FCN2()

# Read the saved state dict, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (newer PyTorch releases offer `weights_only=True` for this).
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))

# Copy the saved parameters into the network, then switch to inference mode
model.load_state_dict(checkpoint)
model.eval()
print('model loaded')

model loaded


# Load Test Data 
We already have a trained model, so we now create the test set and its data loader.

In [4]:
import torchvision.transforms as transforms

# Preprocessing applied to every test image:
#   1. ToTensor   - convert the PIL image to a float tensor scaled to [0, 1]
#   2. Normalize  - apply (x - mean) / std per channel, mapping [0, 1] to [-1, 1]
# Normalize is documented to take one value per channel as a sequence, so the
# single grayscale channel uses one-element tuples. The original `(0.5)` was
# just the float 0.5 in parentheses, not a tuple.
# NOTE(review): this presumably mirrors the preprocessing used when the model
# was trained - confirm against the training notebook.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [5]:
import torchvision

# FashionMNIST test split (train=False); downloaded into ./data when missing
testset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Deterministic iteration order (no shuffling), batches of 10 images,
# loaded by two worker processes
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=10,
    shuffle=False,
    num_workers=2,
)

## Define a Shift Function 

In [7]:
def shift(model, loader, s_right=0, s_down=0):
    """Measure classification accuracy on circularly shifted image batches.

    Every batch is rolled with ``torch.roll`` before being passed to the
    model, so pixels pushed past one edge re-enter on the opposite edge.

    Args:
        model: Callable mapping a batch of images to per-class scores of
            shape (batch, num_classes).
        loader: Iterable yielding ``(inputs, labels)`` pairs. The last two
            dimensions of ``inputs`` are treated as (height, width).
        s_right: Number of pixels to roll along the width axis. Positive
            values move the image right, negative values move it left.
        s_down: Number of pixels to roll along the height axis. Positive
            values move the image down, negative values move it up.

    Returns:
        Accuracy over all samples in ``loader`` as a percentage (float).

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    correct = 0
    total = 0

    # Inference only: no gradients are needed
    with torch.no_grad():
        for inputs, labels in loader:
            # Roll whenever either shift is non-zero. The previous `> 0`
            # guards silently ignored negative (left / up) shifts and
            # reported the unshifted accuracy instead.
            if s_right != 0 or s_down != 0:
                inputs = torch.roll(
                    inputs,
                    shifts=(s_down, s_right),
                    dims=(-2, -1),  # -2 = height, -1 = width
                )

            # Predicted class = index of the largest score in each row
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total
            

## Run Tests

In [8]:
# Evaluate the same model under three conditions:
#   - unshifted baseline
#   - 2 px circular shift to the right
#   - 2 px circular shift to the right and 2 px down
base_accuracy = shift(model, testloader)
right_accuracy = shift(model, testloader, s_right=2)
right_down_accuracy = shift(model, testloader, s_right=2, s_down=2)

# Report all three accuracies together for easy comparison
print(f"Baseline Accuracy: {base_accuracy:.2f}%")
print(f"right shift of 2px Accuracy: {right_accuracy:.2f}%")
print(f"right and down shift of 2px Accuracy: {right_down_accuracy:.2f}%")

Baseline Accuracy: 89.33%
right shift of 2px Accuracy: 55.74%
right and down shift of 2px Accuracy: 43.53%
