# COM3029 Coursework 2 - Group 9

## Research of Model Serving Options

Should this be in the report instead?

## Web Service

Show the basic web server script here.

## Endpoint Testing

Run the unit tests here.

```python
import requests

def send_prediction_request(string_to_predict):
    return requests.post('http://127.0.0.1:5000/', json={'comment': string_to_predict}).text


print(send_prediction_request('test comment 1'))
print(send_prediction_request('test comment 2'))

r = requests.get('http://127.0.0.1:5000/')
print(r.text)
```

### Model Accuracy

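This test sends 50 comments, sampled at random (with replacement) from the GoEmotions test split, to the web service to estimate the accuracy of the model. A prediction counts as correct if it matches any of the comment's labels after they have been mapped to the labels the model was trained on. The transformer model used in our web service should achieve an accuracy of around 60-65%.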

```python
import random
import datasets
from tqdm import tqdm

dataset = datasets.load_dataset("go_emotions")

# Maps all unchosen labels to their most similar counterpart.
label_mappings = [
    'admiration', 'amusement', 'anger', 'annoyance', 'admiration', 'optimism', 'curiosity',
    'curiosity', 'optimism', 'sadness', 'disapproval', 'annoyance', 'sadness', 'joy',
    'sadness', 'gratitude', 'sadness', 'joy', 'love', 'sadness', 'optimism',
    'admiration', 'admiration', 'gratitude', 'remorse', 'sadness', 'surprise', 'neutral'
]

count = 50
correct = 0
for i in tqdm(range(count)):
    choice = random.choice(dataset['test'])
    prediction = requests.post('http://127.0.0.1:5000/', json={'comment': choice['text']}).json()['prediction']

    if prediction in [label_mappings[x] for x in choice['labels']]:
        correct += 1

print("Correct: " + str(correct) + "    Incorrect: " + str(count - correct))
```
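
With only 50 sampled comments, the measured accuracy is a noisy estimate. The sketch below is an illustrative addition, not part of the coursework code: it uses the normal approximation to the binomial distribution to show roughly how wide a 95% interval around the observed accuracy is. The helper name `accuracy_interval` and the example count of 31 correct predictions are assumptions chosen for illustration.

```python
import math


def accuracy_interval(correct, count, z=1.96):
    """Return (accuracy, lower, upper) for a normal-approximation interval.

    Hypothetical helper for illustration; z=1.96 corresponds to ~95% coverage.
    """
    accuracy = correct / count
    # Standard error of a binomial proportion.
    std_err = math.sqrt(accuracy * (1 - accuracy) / count)
    lower = max(0.0, accuracy - z * std_err)
    upper = min(1.0, accuracy + z * std_err)
    return accuracy, lower, upper


# Example with made-up numbers: 31 of 50 comments predicted correctly.
acc, lo, hi = accuracy_interval(31, 50)
print(f"Accuracy: {acc:.1%} (approx. 95% interval: {lo:.1%} to {hi:.1%})")
```

At this sample size the interval extends roughly 13 percentage points either side of the estimate, so a single run that lands slightly outside the 60-65% range is not necessarily a sign of a problem. Increasing `count` narrows the interval.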

## Service Performance & Stress Testing

The stress testing script is run from the command line once the web service is running, as shown here:

`python stress_test.py -u 50 -r 5 -l 30`
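
Locust's `runner.start` treats `user_count` as the total number of simulated users and `spawn_rate` as the number of users started per second, so the example command does not reach full load immediately. The arithmetic below is a small illustrative sketch (the helper name `ramp_up_summary` is hypothetical, and it assumes users are spawned at a constant rate) of how the 30-second run splits between ramp-up and steady load.

```python
def ramp_up_summary(user_count, spawn_rate, duration):
    """Split a run into ramp-up time and time at full load, in seconds.

    Hypothetical helper: assumes users are spawned at a constant rate.
    """
    ramp_up = min(duration, user_count / spawn_rate)
    at_full_load = duration - ramp_up
    return ramp_up, at_full_load


# Values from the example command: -u 50 -r 5 -l 30
ramp_up, at_full_load = ramp_up_summary(50, 5, 30)
print(f"Ramp-up: {ramp_up:.0f}s, full load: {at_full_load:.0f}s")
# prints "Ramp-up: 10s, full load: 20s"
```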

The contents of `stress_test.py` are shown below for reference:

```python
import gevent
from locust import HttpUser, task, events
from locust.env import Environment
from locust.stats import stats_printer, stats_history, StatsCSVFileWriter
from locust.log import setup_logging

from locustfile import StressTest

import matplotlib.pyplot as plt
import pandas as pd

from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("-u", "--user_count", dest="user_count", help="Total number of simulated users to spawn", type=int, default=1)
    parser.add_argument("-r", "--spawn_rate", dest="spawn_rate", help="Number of users to spawn per second", type=int, default=1)
    parser.add_argument("-l", "--length", dest="duration", help="How long the stress test should run, in seconds", type=int, default=30)

    stress_args = parser.parse_args()


    setup_logging("INFO", None)


    # setup Environment and Runner
    env = Environment(user_classes=[StressTest], events=events)
    runner = env.create_local_runner()

    # start a WebUI instance
    web_ui = env.create_web_ui("127.0.0.1", 8089)

    logging = StatsCSVFileWriter(environment=env, base_filepath='./', full_history=True, percentiles_to_report=[90.0])

    # execute init event handlers (only really needed if you have registered any)
    env.events.init.fire(environment=env, runner=runner, web_ui=web_ui)

    # start a greenlet that periodically outputs the current stats
    gevent.spawn(stats_printer(env.stats))

    gevent.spawn(logging)

    # start a greenlet that saves current stats to history
    gevent.spawn(stats_history, env.runner)

    # start the test
    runner.start(user_count=stress_args.user_count, spawn_rate=stress_args.spawn_rate)

    # in duration seconds stop the runner
    gevent.spawn_later(stress_args.duration, lambda: runner.quit())

    # wait for the greenlets
    runner.greenlet.join()

    # stop the web UI server for good measure
    web_ui.stop()

    graph_stats = pd.read_csv('_stats_history.csv')
    graph_stats = graph_stats[graph_stats['Name'] == 'Aggregated']
    x_axis = graph_stats['Timestamp']
    x_axis = x_axis - x_axis.min()

    target_metrics = ['Requests/s','Failures/s','Total Request Count','Total Failure Count','Total Average Response Time']

    for metric in target_metrics:
        y_axis = graph_stats[metric].astype(float)

        plt.plot(list(x_axis.values), list(y_axis.values))
        plt.title(metric)
        save_name = metric.replace('/', '_per_')
        plt.savefig(f'{save_name}.png')
        plt.clf()
```
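
Besides the plots, the same history file can be summarised numerically. The following is a minimal sketch, not part of `stress_test.py`, that relies only on column names already referenced above (`Name`, `Total Request Count`, `Total Failure Count`) and reads the last `Aggregated` row to compute an overall failure rate. The sample rows are invented for illustration, and the helper name `failure_rate` is hypothetical.

```python
import csv
import io


def failure_rate(csv_file):
    """Return total failures / total requests from the last 'Aggregated' row."""
    last_row = None
    for row in csv.DictReader(csv_file):
        if row["Name"] == "Aggregated":
            last_row = row
    if last_row is None:
        raise ValueError("no 'Aggregated' rows found")
    # Parse as float so values such as "250" and "250.0" are both accepted.
    total = float(last_row["Total Request Count"])
    failed = float(last_row["Total Failure Count"])
    return failed / total if total else 0.0


# Invented sample data standing in for _stats_history.csv.
sample = io.StringIO(
    "Timestamp,Name,Total Request Count,Total Failure Count\n"
    "1684800000,Aggregated,120,0\n"
    "1684800001,Aggregated,250,5\n"
)
print(f"Failure rate: {failure_rate(sample):.1%}")
```

To run this against a real test, pass an open file handle instead of the sample, for example `open('_stats_history.csv', newline='')`.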

## Monitoring Capabilities

To monitor the user input and predictions of the web service, the server uses Python's `logging` library to log to a file. For every prediction request, the time, input text, and predicted label are written as a JSON object to the `predictions.txt` log file. An example snippet is shown here:

```json
{"time": "2023-05-23 00:58:45:748136", "text": "literally I feel like crying", "prediction": "sadness"}
{"time": "2023-05-23 00:58:46:135930", "text": "i love this subreddit", "prediction": "love"}
```
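
Because every log line is a standalone JSON object, the file can be read back one line at a time. The sketch below is illustrative only and assumes the log handler writes nothing but the message on each line: it parses the two example lines above, converts the `time` field with a format string matching the timestamps shown, and counts predictions per label. The helper name `parse_log_lines` is hypothetical.

```python
import json
from collections import Counter
from datetime import datetime

TIME_FORMAT = "%Y-%m-%d %H:%M:%S:%f"  # matches the timestamps in the snippet above


def parse_log_lines(lines):
    """Yield one dict per non-empty JSON log line, with 'time' as a datetime."""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        entry = json.loads(line)
        entry["time"] = datetime.strptime(entry["time"], TIME_FORMAT)
        yield entry


example_lines = [
    '{"time": "2023-05-23 00:58:45:748136", "text": "literally I feel like crying", "prediction": "sadness"}',
    '{"time": "2023-05-23 00:58:46:135930", "text": "i love this subreddit", "prediction": "love"}',
]
entries = list(parse_log_lines(example_lines))
# Entries without a 'prediction' key (for example error entries) are skipped.
label_counts = Counter(e["prediction"] for e in entries if "prediction" in e)
print(label_counts)
```

The same function accepts an open file object, since iterating over a file yields its lines.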

JSON is used so that the log file can be parsed programmatically and consumed by a different service, if that ever becomes necessary. This structured logging is implemented using a custom class that stores all the necessary data. A second class logs an error message when the user sends an invalid request. Both are shown below:

```python
import json
from datetime import datetime


class LogMsg(object):
    def __init__(self, text, prediction):
        self.text = text
        self.prediction = prediction
        self.time = datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')

    def __str__(self):
        return json.dumps({'time': self.time, 'text': self.text, 'prediction': self.prediction})
    
class LogMsgInputFormat400Error(object):
    def __init__(self):
        self.time = datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')

    def __str__(self):
        return json.dumps({'time': self.time, 'error': "The message must be JSON in the form json={'comment': string_to_predict}."})
```

A valid request and an invalid request are then logged like so:

```python
# Invalid request:
app.logger.info(LogMsgInputFormat400Error())

# Valid request:
app.logger.info(LogMsg(comment_text, prediction))
```
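
The handler configuration is not shown in this notebook. As a rough sketch of one way it could be set up with the standard `logging` module, a `FileHandler` with a message-only formatter writes each `LogMsg` as a single JSON line. The logger name `predictions_demo` and the temporary file location are assumptions for illustration; the actual service logs through `app.logger` to `predictions.txt`, and its configuration may differ.

```python
import json
import logging
import os
import tempfile
from datetime import datetime


class LogMsg(object):
    """Same structure as the LogMsg class shown earlier."""
    def __init__(self, text, prediction):
        self.text = text
        self.prediction = prediction
        self.time = datetime.now().strftime('%Y-%m-%d %H:%M:%S:%f')

    def __str__(self):
        return json.dumps({'time': self.time, 'text': self.text, 'prediction': self.prediction})


# Assumption: write to a temporary directory so the sketch leaves no files behind.
log_path = os.path.join(tempfile.mkdtemp(), "predictions.txt")

logger = logging.getLogger("predictions_demo")  # hypothetical logger name
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_path)
# Emit only the message, so each line of the file is valid JSON.
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)

# logging calls str() on the message object, which produces the JSON string.
logger.info(LogMsg("i love this subreddit", "love"))

with open(log_path) as f:
    print(f.read().strip())
```

Passing the object rather than a pre-formatted string means the JSON is only built if the record is actually emitted at the configured log level.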

## CI/CD Pipeline