In [1]:
# import packages
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch 
import torchvision
import torch.utils.data.dataloader as DataLoader
from torchvision import transforms

from sklearn.svm import SVC
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier

In [2]:
# Preprocessing shared by every split: convert each image to grayscale,
# then to a tensor.
transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

# Root folders of the three splits (ImageFolder layout: one sub-folder per class).
data_path_train = 'dataset/train'
data_path_test = 'dataset/test'
data_path_valid = 'dataset/valid'


def _load_split(root, shuffle):
    """Build the ImageFolder dataset rooted at `root` and a DataLoader over it.

    Returns the pair (dataset, loader). Batch size and worker count are the
    same for every split; only the shuffling differs.
    """
    dataset = torchvision.datasets.ImageFolder(root=root, transform=transform)
    loader = DataLoader.DataLoader(
        dataset,
        batch_size=100,
        num_workers=1,
        shuffle=shuffle,
    )
    return dataset, loader


# Only the training split is shuffled; test and validation keep a fixed order.
train_dataset, train_loader = _load_split(data_path_train, shuffle=True)
test_dataset, test_loader = _load_split(data_path_test, shuffle=False)
valid_dataset, valid_loader = _load_split(data_path_valid, shuffle=False)

In [3]:
# fully connected neural network (FCNN)
import torch.nn as nn 
from torch.nn import functional as F

class FCNN(nn.Module):
    """Fully connected classifier: 784 -> 100 -> 100 -> 50 -> num_classes.

    The input is flattened to rows of 28*28 values, passed through three
    ReLU-activated linear layers, and the final linear layer is followed by
    a softmax over the class dimension, so `forward` returns per-class
    probabilities rather than raw logits.

    Args:
        num_classes: Size of the output layer. When omitted, it falls back to
            `len(train_dataset.classes)` read from the notebook-level
            `train_dataset`, which is what the original constructor always did.
    """

    def __init__(self, num_classes=None):
        super(FCNN, self).__init__()
        if num_classes is None:
            # Backward-compatible default: derive the class count from the
            # training set defined at notebook level.
            num_classes = len(train_dataset.classes)
        self.fc1 = nn.Linear(28*28, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 50)
        self.fc4 = nn.Linear(50, num_classes)
        # Explicit dim=1 (the class axis of the 2-D activations). Omitting it
        # relies on the deprecated implicit choice and emits a UserWarning on
        # every forward pass; for 2-D input the implicit choice was also 1, so
        # the numerical result is unchanged.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_classes)."""
        # Flatten everything except the batch axis into 784 features.
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.softmax(self.fc4(x))
        return x

