In [1]:
# import packages
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch 
import torchvision
import torch.utils.data.dataloader as DataLoader
from torchvision import transforms

from sklearn.svm import SVC
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier

In [2]:
transform=transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()
])

# training set
data_path_train = 'dataset/train'
train_dataset = torchvision.datasets.ImageFolder(
    root=data_path_train,
    transform=transform
)
train_loader = DataLoader.DataLoader(
    train_dataset,
    batch_size=100,
    num_workers=1,
    shuffle=True
)


# test set
data_path_test = 'dataset/test'
test_dataset = torchvision.datasets.ImageFolder(
    root=data_path_test,
    transform=transform
)
test_loader = DataLoader.DataLoader(
    test_dataset,
    batch_size=100,
    num_workers=1,
    shuffle=False
)

# validation set
data_path_valid = 'dataset/valid'
valid_dataset = torchvision.datasets.ImageFolder(
    root=data_path_valid,
    transform=transform
)
valid_loader = DataLoader.DataLoader(
    valid_dataset,
    batch_size=100,
    num_workers=1,
    shuffle=False
)

In [6]:
# simple neural network
import torch.nn as nn 
from torch.nn import functional as F

class CNN(nn.Module):
    def __init__(self):
        # CNN model
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.conv2 = nn.Conv2d(3, 3, 3)
        self.conv3 = nn.Conv2d(3, 3, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(3*10*10, 100)
        self.fc2 = nn.Linear(100, 50)
        self.fc3 = nn.Linear(50, len(train_dataset.classes))
        self.softmax = nn.Softmax()
        

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = x.view(-1, 3*10*10)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.softmax(self.fc3(x))
        return x

In [10]:
# define a function to help get prediction
def make_prediction(loader, model):
    result_total = []
    reference_total = []
    for index, (data, target) in enumerate(loader):
        data, label = data.to(device), torch.eye(len(train_dataset.classes))[target].to(device)
        output = model(data)

        result = torch.max(output,dim=1).indices.cpu().detach().numpy()
        reference = torch.max(label,dim=1).indices.cpu().detach().numpy()

        result_total.append(result)
        reference_total.append(reference)
    return np.hstack(result_total), np.hstack(reference_total)

In [8]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.0005 # learning rate
optimizer = optim.Adam(model.parameters(), lr=lr)
epoch_num = 1   

loss_list = []

for epoch in range(1,epoch_num+1):
    # training process
    for batch_idx, (data, target) in enumerate(train_loader):
        data, label = data.to(device), torch.eye(len(train_dataset.classes))[target].to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss_list.append(loss.item())
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    # calculate accuracy and other metrics
    torch.save(model.state_dict(), "./model/CNN_parameter_%d.pkl"%epoch)

    # plot the loss figure
    plt.figure(figsize=(12,4))
    plt.plot(loss_list,linewidth=1)
    plt.xlim(0,epoch*len(train_loader))
    plt.xlabel('epoch',fontsize=14)
    plt.ylabel('loss',fontsize=14)
    plt.xticks(np.arange(0,epoch*len(train_loader)+1,len(train_loader)),np.arange(0,epoch+1,1),fontsize=12)
    plt.savefig('./figure/CNN_loss_epoch_%d.png'%epoch)
    plt.close()

    # validation and test process
    result_valid, reference_valid = make_prediction(valid_loader, model)
    result_test, reference_test = make_prediction(test_loader, model)

    # calculate accuracy and other metrics
    print('Validation accuracy: %.4f'%accuracy_score(reference_valid, result_valid))
    print('Test accuracy: %.4f'%accuracy_score(reference_test, result_test))
    print('Validation F1 score: %.4f'%f1_score(reference_valid, result_valid, average='macro'))
    print('Test F1 score: %.4f'%f1_score(reference_test, result_test, average='macro'))
    print('Validation confusion matrix: \n%s'%confusion_matrix(reference_valid, result_valid))
    print('Test confusion matrix: \n%s'%confusion_matrix(reference_test, result_test))
    print('Validation classification report: \n%s'%classification_report(reference_valid, result_valid))
    print('Test classification report: \n%s'%classification_report(reference_test, result_test))

    
    plt.figure(figsize=(6,6))
    plt.title('Validation confusion matrix epoch %d'%epoch,fontsize=16)
    confusion_data = confusion_matrix(reference_valid, result_valid)
    plt.imshow(confusion_data,interpolation='nearest',cmap="YlGnBu",vmax=2500,vmin=0)
    for i in range(confusion_data.shape[0]):
        for j in range(confusion_data.shape[1]):
            plt.text(j,i,confusion_data[i,j],ha="center",va="center",fontsize=12)
    plt.xticks(np.arange(0,confusion_data.shape[1],1),fontsize=12)
    plt.yticks(np.arange(0,confusion_data.shape[0],1),fontsize=12)
    plt.savefig('./figure/CNN_valid_confusion_matrix_epoch_%d.png'%epoch)
    plt.close()


    plt.figure(figsize=(6,6))
    plt.title('Test confusion matrix epoch %d'%epoch,fontsize=16)
    confusion_data = confusion_matrix(reference_test, result_test)
    plt.imshow(confusion_data,interpolation='nearest',cmap="YlGnBu",vmax=2500,vmin=0)
    for i in range(confusion_data.shape[0]):
        for j in range(confusion_data.shape[1]):
            plt.text(j,i,confusion_data[i,j],ha="center",va="center",fontsize=12)
    plt.xticks(np.arange(0,confusion_data.shape[1],1),fontsize=12)
    plt.yticks(np.arange(0,confusion_data.shape[0],1),fontsize=12)
    plt.savefig('./figure/CNN_test_confusion_matrix_epoch_%d.png'%epoch)
    plt.close()



  x = self.softmax(self.fc3(x))




IndexError: index 3 is out of bounds for dimension 0 with size 3

In [11]:
result_valid, reference_valid = make_prediction(valid_loader, model)
result_test, reference_test = make_prediction(test_loader, model)

  x = self.softmax(self.fc3(x))


In [None]:
# new_model = CNN().to(device)
# new_model.load_state_dict(torch.load("./model/CNN_parameter.pkl"))   
# new_model.forward(input)

<All keys matched successfully>

In [13]:
print('Validation accuracy: %.4f'%accuracy_score(reference_valid, result_valid))
print('Test accuracy: %.4f'%accuracy_score(reference_test, result_test))
print('Validation F1 score: %.4f'%f1_score(reference_valid, result_valid, average='macro'))
print('Test F1 score: %.4f'%f1_score(reference_test, result_test, average='macro'))
print('Validation confusion matrix: \n%s'%confusion_matrix(reference_valid, result_valid))
print('Test confusion matrix: \n%s'%confusion_matrix(reference_test, result_test))
print('Validation classification report: \n%s'%classification_report(reference_valid, result_valid))
print('Test classification report: \n%s'%classification_report(reference_test, result_test))

    
plt.figure(figsize=(12,12))
plt.title('Validation confusion matrix epoch %d'%epoch,fontsize=16)
confusion_data = confusion_matrix(reference_valid, result_valid)
plt.imshow(confusion_data,interpolation='nearest',cmap="YlGnBu",vmax=2500,vmin=0)
for i in range(confusion_data.shape[0]):
    for j in range(confusion_data.shape[1]):
        plt.text(j,i,confusion_data[i,j],ha="center",va="center",fontsize=12)
plt.xticks(np.arange(0,confusion_data.shape[1],1),fontsize=12)
plt.yticks(np.arange(0,confusion_data.shape[0],1),fontsize=12)
plt.savefig('./figure/CNN_valid_confusion_matrix_epoch_%d.png'%epoch)
plt.close()

plt.figure(figsize=(12,12))
plt.title('Test confusion matrix epoch %d'%epoch,fontsize=16)
confusion_data = confusion_matrix(reference_test, result_test)
plt.imshow(confusion_data,interpolation='nearest',cmap="YlGnBu",vmax=2500,vmin=0)
for i in range(confusion_data.shape[0]):
    for j in range(confusion_data.shape[1]):
        plt.text(j,i,confusion_data[i,j],ha="center",va="center",fontsize=12)
plt.xticks(np.arange(0,confusion_data.shape[1],1),fontsize=12)
plt.yticks(np.arange(0,confusion_data.shape[0],1),fontsize=12)
plt.savefig('./figure/CNN_test_confusion_matrix_epoch_%d.png'%epoch)
plt.close()

Validation accuracy: 0.4634
Test accuracy: 0.4616
Validation F1 score: 0.3855
Test F1 score: 0.3846
Validation confusion matrix: 
[[ 669  127    0    0   52    0    0   46    1    0   51   68    0  233
   145  456  117    0    0  270  125   34   30   67    9]
 [   9 2152    0    0   13    0    0   18   29    0   41    5    0   11
     2    8    4    0    0   62   20   33   36   24   33]
 [ 390  200    0    0  140    0    0   10    8    0   71   65    0  340
   115  407   15    0    0  123   32   99  185  284   16]
 [   7  220    0    0   72    0    0   15    8    0  379   11    0   62
     2   55    3    0    0  328  212   18  630   16  462]
 [   4   11    0    0 2186    0    0   11    2    0    7   12    0   15
     4   24    3    0    0   72    9   10   50   76    4]
 [ 128  307    0    0  185    0    0   27   21    0  322   37    0  205
   105  159   11    0    0  384  121  142  193   87   66]
 [  34  333    0    0  103    0    0    9    2    0  197   35    0  361
    26   67   26  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
