In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
from sklearn.metrics import accuracy_score, precision_score,recall_score, confusion_matrix

In [2]:
# Read the image paths: images[d] lists the .png files found in Fake_Digits/<d>/,
# and count is the total number of images across all digit folders.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so each listing
# is sorted to make the row order of the tensors built below reproducible.
DATA_DIR = "Fake_Digits"
NUM_CLASSES = 10

images = []
for digit in range(NUM_CLASSES):
    digit_dir = os.path.join(DATA_DIR, str(digit))
    images.append([
        os.path.join(digit_dir, file_name)
        for file_name in sorted(os.listdir(digit_dir))
        if file_name.endswith(".png")
    ])
count = sum(len(paths) for paths in images)

In [3]:
# Tensors to store fake data and fake targets: one row per image.
# fake_data holds the 28*28 = 784 flattened pixel values of each image;
# fake_targets holds its digit label, kept as an integer (long) class index.
fake_data = torch.zeros((count, 28*28))
fake_targets = torch.zeros(count, dtype=torch.long)
# Transform that reduces a decoded image tensor to a single channel.
gray = transforms.Grayscale()

In [4]:
# Read images and targets: row `cnt` of fake_data receives one flattened grayscale
# image, and fake_targets[cnt] receives the index of the digit folder it came from.
# NOTE(review): assumes every PNG decodes to 1 or 3 channels at 28x28 so it fits a
# 784-value row — confirm. Pixels keep read_image's uint8 [0, 255] range (no scaling
# to [0, 1]); confirm this matches how the classifier was trained.
cnt = 0
for i, digit_paths in enumerate(images):
    for image_path in digit_paths:
        decoded = torchvision.io.read_image(path=image_path)
        fake_data[cnt] = gray(decoded).flatten()
        fake_targets[cnt] = i
        cnt += 1

In [5]:
# MNIST Classifier models
class MNISTclassifier(nn.Module):
    """Fully connected digit classifier: input_size -> 512 -> 64 -> 10 class probabilities.

    ReLU follows the first two linear layers, and a softmax over the last dimension
    turns the 10 scores into probabilities that sum to 1 for each sample. Accepts a
    single flattened sample of shape (input_size,) or a batch of shape
    (batch, input_size).

    Note: this definition must be in scope for ``torch.load`` to unpickle a saved
    instance (presumably what C.pkl holds — confirm). An unpickled model keeps the
    submodules stored in the file, so edits here only affect newly built instances.
    """

    def __init__(self, input_size: int):
        super().__init__()
        self.input_layer = nn.Linear(input_size, 512, bias=True)
        self.second_layer = nn.Linear(512, 64, bias=True)
        self.third_layer = nn.Linear(64, 10, bias=True)
        self.relu = nn.ReLU()
        # Explicit dim: nn.Softmax() without one is deprecated and guesses the axis
        # (dim 0 for 1-D input, dim 1 for 2-D). dim=-1 picks the class axis for both.
        self.soft = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities for x of shape (input_size,) or (batch, input_size)."""
        out = self.relu(self.input_layer(x))
        out = self.relu(self.second_layer(out))
        out = self.third_layer(out)
        return self.soft(out)

# Load the previously saved model: C.pkl holds a whole pickled module (it is called
# directly for inference below), not just a state_dict.
# NOTE(review): torch.load with weights_only=False unpickles arbitrary Python objects —
# only load C.pkl from a trusted source.
# weights_only=False is required for a full module on PyTorch >= 2.6, where the default
# became True. map_location='cpu' lets a checkpoint saved on a GPU load here, since
# inference runs on CPU tensors.
Classifier = torch.load('C.pkl', map_location='cpu', weights_only=False)
# Inference mode (a safeguard; the MNISTclassifier layers have no dropout/batch-norm).
Classifier.eval();

In [6]:
# Predicting the results: classify every fake image in one batched forward pass.
# Assumes the loaded model accepts a (count, 784) batch, as MNISTclassifier's
# Linear/ReLU stack does. Rows of fake_data are already flat, so no per-sample flatten.
# torch.no_grad() skips the autograd graph the old per-sample calls built for nothing.
# argmax(dim=1) returns the highest-scoring class per row (the first one on ties), the
# same index as the old `list.index(max(...))`. .float() keeps predict_out in the
# floating dtype it had before.
with torch.no_grad():
    predict_out = Classifier(fake_data).argmax(dim=1).float()

In [7]:
print('------------Accuracy----------------------')
# Predictions are whole-number class indices stored as floats; round before the int
# cast so floating-point noise can never truncate a label.
predict_y = torch.round(predict_out).int()

# Give sklearn plain integer NumPy arrays instead of going through the discouraged
# `.data` attribute, so true and predicted labels share an integer type whatever
# dtype fake_targets was allocated with.
y_true = fake_targets.long().numpy()
y_pred = predict_y.numpy()

print('prediction accuracy',accuracy_score(y_true,y_pred))

# Macro: unweighted mean over the digit classes. Micro: pooled over all samples, which
# for single-label multiclass data equals accuracy (as the recorded output shows).
print('macro precision',precision_score(y_true,y_pred,average='macro'))
print('micro precision',precision_score(y_true,y_pred,average='micro'))

print('macro recall',recall_score(y_true,y_pred,average='macro'))
print('micro recall',recall_score(y_true,y_pred,average='micro'))

------------Accuracy----------------------
prediction accuracy 0.7338709677419355
macro precision 0.7399509803921569
micro precision 0.7338709677419355
macro recall 0.7312337662337662
micro recall 0.7338709677419355
