In [None]:
!pip install requests scikit-learn requests

In [1]:
import os
from json import dump

from sklearn.metrics import classification_report, accuracy_score, f1_score

from sdk import Client, QnaAPI

In [3]:
def parse_data(path):
    """Parse a ``label,<ignored>,sentence`` CSV file into FAQ documents and an (X, y) dataset.

    Only the first two commas on each line are treated as separators, so any
    commas inside the sentence are preserved verbatim. Blank lines are skipped.

    Args:
        path: Path to the text file to read (decoded as UTF-8).

    Returns:
        A tuple ``(docs, X, y)``:
        - ``docs`` maps each label to a dict with keys ``'question'`` (the first
          sentence seen for that label), ``'answer'`` and ``'name'`` (both set to
          the label), and ``'paraphrased_questions'`` (every later sentence for
          that label, in file order);
        - ``X`` is the list of sentences in file order;
        - ``y`` is the parallel list of labels.
    """
    X = []
    y = []
    docs = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # A blank line (e.g. trailing newline at EOF) would otherwise
                # become a bogus '\n' label with an empty sentence.
                continue
            # maxsplit=2: split off label and second field only; re-joining a
            # full split with ", " would alter the sentence's own commas.
            fields = line.split(',', 2)
            label = fields[0]
            sent = fields[2].strip() if len(fields) > 2 else ''
            X.append(sent)
            y.append(label)
            if label not in docs:
                docs[label] = {'question': sent, 'answer': label, 'name': label, 'paraphrased_questions': []}
            else:
                docs[label]['paraphrased_questions'].append(sent)
    return docs, X, y

In [4]:
# AutoFAQ deployment under test; the query endpoint lives under the same host.
host_url = 'https://chat.autofaq.ai'
api_url = host_url + '/core-api/query'

# Train/test splits and the files where raw API responses are saved.
small_test_path = 'small/hwu_small_test.csv'
small_train_path = 'small/hwu_small_train.csv'
small_results_path = 'small/autofaq_results.json'

large_test_path = 'large/hwu_large_test.csv'
large_train_path = 'large/hwu_large_train.csv'
large_results_path = 'large/autofaq_results.json'

# Credentials come from the environment so real values never have to be
# written into (and committed with) the notebook. The fallbacks are the
# original placeholder values.
user_id = int(os.environ.get('AUTOFAQ_USER_ID', '1'))
user_token = os.environ.get('AUTOFAQ_USER_TOKEN', '1')

In [5]:
# Label used when the API returns no hit for a test query.
_NO_ANSWER_LABEL = '<no_answer>'


def _publish_train_service(client, docs_train, name, timeout):
    """Create a service named `name`, upload every training document, publish it; return the create_service response."""
    service_response = client.create_service({'preset': 'en', 'name': name})
    service_id = service_response['service_id']
    for doc in docs_train.values():
        client.create_document(
            service_id,
            question=doc['question'],
            answer=doc['answer'],
            name=doc['name'],
            paraphrases=doc['paraphrased_questions'],
        )
    client.publish_service(service_id, wait_timeout=timeout)
    return service_response


def _predicted_label(result, fallback=_NO_ANSWER_LABEL):
    """Return the 'name' of the first hit in one query response, or `fallback` when there are no hits."""
    hits = result.get('results') or []
    return hits[0]['name'] if hits else fallback


def hwu_metrics_calc(test_path, train_path, user_id, user_token, name, results_json_path, timeout=180):
    """Train an AutoFAQ service on one split, query it with another, and print metrics.

    Uses the module-level `host_url` and `api_url` from the config cell.

    Args:
        test_path: CSV file whose sentences are sent as queries; its labels are the gold answers.
        train_path: CSV file whose documents are uploaded to the new service.
        user_id: AutoFAQ user id passed to `Client`.
        user_token: AutoFAQ user token passed to `Client`.
        name: Name given to the newly created service.
        results_json_path: Where the raw list of query responses is written as JSON.
        timeout: Seconds used both for `Client.HTTP_TIMEOUT` and the publish wait (default 180).

    Side effects:
        Creates and publishes a remote service, writes `results_json_path`, and prints a
        classification report, accuracy, and macro F1.
    """
    _, X_test, y_test = parse_data(test_path)
    docs_train, _, _ = parse_data(train_path)
    print("Parsed data")

    # Set before constructing the client: if Client reads HTTP_TIMEOUT during
    # construction, assigning it afterwards would not affect this client.
    # NOTE: this is a class attribute, so it applies to every Client instance.
    Client.HTTP_TIMEOUT = timeout
    client = Client(host_url=host_url, user_id=user_id, user_token=user_token, namespace='core-api/crud/api/v1')

    service_response = _publish_train_service(client, docs_train, name, timeout)
    print("Service published")

    qna = QnaAPI(api_url, service_response['service_id'], service_response['tokens'][0])

    print("Querying API ...")
    test_results = [qna.query(sentence) for sentence in X_test]

    # Save raw responses before scoring so a scoring problem does not lose the API results.
    results_dir = os.path.dirname(results_json_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(results_json_path, 'w', encoding='utf-8') as f:
        dump(test_results, f)

    y_pred = [_predicted_label(r) for r in test_results]
    # A missing answer counts as a miss for the true class, but the sentinel is
    # not itself scored as an extra class in the macro average.
    labels = sorted((set(y_test) | set(y_pred)) - {_NO_ANSWER_LABEL})
    print(classification_report(y_test, y_pred, labels=labels))
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("F1-Score: ", f1_score(y_test, y_pred, labels=labels, average='macro'))

In [6]:
# Evaluate on the small HWU split.
hwu_metrics_calc(
    test_path=small_test_path,
    train_path=small_train_path,
    user_id=user_id,
    user_token=user_token,
    name='HWU-Small',
    results_json_path=small_results_path,
)

Parsed data
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
Service published
Querying API ...
                          precision    recall  f1-score   support

             alarm_query       1.00      0.84      0.91        19
            alarm_remove       0.64      0.82      0.72        11
               alarm_set       0.70      0.84      0.76        19
       audio_volume_down       0.70      0.88      0.78         8
       audio_volume_mute       0.56      0.33      0.42        15
         audio_volume_up       0.55      0.85      0.67        13
          calendar_query       0.43      0.63      0.51        19
         calendar_remove       1.00     

In [6]:
# Evaluate on the large HWU split.
hwu_metrics_calc(
    test_path=large_test_path,
    train_path=large_train_path,
    user_id=user_id,
    user_token=user_token,
    name='HWU-Large',
    results_json_path=large_results_path,
)

Parsed data
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
Service published
Querying API ...
                          precision    recall  f1-score   support

             alarm_query       0.93      0.91      0.92        94
            alarm_remove       0.91      0.89      0.90        54
               alarm_set       0.83      0.84      0.84        96
       audio_volume_down       0.70      0.80      0.74        40
       audio_volume_mute       0.90      0.86      0.88        76
         audio_volume_up       0.80      0.78      0.79        73
          calendar_query       0.61      0.71      0.66        95
         calendar_remove       0.89     