In [None]:
!pip install requests scikit-learn requests

In [1]:
import os
import time
from json import dump

import numpy as np
from sklearn.metrics import classification_report, accuracy_score, f1_score

from sdk import Client, QnaAPI

In [2]:
def parse_data(path, encoding='utf-8'):
    """Load a comma-separated intent dataset into documents, sentences and labels.

    Every non-blank line is split on commas: field 0 is the label, field 1 is
    ignored, and all remaining fields are re-joined with ", " and stripped to
    form the sentence.

    Args:
        path: Path of the text file to read.
        encoding: Text encoding used to open the file.

    Returns:
        A tuple ``(docs, X, y)``:
          * ``docs`` maps each label to a dict with the keys 'question' (the
            first sentence seen for that label), 'answer' and 'name' (both the
            label itself) and 'paraphrased_questions' (every later sentence
            carrying the same label, in file order);
          * ``X`` is the list of sentences in file order;
          * ``y`` is the list of labels, aligned with ``X``.
    """
    X = []
    y = []
    docs = {}
    with open(path, 'r', encoding=encoding) as f:
        for line in f:
            # A blank line (typically a trailing newline at EOF) is not a
            # sample; without this guard it would be recorded as an empty
            # sentence with the label '\n'.
            if not line.strip():
                continue
            # Drop the line terminator first so a row with a single field does
            # not end up with a newline inside its label.
            fields = line.rstrip('\r\n').split(',')
            label = fields[0]
            # NOTE: commas inside the sentence are restored as ", ", so the
            # spacing may differ from the raw text in the file.
            sent = ", ".join(fields[2:]).strip()
            X.append(sent)
            y.append(label)
            if label not in docs:
                docs[label] = {'question': sent, 'answer': label, 'name': label, 'paraphrased_questions': []}
            else:
                docs[label]['paraphrased_questions'].append(sent)
    return docs, X, y

In [19]:
# Service endpoints. api_url is built from host_url so the two cannot drift apart.
host_url = 'https://chat.autofaq.ai'
namespace = 'core-api/crud/api/v1'
api_url = f'{host_url}/core-api/query'

# HWU dataset files, grouped as (train, test) pairs per dataset size.
small_train_path, small_test_path = 'data/hwu_small_train.csv', 'data/hwu_small_test.csv'
large_train_path, large_test_path = 'data/hwu_large_train.csv', 'data/hwu_large_test.csv'

In [20]:
# Contact info@autofaq.ai to get user_id and user_token.
# Credentials are read from the environment so real values never have to be
# saved inside the notebook; the placeholders are used when the variables are unset.
user_id = int(os.environ.get('AUTOFAQ_USER_ID', 1))
user_token = os.environ.get('AUTOFAQ_USER_TOKEN', 'xx')

In [21]:
def hwu_metrics_calc(test_path, train_path, user_id, user_token, name, namespace,
                     http_timeout=180, publish_timeout=180):
    """Build a service from the training file, query it with the test file and print metrics.

    The function creates a new service through the SDK ``Client``, uploads one
    document per training label, publishes the service, sends every test
    sentence to it through ``QnaAPI`` and prints a classification report,
    accuracy, macro F1 and the mean/std of the per-query response time.

    ``host_url`` and ``api_url`` are read from module-level globals.

    Args:
        test_path: Path of the test file, parsed with ``parse_data``.
        train_path: Path of the training file, parsed with ``parse_data``.
        user_id: User id passed to ``Client``.
        user_token: User token passed to ``Client``.
        name: Name given to the newly created service.
        namespace: API namespace passed to ``Client``.
        http_timeout: Value assigned to ``Client.HTTP_TIMEOUT``
            (presumably seconds - confirm against the SDK).
        publish_timeout: Value passed as ``wait_timeout`` to
            ``client.publish_service`` (presumably seconds - confirm against the SDK).

    Returns:
        None. All results are printed.
    """
    _, X_test, y_test = parse_data(test_path)
    docs_train, _, _ = parse_data(train_path)
    print("Parsed data")

    client = Client(host_url=host_url, user_id=user_id, user_token=user_token, namespace=namespace)
    # This is assigned on the class, not the instance, so it affects every Client.
    Client.HTTP_TIMEOUT = http_timeout

    service_response = client.create_service({'preset': 'en', 'name': name})
    service_id = service_response['service_id']
    for doc in docs_train.values():
        client.create_document(
            service_id,
            question=doc['question'],
            answer=doc['answer'],
            name=doc['name'],
            paraphrases=doc['paraphrased_questions']
        )

    client.publish_service(service_id, wait_timeout=publish_timeout)
    print("Service published")

    qna = QnaAPI(api_url, service_id, service_response['tokens'][0])

    print("Querying API ...")
    test_results = []
    times = []
    for row in X_test:
        # perf_counter is monotonic, so a system clock adjustment during the
        # run cannot distort (or make negative) a measured duration.
        start_time = time.perf_counter()
        res = qna.query(row)
        times.append(time.perf_counter() - start_time)
        test_results.append(res)

    # The top-ranked result's name is taken as the predicted label.
    # NOTE(review): this raises IndexError if a response has an empty 'results'
    # list - confirm the API always returns at least one hit.
    y_pred = [r['results'][0]['name'] for r in test_results]
    print(classification_report(y_test, y_pred))
    print("Accuracy: ", accuracy_score(y_test, y_pred))
    print("F1-Score: ", f1_score(y_test, y_pred, average='macro'))
    print(f"Mean response time: {np.mean(times)} +- {np.std(times)} sec.")

In [15]:
# Benchmark the small HWU split.
hwu_metrics_calc(
    test_path=small_test_path,
    train_path=small_train_path,
    user_id=user_id,
    user_token=user_token,
    name='HWU-Small',
    namespace=namespace,
)

Parsed data
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
Service published
Querying API ...
                          precision    recall  f1-score   support

             alarm_query       0.94      0.89      0.92        19
            alarm_remove       0.60      0.82      0.69        11
               alarm_set       0.71      0.79      0.75        19
       audio_volume_down       0.64      0.88      0.74         8
       audio_volume_mute       0.55      0.40      0.46        15
         audio_volume_up       0.62      0.77      0.69        13
          calendar_query       0.37      0.58      0.45        19
         calenda

In [22]:
# Benchmark the large HWU split.
hwu_metrics_calc(
    test_path=large_test_path,
    train_path=large_train_path,
    user_id=user_id,
    user_token=user_token,
    name='HWU-Large',
    namespace=namespace,
)

Parsed data
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
publish_service wait ..
Service published
Querying API ...
                          precision    recall  f1-score   support

             alarm_query       0.95      0.89      0.92        94
            alarm_remove       0.84      0.91      0.88        54
               alarm_set       0.85      0.85      0.85        96
       audio_volume_down       0.70      0.80      0.74        40
       audio_volume_mu