In [4]:
# define a function to help get prediction
def make_prediction(loader, model, device=None):
    """Run `model` over every batch of `loader` and collect class indices.

    Args:
        loader: Iterable yielding `(data, target)` batches, where `target`
            holds integer class indices.
        model: A `torch.nn.Module` whose output has one row per sample and one
            column per class.
        device: Device the input batches are moved to. Defaults to the device
            of the model's first parameter (CPU if the model has none).

    Returns:
        `(predictions, references)`: two 1-D NumPy arrays with one entry per
        sample, in loader order. Both are empty for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    result_total = []
    reference_total = []

    # Predict in eval mode, then restore whatever mode the caller had set so
    # this can be called safely in the middle of a training loop.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for prediction; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            for data, target in loader:
                output = model(data.to(device))
                result_total.append(torch.max(output, dim=1).indices.cpu().numpy())
                # The targets already are class indices; one-hot encoding them
                # and taking the argmax again would return the same values.
                reference_total.append(torch.as_tensor(target).detach().cpu().numpy())
    finally:
        model.train(was_training)

    if not result_total:
        # np.hstack raises on an empty list; return empty arrays instead.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    return np.hstack(result_total), np.hstack(reference_total)

In [6]:
import torch.optim as optim

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FCNN().to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU machine still loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("./model/FCNN_parameter_1.pkl", map_location=device))
# This cell only evaluates the model, so put it in inference mode.
model.eval()

# validation and test process
result_valid, reference_valid = make_prediction(valid_loader, model)
result_test, reference_test = make_prediction(test_loader, model)

  x = self.softmax(self.fc4(x))


In [7]:
import os

epoch = 1

# Upper bound of the heatmap colour scale, shared by both figures so their
# colours are directly comparable.
CONFUSION_VMAX = 2500
FIGURE_DIR = './figure'


def save_confusion_matrix(matrix, title, path):
    """Draw `matrix` as an annotated heatmap and write it to `path`.

    Rows are the reference (true) labels and columns the predicted labels,
    following the argument order `confusion_matrix(reference, result)` used
    below. The figure is closed after saving so nothing is displayed inline.
    """
    plt.figure(figsize=(12,12))
    plt.title(title,fontsize=16)
    plt.imshow(matrix,interpolation='nearest',cmap="YlGnBu",vmax=CONFUSION_VMAX,vmin=0)
    # Write the raw count into every cell.
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            plt.text(j,i,matrix[i,j],ha="center",va="center",fontsize=12)
    plt.xticks(np.arange(0,matrix.shape[1],1),fontsize=12)
    plt.yticks(np.arange(0,matrix.shape[0],1),fontsize=12)
    plt.xlabel('Predicted label',fontsize=12)
    plt.ylabel('True label',fontsize=12)
    plt.savefig(path)
    plt.close()


# (display name, file tag, reference labels, predicted labels) per split.
splits = [
    ('Validation', 'valid', reference_valid, result_valid),
    ('Test', 'test', reference_test, result_test),
]

# Compute each confusion matrix once and reuse it for printing and plotting.
matrices = {tag: confusion_matrix(ref, res) for _, tag, ref, res in splits}

# calculate accuracy and other metrics (grouped by metric, validation first)
for label, tag, ref, res in splits:
    print('%s accuracy: %.4f'%(label, accuracy_score(ref, res)))
for label, tag, ref, res in splits:
    print('%s F1 score: %.4f'%(label, f1_score(ref, res, average='macro')))
for label, tag, ref, res in splits:
    print('%s confusion matrix: \n%s'%(label, matrices[tag]))
for label, tag, ref, res in splits:
    print('%s classification report: \n%s'%(label, classification_report(ref, res)))

# savefig does not create missing directories.
os.makedirs(FIGURE_DIR, exist_ok=True)
for label, tag, ref, res in splits:
    save_confusion_matrix(
        matrices[tag],
        '%s confusion matrix epoch %d'%(label, epoch),
        './figure/FCNN_%s_confusion_matrix_epoch_%d.png'%(tag, epoch),
    )

Validation accuracy: 0.3461
Test accuracy: 0.3449
Validation F1 score: 0.2928
Test F1 score: 0.2906
Validation confusion matrix: 
[[ 418  193   94   41    8    1   11   92    0  100   11    0   43   97
   616  214  159    0    0   61  214   51    5   62    9]
 [  76 1395  106   34   17    5   11  131    0  162   41    0   13   63
    76   14   15    0    0   49  154   36    4   83   15]
 [  65  196 1077   52   37    7   25   35    0  164   36    0   49   64
   212   69   52    0    0   28   35  152   25   82   38]
 [  12  241   96  354   20   10   27   15    0  279   77    0   60  218
    39   50    4    0    0   43  266  126  116   45  402]
 [  10  174   48   15 1423    7   66   37    0  126   46    0    4   18
    20    8   11    0    0   86   27   35   55  252   32]
 [  64  360   96  103   78   16   87   60    0  205  103    0  136  128
   126   59   27    0    0  120  239  269   36  145   43]
 [  57  164  120   55   35    2  258   88    0  323   23    0  112  246
    75   28   38  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Test classification report: 
              precision    recall  f1-score   support

           0       0.28      0.18      0.22      2500
           1       0.24      0.57      0.34      2500
           2       0.28      0.42      0.34      2500
           3       0.22      0.14      0.17      2500
           4       0.65      0.58      0.62      2500
           5       0.20      0.01      0.02      2500
           6       0.31      0.10      0.15      2500
           7       0.43      0.72      0.54      2500
           8       0.00      0.00      0.00      2500
           9       0.32      0.80      0.46      2500
          10       0.33      0.15      0.21      2500
          11       0.00      0.00      0.00      2500
          12       0.46      0.48      0.47      2500
          13       0.22      0.25      0.24      2500
          14       0.30      0.74      0.43      2500
          15       0.40      0.48      0.44      2500
          16       0.46      0.72      0.56      250