# Introduction

This notebook continues from the _DataPrep_ notebook. It uses a ResNet-based CNN, implemented with the MXNet framework, to classify the wafer-map images.

## Data Set

[Qingyi](https://www.kaggle.com/qingyi). (February 2018). WM-811K wafer map, Version 1. Retrieved January 2018 from https://www.kaggle.com/qingyi/wm811k-wafer-map/downloads/wm811k-wafer-map.zip/1.

## License

Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0

In [None]:
# IPython setup: (re)load the autoreload extension in mode 2 so imported
# modules are reloaded before code runs, and draw matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from mxnet import gluon, init, nd
from mxnet.gluon import data as gdata, loss as gloss, model_zoo
from mxnet.gluon import utils as gutils
import matplotlib.pyplot as plt
import os
import zipfile
import boto3
import sagemaker
from sagemaker.mxnet import MXNet
from sagemaker import get_execution_role

# SageMaker SDK session object for this notebook.
sagemaker_session = sagemaker.Session()

# Execution role; handed to the MXNet estimator as its `role` argument.
role = get_execution_role()

## Data loader and transformations

In [None]:
# S3 data locations handed to the estimator's `fit` call, one entry each for
# the training and validation data.
inputs = dict(
    training='s3://chip-wafer/data/train_rec',
    validation='s3://chip-wafer/data/valid_rec',
)

In [None]:
# Values forwarded to the training script through the estimator's
# `hyperparameters` mapping.
batch_size, epochs = 64, 3
learning_rate = wd = 0.001

## Training

In [None]:
# SageMaker MXNet estimator: runs classify_mxnet.py on one ml.p3.2xlarge
# instance with the hyperparameters defined above (momentum and log_interval
# are fixed here).
m = MXNet(
    "classify_mxnet.py",
    role=role,
    train_instance_count=1,
    train_instance_type="ml.p3.2xlarge",
    framework_version="1.2.1",
    py_version="py3",
    hyperparameters=dict(
        batch_size=batch_size,
        epochs=epochs,
        learning_rate=learning_rate,
        momentum=0.9,
        wd=wd,
        log_interval=200,
    ),
)

In [None]:
# Launch training with the data locations defined in `inputs`.
m.fit(inputs)

## Test Set

In [None]:
# Deploy the trained estimator to an endpoint backed by one ml.m4.xlarge instance.
predictor = m.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

In [None]:
# Image-folder dataset rooted at the local 'data' directory; its `items`
# (path, label) entries and `synsets` class names are used by the cells below.
test_imgs = gdata.vision.ImageFolderDataset('data')

In [None]:
# boto3 client for the SageMaker runtime service; used to call `invoke_endpoint`.
sagemaker_r = boto3.client('sagemaker-runtime')

In [None]:
from sklearn.metrics import *
import io, json
import itertools
def evaluate_metrics(dataset, endpoint_name=None, client=None, log_every=500):
    """Collect true labels and endpoint predictions for every item of a data set.

    Each image file is read as raw bytes, sent to a SageMaker endpoint as an
    ``image/png`` payload, and the JSON response is converted to ``int``.

    Args:
        dataset: Object whose ``items`` attribute is a sequence of entries
            with the image path at index 0 and the label at index 1.
        endpoint_name: Name of the endpoint to invoke. When omitted, the
            endpoint of the notebook-level ``predictor`` is used, so the
            function always targets the model deployed by this notebook
            rather than a name left over from an earlier run.
        client: Object exposing ``invoke_endpoint``. When omitted, the
            notebook-level ``sagemaker_r`` client is used.
        log_every: Print a progress line every ``log_every`` items; a falsy
            value disables progress output.

    Returns:
        A ``(trues, preds)`` tuple of equal-length lists: the labels taken
        from ``dataset.items`` and the integer predictions returned by the
        endpoint, in iteration order.
    """
    if endpoint_name is None:
        # NOTE(review): `endpoint` is the attribute name in SageMaker SDK v1,
        # which the rest of this notebook targets — confirm if the SDK changes.
        endpoint_name = predictor.endpoint
    if client is None:
        client = sagemaker_r

    trues = []
    preds = []
    for index, item in enumerate(dataset.items):
        img_path, label = item[0], item[1]

        # Read the file first so the handle is closed before the network call.
        with open(img_path, 'rb') as image_file:
            payload = image_file.read()

        response = client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=payload,
            ContentType='image/png',
            Accept='application/json',
        )
        prediction = json.loads(response['Body'].read().decode("utf-8"))

        trues.append(label)
        preds.append(int(prediction))

        if log_every and index % log_every == 0:
            print("Working on " + str(index))
    return trues, preds
    

In [None]:
# Send every test image to the endpoint and keep the true and predicted labels.
trues, preds = evaluate_metrics(test_imgs)

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and draw it as an annotated heat map.

    Args:
        cm: 2-D array-like of counts; rows are plotted as the true label and
            columns as the predicted label.
        classes: Tick labels used for both axes.
        normalize: When true, every row is divided by its own sum before
            printing and plotting. Rows that sum to zero are left as zeros
            instead of producing NaN.
        title: Plot title. When empty or ``None``, a title describing whether
            the matrix is normalized is used.
        cmap: Matplotlib colormap passed to ``imshow``.

    Returns:
        The matplotlib ``Axes`` holding the plot.
    """
    # Imported locally so the function does not depend on another notebook
    # cell having imported numpy first.
    import numpy as np

    cm = np.asarray(cm)

    if normalize:
        banner = 'Normalized confusion matrix'
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Divide only where the row total is non-zero; other cells keep the
        # zero they were initialised with, avoiding NaN cells and a NaN
        # colour threshold further down.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_totals != 0)
    else:
        banner = 'Confusion matrix, without normalization'

    if not title:
        title = banner

    print(banner)
    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show every tick and label it with the matching class name.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the x tick labels and anchor them so they do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Annotate every cell; text is white on cells above half the maximum
    # value and black otherwise.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for (i, j), value in np.ndenumerate(cm):
        ax.text(j, i, format(value, fmt),
                ha="center", va="center",
                color="white" if value > thresh else "black")
    fig.tight_layout()
    return ax

In [None]:
import numpy as np
# Overall accuracy, then F1 and F-beta (beta=1.0) for each averaging mode.
print("Accuracy: {0}".format(accuracy_score(trues, preds)))
for average, f1_template, fbeta_template in (
    ('weighted', "Weighted F1 Score: {0}", "Weighted F-beta: {0}"),
    ('macro', "Macro F1 Score: {0}", "Macro F-beta: {0}"),
    ('micro', "Micro F1 Score: {0}", "Micro F-beta: {0}"),
):
    print(f1_template.format(f1_score(trues, preds, average=average)))
    print(fbeta_template.format(
        fbeta_score(trues, preds, average=average, beta=1.0)))

# Per-class report, then the confusion matrix plotted raw and normalized.
print(classification_report(trues, preds, target_names=test_imgs.synsets))
cm = confusion_matrix(trues, preds)
plot_confusion_matrix(cm, test_imgs.synsets, normalize=False)
plot_confusion_matrix(cm, test_imgs.synsets, normalize=True)