In [16]:
import torch
from torch.utils.data import DataLoader

import torchvision
from torchvision.transforms import ToTensor
import pickle
from tqdm.auto import tqdm
from torch import nn

In [17]:
class ImprovedCNNDigits(nn.Module):
    """CNN classifier for single-channel 28x28 digit images with 10 output classes.

    ``layers`` is the convolutional feature extractor and ``classifier`` is the
    fully connected head. Their attribute names, and the order of the modules
    inside each ``nn.Sequential``, determine the state-dict keys
    (``layers.0.weight``, ``classifier.1.weight``, ...). They must stay fixed
    for saved checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()

        # Spatial size for a 28x28 input: 28 -> 26 -> 24 -> pool 12 -> 10 -> pool 5.
        self.layers = nn.Sequential(
            *self._conv_bn_relu(1, 32),
            *self._conv_bn_relu(32, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *self._conv_bn_relu(64, 128),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        flat_features = 128 * 5 * 5  # 128 channels on the final 5x5 feature map
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return ``[Conv2d(3x3, no padding), BatchNorm2d, ReLU]`` as a plain list.

        A list is returned, not a nested Sequential, so the caller can splice the
        modules directly into its own Sequential and keep flat module indices.
        """
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Map a batch of images to raw (un-normalized) class logits of shape (N, 10)."""
        return self.classifier(self.layers(x))

In [24]:
# Load the whole pickled model object (it is used directly as a module below).
# Unpickling presumably needs the ImprovedCNNDigits class defined above.
# SECURITY: weights_only=False unpickles arbitrary objects and can execute code,
# so only load checkpoints from trusted sources.
# weights_only=False is needed for whole-module pickles on newer PyTorch, where
# torch.load defaults to weights_only=True. The argument requires PyTorch >= 1.13.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine.
model = torch.load("model.pth", map_location="cpu", weights_only=False)

In [33]:
# MNIST held-out split, converted to tensors by ToTensor().
test_dataset = torchvision.datasets.MNIST(
    root="/data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Binarize raw pixels: values >= 127 become 255, everything else becomes 0,
# stored as uint8.
# NOTE(review): presumably mirrors the preprocessing used at training time.
# Confirm the threshold matches.
is_ink = test_dataset.data >= 127
test_dataset.data = torch.where(is_ink, 255, 0).to(torch.uint8)

# Fixed order (no shuffling) so prediction positions line up with dataset indices.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=50)


In [34]:
# Collect predicted and true labels for every test image, in dataloader order.
y_preds = []
y_trues = []
model.eval()  # switch layers such as BatchNorm to evaluation behaviour
# Feed inputs to whatever device the model's parameters live on (CPU or GPU).
device = next(model.parameters()).device
with torch.inference_mode():
    for X, y in tqdm(test_dataloader, desc="Making predictions"):
        y_logits = model(X.to(device))
        # Take argmax over the class dimension of the raw logits. Softmax is
        # monotonic, so it cannot change the winner. Not squeezing keeps this
        # correct for a batch of size 1, where .squeeze() would drop the batch
        # dimension and make argmax(dim=1) raise.
        y_pred = y_logits.argmax(dim=1)

        y_preds.extend(y_pred.tolist())
        y_trues.extend(y.tolist())



Making predictions:   0%|          | 0/200 [00:00<?, ?it/s]

In [35]:
# Positions (in test-set order) where the predicted label differs from the true one.
indexes = [
    position
    for position, (predicted, actual) in enumerate(zip(y_preds, y_trues))
    if predicted != actual
]

print(len(indexes), indexes)

173 [321, 445, 449, 495, 582, 591, 659, 684, 717, 839, 844, 846, 882, 947, 1014, 1039, 1062, 1068, 1112, 1156, 1182, 1202, 1226, 1232, 1242, 1247, 1260, 1263, 1299, 1319, 1326, 1337, 1364, 1393, 1500, 1527, 1530, 1621, 1678, 1681, 1709, 1754, 1790, 1878, 1901, 2023, 2043, 2053, 2070, 2098, 2109, 2118, 2130, 2135, 2168, 2182, 2185, 2189, 2266, 2280, 2293, 2387, 2414, 2425, 2454, 2462, 2488, 2597, 2607, 2654, 2780, 2896, 2921, 2927, 2952, 2953, 2995, 3023, 3060, 3073, 3206, 3225, 3289, 3384, 3422, 3503, 3511, 3520, 3559, 3597, 3718, 3767, 3796, 3806, 3808, 3853, 3906, 4075, 4078, 4176, 4205, 4238, 4248, 4256, 4265, 4306, 4344, 4497, 4536, 4571, 4575, 4578, 4615, 4639, 4731, 4740, 4761, 4783, 4807, 4814, 4874, 4956, 4978, 5278, 5634, 5734, 5842, 5888, 5955, 5973, 6011, 6059, 6091, 6157, 6172, 6505, 6532, 6555, 6558, 6571, 6576, 6578, 6597, 6625, 6651, 6740, 6755, 6783, 6847, 7121, 7434, 7619, 7902, 8059, 8081, 8091, 8094, 8095, 8246, 8408, 8520, 8527, 9009, 9015, 9019, 9642, 9664, 9692, 9