# Document Classification Demo

In this notebook we use a Python client library to work with the **Document Classification** service.
We train a classification model and evaluate its performance on a small example dataset.

The notebook demonstrates how to use the client library to invoke the most important functions of the Document Classification REST API.

### Demo content

1. Create a dataset
1. Train and deploy a custom model
1. Run classification requests using the custom model

### Initial steps

**Note:** The following steps are required to run this notebook. Dataset creation, training, and deployment are not available on trial accounts.

1. See this [tutorial](https://developers.sap.com/tutorials/hcp-create-trial-account.html) to learn how to create an **SAP Cloud Platform Trial account**.
2. See this [tutorial](https://developers.sap.com/tutorials/cp-aibus-dc-service-instance.html) to learn how to create a **Document Classification service instance**.

Keep the **service key** for the Document Classification service created in step 2 at hand, as we will use it below.

### 1. Installation and invocation of the client library

In this section we will:
- Pass the service key we obtained in the **initial steps above** to the client library which will use it to communicate with the [Document Classification REST API](https://aiservices-trial-dc.cfapps.eu10.hana.ondemand.com/document-classification/v1/)
- Import the client library package and create an instance of the client
- Try out and learn about the [API of the client library](https://github.com/SAP/business-document-processing/blob/main/API.md)

An example dataset is provided in the repository. You can explore the required dataset structure [here](https://github.com/SAP/business-document-processing/tree/main/examples/document_classification_examples/data/training_data).

In [None]:
# Pass the credentials to the client library

# Copy the content of the service key you created in the "Initial steps" above here, after "service_key = "
# The EXAMPLE below shows the structure in which the service key needs to be pasted
service_key = {
  "url": "*******",
  "uaa": {
    "tenantmode": "*******",
    "sburl": "*******",
    "subaccountid": "*******",
    "clientid": "*******",
    "xsappname": "*******",
    "clientsecret": "*******",
    "url": "*******",
    "uaadomain": "*******",
    "verificationkey": "*******",
    "apiurl": "*******",
    "identityzone": "*******",
    "identityzoneid": "*******",
    "tenantid": "*******",
    "zoneid": "*******"
  },
  "swagger": "/document-classification/v1"
}

# These 4 values are needed for communicating with the Document Classification REST API
api_url = service_key["url"]
uaa_server = service_key["uaa"]["url"]
client_id = service_key["uaa"]["clientid"]
client_secret = service_key["uaa"]["clientsecret"]
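
Before creating the client, it can help to check that the placeholders in the service key were actually replaced. The following is a minimal sketch under that assumption; `extract_credentials` and `PLACEHOLDER` are hypothetical names for illustration and are not part of the client library.

```python
# Hypothetical helper (not part of the client library): validates the four
# service key fields this notebook reads before any request is made.
PLACEHOLDER = "*******"

def extract_credentials(service_key):
    """Return (api_url, uaa_server, client_id, client_secret) or raise ValueError."""
    uaa = service_key.get("uaa", {})
    credentials = {
        "url": service_key.get("url"),
        "uaa.url": uaa.get("url"),
        "uaa.clientid": uaa.get("clientid"),
        "uaa.clientsecret": uaa.get("clientsecret"),
    }
    # A field counts as missing if it is absent, empty, or still the placeholder
    missing = [name for name, value in credentials.items()
               if not value or value == PLACEHOLDER]
    if missing:
        raise ValueError("Service key fields not filled in: " + ", ".join(missing))
    return tuple(credentials.values())

# Made-up key with the same structure as above (the URLs are not real endpoints)
example_key = {
    "url": "https://api.example.invalid",
    "uaa": {
        "url": "https://uaa.example.invalid",
        "clientid": "id",
        "clientsecret": "secret",
    },
}
demo_credentials = extract_credentials(example_key)
print(demo_credentials)
```

Calling the helper with the unedited example key from the cell above would raise a `ValueError` that names the unfilled fields.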



<div class="alert alert-block alert-warning">
<b>Warning:</b> Please provide a model name for training.
</div>

In [None]:
# Model specific configuration
model_name = ""  # choose an arbitrary name; it is assigned to the trained model for identification purposes
dataset_folder = "data/training_data/"  # relative or absolute path to the folder containing the dataset
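
The evaluation later in this notebook reads the ground truth for each `x.pdf` from a file `x.json` that contains a `"classification"` list of `characteristic`/`value` entries. A quick local check for PDFs without such a file might look like the sketch below. The helper name `find_unlabeled_pdfs`, the recursive search, and the `language` characteristic in the demo data are assumptions made for illustration.

```python
import json
import tempfile
from pathlib import Path

def find_unlabeled_pdfs(folder):
    """Return PDF paths under `folder` that lack a usable .json ground-truth file."""
    unlabeled = []
    for pdf_path in sorted(Path(folder).rglob("*.pdf")):
        gt_path = pdf_path.with_suffix(".json")
        try:
            with open(gt_path) as gt_file:
                labels = json.load(gt_file)["classification"]
            ok = all("characteristic" in entry and "value" in entry for entry in labels)
        except (OSError, ValueError, KeyError, TypeError):
            # Missing file, invalid JSON, or unexpected structure
            ok = False
        if not ok:
            unlabeled.append(pdf_path)
    return unlabeled

# Demo on a throwaway folder with made-up files: a.pdf is labeled, b.pdf is not
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    (tmp / "a.pdf").write_bytes(b"%PDF-1.4 dummy")
    (tmp / "a.json").write_text(json.dumps(
        {"classification": [{"characteristic": "language", "value": "en"}]}))
    (tmp / "b.pdf").write_bytes(b"%PDF-1.4 dummy")
    demo_unlabeled = [path.name for path in find_unlabeled_pdfs(tmp)]

print(demo_unlabeled)  # ['b.pdf']
```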

## Initialize Demo

In [None]:
# Import client class from document classification client python package
from sap_business_document_processing import DCApiClient
# Create an instance of this class with the credentials defined in the last cell
my_dc_client = DCApiClient(api_url, client_id, client_secret, uaa_server)

## Create Dataset for training of a new model

In [None]:
# Create Training dataset
response = my_dc_client.create_dataset()
training_dataset_id = response["datasetId"]
print("Dataset created with datasetId: {}".format(training_dataset_id))

In [None]:
# Upload training documents to the dataset from training directory
print("Uploading training documents to the dataset")
my_dc_client.upload_documents_directory_to_dataset(training_dataset_id, dataset_folder)
print("Finished uploading training documents to the dataset")

In [None]:
# Pretty print the dataset statistics
from pprint import pprint
print("Dataset statistics")
dataset_stats = my_dc_client.get_dataset_info(training_dataset_id)
pprint(dataset_stats)

In [None]:
# Visualization of label distribution
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

nrCharacteristics = len(dataset_stats["groundTruths"])
fig, (ax) = plt.subplots(nrCharacteristics,1, figsize=(10, 15), dpi=80, facecolor='w', edgecolor='k')
if nrCharacteristics==1:
    ax = np.array((ax,)) 
for i in range(nrCharacteristics):
    keys = [element["value"] for element in  dataset_stats["groundTruths"][i]["classes"]]
    total = [element["total"] for element in  dataset_stats["groundTruths"][i]["classes"]]
    ax[i].set_ylabel("Absolute")
    ax[i].bar(keys, total)
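
The same statistics can also be summarized numerically, which makes class imbalance easier to spot than in a bar chart. The sketch below assumes only the fields the plotting cell already reads (`groundTruths`, `characteristic`, `classes`, `value`, `total`); the function name `class_shares` and the example numbers are made up.

```python
def class_shares(stats):
    """Map each characteristic to {class value: share of the labeled documents}."""
    shares = {}
    for ground_truth in stats["groundTruths"]:
        totals = {entry["value"]: entry["total"] for entry in ground_truth["classes"]}
        n_documents = sum(totals.values())
        shares[ground_truth["characteristic"]] = {
            value: (count / n_documents if n_documents else 0.0)
            for value, count in totals.items()
        }
    return shares

# Made-up statistics in the same shape as `dataset_stats`
example_stats = {
    "groundTruths": [
        {
            "characteristic": "language",
            "classes": [{"value": "en", "total": 30}, {"value": "de", "total": 10}],
        }
    ]
}
print(class_shares(example_stats))  # {'language': {'en': 0.75, 'de': 0.25}}
```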

## Training

In [None]:
# Train the model
print("Start training job for model with modelName {}".format(model_name))
response = my_dc_client.train_model(model_name, training_dataset_id)
pprint(response)
print("Model training finished with status: {}".format(response.get("status")))
if response.get("status") == "SUCCEEDED":
    model_version = response.get("modelVersion")
    print("Trained model: {}".format(model_name))
    print("Trained model version: {}".format(model_version))

In [None]:
# Check training statistics
response = my_dc_client.get_trained_model_info(model_name, model_version)
training_details = response.pop("details")
pprint(response)

## Deployment

In [None]:
# Deploy model
response = my_dc_client.deploy_model(model_name, model_version)
pprint(response)

## Classification

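The next cell runs inference on all documents in the test set. The train/test split (stratification) is performed inside the Document Classification service and is reproduced here from a checksum of each file's content.  
We are working on exposing the stratification results so that this cell can be shortened.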
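
The split rule used in the next cell can be isolated as a small function: a document counts as a test document when the CRC32 checksum of its bytes, taken modulo 100, falls in the range 90 to 99. Because the result depends only on the file content, the same file always lands in the same split, and roughly 10% of documents end up in the test set. This is a sketch of the rule as it appears in this notebook; the `test_percent` parameter is an addition for illustration.

```python
import binascii

def is_test_document(content, test_percent=10):
    """True if the CRC32 of `content` falls into the top `test_percent` of 100 buckets."""
    return binascii.crc32(content) % 100 >= 100 - test_percent

print(is_test_document(b""))  # False, because the CRC32 of empty input is 0
```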

In [None]:
# Test usage of the model by classifying a few documents and collecting results and ground truth
import binascii
import time
import json
import numpy as np
from collections import defaultdict

filenames = my_dc_client._find_files(dataset_folder, ".pdf")
test_filenames = []
for filename in filenames:
    # A document belongs to the test set if the CRC32 of its content modulo 100 is in 90..99
    with open(filename, 'rb') as pdf_file:
        is_test_document = binascii.crc32(pdf_file.read()) % 100 >= 90
    if is_test_document:
        test_filenames.append(filename)

# Classify all test documents
responses = list(my_dc_client.classify_documents(test_filenames, model_name, model_version))

# Iterate over responses and store results in convenient format
test_prediction = defaultdict(list)
test_probability = defaultdict(lambda: defaultdict(list))
test_ground_truth = defaultdict(list)
for response, filename in zip(responses, test_filenames):
    pprint(response)
    try:
        # Parse response from DC service
        prediction = response["predictions"]
        for element in prediction:
            labels = []
            scores = []
            for subelement in element["results"]:
                labels.append(subelement["label"])
                scores.append(subelement["score"])
                test_probability[element["characteristic"]][subelement["label"]].append(subelement["score"])
            test_prediction[element["characteristic"]].append(labels[np.argmax(np.asarray(scores))])
        # Collect ground truth of all test documents
        with open(filename.replace(".pdf", ".json")) as gt_file:
            gt = json.load(gt_file)
        for element in gt["classification"]:
            test_ground_truth[element["characteristic"]].append(element["value"])
    except KeyError:
        print("Document not used")
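
The parsing step in the cell above relies on a response shape of `predictions`, each with a `characteristic` and a list of `results` carrying `label` and `score`. Picking the highest-scoring label per characteristic can be sketched without NumPy as follows; `top_labels` is a hypothetical helper and the example response is made up.

```python
def top_labels(response):
    """Map each characteristic in a response to its (label, score) with the highest score."""
    best = {}
    for element in response["predictions"]:
        # max() returns the first maximal entry, matching np.argmax on ties
        winner = max(element["results"], key=lambda result: result["score"])
        best[element["characteristic"]] = (winner["label"], winner["score"])
    return best

# Made-up response in the shape parsed above
example_response = {
    "predictions": [
        {
            "characteristic": "language",
            "results": [{"label": "en", "score": 0.9}, {"label": "de", "score": 0.1}],
        }
    ]
}
print(top_labels(example_response))  # {'language': ('en', 0.9)}
```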

In [None]:
# display the ground truth and classification result for a certain document with index idx
idx = 0

for i in range(nrCharacteristics):
    characteristic = dataset_stats["groundTruths"][i]["characteristic"]
    print("Ground truth for characteristic '{}': '{}'".format(characteristic, test_ground_truth[str(characteristic)][idx]))

print("Model predictions:")
pprint(responses[idx])

## Confusion Matrix

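The confusion matrix visualizes the results of a classification task: each row corresponds to a true label, each column to a predicted label, and each cell counts the documents with that combination.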
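
To make the layout concrete, here is a minimal pure-Python sketch of the same counting that `sklearn.metrics.confusion_matrix` performs in the next cell, together with the optional row normalization. The helper names and the example labels are made up for illustration.

```python
from collections import Counter

def confusion_counts(y_true, y_pred, classes):
    """Rows follow the true label, columns the predicted label, both ordered as `classes`."""
    pairs = Counter(zip(y_true, y_pred))
    return [[pairs[(true, pred)] for pred in classes] for true in classes]

def normalize_rows(matrix):
    """Divide each row by its sum, so a row shows how one true class was predicted."""
    return [[cell / sum(row) if sum(row) else 0.0 for cell in row] for row in matrix]

# Made-up example: four English documents, one of them predicted as German
demo_true = ["en", "en", "en", "en", "de"]
demo_pred = ["en", "en", "en", "de", "de"]
demo_cm = confusion_counts(demo_true, demo_pred, ["de", "en"])
print(demo_cm)                  # [[1, 0], [1, 3]]
print(normalize_rows(demo_cm))  # [[1.0, 0.0], [0.25, 0.75]]
```

Correct predictions sit on the diagonal, so off-diagonal cells show which classes are confused with each other.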

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

font = {'size'   : 22}
plt.rc('font', **font)

def plot_confusion_matrix(ax, char, y_true, y_pred, classes,
                          normalize=True,
                          title=None,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = "{}: Normalized confusion matrix".format(char)
        else:
            title = "{}: Confusion matrix without normalization".format(char)

    # Compute confusion matrix; fixing the label order keeps rows and columns aligned with the tick labels
    cm = confusion_matrix(y_true, y_pred, labels=classes)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)

    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel="True label",
           xlabel="Predicted label",
           xlim=(-0.5,len(classes)-0.5),
           ylim=(-0.5,len(classes)-0.5))


    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right',
             rotation_mode='anchor')

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha='center', va='center', color='white' if cm[i, j] > thresh else 'black')
    # Use the figure that owns this axis instead of relying on a global `fig`
    ax.figure.show()

fig, ax = plt.subplots(len(test_ground_truth.keys()), 1, figsize=(14,28))
if len(test_ground_truth.keys())==1:
    ax = np.array((ax,)) 
for idx, characteristic in enumerate(test_ground_truth.keys()):
    plot_confusion_matrix(ax[idx], characteristic,
                          test_ground_truth[characteristic], 
                          test_prediction[characteristic], 
                          np.unique(np.asarray(test_ground_truth[characteristic])), 
                          normalize=False)
fig.subplots_adjust(hspace=0.5)

## Precision-Recall curves

Note that a precision-recall curve is only well defined for binary classification tasks.  
For the multi-class case, each class is treated independently as a binary problem (that class versus all others), based on the confidence scores provided by the service.
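
Each point on such a curve is the precision and recall obtained at one score threshold. The sketch below computes a single point for one class treated as the positive class. The helper name, the example data, and the convention of reporting precision 1.0 when nothing is predicted positive are choices made for this illustration.

```python
def precision_recall_at(scores, is_positive, threshold):
    """Precision and recall when every score >= threshold counts as a positive prediction."""
    predicted = [score >= threshold for score in scores]
    tp = sum(pred and truth for pred, truth in zip(predicted, is_positive))
    fp = sum(pred and not truth for pred, truth in zip(predicted, is_positive))
    fn = sum(truth and not pred for pred, truth in zip(predicted, is_positive))
    precision = tp / (tp + fp) if tp + fp else 1.0  # convention when nothing is predicted
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Made-up example: scores for the class "en" on four documents
demo_labels = ["en", "en", "de", "de"]
demo_scores_en = [0.9, 0.4, 0.6, 0.1]
demo_is_en = [label == "en" for label in demo_labels]

print(precision_recall_at(demo_scores_en, demo_is_en, 0.7))  # (1.0, 0.5)
print(precision_recall_at(demo_scores_en, demo_is_en, 0.5))  # (0.5, 0.5)
```

Lowering the threshold trades precision for recall. The gray lines drawn by `plot_f_score` in the next cell are curves of constant F1: solving $F_1 = 2PR/(P+R)$ for precision gives $P = F_1 R / (2R - F_1)$, which is the expression used there.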

In [None]:
# Visualize the PR curve for each characteristic (NOTE: this is a bit boring in this example; consider creating a more challenging dataset)
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve

def plot_f_score(ax):
    f_scores = np.linspace(0.2, 0.8, num=4)
    for f_score in f_scores:
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        l, = ax.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
        ax.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))

fig, ax = plt.subplots(len(test_ground_truth.keys()), 1, figsize=(12, 24), dpi=80, facecolor='w', edgecolor='k')
if len(test_ground_truth.keys())==1:
    ax = np.array((ax,)) 

for idx, characteristic in enumerate(test_ground_truth.keys()):
    for label in np.unique(np.asarray(test_ground_truth[characteristic])):
        gt = [subelement == label for subelement in test_ground_truth[characteristic]]
        prediction = test_probability[characteristic][label]
        precision, recall, thresholds = precision_recall_curve(gt, prediction)
        ax[idx].plot(recall, precision, label=label)
    ax[idx].set_xlabel('Recall')
    ax[idx].set_ylabel('Precision')
    ax[idx].set_xlim(-0.1,1.1)
    ax[idx].set_ylim(-0.1,1.1)
    ax[idx].set_title('{}: Precision-Recall curves'.format(characteristic))   
    ax[idx].spines["top"].set_visible(False)
    ax[idx].spines["right"].set_visible(False)
    ax[idx].get_xaxis().tick_bottom()
    ax[idx].get_yaxis().tick_left()
    ax[idx].legend()
    ax[idx].grid()
    plot_f_score(ax[idx])

fig.show()